summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Merickel <michael@merickel.org>2017-04-30 02:00:48 -0500
committerMichael Merickel <michael@merickel.org>2017-04-30 02:00:48 -0500
commit3f14d63c009ae7f101b7aeb4525bab2dfe66fa11 (patch)
tree1654d1d20d5a8aa13b9d76cfc2b8d844a3f3be49
parent682a9b9df6f42f8261daa077f04b47b65bf00c34 (diff)
downloadpyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.tar.gz
pyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.tar.bz2
pyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.zip
restore the ``ICSRFStoragePolicy.check_csrf_token`` api
-rw-r--r--pyramid/csrf.py35
-rw-r--r--pyramid/interfaces.py10
-rw-r--r--pyramid/tests/test_csrf.py121
3 files changed, 113 insertions, 53 deletions
diff --git a/pyramid/csrf.py b/pyramid/csrf.py
index 1910e4ec8..c8f097777 100644
--- a/pyramid/csrf.py
+++ b/pyramid/csrf.py
@@ -47,6 +47,12 @@ class LegacySessionCSRFStoragePolicy(object):
generating a new one if needed."""
return request.session.get_csrf_token()
+ def check_csrf_token(self, request, supplied_token):
+ """ Returns ``True`` if the ``supplied_token`` is valid."""
+ expected_token = self.get_csrf_token(request)
+ return not strings_differ(
+ bytes_(expected_token), bytes_(supplied_token))
+
@implementer(ICSRFStoragePolicy)
class SessionCSRFStoragePolicy(object):
@@ -82,6 +88,12 @@ class SessionCSRFStoragePolicy(object):
token = self.new_csrf_token(request)
return token
+ def check_csrf_token(self, request, supplied_token):
+ """ Returns ``True`` if the ``supplied_token`` is valid."""
+ expected_token = self.get_csrf_token(request)
+ return not strings_differ(
+ bytes_(expected_token), bytes_(supplied_token))
+
@implementer(ICSRFStoragePolicy)
class CookieCSRFStoragePolicy(object):
@@ -133,6 +145,12 @@ class CookieCSRFStoragePolicy(object):
token = self.new_csrf_token(request)
return token
+ def check_csrf_token(self, request, supplied_token):
+ """ Returns ``True`` if the ``supplied_token`` is valid."""
+ expected_token = self.get_csrf_token(request)
+ return not strings_differ(
+ bytes_(expected_token), bytes_(supplied_token))
+
def get_csrf_token(request):
""" Get the currently active CSRF token for the request passed, generating
@@ -140,6 +158,7 @@ def get_csrf_token(request):
calls the equivalent method in the chosen CSRF protection implementation.
.. versionadded :: 1.9
+
"""
registry = request.registry
csrf = registry.getUtility(ICSRFStoragePolicy)
@@ -152,6 +171,7 @@ def new_csrf_token(request):
chosen CSRF protection implementation.
.. versionadded :: 1.9
+
"""
registry = request.registry
csrf = registry.getUtility(ICSRFStoragePolicy)
@@ -171,9 +191,8 @@ def check_csrf_token(request,
function, the string ``X-CSRF-Token`` will be used to look up the token in
``request.headers``.
- If the value supplied by post or by header doesn't match the value supplied
- by ``policy.get_csrf_token()`` (where ``policy`` is an implementation of
- :class:`pyramid.interfaces.ICSRFStoragePolicy`), and ``raises`` is
+ If the value supplied by post or by header cannot be verified by the
+ :class:`pyramid.interfaces.ICSRFStoragePolicy`, and ``raises`` is
``True``, this function will raise an
:exc:`pyramid.exceptions.BadCSRFToken` exception. If the values differ
and ``raises`` is ``False``, this function will return ``False``. If the
@@ -191,7 +210,10 @@ def check_csrf_token(request,
a header.
.. versionchanged:: 1.9
- Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf`
+ Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` and updated
+ to use the configured :class:`pyramid.interfaces.ICSRFStoragePolicy` to
+ verify the CSRF token.
+
"""
supplied_token = ""
# We first check the headers for a csrf token, as that is significantly
@@ -207,8 +229,8 @@ def check_csrf_token(request,
if supplied_token == "" and token is not None:
supplied_token = request.POST.get(token, "")
- expected_token = get_csrf_token(request)
- if strings_differ(bytes_(expected_token), bytes_(supplied_token)):
+ policy = request.registry.getUtility(ICSRFStoragePolicy)
+ if not policy.check_csrf_token(request, text_(supplied_token)):
if raises:
raise BadCSRFToken('check_csrf_token(): Invalid token')
return False
@@ -239,6 +261,7 @@ def check_csrf_origin(request, trusted_origins=None, raises=True):
.. versionchanged:: 1.9
Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf`
+
"""
def _fail(reason):
if raises:
diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py
index 853e8fcdd..ab83813c8 100644
--- a/pyramid/interfaces.py
+++ b/pyramid/interfaces.py
@@ -1010,6 +1010,16 @@ class ICSRFStoragePolicy(Interface):
"""
+ def check_csrf_token(request, token):
+ """ Determine if the supplied ``token`` is valid. Most implementations
+ should simply compare the ``token`` to the current value of
+ ``get_csrf_token`` but it is possible to verify the token using
+ any mechanism necessary using this method.
+
+ Returns ``True`` if the ``token`` is valid, otherwise ``False``.
+
+ """
+
class IIntrospector(Interface):
def get(category_name, discriminator, default=None):
diff --git a/pyramid/tests/test_csrf.py b/pyramid/tests/test_csrf.py
index cd7ba2951..f01780ad8 100644
--- a/pyramid/tests/test_csrf.py
+++ b/pyramid/tests/test_csrf.py
@@ -1,61 +1,20 @@
import unittest
-from zope.interface.interfaces import ComponentLookupError
-
from pyramid import testing
from pyramid.config import Configurator
-from pyramid.events import BeforeRender
-
-
-class Test_get_csrf_token(unittest.TestCase):
- def setUp(self):
- self.config = testing.setUp()
-
- def _callFUT(self, *args, **kwargs):
- from pyramid.csrf import get_csrf_token
- return get_csrf_token(*args, **kwargs)
-
- def test_no_override_csrf_utility_registered(self):
- request = testing.DummyRequest()
- self._callFUT(request)
-
- def test_success(self):
- self.config.set_csrf_storage_policy(DummyCSRF())
- request = testing.DummyRequest()
-
- csrf_token = self._callFUT(request)
-
- self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')
-
-
-class Test_new_csrf_token(unittest.TestCase):
- def setUp(self):
- self.config = testing.setUp()
-
- def _callFUT(self, *args, **kwargs):
- from pyramid.csrf import new_csrf_token
- return new_csrf_token(*args, **kwargs)
-
- def test_no_override_csrf_utility_registered(self):
- request = testing.DummyRequest()
- self._callFUT(request)
-
- def test_success(self):
- self.config.set_csrf_storage_policy(DummyCSRF())
- request = testing.DummyRequest()
-
- csrf_token = self._callFUT(request)
-
- self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')
class TestLegacySessionCSRFStoragePolicy(unittest.TestCase):
class MockSession(object):
+ def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'):
+ self.current_token = current_token
+
def new_csrf_token(self):
- return 'e5e9e30a08b34ff9842ff7d2b958c14b'
+ self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b'
+ return self.current_token
def get_csrf_token(self):
- return '02821185e4c94269bdc38e6eeae0a2f8'
+ return self.current_token
def _makeOne(self):
from pyramid.csrf import LegacySessionCSRFStoragePolicy
@@ -86,6 +45,13 @@ class TestLegacySessionCSRFStoragePolicy(unittest.TestCase):
'e5e9e30a08b34ff9842ff7d2b958c14b'
)
+ def test_check_csrf_token(self):
+ request = DummyRequest(session=self.MockSession('foo'))
+
+ policy = self._makeOne()
+ self.assertTrue(policy.check_csrf_token(request, 'foo'))
+ self.assertFalse(policy.check_csrf_token(request, 'bar'))
+
class TestSessionCSRFStoragePolicy(unittest.TestCase):
def _makeOne(self, **kw):
@@ -121,6 +87,16 @@ class TestSessionCSRFStoragePolicy(unittest.TestCase):
self.assertNotEqual(token, 'foo')
self.assertEqual(token, policy.get_csrf_token(request))
+ def test_check_csrf_token(self):
+ request = DummyRequest(session={})
+
+ policy = self._makeOne()
+ self.assertFalse(policy.check_csrf_token(request, 'foo'))
+
+ request.session = {'_csrft_': 'foo'}
+ self.assertTrue(policy.check_csrf_token(request, 'foo'))
+ self.assertFalse(policy.check_csrf_token(request, 'bar'))
+
class TestCookieCSRFStoragePolicy(unittest.TestCase):
def _makeOne(self, **kw):
@@ -189,6 +165,57 @@ class TestCookieCSRFStoragePolicy(unittest.TestCase):
self.assertNotEqual(token, 'foo')
self.assertEqual(token, policy.get_csrf_token(request))
+ def test_check_csrf_token(self):
+ request = DummyRequest()
+
+ policy = self._makeOne()
+ self.assertFalse(policy.check_csrf_token(request, 'foo'))
+
+ request.cookies = {'csrf_token': 'foo'}
+ self.assertTrue(policy.check_csrf_token(request, 'foo'))
+ self.assertFalse(policy.check_csrf_token(request, 'bar'))
+
+class Test_get_csrf_token(unittest.TestCase):
+ def setUp(self):
+ self.config = testing.setUp()
+
+ def _callFUT(self, *args, **kwargs):
+ from pyramid.csrf import get_csrf_token
+ return get_csrf_token(*args, **kwargs)
+
+ def test_no_override_csrf_utility_registered(self):
+ request = testing.DummyRequest()
+ self._callFUT(request)
+
+ def test_success(self):
+ self.config.set_csrf_storage_policy(DummyCSRF())
+ request = testing.DummyRequest()
+
+ csrf_token = self._callFUT(request)
+
+ self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')
+
+
+class Test_new_csrf_token(unittest.TestCase):
+ def setUp(self):
+ self.config = testing.setUp()
+
+ def _callFUT(self, *args, **kwargs):
+ from pyramid.csrf import new_csrf_token
+ return new_csrf_token(*args, **kwargs)
+
+ def test_no_override_csrf_utility_registered(self):
+ request = testing.DummyRequest()
+ self._callFUT(request)
+
+ def test_success(self):
+ self.config.set_csrf_storage_policy(DummyCSRF())
+ request = testing.DummyRequest()
+
+ csrf_token = self._callFUT(request)
+
+ self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')
+
class Test_check_csrf_token(unittest.TestCase):
def setUp(self):