summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.txt5
-rwxr-xr-xpyramid/paster_templates/alchemy/+package+/__init__.py_tmpl2
-rwxr-xr-xpyramid/paster_templates/alchemy/+package+/models.py25
-rw-r--r--pyramid/paster_templates/alchemy/+package+/tests.py_tmpl60
4 files changed, 77 insertions, 15 deletions
diff --git a/CHANGES.txt b/CHANGES.txt
index 8ee0b0bd0..160f28bff 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -23,6 +23,11 @@ Features
converted implicitly to strings in the result. Previously passing integers
or longs as elements would cause a TypeError.
+- ``pyramid_alchemy`` paster template now uses ``query.get`` rather than
+ ``query.filter_by`` to take better advantage of identity map caching.
+
+- ``pyramid_alchemy`` paster template now has unit tests.
+
1.0 (2011-01-30)
================
diff --git a/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl b/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl
index 63d43b052..8dfb3bf0a 100755
--- a/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl
+++ b/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl
@@ -11,7 +11,7 @@ def main(global_config, **settings):
config = Configurator(settings=settings, root_factory=get_root)
config.add_static_view('static', '{{package}}:static')
config.add_view('{{package}}.views.view_root',
- context='{{package}}.models.MyApp',
+ context='{{package}}.models.MyRoot',
renderer="templates/root.pt")
config.add_view('{{package}}.views.view_model',
context='{{package}}.models.MyModel',
diff --git a/pyramid/paster_templates/alchemy/+package+/models.py b/pyramid/paster_templates/alchemy/+package+/models.py
index 82d6bec0c..f1b47f98c 100755
--- a/pyramid/paster_templates/alchemy/+package+/models.py
+++ b/pyramid/paster_templates/alchemy/+package+/models.py
@@ -6,9 +6,7 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm.exc import NoResultFound
-from sqlalchemy import create_engine
from sqlalchemy import Integer
from sqlalchemy import Unicode
from sqlalchemy import Column
@@ -28,7 +26,7 @@ class MyModel(Base):
self.name = name
self.value = value
-class MyApp(object):
+class MyRoot(object):
__name__ = None
__parent__ = None
@@ -39,16 +37,14 @@ class MyApp(object):
except (ValueError, TypeError):
raise KeyError(key)
- query = session.query(MyModel).filter_by(id=id)
-
- try:
- item = query.one()
- item.__parent__ = self
- item.__name__ = key
- return item
- except NoResultFound:
+ item = session.query(MyModel).get(id)
+ if item is None:
raise KeyError(key)
+ item.__parent__ = self
+ item.__name__ = key
+ return item
+
def get(self, key, default=None):
try:
item = self.__getitem__(key)
@@ -61,9 +57,9 @@ class MyApp(object):
query = session.query(MyModel)
return iter(query)
-root = MyApp()
+root = MyRoot()
-def default_get_root(request):
+def root_factory(request):
return root
def populate():
@@ -81,7 +77,8 @@ def initialize_sql(engine):
populate()
except IntegrityError:
DBSession.rollback()
+ return DBSession
def appmaker(engine):
initialize_sql(engine)
- return default_get_root
+ return root_factory
diff --git a/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl b/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl
new file mode 100644
index 000000000..c073bfc88
--- /dev/null
+++ b/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl
@@ -0,0 +1,60 @@
+import unittest
+
+from pyramid import testing
+
+def _initTestingDB():
+ from sqlalchemy import create_engine
+ from {{package}}.models import initialize_sql
+ session = initialize_sql(create_engine('sqlite://'))
+ return session
+
+class TestMyRoot(unittest.TestCase):
+ def setUp(self):
+ self.config = testing.setUp()
+ self.session = _initTestingDB()
+
+ def tearDown(self):
+ testing.tearDown()
+ self.session.remove()
+
+ def _makeOne(self):
+ from {{package}}.models import MyRoot
+ return MyRoot()
+
+ def test___getitem__hit(self):
+ from {{package}}.models import MyModel
+ root = self._makeOne()
+ first = root['1']
+ self.assertEqual(first.__class__, MyModel)
+ self.assertEqual(first.__parent__, root)
+ self.assertEqual(first.__name__, '1')
+
+ def test___getitem__miss(self):
+ root = self._makeOne()
+ self.assertRaises(KeyError, root.__getitem__, '100')
+
+ def test___getitem__notint(self):
+ root = self._makeOne()
+ self.assertRaises(KeyError, root.__getitem__, 'notint')
+
+ def test_get_hit(self):
+ from {{package}}.models import MyModel
+ root = self._makeOne()
+ first = root.get('1')
+ self.assertEqual(first.__class__, MyModel)
+ self.assertEqual(first.__parent__, root)
+ self.assertEqual(first.__name__, '1')
+
+ def test_get_miss(self):
+ root = self._makeOne()
+ self.assertEqual(root.get('100', 'default'), 'default')
+ self.assertEqual(root.get('100'), None)
+
+ def test___iter__(self):
+ root = self._makeOne()
+ iterable = iter(root)
+ result = list(iterable)
+ self.assertEqual(len(result), 1)
+ model = result[0]
+ self.assertEqual(model.id, 1)
+