diff options
| -rw-r--r-- | pyramid/csrf.py | 8 | ||||
| -rw-r--r-- | pyramid/tests/test_csrf.py | 26 |
2 files changed, 33 insertions, 1 deletions
diff --git a/pyramid/csrf.py b/pyramid/csrf.py index b2788a764..f282eb569 100644 --- a/pyramid/csrf.py +++ b/pyramid/csrf.py @@ -184,7 +184,13 @@ def check_csrf_token(request, if supplied_token == "" and token is not None: supplied_token = request.POST.get(token, "") - policy = request.registry.getUtility(ICSRFStoragePolicy) + policy = request.registry.queryUtility(ICSRFStoragePolicy) + if policy is None: + # There is no policy set, but we are trying to validate a CSRF token + # This means explicit validation has been asked for without configuring + # the CSRF implementation. Fall back to SessionCSRF as that is the + # default + policy = SessionCSRF() if not policy.check_csrf_token(request, supplied_token): if raises: raise BadCSRFToken('check_csrf_token(): Invalid token') diff --git a/pyramid/tests/test_csrf.py b/pyramid/tests/test_csrf.py index 8866f3601..3994a31d4 100644 --- a/pyramid/tests/test_csrf.py +++ b/pyramid/tests/test_csrf.py @@ -313,6 +313,32 @@ class Test_check_csrf_token(unittest.TestCase): self.assertEqual(self._callFUT(request, token='csrf_token'), True) +class Test_check_csrf_token_without_defaults_configured(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from ..csrf import check_csrf_token + return check_csrf_token(*args, **kwargs) + + def test_success_token(self): + request = testing.DummyRequest() + request.method = "POST" + request.POST = {'csrf_token': request.session.get_csrf_token()} + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_failure_raises(self): + from pyramid.exceptions import BadCSRFToken + request = testing.DummyRequest() + self.assertRaises(BadCSRFToken, self._callFUT, request, + 'csrf_token') + + def test_failure_no_raises(self): + request = testing.DummyRequest() + result = self._callFUT(request, 'csrf_token', raises=False) + self.assertEqual(result, False) + + class Test_check_csrf_origin(unittest.TestCase): def _callFUT(self, *args, **kwargs): from ..csrf import check_csrf_origin |
