summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES.rst12
-rw-r--r--pyramid/scripts/pshell.py110
-rw-r--r--pyramid/tests/test_scripts/dummy.py2
-rw-r--r--pyramid/tests/test_scripts/test_pshell.py29
-rw-r--r--pyramid/tests/test_util.py38
-rw-r--r--pyramid/util.py18
6 files changed, 156 insertions, 53 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index e09c3723c..47738e29b 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -41,6 +41,14 @@ Features
exception/response object for a HTTP 308 redirect.
See https://github.com/Pylons/pyramid/pull/3302
+- Within ``pshell``, allow the user-defined ``setup`` function to be a
+ generator, in which case it may wrap the command's lifecycle.
+ See https://github.com/Pylons/pyramid/pull/3318
+
+- Within ``pshell``, variables defined by the ``[pshell]`` settings are
+ available within the user-defined ``setup`` function.
+ See https://github.com/Pylons/pyramid/pull/3318
+
Bug Fixes
---------
@@ -76,6 +84,10 @@ Backward Incompatibilities
``pyramid.session.UnencryptedCookieSessionFactoryConfig``.
See https://github.com/Pylons/pyramid/pull/3300
+- Variables defined in the ``[pshell]`` section of the settings will no
+ longer override those set by the ``setup`` function.
+ See https://github.com/Pylons/pyramid/pull/3318
+
Documentation Changes
---------------------
diff --git a/pyramid/scripts/pshell.py b/pyramid/scripts/pshell.py
index bb201dbc2..4898eb39f 100644
--- a/pyramid/scripts/pshell.py
+++ b/pyramid/scripts/pshell.py
@@ -1,4 +1,5 @@
from code import interact
+from contextlib import contextmanager
import argparse
import os
import sys
@@ -7,6 +8,7 @@ import pkg_resources
from pyramid.compat import exec_
from pyramid.util import DottedNameResolver
+from pyramid.util import make_contextmanager
from pyramid.paster import bootstrap
from pyramid.settings import aslist
@@ -85,6 +87,7 @@ class PShellCommand(object):
preferred_shells = []
setup = None
pystartup = os.environ.get('PYTHONSTARTUP')
+ resolver = DottedNameResolver(None)
def __init__(self, argv, quiet=False):
self.quiet = quiet
@@ -92,7 +95,6 @@ class PShellCommand(object):
def pshell_file_config(self, loader, defaults):
settings = loader.get_settings('pshell', defaults)
- resolver = DottedNameResolver(None)
self.loaded_objects = {}
self.object_help = {}
self.setup = None
@@ -102,7 +104,7 @@ class PShellCommand(object):
elif k == 'default_shell':
self.preferred_shells = [x.lower() for x in aslist(v)]
else:
- self.loaded_objects[k] = resolver.maybe_resolve(v)
+ self.loaded_objects[k] = self.resolver.maybe_resolve(v)
self.object_help[k] = v
def out(self, msg): # pragma: no cover
@@ -115,18 +117,36 @@ class PShellCommand(object):
if not self.args.config_uri:
self.out('Requires a config file argument')
return 2
+
config_uri = self.args.config_uri
config_vars = parse_vars(self.args.config_vars)
loader = self.get_config_loader(config_uri)
loader.setup_logging(config_vars)
self.pshell_file_config(loader, config_vars)
- env = self.bootstrap(config_uri, options=config_vars)
+ self.env = self.bootstrap(config_uri, options=config_vars)
# remove the closer from the env
- self.closer = env.pop('closer')
+ self.closer = self.env.pop('closer')
+
+ try:
+ if shell is None:
+ try:
+ shell = self.make_shell()
+ except ValueError as e:
+ self.out(str(e))
+ return 1
+
+ with self.setup_env():
+ shell(self.env, self.help)
+
+ finally:
+ self.closer()
+ @contextmanager
+ def setup_env(self):
# setup help text for default environment
+ env = self.env
env_help = dict(env)
env_help['app'] = 'The WSGI application.'
env_help['root'] = 'Root of the default resource tree.'
@@ -135,65 +155,55 @@ class PShellCommand(object):
env_help['root_factory'] = (
'Default root factory used to create `root`.')
+ # load the pshell section of the ini file
+ env.update(self.loaded_objects)
+
+ # eliminate duplicates from env, allowing custom vars to override
+ for k in self.loaded_objects:
+ if k in env_help:
+ del env_help[k]
+
# override use_script with command-line options
if self.args.setup:
self.setup = self.args.setup
if self.setup:
- # store the env before muddling it with the script
- orig_env = env.copy()
-
# call the setup callable
- resolver = DottedNameResolver(None)
- setup = resolver.maybe_resolve(self.setup)
- setup(env)
+ self.setup = self.resolver.maybe_resolve(self.setup)
+ # store the env before muddling it with the script
+ orig_env = env.copy()
+ setup_manager = make_contextmanager(self.setup)
+ with setup_manager(env):
# remove any objects from default help that were overidden
for k, v in env.items():
- if k not in orig_env or env[k] != orig_env[k]:
+ if k not in orig_env or v != orig_env[k]:
if getattr(v, '__doc__', False):
env_help[k] = v.__doc__.replace("\n", " ")
else:
env_help[k] = v
-
- # load the pshell section of the ini file
- env.update(self.loaded_objects)
-
- # eliminate duplicates from env, allowing custom vars to override
- for k in self.loaded_objects:
- if k in env_help:
- del env_help[k]
-
- # generate help text
- help = ''
- if env_help:
- help += 'Environment:'
- for var in sorted(env_help.keys()):
- help += '\n %-12s %s' % (var, env_help[var])
-
- if self.object_help:
- help += '\n\nCustom Variables:'
- for var in sorted(self.object_help.keys()):
- help += '\n %-12s %s' % (var, self.object_help[var])
-
- if shell is None:
- try:
- shell = self.make_shell()
- except ValueError as e:
- self.out(str(e))
- self.closer()
- return 1
-
- if self.pystartup and os.path.isfile(self.pystartup):
- with open(self.pystartup, 'rb') as fp:
- exec_(fp.read().decode('utf-8'), env)
- if '__builtins__' in env:
- del env['__builtins__']
-
- try:
- shell(env, help)
- finally:
- self.closer()
+ del orig_env
+
+ # generate help text
+ help = ''
+ if env_help:
+ help += 'Environment:'
+ for var in sorted(env_help.keys()):
+ help += '\n %-12s %s' % (var, env_help[var])
+
+ if self.object_help:
+ help += '\n\nCustom Variables:'
+ for var in sorted(self.object_help.keys()):
+ help += '\n %-12s %s' % (var, self.object_help[var])
+
+ if self.pystartup and os.path.isfile(self.pystartup):
+ with open(self.pystartup, 'rb') as fp:
+ exec_(fp.read().decode('utf-8'), env)
+ if '__builtins__' in env:
+ del env['__builtins__']
+
+ self.help = help.strip()
+ yield
def show_shells(self):
shells = self.find_all_shells()
diff --git a/pyramid/tests/test_scripts/dummy.py b/pyramid/tests/test_scripts/dummy.py
index 2d2b0549f..f1ef403f8 100644
--- a/pyramid/tests/test_scripts/dummy.py
+++ b/pyramid/tests/test_scripts/dummy.py
@@ -22,11 +22,13 @@ class DummyShell(object):
env = {}
help = ''
called = False
+ dummy_attr = 1
def __call__(self, env, help):
self.env = env
self.help = help
self.called = True
+ self.env['request'].dummy_attr = self.dummy_attr
class DummyInteractor:
def __call__(self, banner, local):
diff --git a/pyramid/tests/test_scripts/test_pshell.py b/pyramid/tests/test_scripts/test_pshell.py
index ca9eb7af2..df664bea9 100644
--- a/pyramid/tests/test_scripts/test_pshell.py
+++ b/pyramid/tests/test_scripts/test_pshell.py
@@ -226,6 +226,33 @@ class TestPShellCommand(unittest.TestCase):
self.assertTrue(self.bootstrap.closer.called)
self.assertTrue(shell.help)
+ def test_command_setup_generator(self):
+ command = self._makeOne()
+ did_resume_after_yield = {}
+ def setup(env):
+ env['a'] = 1
+ env['root'] = 'root override'
+ env['none'] = None
+ request = env['request']
+ yield
+ did_resume_after_yield['result'] = True
+ self.assertEqual(request.dummy_attr, 1)
+ self.loader.settings = {'pshell': {'setup': setup}}
+ shell = dummy.DummyShell()
+ command.run(shell)
+ self.assertEqual(self.bootstrap.a[0], '/foo/bar/myapp.ini#myapp')
+ self.assertEqual(shell.env, {
+ 'app':self.bootstrap.app, 'root':'root override',
+ 'registry':self.bootstrap.registry,
+ 'request':self.bootstrap.request,
+ 'root_factory':self.bootstrap.root_factory,
+ 'a':1,
+ 'none': None,
+ })
+ self.assertTrue(did_resume_after_yield['result'])
+ self.assertTrue(self.bootstrap.closer.called)
+ self.assertTrue(shell.help)
+
def test_command_default_shell_option(self):
command = self._makeOne()
ipshell = dummy.DummyShell()
@@ -259,7 +286,7 @@ class TestPShellCommand(unittest.TestCase):
'registry':self.bootstrap.registry,
'request':self.bootstrap.request,
'root_factory':self.bootstrap.root_factory,
- 'a':1, 'm':model,
+ 'a':1, 'm':'model override',
})
self.assertTrue(self.bootstrap.closer.called)
self.assertTrue(shell.help)
diff --git a/pyramid/tests/test_util.py b/pyramid/tests/test_util.py
index ab9de262e..0f7671d59 100644
--- a/pyramid/tests/test_util.py
+++ b/pyramid/tests/test_util.py
@@ -889,3 +889,41 @@ class Test_is_same_domain(unittest.TestCase):
self.assertTrue(self._callFUT("example.com:8080", "example.com:8080"))
self.assertFalse(self._callFUT("example.com:8080", "example.com"))
self.assertFalse(self._callFUT("example.com", "example.com:8080"))
+
+
+class Test_make_contextmanager(unittest.TestCase):
+ def _callFUT(self, *args, **kw):
+ from pyramid.util import make_contextmanager
+ return make_contextmanager(*args, **kw)
+
+ def test_with_None(self):
+ mgr = self._callFUT(None)
+ with mgr() as ctx:
+ self.assertIsNone(ctx)
+
+ def test_with_generator(self):
+ def mygen(ctx):
+ yield ctx
+ mgr = self._callFUT(mygen)
+ with mgr('a') as ctx:
+ self.assertEqual(ctx, 'a')
+
+ def test_with_multiple_yield_generator(self):
+ def mygen():
+ yield 'a'
+ yield 'b'
+ mgr = self._callFUT(mygen)
+ try:
+ with mgr() as ctx:
+ self.assertEqual(ctx, 'a')
+ except RuntimeError:
+ pass
+ else: # pragma: no cover
+ raise AssertionError('expected raise from multiple yields')
+
+ def test_with_regular_fn(self):
+ def mygen():
+ return 'a'
+ mgr = self._callFUT(mygen)
+ with mgr() as ctx:
+ self.assertEqual(ctx, 'a')
diff --git a/pyramid/util.py b/pyramid/util.py
index 09a3e530f..77a0f306b 100644
--- a/pyramid/util.py
+++ b/pyramid/util.py
@@ -1,4 +1,4 @@
-import contextlib
+from contextlib import contextmanager
import functools
try:
# py2.7.7+ and py3.3+ have native comparison support
@@ -613,7 +613,7 @@ def get_callable_name(name):
)
raise ConfigurationError(msg % name)
-@contextlib.contextmanager
+@contextmanager
def hide_attrs(obj, *attrs):
"""
Temporarily delete object attrs and restore afterward.
@@ -648,3 +648,17 @@ def is_same_domain(host, pattern):
return (pattern[0] == "." and
(host.endswith(pattern) or host == pattern[1:]) or
pattern == host)
+
+
+def make_contextmanager(fn):
+ if inspect.isgeneratorfunction(fn):
+ return contextmanager(fn)
+
+ if fn is None:
+ fn = lambda *a, **kw: None
+
+ @contextmanager
+ @functools.wraps(fn)
+ def wrapper(*a, **kw):
+ yield fn(*a, **kw)
+ return wrapper