summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/api/scripting.rst2
-rw-r--r--pyramid/scripting.py36
-rw-r--r--pyramid/tests/test_scripting.py57
3 files changed, 91 insertions, 4 deletions
diff --git a/docs/api/scripting.rst b/docs/api/scripting.rst
index 2029578ba..3e9a814fc 100644
--- a/docs/api/scripting.rst
+++ b/docs/api/scripting.rst
@@ -7,5 +7,7 @@
.. autofunction:: get_root
+ .. autofunction:: get_root2
+
.. autofunction:: make_request
diff --git a/pyramid/scripting.py b/pyramid/scripting.py
index 79523dff1..cbcba95df 100644
--- a/pyramid/scripting.py
+++ b/pyramid/scripting.py
@@ -1,6 +1,9 @@
from pyramid.config import global_registries
from pyramid.request import Request
from pyramid.interfaces import IRequestFactory
+from pyramid.interfaces import IRootFactory
+from pyramid.threadlocal import manager as threadlocal_manager
+from pyramid.traversal import DefaultRootFactory
def get_root(app, request=None):
""" Return a tuple composed of ``(root, closer)`` when provided a
@@ -23,6 +26,38 @@ def get_root(app, request=None):
root = app.root_factory(request)
return root, closer
+def get_root2(request=None, registry=None):
+ """ Return a tuple composed of ``(root, closer)``. The ``root``
+ returned is the application's root object. The ``closer`` returned
+ is a callable (accepting no arguments) that should be called when
+ your scripting application is finished using the root.
+
+ If ``request`` is None, a default one is constructed using
+ :meth:`pyramid.scripting.make_request`. It is used as the request
+ passed to the :app:`Pyramid` application root factory.
+
+ If ``registry`` is not supplied, the last registry loaded from
+ :attr:`pyramid.config.global_registries` will be used. If you have
+ loaded more than one :app:`Pyramid` application in the current
+ process, you may not want to use the last registry loaded, thus
+ you can search the ``global_registries`` and supply the appropriate
+ one based on your own criteria.
+ """
+ if registry is None:
+ registry = getattr(request, 'registry', global_registries.last)
+ if request is None:
+ request = make_request('/', registry)
+ request.registry = registry
+ threadlocals = {'registry':registry, 'request':request}
+ threadlocal_manager.push(threadlocals)
+ def closer(request=request): # keep request alive via this function default
+ threadlocal_manager.pop()
+ q = registry.queryUtility
+ root_factory = registry.queryUtility(IRootFactory,
+ default=DefaultRootFactory)
+ root = root_factory(request)
+ return root, closer
+
def make_request(path, registry=None):
""" Return a :meth:`pyramid.request.Request` object anchored at a
given path. The object returned will be generated from the supplied
@@ -48,4 +83,3 @@ def make_request(path, registry=None):
request = request_factory.blank(path)
request.registry = registry
return request
-
diff --git a/pyramid/tests/test_scripting.py b/pyramid/tests/test_scripting.py
index 315ab222f..9bf57be06 100644
--- a/pyramid/tests/test_scripting.py
+++ b/pyramid/tests/test_scripting.py
@@ -35,6 +35,53 @@ class TestGetRoot(unittest.TestCase):
pushed = app.threadlocal_manager.pushed[0]
self.assertEqual(pushed['request'].environ['path'], '/')
+class TestGetRoot2(unittest.TestCase):
+ def _callFUT(self, request=None, registry=None):
+ from pyramid.scripting import get_root2
+ return get_root2(request, registry)
+
+ def _makeRegistry(self):
+ return DummyRegistry(DummyFactory)
+
+ def setUp(self):
+ from pyramid.threadlocal import manager
+ self.manager = manager
+ self.default = manager.get()
+
+ def tearDown(self):
+ self.assertEqual(self.default, self.manager.get())
+
+ def test_it_norequest(self):
+ registry = self._makeRegistry()
+ root, closer = self._callFUT(registry=registry)
+ pushed = self.manager.get()
+ self.assertEqual(pushed['registry'], registry)
+ self.assertEqual(pushed['request'].registry, registry)
+ self.assertEqual(root.a, (pushed['request'],))
+ closer()
+
+ def test_it_withrequest(self):
+ request = DummyRequest({})
+ registry = request.registry = self._makeRegistry()
+ root, closer = self._callFUT(request)
+ pushed = self.manager.get()
+ self.assertEqual(pushed['request'], request)
+ self.assertEqual(pushed['registry'], registry)
+ self.assertEqual(pushed['request'].registry, registry)
+ self.assertEqual(root.a, (request,))
+ closer()
+
+ def test_it_with_request_and_registry(self):
+ request = DummyRequest({})
+ registry = request.registry = self._makeRegistry()
+ root, closer = self._callFUT(request, registry)
+ pushed = self.manager.get()
+ self.assertEqual(pushed['request'], request)
+ self.assertEqual(pushed['registry'], registry)
+ self.assertEqual(pushed['request'].registry, registry)
+ self.assertEqual(root.a, (request,))
+ closer()
+
class TestMakeRequest(unittest.TestCase):
def _callFUT(self, path='/', registry=None):
from pyramid.scripting import make_request
@@ -65,12 +112,16 @@ class DummyFactory(object):
req = DummyRequest({'path': path})
return req
+ def __init__(self, *a, **kw):
+ self.a = a
+ self.kw = kw
+
class DummyRegistry(object):
- def __init__(self, result=None):
- self.result = result
+ def __init__(self, factory=None):
+ self.factory = factory
def queryUtility(self, iface, default=None):
- return self.result or default
+ return self.factory or default
dummy_registry = DummyRegistry(DummyFactory)