summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris McDonough <chrism@plope.com>2012-08-03 12:55:12 -0400
committerChris McDonough <chrism@plope.com>2012-08-03 12:55:12 -0400
commita00621e45ef29cde34469798144156c80a17a1e9 (patch)
tree37e6122d52b74313403250dd31b3c49283601e02
parentfc3f23c094795e6c889531c9706ec9b1153aac67 (diff)
downloadpyramid-a00621e45ef29cde34469798144156c80a17a1e9.tar.gz
pyramid-a00621e45ef29cde34469798144156c80a17a1e9.tar.bz2
pyramid-a00621e45ef29cde34469798144156c80a17a1e9.zip
first cut at extensible view predicates via config.add_view_predicate; still requires testing of predicates themselves
-rw-r--r--pyramid/config/__init__.py2
-rw-r--r--pyramid/config/predicates.py218
-rw-r--r--pyramid/config/util.py56
-rw-r--r--pyramid/config/views.py116
-rw-r--r--pyramid/interfaces.py3
-rw-r--r--pyramid/testing.py1
-rw-r--r--pyramid/tests/test_config/test_init.py21
-rw-r--r--pyramid/tests/test_config/test_views.py12
-rw-r--r--pyramid/tests/test_url.py10
-rw-r--r--pyramid/tests/test_view.py13
10 files changed, 408 insertions, 44 deletions
diff --git a/pyramid/config/__init__.py b/pyramid/config/__init__.py
index 52d7aca83..5eb860ed5 100644
--- a/pyramid/config/__init__.py
+++ b/pyramid/config/__init__.py
@@ -353,6 +353,8 @@ class Configurator(
for name, renderer in DEFAULT_RENDERERS:
self.add_renderer(name, renderer)
+ self.add_default_view_predicates()
+
if exceptionresponse_view is not None:
exceptionresponse_view = self.maybe_dotted(exceptionresponse_view)
self.add_view(exceptionresponse_view, context=IExceptionResponse)
diff --git a/pyramid/config/predicates.py b/pyramid/config/predicates.py
new file mode 100644
index 000000000..24ec89c6b
--- /dev/null
+++ b/pyramid/config/predicates.py
@@ -0,0 +1,218 @@
+import re
+
+from pyramid.compat import is_nonstr_iter
+
+from pyramid.exceptions import ConfigurationError
+
+from pyramid.traversal import (
+ find_interface,
+ traversal_path,
+ )
+
+from pyramid.urldispatch import _compile_route
+
+from .util import as_sorted_tuple
+
+class XHRPredicate(object):
+ def __init__(self, val):
+ self.val = bool(val)
+
+ def __text__(self):
+ return 'xhr = True'
+
+ def __phash__(self):
+ return 'xhr:%r' % (self.val,)
+
+ def __call__(self, context, request):
+ return request.is_xhr
+
+
+class RequestMethodPredicate(object):
+ def __init__(self, val):
+ self.val = as_sorted_tuple(val)
+
+ def __text__(self):
+ return 'request method = %r' % (self.val,)
+
+ def __phash__(self):
+ L = []
+ for v in self.val:
+ L.append('request_method:%r' % v)
+ return L
+
+ def __call__(self, context, request):
+ return request.method in self.val
+
+class PathInfoPredicate(object):
+ def __init__(self, val):
+ self.orig = val
+ try:
+ val = re.compile(val)
+ except re.error as why:
+ raise ConfigurationError(why.args[0])
+ self.val = val
+
+ def __text__(self):
+ return 'path_info = %s' % (self.orig,)
+
+ def __phash__(self):
+ return 'path_info:%r' % (self.orig,)
+
+ def __call__(self, context, request):
+ return self.val.match(request.upath_info) is not None
+
+class RequestParamPredicate(object):
+ def __init__(self, val):
+ name = val
+ v = None
+ if '=' in name:
+ name, v = name.split('=', 1)
+ if v is None:
+ self.text = 'request_param %s' % (name,)
+ else:
+ self.text = 'request_param %s = %s' % (name, v)
+ self.name = name
+ self.val = v
+
+ def __text__(self):
+ return self.text
+
+ def __phash__(self):
+ return 'request_param:%r=%r' % (self.name, self.val)
+
+ def __call__(self, context, request):
+ if self.val is None:
+ return self.name in request.params
+ return request.params.get(self.name) == self.val
+
+
+class HeaderPredicate(object):
+ def __init__(self, val):
+ name = val
+ v = None
+ if ':' in name:
+ name, v = name.split(':', 1)
+ try:
+ v = re.compile(v)
+ except re.error as why:
+ raise ConfigurationError(why.args[0])
+ if v is None:
+ self.text = 'header %s' % (name,)
+ else:
+ self.text = 'header %s = %s' % (name, v)
+ self.name = name
+ self.val = v
+
+ def __text__(self):
+ return self.text
+
+ def __phash__(self):
+ return 'header:%r=%r' % (self.name, self.val)
+
+ def __call__(self, context, request):
+ if self.val is None:
+ return self.name in request.headers
+ val = request.headers.get(self.name)
+ if val is None:
+ return False
+ return self.val.match(val) is not None
+
+class AcceptPredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'accept = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'accept:%r' % (self.val,)
+
+ def __call__(self, context, request):
+ return self.val in request.accept
+
+class ContainmentPredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'containment = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'containment:%r' % hash(self.val)
+
+ def __call__(self, context, request):
+ ctx = getattr(request, 'context', context)
+ return find_interface(ctx, self.val) is not None
+
+class RequestTypePredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'request_type = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'request_type:%r' % hash(self.val)
+
+ def __call__(self, context, request):
+ return self.val.providedBy(request)
+
+class MatchParamPredicate(object):
+ def __init__(self, val):
+ if not is_nonstr_iter(val):
+ val = (val,)
+ val = sorted(val)
+ self.val = val
+ self.reqs = [
+ (x.strip(), y.strip()) for x, y in [ p.split('=', 1) for p in val ]
+ ]
+
+ def __text__(self):
+ return 'match_param %s' % (self.val,)
+
+ def __phash__(self):
+ L = []
+ for k, v in self.reqs:
+ L.append('match_param:%r=%r' % (k, v))
+ return L
+
+ def __call__(self, context, request):
+ for k, v in self.reqs:
+ if request.matchdict.get(k) != v:
+ return False
+ return True
+
+class CustomPredicate(object):
+ def __init__(self, func):
+ self.func = func
+
+ def __text__(self):
+ return getattr(self.func, '__text__', repr(self.func))
+
+ def __phash__(self):
+ return 'custom:%r' % hash(self.func)
+
+ def __call__(self, context, request):
+ return self.func(context, request)
+
+
+class TraversePredicate(object):
+ def __init__(self, val):
+ _, self.tgenerate = _compile_route(val)
+ self.val = val
+
+ def __text__(self):
+ return 'traverse matchdict pseudo-predicate'
+
+ def __phash__(self):
+ return ''
+
+ def __call__(self, context, request):
+ if 'traverse' in context:
+ return True
+ m = context['match']
+ tvalue = self.tgenerate(m)
+ m['traverse'] = traversal_path(tvalue)
+ return True
+
+
diff --git a/pyramid/config/util.py b/pyramid/config/util.py
index 027060db3..1574d7da6 100644
--- a/pyramid/config/util.py
+++ b/pyramid/config/util.py
@@ -324,14 +324,26 @@ class TopologicalSorter(object):
self.names = []
self.req_before = set()
self.req_after = set()
+ self.name2before = {}
+ self.name2after = {}
self.name2val = {}
self.order = []
self.default_before = default_before
self.default_after = default_after
self.first = first
self.last = last
-
+
+ def remove(self, name):
+ if name in self.names:
+ self.names.remove(name)
+ del self.name2val[name]
+ for u in self.name2after.get(name, []):
+ self.order.remove((u, name))
+ for u in self.name2before.get(name, []):
+ self.order.remove((name, u))
+
def add(self, name, val, after=None, before=None):
+ self.remove(name)
self.names.append(name)
self.name2val[name] = val
if after is None and before is None:
@@ -340,11 +352,13 @@ class TopologicalSorter(object):
if after is not None:
if not is_nonstr_iter(after):
after = (after,)
+ self.name2after[name] = after
self.order += [(u, name) for u in after]
self.req_after.add(name)
if before is not None:
if not is_nonstr_iter(before):
before = (before,)
+ self.name2before[name] = before
self.order += [(name, o) for o in before]
self.req_before.add(name)
@@ -432,3 +446,43 @@ class CyclicDependencyError(Exception):
L.append('%r sorts before %r' % (dependent, dependees))
msg = 'Implicit ordering cycle:' + '; '.join(L)
return msg
+
+class PredicateList(object):
+ def __init__(self):
+ self.sorter = TopologicalSorter()
+
+ def add(self, name, factory, weighs_more_than=None, weighs_less_than=None):
+ self.sorter.add(name, factory, after=weighs_more_than,
+ before=weighs_less_than)
+
+ def make(self, **kw):
+ ordered = self.sorter.sorted()
+ phash = md5()
+ weights = []
+ predicates = []
+ for order, (name, predicate_factory) in enumerate(ordered):
+ vals = kw.pop(name, None)
+ if vals is None:
+ continue
+ if not isinstance(vals, SequenceOfPredicateValues):
+ vals = (vals,)
+ for val in vals:
+ predicate = predicate_factory(val)
+ hashes = predicate.__phash__()
+ if not is_nonstr_iter(hashes):
+ hashes = [hashes]
+ for h in hashes:
+ phash.update(bytes_(h))
+ predicate = predicate_factory(val)
+ weights.append(1 << order)
+ predicates.append(predicate)
+ if kw:
+ raise ConfigurationError('Unknown predicate values: %r' % (kw,))
+ score = 0
+ for bit in weights:
+ score = score | bit
+ order = (MAX_ORDER - score) / (len(predicates) + 1)
+ return order, predicates, phash.hexdigest()
+
+class SequenceOfPredicateValues(tuple):
+ pass
diff --git a/pyramid/config/views.py b/pyramid/config/views.py
index 4354b4691..b59d18400 100644
--- a/pyramid/config/views.py
+++ b/pyramid/config/views.py
@@ -20,6 +20,7 @@ from pyramid.interfaces import (
IException,
IExceptionViewClassifier,
IMultiView,
+ IPredicateList,
IRendererFactory,
IRequest,
IResponse,
@@ -65,12 +66,15 @@ from pyramid.view import (
from pyramid.util import object_description
+from pyramid.config import predicates
+
from pyramid.config.util import (
DEFAULT_PHASH,
MAX_ORDER,
action_method,
as_sorted_tuple,
- make_predicates,
+ PredicateList,
+ SequenceOfPredicateValues,
)
urljoin = urlparse.urljoin
@@ -272,11 +276,11 @@ class ViewDeriver(object):
@wraps_view
def predicated_view(self, view):
- predicates = self.kw.get('predicates', ())
- if not predicates:
+ preds = self.kw.get('predicates', ())
+ if not preds:
return view
def predicate_wrapper(context, request):
- for predicate in predicates:
+ for predicate in preds:
if not predicate(context, request):
view_name = getattr(view, '__name__', view)
raise PredicateMismatch(
@@ -285,9 +289,9 @@ class ViewDeriver(object):
return view(context, request)
def checker(context, request):
return all((predicate(context, request) for predicate in
- predicates))
+ preds))
predicate_wrapper.__predicated__ = checker
- predicate_wrapper.__predicates__ = predicates
+ predicate_wrapper.__predicates__ = preds
return predicate_wrapper
@wraps_view
@@ -634,10 +638,10 @@ class ViewsConfiguratorMixin(object):
def add_view(self, view=None, name="", for_=None, permission=None,
request_type=None, route_name=None, request_method=None,
request_param=None, containment=None, attr=None,
- renderer=None, wrapper=None, xhr=False, accept=None,
+ renderer=None, wrapper=None, xhr=None, accept=None,
header=None, path_info=None, custom_predicates=(),
context=None, decorator=None, mapper=None, http_cache=None,
- match_param=None):
+ match_param=None, **other_predicates):
""" Add a :term:`view configuration` to the current
configuration state. Arguments to ``add_view`` are broken
down below into *predicate* arguments and *non-predicate*
@@ -1003,12 +1007,6 @@ class ViewsConfiguratorMixin(object):
# GET implies HEAD too
request_method = as_sorted_tuple(request_method + ('HEAD',))
- order, predicates, phash = make_predicates(xhr=xhr,
- request_method=request_method, path_info=path_info,
- request_param=request_param, header=header, accept=accept,
- containment=containment, request_type=request_type,
- match_param=match_param, custom=custom_predicates)
-
if context is None:
context = for_
@@ -1024,12 +1022,24 @@ class ViewsConfiguratorMixin(object):
registry = self.registry)
introspectables = []
- discriminator = [
- 'view', context, name, request_type, IView, containment,
- request_param, request_method, route_name, attr,
- xhr, accept, header, path_info, match_param]
- discriminator.extend(sorted([hash(x) for x in custom_predicates]))
- discriminator = tuple(discriminator)
+ pvals = other_predicates
+ pvals.update(
+ dict(
+ xhr=xhr,
+ request_method=request_method,
+ path_info=path_info,
+ request_param=request_param,
+ header=header,
+ accept=accept,
+ containment=containment,
+ request_type=request_type,
+ match_param=match_param,
+ custom=SequenceOfPredicateValues(custom_predicates),
+ )
+ )
+
+ discriminator = ('view', context, name, route_name, attr,
+ str(sorted(pvals.items())))
if inspect.isclass(view) and attr:
view_desc = 'method %r of %s' % (
attr, self.object_description(view))
@@ -1057,9 +1067,13 @@ class ViewsConfiguratorMixin(object):
decorator=decorator,
)
)
+ view_intr.update(**other_predicates)
introspectables.append(view_intr)
+ predlist = self.view_predlist
def register(permission=permission, renderer=renderer):
+ order, preds, phash = predlist.make(**pvals)
+ view_intr.update({'phash':phash})
request_iface = IRequest
if route_name is not None:
request_iface = self.registry.queryUtility(IRouteRequest,
@@ -1087,7 +1101,7 @@ class ViewsConfiguratorMixin(object):
# __no_permission_required__ handled by _secure_view
deriver = ViewDeriver(registry=self.registry,
permission=permission,
- predicates=predicates,
+ predicates=preds,
attr=attr,
renderer=renderer,
wrapper_viewname=wrapper,
@@ -1230,6 +1244,66 @@ class ViewsConfiguratorMixin(object):
introspectables.append(perm_intr)
self.action(discriminator, register, introspectables=introspectables)
+ @property
+ def view_predlist(self):
+ predlist = self.registry.queryUtility(IPredicateList, name='view')
+ if predlist is None:
+ predlist = PredicateList()
+ self.registry.registerUtility(predlist, IPredicateList, name='view')
+ return predlist
+
+ @action_method
+ def add_view_predicate(self, name, factory, weighs_more_than=None,
+ weighs_less_than=None):
+ """ Adds a view predicate factory. The view predicate can later be
+ named as a keyword argument to
+ :meth:`pyramid.config.Configurator.add_view`.
+
+ ``name`` should be the name of the predicate. It must be a valid
+ Python identifier (it will be used as a keyword argument to
+ ``add_view``).
+
+ ``factory`` should be a :term:`predicate factory`.
+ """
+ discriminator = ('view predicate', name)
+ intr = self.introspectable(
+ 'view predicates',
+ discriminator,
+ 'view predicate named %s' % name,
+ 'view predicate')
+ intr['name'] = name
+ intr['factory'] = factory
+ intr['weighs_more_than'] = weighs_more_than
+ intr['weighs_less_than'] = weighs_less_than
+ def register():
+ predlist = self.view_predlist
+ predlist.add(name, factory, weighs_more_than=weighs_more_than,
+ weighs_less_than=weighs_less_than)
+ self.action(discriminator, register, introspectables=(intr,),
+ order=PHASE1_CONFIG) # must be registered before views added
+
+ def add_default_view_predicates(self):
+ self.add_view_predicate(
+ 'xhr', predicates.XHRPredicate)
+ self.add_view_predicate(
+ 'request_method', predicates.RequestMethodPredicate)
+ self.add_view_predicate(
+ 'path_info', predicates.PathInfoPredicate)
+ self.add_view_predicate(
+ 'request_param', predicates.RequestParamPredicate)
+ self.add_view_predicate(
+ 'header', predicates.HeaderPredicate)
+ self.add_view_predicate(
+ 'accept', predicates.AcceptPredicate)
+ self.add_view_predicate(
+ 'containment', predicates.ContainmentPredicate)
+ self.add_view_predicate(
+ 'request_type', predicates.RequestTypePredicate)
+ self.add_view_predicate(
+ 'match_param', predicates.MatchParamPredicate)
+ self.add_view_predicate(
+ 'custom', predicates.CustomPredicate)
+
def derive_view(self, view, attr=None, renderer=None):
"""
Create a :term:`view callable` using the function, instance,
diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py
index 1445ee394..114a01854 100644
--- a/pyramid/interfaces.py
+++ b/pyramid/interfaces.py
@@ -1111,6 +1111,9 @@ class IJSONAdapter(Interface):
into a JSON-serializable primitive.
"""
+class IPredicateList(Interface):
+ """ Interface representing a predicate list """
+
# configuration phases: a lower phase number means the actions associated
# with this phase will be executed earlier than those with later phase
# numbers. The default phase number is 0, FTR.
diff --git a/pyramid/testing.py b/pyramid/testing.py
index 40e90cda6..2628dc817 100644
--- a/pyramid/testing.py
+++ b/pyramid/testing.py
@@ -824,6 +824,7 @@ def setUp(registry=None, request=None, hook_zca=True, autocommit=True,
# ``render_template`` and friends went behind the back of
# any existing renderer factory lookup system.
config.add_renderer(name, renderer)
+ config.add_default_view_predicates()
config.commit()
global have_zca
try:
diff --git a/pyramid/tests/test_config/test_init.py b/pyramid/tests/test_config/test_init.py
index 37c3de275..b23168aaa 100644
--- a/pyramid/tests/test_config/test_init.py
+++ b/pyramid/tests/test_config/test_init.py
@@ -349,7 +349,7 @@ class ConfiguratorTests(unittest.TestCase):
config.setup_registry()
self.assertEqual(reg.has_listeners, True)
- def test_setup_registry_registers_default_exceptionresponse_view(self):
+ def test_setup_registry_registers_default_exceptionresponse_views(self):
from webob.exc import WSGIHTTPException
from pyramid.interfaces import IExceptionResponse
from pyramid.view import default_exceptionresponse_view
@@ -357,6 +357,7 @@ class ConfiguratorTests(unittest.TestCase):
config = self._makeOne(reg)
views = []
config.add_view = lambda *arg, **kw: views.append((arg, kw))
+ config.add_default_view_predicates = lambda *arg: None
config._add_tween = lambda *arg, **kw: False
config.setup_registry()
self.assertEqual(views[0], ((default_exceptionresponse_view,),
@@ -364,6 +365,16 @@ class ConfiguratorTests(unittest.TestCase):
self.assertEqual(views[1], ((default_exceptionresponse_view,),
{'context':WSGIHTTPException}))
+ def test_setup_registry_registers_default_view_predicates(self):
+ reg = DummyRegistry()
+ config = self._makeOne(reg)
+ vp_called = []
+ config.add_view = lambda *arg, **kw: None
+ config.add_default_view_predicates = lambda *arg: vp_called.append(True)
+ config._add_tween = lambda *arg, **kw: False
+ config.setup_registry()
+ self.assertTrue(vp_called)
+
def test_setup_registry_registers_default_webob_iresponse_adapter(self):
from webob import Response
from pyramid.interfaces import IResponse
@@ -1940,10 +1951,11 @@ class DummyEvent:
pass
class DummyRegistry(object):
- def __init__(self, adaptation=None):
+ def __init__(self, adaptation=None, util=None):
self.utilities = []
self.adapters = []
self.adaptation = adaptation
+ self.util = util
def subscribers(self, events, name):
self.events = events
return events
@@ -1953,6 +1965,8 @@ class DummyRegistry(object):
self.adapters.append((arg, kw))
def queryAdapter(self, *arg, **kw):
return self.adaptation
+ def queryUtility(self, *arg, **kw):
+ return self.util
from pyramid.interfaces import IResponse
@implementer(IResponse)
@@ -1983,3 +1997,6 @@ class DummyIntrospectable(object):
def register(self, introspector, action_info):
self.registered.append((introspector, action_info))
+class DummyPredicateList(object):
+ def add(self, name, factory, weighs_more_than=None, weighs_less_than=None):
+ pass
diff --git a/pyramid/tests/test_config/test_views.py b/pyramid/tests/test_config/test_views.py
index ebf1dfb39..ea8883478 100644
--- a/pyramid/tests/test_config/test_views.py
+++ b/pyramid/tests/test_config/test_views.py
@@ -970,8 +970,8 @@ class TestViewsConfigurationMixin(unittest.TestCase):
wrapper = self._getViewCallable(config)
self.assertTrue(IMultiView.providedBy(wrapper))
request = self._makeRequest(config)
- self.assertEqual(wrapper.__discriminator__(foo, request)[5], IFoo)
- self.assertEqual(wrapper.__discriminator__(bar, request)[5], IBar)
+ self.assertTrue('IFoo' in wrapper.__discriminator__(foo, request)[5])
+ self.assertTrue('IBar' in wrapper.__discriminator__(bar, request)[5])
def test_add_view_with_template_renderer(self):
from pyramid.tests import test_config
@@ -1217,8 +1217,8 @@ class TestViewsConfigurationMixin(unittest.TestCase):
def test_add_view_with_header_badregex(self):
view = lambda *arg: 'OK'
config = self._makeOne()
- self.assertRaises(ConfigurationError,
- config.add_view, view=view, header='Host:a\\')
+ config.add_view(view, header='Host:a\\')
+ self.assertRaises(ConfigurationError, config.commit)
def test_add_view_with_header_noval_match(self):
from pyramid.renderers import null_renderer
@@ -1323,8 +1323,8 @@ class TestViewsConfigurationMixin(unittest.TestCase):
def test_add_view_with_path_info_badregex(self):
view = lambda *arg: 'OK'
config = self._makeOne()
- self.assertRaises(ConfigurationError,
- config.add_view, view=view, path_info='\\')
+ config.add_view(view, path_info='\\')
+ self.assertRaises(ConfigurationError, config.commit)
def test_add_view_with_path_info_match(self):
from pyramid.renderers import null_renderer
diff --git a/pyramid/tests/test_url.py b/pyramid/tests/test_url.py
index 50deb63f3..a7a565356 100644
--- a/pyramid/tests/test_url.py
+++ b/pyramid/tests/test_url.py
@@ -2,10 +2,8 @@ import os
import unittest
import warnings
-from pyramid.testing import (
- setUp,
- tearDown,
- )
+from pyramid import testing
+
from pyramid.compat import (
text_,
native_,
@@ -14,10 +12,10 @@ from pyramid.compat import (
class TestURLMethodsMixin(unittest.TestCase):
def setUp(self):
- self.config = setUp()
+ self.config = testing.setUp()
def tearDown(self):
- tearDown()
+ testing.tearDown()
def _makeOne(self, environ=None):
from pyramid.url import URLMethodsMixin
diff --git a/pyramid/tests/test_view.py b/pyramid/tests/test_view.py
index a105adb70..ee4994172 100644
--- a/pyramid/tests/test_view.py
+++ b/pyramid/tests/test_view.py
@@ -3,17 +3,14 @@ import sys
from zope.interface import implementer
-from pyramid.testing import (
- setUp,
- tearDown,
- )
+from pyramid import testing
class BaseTest(object):
def setUp(self):
- self.config = setUp()
+ self.config = testing.setUp()
def tearDown(self):
- tearDown()
+ testing.tearDown()
def _registerView(self, reg, app, name):
from pyramid.interfaces import IRequest
@@ -334,10 +331,10 @@ class TestIsResponse(unittest.TestCase):
class TestViewConfigDecorator(unittest.TestCase):
def setUp(self):
- setUp()
+ testing.setUp()
def tearDown(self):
- tearDown()
+ testing.tearDown()
def _getTargetClass(self):
from pyramid.view import view_config