diff options
| -rw-r--r-- | CHANGES.txt | 5 | ||||
| -rwxr-xr-x | pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl | 2 | ||||
| -rwxr-xr-x | pyramid/paster_templates/alchemy/+package+/models.py | 25 | ||||
| -rw-r--r-- | pyramid/paster_templates/alchemy/+package+/tests.py_tmpl | 60 |
4 files changed, 77 insertions, 15 deletions
diff --git a/CHANGES.txt b/CHANGES.txt index 8ee0b0bd0..160f28bff 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -23,6 +23,11 @@ Features converted implicitly to strings in the result. Previously passing integers or longs as elements would cause a TypeError. +- ``pyramid_alchemy`` paster template now uses ``query.get`` rather than + ``query.filter_by`` to take better advantage of identity map caching. + +- ``pyramid_alchemy`` paster template now has unit tests. + 1.0 (2011-01-30) ================ diff --git a/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl b/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl index 63d43b052..8dfb3bf0a 100755 --- a/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl +++ b/pyramid/paster_templates/alchemy/+package+/__init__.py_tmpl @@ -11,7 +11,7 @@ def main(global_config, **settings): config = Configurator(settings=settings, root_factory=get_root) config.add_static_view('static', '{{package}}:static') config.add_view('{{package}}.views.view_root', - context='{{package}}.models.MyApp', + context='{{package}}.models.MyRoot', renderer="templates/root.pt") config.add_view('{{package}}.views.view_model', context='{{package}}.models.MyModel', diff --git a/pyramid/paster_templates/alchemy/+package+/models.py b/pyramid/paster_templates/alchemy/+package+/models.py index 82d6bec0c..f1b47f98c 100755 --- a/pyramid/paster_templates/alchemy/+package+/models.py +++ b/pyramid/paster_templates/alchemy/+package+/models.py @@ -6,9 +6,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm.exc import NoResultFound -from sqlalchemy import create_engine from sqlalchemy import Integer from sqlalchemy import Unicode from sqlalchemy import Column @@ -28,7 +26,7 @@ class MyModel(Base): self.name = name self.value = value -class MyApp(object): +class MyRoot(object): __name__ = None __parent__ = None @@ -39,16 +37,14 @@ class MyApp(object): except (ValueError, TypeError): raise KeyError(key) - query = session.query(MyModel).filter_by(id=id) - - try: - item = query.one() - item.__parent__ = self - item.__name__ = key - return item - except NoResultFound: + item = session.query(MyModel).get(id) + if item is None: raise KeyError(key) + item.__parent__ = self + item.__name__ = key + return item + def get(self, key, default=None): try: item = self.__getitem__(key) @@ -61,9 +57,9 @@ class MyApp(object): query = session.query(MyModel) return iter(query) -root = MyApp() +root = MyRoot() -def default_get_root(request): +def root_factory(request): return root def populate(): @@ -81,7 +77,8 @@ def initialize_sql(engine): populate() except IntegrityError: DBSession.rollback() + return DBSession def appmaker(engine): initialize_sql(engine) - return default_get_root + return root_factory diff --git a/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl b/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl new file mode 100644 index 000000000..c073bfc88 --- /dev/null +++ b/pyramid/paster_templates/alchemy/+package+/tests.py_tmpl @@ -0,0 +1,60 @@ +import unittest + +from pyramid import testing + +def _initTestingDB(): + from sqlalchemy import create_engine + from {{package}}.models import initialize_sql + session = initialize_sql(create_engine('sqlite://')) + return session + +class TestMyRoot(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + self.session = _initTestingDB() + + def tearDown(self): + testing.tearDown() + self.session.remove() + + def _makeOne(self): + from {{package}}.models import MyRoot + return MyRoot() + + def test___getitem__hit(self): + from {{package}}.models import MyModel + root = self._makeOne() + first = root['1'] + self.assertEqual(first.__class__, MyModel) + self.assertEqual(first.__parent__, root) + self.assertEqual(first.__name__, '1') + + def test___getitem__miss(self): + root = self._makeOne() + self.assertRaises(KeyError, root.__getitem__, '100') + + def test___getitem__notint(self): + root = self._makeOne() + self.assertRaises(KeyError, root.__getitem__, 'notint') + + def test_get_hit(self): + from {{package}}.models import MyModel + root = self._makeOne() + first = root.get('1') + self.assertEqual(first.__class__, MyModel) + self.assertEqual(first.__parent__, root) + self.assertEqual(first.__name__, '1') + + def test_get_miss(self): + root = self._makeOne() + self.assertEqual(root.get('100', 'default'), 'default') + self.assertEqual(root.get('100'), None) + + def test___iter__(self): + root = self._makeOne() + iterable = iter(root) + result = list(iterable) + self.assertEqual(len(result), 1) + model = result[0] + self.assertEqual(model.id, 1) + |
