diff options
| -rw-r--r-- | pyramid/scaffolds/alchemy/+package+/__init__.py | 11 | ||||
| -rw-r--r-- | pyramid/scaffolds/alchemy/+package+/models.py | 38 | ||||
| -rw-r--r-- | pyramid/scaffolds/alchemy/+package+/scripts/initializedb.py | 17 | ||||
| -rw-r--r-- | pyramid/scaffolds/alchemy/+package+/tests.py_tmpl | 74 | ||||
| -rw-r--r-- | pyramid/scaffolds/alchemy/+package+/views.py_tmpl | 7 |
5 files changed, 88 insertions, 59 deletions
diff --git a/pyramid/scaffolds/alchemy/+package+/__init__.py b/pyramid/scaffolds/alchemy/+package+/__init__.py index 867049e4f..116839351 100644 --- a/pyramid/scaffolds/alchemy/+package+/__init__.py +++ b/pyramid/scaffolds/alchemy/+package+/__init__.py @@ -1,20 +1,11 @@ from pyramid.config import Configurator -from sqlalchemy import engine_from_config - -from .models import ( - DBSession, - Base, - ) - def main(global_config, **settings): """ This function returns a Pyramid WSGI application. """ - engine = engine_from_config(settings, 'sqlalchemy.') - DBSession.configure(bind=engine) - Base.metadata.bind = engine config = Configurator(settings=settings) config.include('pyramid_chameleon') + config.include('.models') config.add_static_view('static', 'static', cache_max_age=3600) config.add_route('home', '/') config.scan() diff --git a/pyramid/scaffolds/alchemy/+package+/models.py b/pyramid/scaffolds/alchemy/+package+/models.py index a0d3e7b71..ccf1f2379 100644 --- a/pyramid/scaffolds/alchemy/+package+/models.py +++ b/pyramid/scaffolds/alchemy/+package+/models.py @@ -5,17 +5,42 @@ from sqlalchemy import ( Text, ) +from sqlalchemy import engine_from_config from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import zope.sqlalchemy -from sqlalchemy.orm import ( - scoped_session, - sessionmaker, + +Base = declarative_base() + + +def includeme(config): + settings = config.get_settings() + dbmaker = get_dbmaker(get_engine(settings)) + + config.add_request_method( + lambda r: get_session(r.tm, dbmaker), + 'dbsession', + reify=True ) -from zope.sqlalchemy import ZopeTransactionExtension + config.include('pyramid_tm') -DBSession = scoped_session(sessionmaker(extension=ZopeTransactionExtension())) -Base = declarative_base() + +def get_session(transaction_manager, dbmaker): + dbsession = dbmaker() + zope.sqlalchemy.register(dbsession, transaction_manager=transaction_manager) + return dbsession + + +def get_engine(settings, prefix='sqlalchemy.'): + return engine_from_config(settings, prefix) + + +def get_dbmaker(engine): + dbmaker = sessionmaker() + dbmaker.configure(bind=engine) + return dbmaker class MyModel(Base): @@ -24,4 +49,5 @@ class MyModel(Base): name = Column(Text) value = Column(Integer) + Index('my_index', MyModel.name, unique=True, mysql_length=255) diff --git a/pyramid/scaffolds/alchemy/+package+/scripts/initializedb.py b/pyramid/scaffolds/alchemy/+package+/scripts/initializedb.py index 7dfdece15..43e25bff8 100644 --- a/pyramid/scaffolds/alchemy/+package+/scripts/initializedb.py +++ b/pyramid/scaffolds/alchemy/+package+/scripts/initializedb.py @@ -2,8 +2,6 @@ import os import sys import transaction -from sqlalchemy import engine_from_config - from pyramid.paster import ( get_appsettings, setup_logging, @@ -12,9 +10,11 @@ from pyramid.paster import ( from pyramid.scripts.common import parse_vars from ..models import ( - DBSession, MyModel, Base, + get_session, + get_engine, + get_dbmaker, ) @@ -32,9 +32,14 @@ def main(argv=sys.argv): options = parse_vars(argv[2:]) setup_logging(config_uri) settings = get_appsettings(config_uri, options=options) - engine = engine_from_config(settings, 'sqlalchemy.') - DBSession.configure(bind=engine) + + engine = get_engine(settings) + dbmaker = get_dbmaker(engine) + + dbsession = get_session(transaction.manager, dbmaker) + Base.metadata.create_all(engine) + with transaction.manager: model = MyModel(name='one', value=1) - DBSession.add(model) + dbsession.add(model) diff --git a/pyramid/scaffolds/alchemy/+package+/tests.py_tmpl b/pyramid/scaffolds/alchemy/+package+/tests.py_tmpl index e6425eb91..4ce706077 100644 --- a/pyramid/scaffolds/alchemy/+package+/tests.py_tmpl +++ b/pyramid/scaffolds/alchemy/+package+/tests.py_tmpl @@ -3,53 +3,63 @@ import transaction from pyramid import testing -from .models import DBSession +def dummy_request(dbsession): + return testing.DummyRequest(dbsession=dbsession) -class TestMyViewSuccessCondition(unittest.TestCase): + +class BaseTest(unittest.TestCase): def setUp(self): - self.config = testing.setUp() - from sqlalchemy import create_engine - engine = create_engine('sqlite://') + self.config = testing.setUp(settings={ + 'sqlalchemy.url': 'sqlite:///:memory:' + }) + self.config.include('.models') + settings = self.config.get_settings() + from .models import ( - Base, - MyModel, + get_session, + get_engine, + get_dbmaker, ) - DBSession.configure(bind=engine) - Base.metadata.create_all(engine) - with transaction.manager: - model = MyModel(name='one', value=55) - DBSession.add(model) + + self.engine = get_engine(settings) + dbmaker = get_dbmaker(self.engine) + + self.session = get_session(transaction.manager, dbmaker) + + def init_database(self): + from .models import Base + Base.metadata.create_all(self.engine) def tearDown(self): - DBSession.remove() + from .models import Base + testing.tearDown() + transaction.abort() + Base.metadata.create_all(self.engine) + + +class TestMyViewSuccessCondition(BaseTest): + + def setUp(self): + super(TestMyViewSuccessCondition, self).setUp() + self.init_database() + + from .models import MyModel + + model = MyModel(name='one', value=55) + self.session.add(model) def test_passing_view(self): from .views import my_view - request = testing.DummyRequest() - info = my_view(request) + info = my_view(dummy_request(self.session)) self.assertEqual(info['one'].name, 'one') self.assertEqual(info['project'], '{{project}}') -class TestMyViewFailureCondition(unittest.TestCase): - def setUp(self): - self.config = testing.setUp() - from sqlalchemy import create_engine - engine = create_engine('sqlite://') - from .models import ( - Base, - MyModel, - ) - DBSession.configure(bind=engine) - - def tearDown(self): - DBSession.remove() - testing.tearDown() +class TestMyViewFailureCondition(BaseTest): def test_failing_view(self): from .views import my_view - request = testing.DummyRequest() - info = my_view(request) - self.assertEqual(info.status_int, 500)
\ No newline at end of file + info = my_view(dummy_request(self.session)) + self.assertEqual(info.status_int, 500) diff --git a/pyramid/scaffolds/alchemy/+package+/views.py_tmpl b/pyramid/scaffolds/alchemy/+package+/views.py_tmpl index 292bce579..b559b31ce 100644 --- a/pyramid/scaffolds/alchemy/+package+/views.py_tmpl +++ b/pyramid/scaffolds/alchemy/+package+/views.py_tmpl @@ -3,16 +3,13 @@ from pyramid.view import view_config from sqlalchemy.exc import DBAPIError -from .models import ( - DBSession, - MyModel, - ) +from .models import MyModel @view_config(route_name='home', renderer='templates/mytemplate.pt') def my_view(request): try: - one = DBSession.query(MyModel).filter(MyModel.name == 'one').first() + one = request.dbsession.query(MyModel).filter(MyModel.name == 'one').first() except DBAPIError: return Response(conn_err_msg, content_type='text/plain', status_int=500) return {'one': one, 'project': '{{project}}'} |
