From bfe654f0af480473531bdff77cefb676568aac8f Mon Sep 17 00:00:00 2001 From: Luke Cyca Date: Thu, 30 May 2013 10:25:58 -0700 Subject: Support CSRF via X-CSRFToken Header --- pyramid/session.py | 26 +++++++++++++++++--------- pyramid/tests/test_session.py | 21 +++++++++++++++------ 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/pyramid/session.py b/pyramid/session.py index 7db8c8e0e..0433488d8 100644 --- a/pyramid/session.py +++ b/pyramid/session.py @@ -81,15 +81,22 @@ def signed_deserialize(serialized, secret, hmac=hmac): return pickle.loads(pickled) -def check_csrf_token(request, token='csrf_token', raises=True): +def check_csrf_token(request, + token='csrf_token', + header='X-CSRFToken', + raises=True): """ Check the CSRF token in the request's session against the value in - ``request.params.get(token)``. If a ``token`` keyword is not supplied - to this function, the string ``csrf_token`` will be used to look up - the token within ``request.params``. If the value in - ``request.params.get(token)`` doesn't match the value supplied by - ``request.session.get_csrf_token()``, and ``raises`` is ``True``, this - function will raise an :exc:`pyramid.httpexceptions.HTTPBadRequest` - exception. If the check does succeed and ``raises`` is ``False``, this + ``request.params.get(token)`` or ``request.headers.get(header)``. + If a ``token`` keyword is not supplied to this function, the string + ``csrf_token`` will be used to look up the token in ``request.params``. + If a ``header`` keyword is not supplied to this function, the string + ``X-CSRFToken`` will be used to look up the token in ``request.headers``. + + If the value supplied by param or by header doesn't match the value + supplied by ``request.session.get_csrf_token()``, and ``raises`` is + ``True``, this function will raise an + :exc:`pyramid.httpexceptions.HTTPBadRequest` exception. + If the check does succeed and ``raises`` is ``False``, this function will return ``False``. If the CSRF check is successful, this function will return ``True`` unconditionally. @@ -98,7 +105,8 @@ def check_csrf_token(request, token='csrf_token', raises=True): .. versionadded:: 1.4a2 """ - if request.params.get(token) != request.session.get_csrf_token(): + supplied_token = request.params.get(token, request.headers.get(header)) + if supplied_token != request.session.get_csrf_token(): if raises: raise HTTPBadRequest('incorrect CSRF token') return False diff --git a/pyramid/tests/test_session.py b/pyramid/tests/test_session.py index b3e0e20c4..d3bafb26e 100644 --- a/pyramid/tests/test_session.py +++ b/pyramid/tests/test_session.py @@ -356,20 +356,29 @@ class Test_signed_deserialize(unittest.TestCase): self.assertRaises(ValueError, self._callFUT, serialized, 'secret') class Test_check_csrf_token(unittest.TestCase): - def _callFUT(self, request, token, raises=True): + def _callFUT(self, *args, **kwargs): from ..session import check_csrf_token - return check_csrf_token(request, token, raises=raises) + return check_csrf_token(*args, **kwargs) - def test_success(self): + def test_success_token(self): request = testing.DummyRequest() request.params['csrf_token'] = request.session.get_csrf_token() - self.assertEqual(self._callFUT(request, 'csrf_token'), True) + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_success_header(self): + request = testing.DummyRequest() + request.headers['X-CSRFToken'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request, header='X-CSRFToken'), True) def test_success_default_token(self): - from ..session import check_csrf_token request = testing.DummyRequest() request.params['csrf_token'] = request.session.get_csrf_token() - self.assertEqual(check_csrf_token(request), True) + self.assertEqual(self._callFUT(request), True) + + def test_success_default_header(self): + request = testing.DummyRequest() + request.headers['X-CSRFToken'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request), True) def test_failure_raises(self): from pyramid.httpexceptions import HTTPBadRequest -- cgit v1.2.3