summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Merickel <michael@merickel.org>2016-03-22 01:02:05 -0500
committerMichael Merickel <michael@merickel.org>2016-04-10 22:12:38 -0500
commit9e9fa9ac40bdd79fbce69f94a13d705e40f3d458 (patch)
treeca50ee15c4790e7156771ee49c19441c123cd728
parentfa43952e617ad68c52447da28fc7f5be23ff4b10 (diff)
downloadpyramid-9e9fa9ac40bdd79fbce69f94a13d705e40f3d458.tar.gz
pyramid-9e9fa9ac40bdd79fbce69f94a13d705e40f3d458.tar.bz2
pyramid-9e9fa9ac40bdd79fbce69f94a13d705e40f3d458.zip
add a csrf_view to the view pipeline supporting a require_csrf option
-rw-r--r--docs/narr/hooks.rst28
-rw-r--r--pyramid/config/views.py34
-rw-r--r--pyramid/tests/test_config/test_views.py41
-rw-r--r--pyramid/tests/test_viewderivers.py31
-rw-r--r--pyramid/viewderivers.py15
5 files changed, 139 insertions, 10 deletions
diff --git a/docs/narr/hooks.rst b/docs/narr/hooks.rst
index 2c3782387..e7db97565 100644
--- a/docs/narr/hooks.rst
+++ b/docs/narr/hooks.rst
@@ -1580,6 +1580,11 @@ There are several built-in view derivers that :app:`Pyramid` will automatically
apply to any view. Below they are defined in order from furthest to closest to
the user-defined :term:`view callable`:
+``csrf_view``
+
+ Used to check the CSRF token provided in the request. This element is a
+ no-op if ``require_csrf`` is not defined.
+
``secured_view``
Enforce the ``permission`` defined on the view. This element is a no-op if no
@@ -1656,27 +1661,32 @@ View derivers are unique in that they have access to most of the options
passed to :meth:`pyramid.config.Configurator.add_view` in order to decide what
to do, and they have a chance to affect every view in the application.
-Let's look at one more example which will protect views by requiring a CSRF
-token unless ``disable_csrf=True`` is passed to the view:
+Let's override the default CSRF checker to default to on instead of off and
+only check ``POST`` requests:
.. code-block:: python
:linenos:
from pyramid.response import Response
from pyramid.session import check_csrf_token
+ from pyramid.viewderivers import INGRESS
- def require_csrf_view(view, info):
+ def csrf_view(view, info):
+ val = info.options.get('require_csrf', True)
wrapper_view = view
- if not info.options.get('disable_csrf', False):
- def wrapper_view(context, request):
+ if val:
+ if val is True:
+ val = 'csrf_token'
+ def csrf_view(context, request):
if request.method == 'POST':
- check_csrf_token(request)
+ check_csrf_token(request, val, raises=True)
return view(context, request)
+ wrapper_view = csrf_view
return wrapper_view
- require_csrf_view.options = ('disable_csrf',)
+ csrf_view.options = ('require_csrf',)
- config.add_view_deriver(require_csrf_view)
+ config.add_view_deriver(csrf_view, 'csrf_view', over='secured_view', under=INGRESS)
def protected_view(request):
return Response('protected')
@@ -1685,7 +1695,7 @@ token unless ``disable_csrf=True`` is passed to the view:
return Response('unprotected')
config.add_view(protected_view, name='safe')
- config.add_view(unprotected_view, name='unsafe', disable_csrf=True)
+ config.add_view(unprotected_view, name='unsafe', require_csrf=False)
Navigating to ``/safe`` with a POST request will then fail when the call to
:func:`pyramid.session.check_csrf_token` raises a
diff --git a/pyramid/config/views.py b/pyramid/config/views.py
index 3f6a9080d..58fdbfd06 100644
--- a/pyramid/config/views.py
+++ b/pyramid/config/views.py
@@ -213,6 +213,7 @@ class ViewsConfiguratorMixin(object):
http_cache=None,
match_param=None,
check_csrf=None,
+ require_csrf=None,
**view_options):
""" Add a :term:`view configuration` to the current
configuration state. Arguments to ``add_view`` are broken
@@ -366,6 +367,32 @@ class ViewsConfiguratorMixin(object):
before returning the response from the view. This effectively
disables any HTTP caching done by ``http_cache`` for that response.
+ require_csrf
+
+ .. versionadded:: 1.7
+
+ If specified, this value should be one of ``None``, ``True``,
+ ``False``, or a string representing the 'check name'. If the value
+ is ``True`` or a string, CSRF checking will be performed. If the
+ value is ``False`` or ``None``, CSRF checking will not be performed.
+
+ If the value provided is a string, that string will be used as the
+ 'check name'. If the value provided is ``True``, ``csrf_token`` will
+ be used as the check name.
+
+ If CSRF checking is performed, the checked value will be the value
+ of ``request.params[check_name]``. This value will be compared
+ against the value of ``request.session.get_csrf_token()``, and the
+ check will pass if these two values are the same. If the check
+ passes, the associated view will be permitted to execute. If the
+ check fails, the associated view will not be permitted to execute
+ and a :class:`pyramid.exceptions.BadCSRFToken` exception will
+ be raised. This exception may be caught and handled by an
+ :term:`exception view`.
+
+ Note that using this feature requires a :term:`session factory` to
+ have been configured.
+
wrapper
The :term:`view name` of a different :term:`view
@@ -805,6 +832,8 @@ class ViewsConfiguratorMixin(object):
path_info=path_info,
match_param=match_param,
check_csrf=check_csrf,
+ http_cache=http_cache,
+ require_csrf=require_csrf,
callable=view,
mapper=mapper,
decorator=decorator,
@@ -860,6 +889,7 @@ class ViewsConfiguratorMixin(object):
decorator=decorator,
mapper=mapper,
http_cache=http_cache,
+ require_csrf=require_csrf,
extra_options=ovals,
)
derived_view.__discriminator__ = lambda *arg: discriminator
@@ -1183,6 +1213,7 @@ class ViewsConfiguratorMixin(object):
def add_default_view_derivers(self):
d = pyramid.viewderivers
derivers = [
+ ('csrf_view', d.csrf_view),
('secured_view', d.secured_view),
('owrapped_view', d.owrapped_view),
('http_cached_view', d.http_cached_view),
@@ -1284,7 +1315,7 @@ class ViewsConfiguratorMixin(object):
viewname=None, accept=None, order=MAX_ORDER,
phash=DEFAULT_PHASH, decorator=None,
mapper=None, http_cache=None, context=None,
- extra_options=None):
+ require_csrf=None, extra_options=None):
view = self.maybe_dotted(view)
mapper = self.maybe_dotted(mapper)
if isinstance(renderer, string_types):
@@ -1311,6 +1342,7 @@ class ViewsConfiguratorMixin(object):
mapper=mapper,
decorator=decorator,
http_cache=http_cache,
+ require_csrf=require_csrf,
)
if extra_options:
options.update(extra_options)
diff --git a/pyramid/tests/test_config/test_views.py b/pyramid/tests/test_config/test_views.py
index b2513c42c..55ead55c2 100644
--- a/pyramid/tests/test_config/test_views.py
+++ b/pyramid/tests/test_config/test_views.py
@@ -1570,6 +1570,43 @@ class TestViewsConfigurationMixin(unittest.TestCase):
config.add_view(view=view2)
self.assertRaises(ConfigurationConflictError, config.commit)
+ def test_add_view_with_csrf_header(self):
+ from pyramid.renderers import null_renderer
+ def view(request):
+ return 'OK'
+ config = self._makeOne(autocommit=True)
+ config.add_view(view, require_csrf=True, renderer=null_renderer)
+ view = self._getViewCallable(config)
+ request = self._makeRequest(config)
+ request.headers = {'X-CSRF-Token': 'foo'}
+ request.session = DummySession({'csrf_token': 'foo'})
+ self.assertEqual(view(None, request), 'OK')
+
+ def test_add_view_with_csrf_param(self):
+ from pyramid.renderers import null_renderer
+ def view(request):
+ return 'OK'
+ config = self._makeOne(autocommit=True)
+ config.add_view(view, require_csrf='st', renderer=null_renderer)
+ view = self._getViewCallable(config)
+ request = self._makeRequest(config)
+ request.params = {'st': 'foo'}
+ request.headers = {}
+ request.session = DummySession({'csrf_token': 'foo'})
+ self.assertEqual(view(None, request), 'OK')
+
+ def test_add_view_with_missing_csrf_header(self):
+ from pyramid.exceptions import BadCSRFToken
+ from pyramid.renderers import null_renderer
+ def view(request): return 'OK'
+ config = self._makeOne(autocommit=True)
+ config.add_view(view, require_csrf=True, renderer=null_renderer)
+ view = self._getViewCallable(config)
+ request = self._makeRequest(config)
+ request.headers = {}
+ request.session = DummySession({'csrf_token': 'foo'})
+ self.assertRaises(BadCSRFToken, lambda: view(None, request))
+
def test_add_view_with_permission(self):
from pyramid.renderers import null_renderer
view1 = lambda *arg: 'OK'
@@ -3233,3 +3270,7 @@ class DummyIntrospector(object):
return self.getval
def relate(self, a, b):
self.related.append((a, b))
+
+class DummySession(dict):
+ def get_csrf_token(self):
+ return self['csrf_token']
diff --git a/pyramid/tests/test_viewderivers.py b/pyramid/tests/test_viewderivers.py
index 1823beb4d..0dd70b74a 100644
--- a/pyramid/tests/test_viewderivers.py
+++ b/pyramid/tests/test_viewderivers.py
@@ -1090,6 +1090,28 @@ class TestDeriveView(unittest.TestCase):
self.assertRaises(ConfigurationError, self.config._derive_view,
view, http_cache=(None,))
+ def test_csrf_view_requires_header(self):
+ response = DummyResponse()
+ def inner_view(request):
+ return response
+ request = self._makeRequest()
+ request.session = DummySession({'csrf_token': 'foo'})
+ request.headers = {'X-CSRF-Token': 'foo'}
+ view = self.config._derive_view(inner_view, require_csrf=True)
+ result = view(None, request)
+ self.assertTrue(result is response)
+
+ def test_csrf_view_requires_param(self):
+ response = DummyResponse()
+ def inner_view(request):
+ return response
+ request = self._makeRequest()
+ request.session = DummySession({'csrf_token': 'foo'})
+ request.params['DUMMY'] = 'foo'
+ view = self.config._derive_view(inner_view, require_csrf='DUMMY')
+ result = view(None, request)
+ self.assertTrue(result is response)
+
class TestDerivationOrder(unittest.TestCase):
def setUp(self):
@@ -1110,6 +1132,7 @@ class TestDerivationOrder(unittest.TestCase):
derivers_sorted = derivers.sorted()
dlist = [d for (d, _) in derivers_sorted]
self.assertEqual([
+ 'csrf_view',
'secured_view',
'owrapped_view',
'http_cached_view',
@@ -1132,6 +1155,7 @@ class TestDerivationOrder(unittest.TestCase):
derivers_sorted = derivers.sorted()
dlist = [d for (d, _) in derivers_sorted]
self.assertEqual([
+ 'csrf_view',
'secured_view',
'owrapped_view',
'http_cached_view',
@@ -1152,6 +1176,7 @@ class TestDerivationOrder(unittest.TestCase):
derivers_sorted = derivers.sorted()
dlist = [d for (d, _) in derivers_sorted]
self.assertEqual([
+ 'csrf_view',
'secured_view',
'owrapped_view',
'http_cached_view',
@@ -1173,6 +1198,7 @@ class TestDerivationOrder(unittest.TestCase):
derivers_sorted = derivers.sorted()
dlist = [d for (d, _) in derivers_sorted]
self.assertEqual([
+ 'csrf_view',
'secured_view',
'owrapped_view',
'http_cached_view',
@@ -1408,6 +1434,7 @@ class DummyRequest:
self.environ = environ
self.params = {}
self.cookies = {}
+ self.headers = {}
self.response = DummyResponse()
class DummyLogger:
@@ -1428,6 +1455,10 @@ class DummySecurityPolicy:
def permits(self, context, principals, permission):
return self.permitted
+class DummySession(dict):
+ def get_csrf_token(self):
+ return self['csrf_token']
+
def parse_httpdate(s):
import datetime
# cannot use %Z, must use literal GMT; Jython honors timezone
diff --git a/pyramid/viewderivers.py b/pyramid/viewderivers.py
index 8061e5d4a..7560fa67f 100644
--- a/pyramid/viewderivers.py
+++ b/pyramid/viewderivers.py
@@ -6,6 +6,7 @@ from zope.interface import (
)
from pyramid.security import NO_PERMISSION_REQUIRED
+from pyramid.session import check_csrf_token
from pyramid.response import Response
from pyramid.interfaces import (
@@ -455,5 +456,19 @@ def decorated_view(view, info):
decorated_view.options = ('decorator',)
+def csrf_view(view, info):
+ val = info.options.get('require_csrf')
+ wrapped_view = view
+ if val:
+ if val is True:
+ val = 'csrf_token'
+ def csrf_view(context, request):
+ check_csrf_token(request, val, raises=True)
+ return view(context, request)
+ wrapped_view = csrf_view
+ return wrapped_view
+
+csrf_view.options = ('require_csrf',)
+
VIEW = 'VIEW'
INGRESS = 'INGRESS'