summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Wilkes <git@matthewwilkes.name>2016-12-09 12:00:17 +0100
committerMatthew Wilkes <git@matthewwilkes.name>2017-04-12 12:14:12 +0100
commitf6d63a41d37b0647c49e53bb54f009f7da4d5079 (patch)
tree5ace3508e4add4ba1a71b95708719897d7576802
parentfe0d223ad08bcab724d216b3a877b690c5795f73 (diff)
downloadpyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.tar.gz
pyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.tar.bz2
pyramid-f6d63a41d37b0647c49e53bb54f009f7da4d5079.zip
Fix a bug where people that didn't configure CSRF protection but did configure a session and set explicit checks would see an exception
-rw-r--r--pyramid/csrf.py8
-rw-r--r--pyramid/tests/test_csrf.py26
2 files changed, 33 insertions, 1 deletions
diff --git a/pyramid/csrf.py b/pyramid/csrf.py
index b2788a764..f282eb569 100644
--- a/pyramid/csrf.py
+++ b/pyramid/csrf.py
@@ -184,7 +184,13 @@ def check_csrf_token(request,
if supplied_token == "" and token is not None:
supplied_token = request.POST.get(token, "")
- policy = request.registry.getUtility(ICSRFStoragePolicy)
+ policy = request.registry.queryUtility(ICSRFStoragePolicy)
+ if policy is None:
+ # There is no policy set, but we are trying to validate a CSRF token
+ # This means explicit validation has been asked for without configuring
+ # the CSRF implementation. Fall back to SessionCSRF as that is the
+ # default
+ policy = SessionCSRF()
if not policy.check_csrf_token(request, supplied_token):
if raises:
raise BadCSRFToken('check_csrf_token(): Invalid token')
diff --git a/pyramid/tests/test_csrf.py b/pyramid/tests/test_csrf.py
index 8866f3601..3994a31d4 100644
--- a/pyramid/tests/test_csrf.py
+++ b/pyramid/tests/test_csrf.py
@@ -313,6 +313,32 @@ class Test_check_csrf_token(unittest.TestCase):
self.assertEqual(self._callFUT(request, token='csrf_token'), True)
+class Test_check_csrf_token_without_defaults_configured(unittest.TestCase):
+ def setUp(self):
+ self.config = testing.setUp()
+
+ def _callFUT(self, *args, **kwargs):
+ from ..csrf import check_csrf_token
+ return check_csrf_token(*args, **kwargs)
+
+ def test_success_token(self):
+ request = testing.DummyRequest()
+ request.method = "POST"
+ request.POST = {'csrf_token': request.session.get_csrf_token()}
+ self.assertEqual(self._callFUT(request, token='csrf_token'), True)
+
+ def test_failure_raises(self):
+ from pyramid.exceptions import BadCSRFToken
+ request = testing.DummyRequest()
+ self.assertRaises(BadCSRFToken, self._callFUT, request,
+ 'csrf_token')
+
+ def test_failure_no_raises(self):
+ request = testing.DummyRequest()
+ result = self._callFUT(request, 'csrf_token', raises=False)
+ self.assertEqual(result, False)
+
+
class Test_check_csrf_origin(unittest.TestCase):
def _callFUT(self, *args, **kwargs):
from ..csrf import check_csrf_origin