diff options
| author | Matthew Wilkes <git@matthewwilkes.name> | 2016-12-09 12:00:17 +0100 |
|---|---|---|
| committer | Matthew Wilkes <git@matthewwilkes.name> | 2017-04-12 12:14:12 +0100 |
| commit | f6d63a41d37b0647c49e53bb54f009f7da4d5079 (patch) | |
| tree | 5ace3508e4add4ba1a71b95708719897d7576802 | |
| parent | fe0d223ad08bcab724d216b3a877b690c5795f73 (diff) | |
| download | pyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.tar.gz pyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.tar.bz2 pyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.zip | |
Fix a bug where people that didn't configure CSRF protection but did configure a session and set explicit checks would see an exception
| -rw-r--r-- | pyramid/csrf.py | 8 | ||||
| -rw-r--r-- | pyramid/tests/test_csrf.py | 26 |
2 files changed, 33 insertions, 1 deletions
diff --git a/pyramid/csrf.py b/pyramid/csrf.py index b2788a764..f282eb569 100644 --- a/pyramid/csrf.py +++ b/pyramid/csrf.py @@ -184,7 +184,13 @@ def check_csrf_token(request, if supplied_token == "" and token is not None: supplied_token = request.POST.get(token, "") - policy = request.registry.getUtility(ICSRFStoragePolicy) + policy = request.registry.queryUtility(ICSRFStoragePolicy) + if policy is None: + # There is no policy set, but we are trying to validate a CSRF token + # This means explicit validation has been asked for without configuring + # the CSRF implementation. Fall back to SessionCSRF as that is the + # default + policy = SessionCSRF() if not policy.check_csrf_token(request, supplied_token): if raises: raise BadCSRFToken('check_csrf_token(): Invalid token') diff --git a/pyramid/tests/test_csrf.py b/pyramid/tests/test_csrf.py index 8866f3601..3994a31d4 100644 --- a/pyramid/tests/test_csrf.py +++ b/pyramid/tests/test_csrf.py @@ -313,6 +313,32 @@ class Test_check_csrf_token(unittest.TestCase): self.assertEqual(self._callFUT(request, token='csrf_token'), True) +class Test_check_csrf_token_without_defaults_configured(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from ..csrf import check_csrf_token + return check_csrf_token(*args, **kwargs) + + def test_success_token(self): + request = testing.DummyRequest() + request.method = "POST" + request.POST = {'csrf_token': request.session.get_csrf_token()} + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_failure_raises(self): + from pyramid.exceptions import BadCSRFToken + request = testing.DummyRequest() + self.assertRaises(BadCSRFToken, self._callFUT, request, + 'csrf_token') + + def test_failure_no_raises(self): + request = testing.DummyRequest() + result = self._callFUT(request, 'csrf_token', raises=False) + self.assertEqual(result, False) + + class Test_check_csrf_origin(unittest.TestCase): def _callFUT(self, *args, **kwargs): from ..csrf import check_csrf_origin |
