summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Merickel <mmerickel@users.noreply.github.com>2016-04-18 09:56:23 -0500
committerMichael Merickel <mmerickel@users.noreply.github.com>2016-04-18 09:56:23 -0500
commit6c16fb020027fac47e4d2e335cd9e264dba8aa3b (patch)
tree1306181202cb8313f16080789f5b9ab1eeb61d53
parent8840437df934a3a29a19be4bfee96cbcf5d537ff (diff)
parent6f524a94157b0caa471222b0d9768a48173c1c7e (diff)
downloadpyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.tar.gz
pyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.tar.bz2
pyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.zip
Merge pull request #2517 from mmerickel/fix/disable-csrf-on-exception-views-by-default
disable csrf checking on all exception views unless explicitly turned on
-rw-r--r--pyramid/tests/test_viewderivers.py59
-rw-r--r--pyramid/viewderivers.py24
2 files changed, 74 insertions, 9 deletions
diff --git a/pyramid/tests/test_viewderivers.py b/pyramid/tests/test_viewderivers.py
index 4767da580..4a7a04197 100644
--- a/pyramid/tests/test_viewderivers.py
+++ b/pyramid/tests/test_viewderivers.py
@@ -1297,6 +1297,64 @@ class TestDeriveView(unittest.TestCase):
result = view(None, request)
self.assertTrue(result is response)
+ def test_csrf_view_skipped_by_default_on_exception_view(self):
+ from pyramid.request import Request
+ def view(request):
+ raise ValueError
+ def excview(request):
+ return 'hello'
+ self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
+ self.config.set_session_factory(
+ lambda request: DummySession({'csrf_token': 'foo'}))
+ self.config.add_view(view, name='foo', require_csrf=False)
+ self.config.add_view(excview, context=ValueError, renderer='string')
+ app = self.config.make_wsgi_app()
+ request = Request.blank('/foo', base_url='http://example.com')
+ request.method = 'POST'
+ response = request.get_response(app)
+ self.assertTrue(b'hello' in response.body)
+
+ def test_csrf_view_failed_on_explicit_exception_view(self):
+ from pyramid.exceptions import BadCSRFToken
+ from pyramid.request import Request
+ def view(request):
+ raise ValueError
+ def excview(request): pass
+ self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
+ self.config.set_session_factory(
+ lambda request: DummySession({'csrf_token': 'foo'}))
+ self.config.add_view(view, name='foo', require_csrf=False)
+ self.config.add_view(excview, context=ValueError, renderer='string',
+ require_csrf=True)
+ app = self.config.make_wsgi_app()
+ request = Request.blank('/foo', base_url='http://example.com')
+ request.method = 'POST'
+ try:
+ request.get_response(app)
+ except BadCSRFToken:
+ pass
+ else: # pragma: no cover
+ raise AssertionError
+
+ def test_csrf_view_passed_on_explicit_exception_view(self):
+ from pyramid.request import Request
+ def view(request):
+ raise ValueError
+ def excview(request):
+ return 'hello'
+ self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
+ self.config.set_session_factory(
+ lambda request: DummySession({'csrf_token': 'foo'}))
+ self.config.add_view(view, name='foo', require_csrf=False)
+ self.config.add_view(excview, context=ValueError, renderer='string',
+ require_csrf=True)
+ app = self.config.make_wsgi_app()
+ request = Request.blank('/foo', base_url='http://example.com')
+ request.method = 'POST'
+ request.headers['X-CSRF-Token'] = 'foo'
+ response = request.get_response(app)
+ self.assertTrue(b'hello' in response.body)
+
class TestDerivationOrder(unittest.TestCase):
def setUp(self):
@@ -1554,7 +1612,6 @@ class TestDeriverIntegration(unittest.TestCase):
from pyramid.interfaces import IRequest
from pyramid.interfaces import IView
from pyramid.interfaces import IViewClassifier
- from pyramid.interfaces import IExceptionViewClassifier
classifier = IViewClassifier
if ctx_iface is None:
ctx_iface = Interface
diff --git a/pyramid/viewderivers.py b/pyramid/viewderivers.py
index d9d9c2904..fbe7cd660 100644
--- a/pyramid/viewderivers.py
+++ b/pyramid/viewderivers.py
@@ -483,21 +483,29 @@ def csrf_view(view, info):
default_val = _parse_csrf_setting(
info.settings.get('pyramid.require_default_csrf'),
'Config setting "pyramid.require_default_csrf"')
- val = _parse_csrf_setting(
+ explicit_val = _parse_csrf_setting(
info.options.get('require_csrf'),
'View option "require_csrf"')
- if (val is True and default_val) or val is None:
- val = default_val
- if val is True:
- val = 'csrf_token'
+ resolved_val = explicit_val
+ if (explicit_val is True and default_val) or explicit_val is None:
+ resolved_val = default_val
+ if resolved_val is True:
+ resolved_val = 'csrf_token'
wrapped_view = view
- if val:
+ if resolved_val:
def csrf_view(context, request):
# Assume that anything not defined as 'safe' by RFC2616 needs
# protection
- if request.method not in SAFE_REQUEST_METHODS:
+ if (
+ request.method not in SAFE_REQUEST_METHODS and
+ (
+ # skip exception views unless value is explicitly defined
+ getattr(request, 'exception', None) is None or
+ explicit_val is not None
+ )
+ ):
check_csrf_origin(request, raises=True)
- check_csrf_token(request, val, raises=True)
+ check_csrf_token(request, resolved_val, raises=True)
return view(context, request)
wrapped_view = csrf_view
return wrapped_view