summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Cyca <me@lukecyca.com>2013-05-30 10:25:58 -0700
committerLuke Cyca <me@lukecyca.com>2013-05-30 10:25:58 -0700
commitbfe654f0af480473531bdff77cefb676568aac8f (patch)
tree226e2b59b13e377c61d556bb4f04a1334cc1f077
parent30715a7db397d90e786b85715303bfaf34109b31 (diff)
downloadpyramid-bfe654f0af480473531bdff77cefb676568aac8f.tar.gz
pyramid-bfe654f0af480473531bdff77cefb676568aac8f.tar.bz2
pyramid-bfe654f0af480473531bdff77cefb676568aac8f.zip
Support CSRF via X-CSRFToken Header
-rw-r--r--pyramid/session.py26
-rw-r--r--pyramid/tests/test_session.py21
2 files changed, 32 insertions, 15 deletions
diff --git a/pyramid/session.py b/pyramid/session.py
index 7db8c8e0e..0433488d8 100644
--- a/pyramid/session.py
+++ b/pyramid/session.py
@@ -81,15 +81,22 @@ def signed_deserialize(serialized, secret, hmac=hmac):
return pickle.loads(pickled)
-def check_csrf_token(request, token='csrf_token', raises=True):
+def check_csrf_token(request,
+ token='csrf_token',
+ header='X-CSRFToken',
+ raises=True):
""" Check the CSRF token in the request's session against the value in
- ``request.params.get(token)``. If a ``token`` keyword is not supplied
- to this function, the string ``csrf_token`` will be used to look up
- the token within ``request.params``. If the value in
- ``request.params.get(token)`` doesn't match the value supplied by
- ``request.session.get_csrf_token()``, and ``raises`` is ``True``, this
- function will raise an :exc:`pyramid.httpexceptions.HTTPBadRequest`
- exception. If the check does succeed and ``raises`` is ``False``, this
+ ``request.params.get(token)`` or ``request.headers.get(header)``.
+ If a ``token`` keyword is not supplied to this function, the string
+ ``csrf_token`` will be used to look up the token in ``request.params``.
+ If a ``header`` keyword is not supplied to this function, the string
+ ``X-CSRFToken`` will be used to look up the token in ``request.headers``.
+
+ If the value supplied by param or by header doesn't match the value
+ supplied by ``request.session.get_csrf_token()``, and ``raises`` is
+ ``True``, this function will raise an
+ :exc:`pyramid.httpexceptions.HTTPBadRequest` exception.
+ If the check does succeed and ``raises`` is ``False``, this
function will return ``False``. If the CSRF check is successful, this
function will return ``True`` unconditionally.
@@ -98,7 +105,8 @@ def check_csrf_token(request, token='csrf_token', raises=True):
.. versionadded:: 1.4a2
"""
- if request.params.get(token) != request.session.get_csrf_token():
+ supplied_token = request.params.get(token, request.headers.get(header))
+ if supplied_token != request.session.get_csrf_token():
if raises:
raise HTTPBadRequest('incorrect CSRF token')
return False
diff --git a/pyramid/tests/test_session.py b/pyramid/tests/test_session.py
index b3e0e20c4..d3bafb26e 100644
--- a/pyramid/tests/test_session.py
+++ b/pyramid/tests/test_session.py
@@ -356,20 +356,29 @@ class Test_signed_deserialize(unittest.TestCase):
self.assertRaises(ValueError, self._callFUT, serialized, 'secret')
class Test_check_csrf_token(unittest.TestCase):
- def _callFUT(self, request, token, raises=True):
+ def _callFUT(self, *args, **kwargs):
from ..session import check_csrf_token
- return check_csrf_token(request, token, raises=raises)
+ return check_csrf_token(*args, **kwargs)
- def test_success(self):
+ def test_success_token(self):
request = testing.DummyRequest()
request.params['csrf_token'] = request.session.get_csrf_token()
- self.assertEqual(self._callFUT(request, 'csrf_token'), True)
+ self.assertEqual(self._callFUT(request, token='csrf_token'), True)
+
+ def test_success_header(self):
+ request = testing.DummyRequest()
+ request.headers['X-CSRFToken'] = request.session.get_csrf_token()
+ self.assertEqual(self._callFUT(request, header='X-CSRFToken'), True)
def test_success_default_token(self):
- from ..session import check_csrf_token
request = testing.DummyRequest()
request.params['csrf_token'] = request.session.get_csrf_token()
- self.assertEqual(check_csrf_token(request), True)
+ self.assertEqual(self._callFUT(request), True)
+
+ def test_success_default_header(self):
+ request = testing.DummyRequest()
+ request.headers['X-CSRFToken'] = request.session.get_csrf_token()
+ self.assertEqual(self._callFUT(request), True)
def test_failure_raises(self):
from pyramid.httpexceptions import HTTPBadRequest