diff options
| author | Michael Merickel <mmerickel@users.noreply.github.com> | 2016-04-18 09:56:23 -0500 |
|---|---|---|
| committer | Michael Merickel <mmerickel@users.noreply.github.com> | 2016-04-18 09:56:23 -0500 |
| commit | 6c16fb020027fac47e4d2e335cd9e264dba8aa3b (patch) | |
| tree | 1306181202cb8313f16080789f5b9ab1eeb61d53 | |
| parent | 8840437df934a3a29a19be4bfee96cbcf5d537ff (diff) | |
| parent | 6f524a94157b0caa471222b0d9768a48173c1c7e (diff) | |
| download | pyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.tar.gz pyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.tar.bz2 pyramid-6c16fb020027fac47e4d2e335cd9e264dba8aa3b.zip | |
Merge pull request #2517 from mmerickel/fix/disable-csrf-on-exception-views-by-default
disable csrf checking on all exception views unless explicitly turned on
| -rw-r--r-- | pyramid/tests/test_viewderivers.py | 59 | ||||
| -rw-r--r-- | pyramid/viewderivers.py | 24 |
2 files changed, 74 insertions, 9 deletions
diff --git a/pyramid/tests/test_viewderivers.py b/pyramid/tests/test_viewderivers.py index 4767da580..4a7a04197 100644 --- a/pyramid/tests/test_viewderivers.py +++ b/pyramid/tests/test_viewderivers.py @@ -1297,6 +1297,64 @@ class TestDeriveView(unittest.TestCase): result = view(None, request) self.assertTrue(result is response) + def test_csrf_view_skipped_by_default_on_exception_view(self): + from pyramid.request import Request + def view(request): + raise ValueError + def excview(request): + return 'hello' + self.config.add_settings({'pyramid.require_default_csrf': 'yes'}) + self.config.set_session_factory( + lambda request: DummySession({'csrf_token': 'foo'})) + self.config.add_view(view, name='foo', require_csrf=False) + self.config.add_view(excview, context=ValueError, renderer='string') + app = self.config.make_wsgi_app() + request = Request.blank('/foo', base_url='http://example.com') + request.method = 'POST' + response = request.get_response(app) + self.assertTrue(b'hello' in response.body) + + def test_csrf_view_failed_on_explicit_exception_view(self): + from pyramid.exceptions import BadCSRFToken + from pyramid.request import Request + def view(request): + raise ValueError + def excview(request): pass + self.config.add_settings({'pyramid.require_default_csrf': 'yes'}) + self.config.set_session_factory( + lambda request: DummySession({'csrf_token': 'foo'})) + self.config.add_view(view, name='foo', require_csrf=False) + self.config.add_view(excview, context=ValueError, renderer='string', + require_csrf=True) + app = self.config.make_wsgi_app() + request = Request.blank('/foo', base_url='http://example.com') + request.method = 'POST' + try: + request.get_response(app) + except BadCSRFToken: + pass + else: # pragma: no cover + raise AssertionError + + def test_csrf_view_passed_on_explicit_exception_view(self): + from pyramid.request import Request + def view(request): + raise ValueError + def excview(request): + return 'hello' + self.config.add_settings({'pyramid.require_default_csrf': 'yes'}) + self.config.set_session_factory( + lambda request: DummySession({'csrf_token': 'foo'})) + self.config.add_view(view, name='foo', require_csrf=False) + self.config.add_view(excview, context=ValueError, renderer='string', + require_csrf=True) + app = self.config.make_wsgi_app() + request = Request.blank('/foo', base_url='http://example.com') + request.method = 'POST' + request.headers['X-CSRF-Token'] = 'foo' + response = request.get_response(app) + self.assertTrue(b'hello' in response.body) + class TestDerivationOrder(unittest.TestCase): def setUp(self): @@ -1554,7 +1612,6 @@ class TestDeriverIntegration(unittest.TestCase): from pyramid.interfaces import IRequest from pyramid.interfaces import IView from pyramid.interfaces import IViewClassifier - from pyramid.interfaces import IExceptionViewClassifier classifier = IViewClassifier if ctx_iface is None: ctx_iface = Interface diff --git a/pyramid/viewderivers.py b/pyramid/viewderivers.py index d9d9c2904..fbe7cd660 100644 --- a/pyramid/viewderivers.py +++ b/pyramid/viewderivers.py @@ -483,21 +483,29 @@ def csrf_view(view, info): default_val = _parse_csrf_setting( info.settings.get('pyramid.require_default_csrf'), 'Config setting "pyramid.require_default_csrf"') - val = _parse_csrf_setting( + explicit_val = _parse_csrf_setting( info.options.get('require_csrf'), 'View option "require_csrf"') - if (val is True and default_val) or val is None: - val = default_val - if val is True: - val = 'csrf_token' + resolved_val = explicit_val + if (explicit_val is True and default_val) or explicit_val is None: + resolved_val = default_val + if resolved_val is True: + resolved_val = 'csrf_token' wrapped_view = view - if val: + if resolved_val: def csrf_view(context, request): # Assume that anything not defined as 'safe' by RFC2616 needs # protection - if request.method not in SAFE_REQUEST_METHODS: + if ( + request.method not in SAFE_REQUEST_METHODS and + ( + # skip exception views unless value is explicitly defined + getattr(request, 'exception', None) is None or + explicit_val is not None + ) + ): check_csrf_origin(request, raises=True) - check_csrf_token(request, val, raises=True) + check_csrf_token(request, resolved_val, raises=True) return view(context, request) wrapped_view = csrf_view return wrapped_view |
