diff options
| author | Luke Cyca <me@lukecyca.com> | 2013-05-30 10:25:58 -0700 |
|---|---|---|
| committer | Luke Cyca <me@lukecyca.com> | 2013-05-30 10:25:58 -0700 |
| commit | bfe654f0af480473531bdff77cefb676568aac8f (patch) | |
| tree | 226e2b59b13e377c61d556bb4f04a1334cc1f077 | |
| parent | 30715a7db397d90e786b85715303bfaf34109b31 (diff) | |
| download | pyramid-bfe654f0af480473531bdff77cefb676568aac8f.tar.gz pyramid-bfe654f0af480473531bdff77cefb676568aac8f.tar.bz2 pyramid-bfe654f0af480473531bdff77cefb676568aac8f.zip | |
Support CSRF via X-CSRFToken Header
| -rw-r--r-- | pyramid/session.py | 26 | ||||
| -rw-r--r-- | pyramid/tests/test_session.py | 21 |
2 files changed, 32 insertions, 15 deletions
diff --git a/pyramid/session.py b/pyramid/session.py index 7db8c8e0e..0433488d8 100644 --- a/pyramid/session.py +++ b/pyramid/session.py @@ -81,15 +81,22 @@ def signed_deserialize(serialized, secret, hmac=hmac): return pickle.loads(pickled) -def check_csrf_token(request, token='csrf_token', raises=True): +def check_csrf_token(request, + token='csrf_token', + header='X-CSRFToken', + raises=True): """ Check the CSRF token in the request's session against the value in - ``request.params.get(token)``. If a ``token`` keyword is not supplied - to this function, the string ``csrf_token`` will be used to look up - the token within ``request.params``. If the value in - ``request.params.get(token)`` doesn't match the value supplied by - ``request.session.get_csrf_token()``, and ``raises`` is ``True``, this - function will raise an :exc:`pyramid.httpexceptions.HTTPBadRequest` - exception. If the check does succeed and ``raises`` is ``False``, this + ``request.params.get(token)`` or ``request.headers.get(header)``. + If a ``token`` keyword is not supplied to this function, the string + ``csrf_token`` will be used to look up the token in ``request.params``. + If a ``header`` keyword is not supplied to this function, the string + ``X-CSRFToken`` will be used to look up the token in ``request.headers``. + + If the value supplied by param or by header doesn't match the value + supplied by ``request.session.get_csrf_token()``, and ``raises`` is + ``True``, this function will raise an + :exc:`pyramid.httpexceptions.HTTPBadRequest` exception. + If the check does succeed and ``raises`` is ``False``, this function will return ``False``. If the CSRF check is successful, this function will return ``True`` unconditionally. @@ -98,7 +105,8 @@ def check_csrf_token(request, token='csrf_token', raises=True): .. versionadded:: 1.4a2 """ - if request.params.get(token) != request.session.get_csrf_token(): + supplied_token = request.params.get(token, request.headers.get(header)) + if supplied_token != request.session.get_csrf_token(): if raises: raise HTTPBadRequest('incorrect CSRF token') return False diff --git a/pyramid/tests/test_session.py b/pyramid/tests/test_session.py index b3e0e20c4..d3bafb26e 100644 --- a/pyramid/tests/test_session.py +++ b/pyramid/tests/test_session.py @@ -356,20 +356,29 @@ class Test_signed_deserialize(unittest.TestCase): self.assertRaises(ValueError, self._callFUT, serialized, 'secret') class Test_check_csrf_token(unittest.TestCase): - def _callFUT(self, request, token, raises=True): + def _callFUT(self, *args, **kwargs): from ..session import check_csrf_token - return check_csrf_token(request, token, raises=raises) + return check_csrf_token(*args, **kwargs) - def test_success(self): + def test_success_token(self): request = testing.DummyRequest() request.params['csrf_token'] = request.session.get_csrf_token() - self.assertEqual(self._callFUT(request, 'csrf_token'), True) + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_success_header(self): + request = testing.DummyRequest() + request.headers['X-CSRFToken'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request, header='X-CSRFToken'), True) def test_success_default_token(self): - from ..session import check_csrf_token request = testing.DummyRequest() request.params['csrf_token'] = request.session.get_csrf_token() - self.assertEqual(check_csrf_token(request), True) + self.assertEqual(self._callFUT(request), True) + + def test_success_default_header(self): + request = testing.DummyRequest() + request.headers['X-CSRFToken'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request), True) def test_failure_raises(self): from pyramid.httpexceptions import HTTPBadRequest |
