From a00621e45ef29cde34469798144156c80a17a1e9 Mon Sep 17 00:00:00 2001 From: Chris McDonough Date: Fri, 3 Aug 2012 12:55:12 -0400 Subject: first cut at extensible view predicates via config.add_view_predicate; still requires testing of predicates themselves --- pyramid/config/__init__.py | 2 + pyramid/config/predicates.py | 218 ++++++++++++++++++++++++++++++++ pyramid/config/util.py | 56 +++++++- pyramid/config/views.py | 116 ++++++++++++++--- pyramid/interfaces.py | 3 + pyramid/testing.py | 1 + pyramid/tests/test_config/test_init.py | 21 ++- pyramid/tests/test_config/test_views.py | 12 +- pyramid/tests/test_url.py | 10 +- pyramid/tests/test_view.py | 13 +- 10 files changed, 408 insertions(+), 44 deletions(-) create mode 100644 pyramid/config/predicates.py diff --git a/pyramid/config/__init__.py b/pyramid/config/__init__.py index 52d7aca83..5eb860ed5 100644 --- a/pyramid/config/__init__.py +++ b/pyramid/config/__init__.py @@ -353,6 +353,8 @@ class Configurator( for name, renderer in DEFAULT_RENDERERS: self.add_renderer(name, renderer) + self.add_default_view_predicates() + if exceptionresponse_view is not None: exceptionresponse_view = self.maybe_dotted(exceptionresponse_view) self.add_view(exceptionresponse_view, context=IExceptionResponse) diff --git a/pyramid/config/predicates.py b/pyramid/config/predicates.py new file mode 100644 index 000000000..24ec89c6b --- /dev/null +++ b/pyramid/config/predicates.py @@ -0,0 +1,218 @@ +import re + +from pyramid.compat import is_nonstr_iter + +from pyramid.exceptions import ConfigurationError + +from pyramid.traversal import ( + find_interface, + traversal_path, + ) + +from pyramid.urldispatch import _compile_route + +from .util import as_sorted_tuple + +class XHRPredicate(object): + def __init__(self, val): + self.val = bool(val) + + def __text__(self): + return 'xhr = True' + + def __phash__(self): + return 'xhr:%r' % (self.val,) + + def __call__(self, context, request): + return request.is_xhr + + +class RequestMethodPredicate(object): + def __init__(self, val): + self.val = as_sorted_tuple(val) + + def __text__(self): + return 'request method = %r' % (self.val,) + + def __phash__(self): + L = [] + for v in self.val: + L.append('request_method:%r' % v) + return L + + def __call__(self, context, request): + return request.method in self.val + +class PathInfoPredicate(object): + def __init__(self, val): + self.orig = val + try: + val = re.compile(val) + except re.error as why: + raise ConfigurationError(why.args[0]) + self.val = val + + def __text__(self): + return 'path_info = %s' % (self.orig,) + + def __phash__(self): + return 'path_info:%r' % (self.orig,) + + def __call__(self, context, request): + return self.val.match(request.upath_info) is not None + +class RequestParamPredicate(object): + def __init__(self, val): + name = val + v = None + if '=' in name: + name, v = name.split('=', 1) + if v is None: + self.text = 'request_param %s' % (name,) + else: + self.text = 'request_param %s = %s' % (name, v) + self.name = name + self.val = v + + def __text__(self): + return self.text + + def __phash__(self): + return 'request_param:%r=%r' % (self.name, self.val) + + def __call__(self, context, request): + if self.val is None: + return self.name in request.params + return request.params.get(self.name) == self.val + + +class HeaderPredicate(object): + def __init__(self, val): + name = val + v = None + if ':' in name: + name, v = name.split(':', 1) + try: + v = re.compile(v) + except re.error as why: + raise ConfigurationError(why.args[0]) + if v is None: + self.text = 'header %s' % (name,) + else: + self.text = 'header %s = %s' % (name, v) + self.name = name + self.val = v + + def __text__(self): + return self.text + + def __phash__(self): + return 'header:%r=%r' % (self.name, self.val) + + def __call__(self, context, request): + if self.val is None: + return self.name in request.headers + val = request.headers.get(self.name) + if val is None: + return False + return self.val.match(val) is not None + +class AcceptPredicate(object): + def __init__(self, val): + self.val = val + + def __text__(self): + return 'accept = %s' % (self.val,) + + def __phash__(self): + return 'accept:%r' % (self.val,) + + def __call__(self, context, request): + return self.val in request.accept + +class ContainmentPredicate(object): + def __init__(self, val): + self.val = val + + def __text__(self): + return 'containment = %s' % (self.val,) + + def __phash__(self): + return 'containment:%r' % hash(self.val) + + def __call__(self, context, request): + ctx = getattr(request, 'context', context) + return find_interface(ctx, self.val) is not None + +class RequestTypePredicate(object): + def __init__(self, val): + self.val = val + + def __text__(self): + return 'request_type = %s' % (self.val,) + + def __phash__(self): + return 'request_type:%r' % hash(self.val) + + def __call__(self, context, request): + return self.val.providedBy(request) + +class MatchParamPredicate(object): + def __init__(self, val): + if not is_nonstr_iter(val): + val = (val,) + val = sorted(val) + self.val = val + self.reqs = [ + (x.strip(), y.strip()) for x, y in [ p.split('=', 1) for p in val ] + ] + + def __text__(self): + return 'match_param %s' % (self.val,) + + def __phash__(self): + L = [] + for k, v in self.reqs: + L.append('match_param:%r=%r' % (k, v)) + return L + + def __call__(self, context, request): + for k, v in self.reqs: + if request.matchdict.get(k) != v: + return False + return True + +class CustomPredicate(object): + def __init__(self, func): + self.func = func + + def __text__(self): + return getattr(self.func, '__text__', repr(self.func)) + + def __phash__(self): + return 'custom:%r' % hash(self.func) + + def __call__(self, context, request): + return self.func(context, request) + + +class TraversePredicate(object): + def __init__(self, val): + _, self.tgenerate = _compile_route(val) + self.val = val + + def __text__(self): + return 'traverse matchdict pseudo-predicate' + + def __phash__(self): + return '' + + def __call__(self, context, request): + if 'traverse' in context: + return True + m = context['match'] + tvalue = self.tgenerate(m) + m['traverse'] = traversal_path(tvalue) + return True + + diff --git a/pyramid/config/util.py b/pyramid/config/util.py index 027060db3..1574d7da6 100644 --- a/pyramid/config/util.py +++ b/pyramid/config/util.py @@ -324,14 +324,26 @@ class TopologicalSorter(object): self.names = [] self.req_before = set() self.req_after = set() + self.name2before = {} + self.name2after = {} self.name2val = {} self.order = [] self.default_before = default_before self.default_after = default_after self.first = first self.last = last - + + def remove(self, name): + if name in self.names: + self.names.remove(name) + del self.name2val[name] + for u in self.name2after.get(name, []): + self.order.remove((u, name)) + for u in self.name2before.get(name, []): + self.order.remove((name, u)) + def add(self, name, val, after=None, before=None): + self.remove(name) self.names.append(name) self.name2val[name] = val if after is None and before is None: @@ -340,11 +352,13 @@ class TopologicalSorter(object): if after is not None: if not is_nonstr_iter(after): after = (after,) + self.name2after[name] = after self.order += [(u, name) for u in after] self.req_after.add(name) if before is not None: if not is_nonstr_iter(before): before = (before,) + self.name2before[name] = before self.order += [(name, o) for o in before] self.req_before.add(name) @@ -432,3 +446,43 @@ class CyclicDependencyError(Exception): L.append('%r sorts before %r' % (dependent, dependees)) msg = 'Implicit ordering cycle:' + '; '.join(L) return msg + +class PredicateList(object): + def __init__(self): + self.sorter = TopologicalSorter() + + def add(self, name, factory, weighs_more_than=None, weighs_less_than=None): + self.sorter.add(name, factory, after=weighs_more_than, + before=weighs_less_than) + + def make(self, **kw): + ordered = self.sorter.sorted() + phash = md5() + weights = [] + predicates = [] + for order, (name, predicate_factory) in enumerate(ordered): + vals = kw.pop(name, None) + if vals is None: + continue + if not isinstance(vals, SequenceOfPredicateValues): + vals = (vals,) + for val in vals: + predicate = predicate_factory(val) + hashes = predicate.__phash__() + if not is_nonstr_iter(hashes): + hashes = [hashes] + for h in hashes: + phash.update(bytes_(h)) + predicate = predicate_factory(val) + weights.append(1 << order) + predicates.append(predicate) + if kw: + raise ConfigurationError('Unknown predicate values: %r' % (kw,)) + score = 0 + for bit in weights: + score = score | bit + order = (MAX_ORDER - score) / (len(predicates) + 1) + return order, predicates, phash.hexdigest() + +class SequenceOfPredicateValues(tuple): + pass diff --git a/pyramid/config/views.py b/pyramid/config/views.py index 4354b4691..b59d18400 100644 --- a/pyramid/config/views.py +++ b/pyramid/config/views.py @@ -20,6 +20,7 @@ from pyramid.interfaces import ( IException, IExceptionViewClassifier, IMultiView, + IPredicateList, IRendererFactory, IRequest, IResponse, @@ -65,12 +66,15 @@ from pyramid.view import ( from pyramid.util import object_description +from pyramid.config import predicates + from pyramid.config.util import ( DEFAULT_PHASH, MAX_ORDER, action_method, as_sorted_tuple, - make_predicates, + PredicateList, + SequenceOfPredicateValues, ) urljoin = urlparse.urljoin @@ -272,11 +276,11 @@ class ViewDeriver(object): @wraps_view def predicated_view(self, view): - predicates = self.kw.get('predicates', ()) - if not predicates: + preds = self.kw.get('predicates', ()) + if not preds: return view def predicate_wrapper(context, request): - for predicate in predicates: + for predicate in preds: if not predicate(context, request): view_name = getattr(view, '__name__', view) raise PredicateMismatch( @@ -285,9 +289,9 @@ class ViewDeriver(object): return view(context, request) def checker(context, request): return all((predicate(context, request) for predicate in - predicates)) + preds)) predicate_wrapper.__predicated__ = checker - predicate_wrapper.__predicates__ = predicates + predicate_wrapper.__predicates__ = preds return predicate_wrapper @wraps_view @@ -634,10 +638,10 @@ class ViewsConfiguratorMixin(object): def add_view(self, view=None, name="", for_=None, permission=None, request_type=None, route_name=None, request_method=None, request_param=None, containment=None, attr=None, - renderer=None, wrapper=None, xhr=False, accept=None, + renderer=None, wrapper=None, xhr=None, accept=None, header=None, path_info=None, custom_predicates=(), context=None, decorator=None, mapper=None, http_cache=None, - match_param=None): + match_param=None, **other_predicates): """ Add a :term:`view configuration` to the current configuration state. Arguments to ``add_view`` are broken down below into *predicate* arguments and *non-predicate* @@ -1003,12 +1007,6 @@ class ViewsConfiguratorMixin(object): # GET implies HEAD too request_method = as_sorted_tuple(request_method + ('HEAD',)) - order, predicates, phash = make_predicates(xhr=xhr, - request_method=request_method, path_info=path_info, - request_param=request_param, header=header, accept=accept, - containment=containment, request_type=request_type, - match_param=match_param, custom=custom_predicates) - if context is None: context = for_ @@ -1024,12 +1022,24 @@ class ViewsConfiguratorMixin(object): registry = self.registry) introspectables = [] - discriminator = [ - 'view', context, name, request_type, IView, containment, - request_param, request_method, route_name, attr, - xhr, accept, header, path_info, match_param] - discriminator.extend(sorted([hash(x) for x in custom_predicates])) - discriminator = tuple(discriminator) + pvals = other_predicates + pvals.update( + dict( + xhr=xhr, + request_method=request_method, + path_info=path_info, + request_param=request_param, + header=header, + accept=accept, + containment=containment, + request_type=request_type, + match_param=match_param, + custom=SequenceOfPredicateValues(custom_predicates), + ) + ) + + discriminator = ('view', context, name, route_name, attr, + str(sorted(pvals.items()))) if inspect.isclass(view) and attr: view_desc = 'method %r of %s' % ( attr, self.object_description(view)) @@ -1057,9 +1067,13 @@ class ViewsConfiguratorMixin(object): decorator=decorator, ) ) + view_intr.update(**other_predicates) introspectables.append(view_intr) + predlist = self.view_predlist def register(permission=permission, renderer=renderer): + order, preds, phash = predlist.make(**pvals) + view_intr.update({'phash':phash}) request_iface = IRequest if route_name is not None: request_iface = self.registry.queryUtility(IRouteRequest, @@ -1087,7 +1101,7 @@ class ViewsConfiguratorMixin(object): # __no_permission_required__ handled by _secure_view deriver = ViewDeriver(registry=self.registry, permission=permission, - predicates=predicates, + predicates=preds, attr=attr, renderer=renderer, wrapper_viewname=wrapper, @@ -1230,6 +1244,66 @@ class ViewsConfiguratorMixin(object): introspectables.append(perm_intr) self.action(discriminator, register, introspectables=introspectables) + @property + def view_predlist(self): + predlist = self.registry.queryUtility(IPredicateList, name='view') + if predlist is None: + predlist = PredicateList() + self.registry.registerUtility(predlist, IPredicateList, name='view') + return predlist + + @action_method + def add_view_predicate(self, name, factory, weighs_more_than=None, + weighs_less_than=None): + """ Adds a view predicate factory. The view predicate can later be + named as a keyword argument to + :meth:`pyramid.config.Configurator.add_view`. + + ``name`` should be the name of the predicate. It must be a valid + Python identifier (it will be used as a keyword argument to + ``add_view``). + + ``factory`` should be a :term:`predicate factory`. + """ + discriminator = ('view predicate', name) + intr = self.introspectable( + 'view predicates', + discriminator, + 'view predicate named %s' % name, + 'view predicate') + intr['name'] = name + intr['factory'] = factory + intr['weighs_more_than'] = weighs_more_than + intr['weighs_less_than'] = weighs_less_than + def register(): + predlist = self.view_predlist + predlist.add(name, factory, weighs_more_than=weighs_more_than, + weighs_less_than=weighs_less_than) + self.action(discriminator, register, introspectables=(intr,), + order=PHASE1_CONFIG) # must be registered before views added + + def add_default_view_predicates(self): + self.add_view_predicate( + 'xhr', predicates.XHRPredicate) + self.add_view_predicate( + 'request_method', predicates.RequestMethodPredicate) + self.add_view_predicate( + 'path_info', predicates.PathInfoPredicate) + self.add_view_predicate( + 'request_param', predicates.RequestParamPredicate) + self.add_view_predicate( + 'header', predicates.HeaderPredicate) + self.add_view_predicate( + 'accept', predicates.AcceptPredicate) + self.add_view_predicate( + 'containment', predicates.ContainmentPredicate) + self.add_view_predicate( + 'request_type', predicates.RequestTypePredicate) + self.add_view_predicate( + 'match_param', predicates.MatchParamPredicate) + self.add_view_predicate( + 'custom', predicates.CustomPredicate) + def derive_view(self, view, attr=None, renderer=None): """ Create a :term:`view callable` using the function, instance, diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py index 1445ee394..114a01854 100644 --- a/pyramid/interfaces.py +++ b/pyramid/interfaces.py @@ -1111,6 +1111,9 @@ class IJSONAdapter(Interface): into a JSON-serializable primitive. """ +class IPredicateList(Interface): + """ Interface representing a predicate list """ + # configuration phases: a lower phase number means the actions associated # with this phase will be executed earlier than those with later phase # numbers. The default phase number is 0, FTR. diff --git a/pyramid/testing.py b/pyramid/testing.py index 40e90cda6..2628dc817 100644 --- a/pyramid/testing.py +++ b/pyramid/testing.py @@ -824,6 +824,7 @@ def setUp(registry=None, request=None, hook_zca=True, autocommit=True, # ``render_template`` and friends went behind the back of # any existing renderer factory lookup system. config.add_renderer(name, renderer) + config.add_default_view_predicates() config.commit() global have_zca try: diff --git a/pyramid/tests/test_config/test_init.py b/pyramid/tests/test_config/test_init.py index 37c3de275..b23168aaa 100644 --- a/pyramid/tests/test_config/test_init.py +++ b/pyramid/tests/test_config/test_init.py @@ -349,7 +349,7 @@ class ConfiguratorTests(unittest.TestCase): config.setup_registry() self.assertEqual(reg.has_listeners, True) - def test_setup_registry_registers_default_exceptionresponse_view(self): + def test_setup_registry_registers_default_exceptionresponse_views(self): from webob.exc import WSGIHTTPException from pyramid.interfaces import IExceptionResponse from pyramid.view import default_exceptionresponse_view @@ -357,6 +357,7 @@ class ConfiguratorTests(unittest.TestCase): config = self._makeOne(reg) views = [] config.add_view = lambda *arg, **kw: views.append((arg, kw)) + config.add_default_view_predicates = lambda *arg: None config._add_tween = lambda *arg, **kw: False config.setup_registry() self.assertEqual(views[0], ((default_exceptionresponse_view,), @@ -364,6 +365,16 @@ class ConfiguratorTests(unittest.TestCase): self.assertEqual(views[1], ((default_exceptionresponse_view,), {'context':WSGIHTTPException})) + def test_setup_registry_registers_default_view_predicates(self): + reg = DummyRegistry() + config = self._makeOne(reg) + vp_called = [] + config.add_view = lambda *arg, **kw: None + config.add_default_view_predicates = lambda *arg: vp_called.append(True) + config._add_tween = lambda *arg, **kw: False + config.setup_registry() + self.assertTrue(vp_called) + def test_setup_registry_registers_default_webob_iresponse_adapter(self): from webob import Response from pyramid.interfaces import IResponse @@ -1940,10 +1951,11 @@ class DummyEvent: pass class DummyRegistry(object): - def __init__(self, adaptation=None): + def __init__(self, adaptation=None, util=None): self.utilities = [] self.adapters = [] self.adaptation = adaptation + self.util = util def subscribers(self, events, name): self.events = events return events @@ -1953,6 +1965,8 @@ class DummyRegistry(object): self.adapters.append((arg, kw)) def queryAdapter(self, *arg, **kw): return self.adaptation + def queryUtility(self, *arg, **kw): + return self.util from pyramid.interfaces import IResponse @implementer(IResponse) @@ -1983,3 +1997,6 @@ class DummyIntrospectable(object): def register(self, introspector, action_info): self.registered.append((introspector, action_info)) +class DummyPredicateList(object): + def add(self, name, factory, weighs_more_than=None, weighs_less_than=None): + pass diff --git a/pyramid/tests/test_config/test_views.py b/pyramid/tests/test_config/test_views.py index ebf1dfb39..ea8883478 100644 --- a/pyramid/tests/test_config/test_views.py +++ b/pyramid/tests/test_config/test_views.py @@ -970,8 +970,8 @@ class TestViewsConfigurationMixin(unittest.TestCase): wrapper = self._getViewCallable(config) self.assertTrue(IMultiView.providedBy(wrapper)) request = self._makeRequest(config) - self.assertEqual(wrapper.__discriminator__(foo, request)[5], IFoo) - self.assertEqual(wrapper.__discriminator__(bar, request)[5], IBar) + self.assertTrue('IFoo' in wrapper.__discriminator__(foo, request)[5]) + self.assertTrue('IBar' in wrapper.__discriminator__(bar, request)[5]) def test_add_view_with_template_renderer(self): from pyramid.tests import test_config @@ -1217,8 +1217,8 @@ class TestViewsConfigurationMixin(unittest.TestCase): def test_add_view_with_header_badregex(self): view = lambda *arg: 'OK' config = self._makeOne() - self.assertRaises(ConfigurationError, - config.add_view, view=view, header='Host:a\\') + config.add_view(view, header='Host:a\\') + self.assertRaises(ConfigurationError, config.commit) def test_add_view_with_header_noval_match(self): from pyramid.renderers import null_renderer @@ -1323,8 +1323,8 @@ class TestViewsConfigurationMixin(unittest.TestCase): def test_add_view_with_path_info_badregex(self): view = lambda *arg: 'OK' config = self._makeOne() - self.assertRaises(ConfigurationError, - config.add_view, view=view, path_info='\\') + config.add_view(view, path_info='\\') + self.assertRaises(ConfigurationError, config.commit) def test_add_view_with_path_info_match(self): from pyramid.renderers import null_renderer diff --git a/pyramid/tests/test_url.py b/pyramid/tests/test_url.py index 50deb63f3..a7a565356 100644 --- a/pyramid/tests/test_url.py +++ b/pyramid/tests/test_url.py @@ -2,10 +2,8 @@ import os import unittest import warnings -from pyramid.testing import ( - setUp, - tearDown, - ) +from pyramid import testing + from pyramid.compat import ( text_, native_, @@ -14,10 +12,10 @@ from pyramid.compat import ( class TestURLMethodsMixin(unittest.TestCase): def setUp(self): - self.config = setUp() + self.config = testing.setUp() def tearDown(self): - tearDown() + testing.tearDown() def _makeOne(self, environ=None): from pyramid.url import URLMethodsMixin diff --git a/pyramid/tests/test_view.py b/pyramid/tests/test_view.py index a105adb70..ee4994172 100644 --- a/pyramid/tests/test_view.py +++ b/pyramid/tests/test_view.py @@ -3,17 +3,14 @@ import sys from zope.interface import implementer -from pyramid.testing import ( - setUp, - tearDown, - ) +from pyramid import testing class BaseTest(object): def setUp(self): - self.config = setUp() + self.config = testing.setUp() def tearDown(self): - tearDown() + testing.tearDown() def _registerView(self, reg, app, name): from pyramid.interfaces import IRequest @@ -334,10 +331,10 @@ class TestIsResponse(unittest.TestCase): class TestViewConfigDecorator(unittest.TestCase): def setUp(self): - setUp() + testing.setUp() def tearDown(self): - tearDown() + testing.tearDown() def _getTargetClass(self): from pyramid.view import view_config -- cgit v1.2.3