diff options
| author | Michael Merickel <michael@merickel.org> | 2017-04-30 02:00:48 -0500 |
|---|---|---|
| committer | Michael Merickel <michael@merickel.org> | 2017-04-30 02:00:48 -0500 |
| commit | 3f14d63c009ae7f101b7aeb4525bab2dfe66fa11 (patch) | |
| tree | 1654d1d20d5a8aa13b9d76cfc2b8d844a3f3be49 | |
| parent | 682a9b9df6f42f8261daa077f04b47b65bf00c34 (diff) | |
| download | pyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.tar.gz pyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.tar.bz2 pyramid-3f14d63c009ae7f101b7aeb4525bab2dfe66fa11.zip | |
restore the ``ICSRFStoragePolicy.check_csrf_token`` api
| -rw-r--r-- | pyramid/csrf.py | 35 | ||||
| -rw-r--r-- | pyramid/interfaces.py | 10 | ||||
| -rw-r--r-- | pyramid/tests/test_csrf.py | 121 |
3 files changed, 113 insertions, 53 deletions
diff --git a/pyramid/csrf.py b/pyramid/csrf.py index 1910e4ec8..c8f097777 100644 --- a/pyramid/csrf.py +++ b/pyramid/csrf.py @@ -47,6 +47,12 @@ class LegacySessionCSRFStoragePolicy(object): generating a new one if needed.""" return request.session.get_csrf_token() + def check_csrf_token(self, request, supplied_token): + """ Returns ``True`` if the ``supplied_token`` is valid.""" + expected_token = self.get_csrf_token(request) + return not strings_differ( + bytes_(expected_token), bytes_(supplied_token)) + @implementer(ICSRFStoragePolicy) class SessionCSRFStoragePolicy(object): @@ -82,6 +88,12 @@ class SessionCSRFStoragePolicy(object): token = self.new_csrf_token(request) return token + def check_csrf_token(self, request, supplied_token): + """ Returns ``True`` if the ``supplied_token`` is valid.""" + expected_token = self.get_csrf_token(request) + return not strings_differ( + bytes_(expected_token), bytes_(supplied_token)) + @implementer(ICSRFStoragePolicy) class CookieCSRFStoragePolicy(object): @@ -133,6 +145,12 @@ class CookieCSRFStoragePolicy(object): token = self.new_csrf_token(request) return token + def check_csrf_token(self, request, supplied_token): + """ Returns ``True`` if the ``supplied_token`` is valid.""" + expected_token = self.get_csrf_token(request) + return not strings_differ( + bytes_(expected_token), bytes_(supplied_token)) + def get_csrf_token(request): """ Get the currently active CSRF token for the request passed, generating @@ -140,6 +158,7 @@ def get_csrf_token(request): calls the equivalent method in the chosen CSRF protection implementation. .. versionadded :: 1.9 + """ registry = request.registry csrf = registry.getUtility(ICSRFStoragePolicy) @@ -152,6 +171,7 @@ def new_csrf_token(request): chosen CSRF protection implementation. .. versionadded :: 1.9 + """ registry = request.registry csrf = registry.getUtility(ICSRFStoragePolicy) @@ -171,9 +191,8 @@ def check_csrf_token(request, function, the string ``X-CSRF-Token`` will be used to look up the token in ``request.headers``. - If the value supplied by post or by header doesn't match the value supplied - by ``policy.get_csrf_token()`` (where ``policy`` is an implementation of - :class:`pyramid.interfaces.ICSRFStoragePolicy`), and ``raises`` is + If the value supplied by post or by header cannot be verified by the + :class:`pyramid.interfaces.ICSRFStoragePolicy`, and ``raises`` is ``True``, this function will raise an :exc:`pyramid.exceptions.BadCSRFToken` exception. If the values differ and ``raises`` is ``False``, this function will return ``False``. If the @@ -191,7 +210,10 @@ def check_csrf_token(request, a header. .. versionchanged:: 1.9 - Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` + Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` and updated + to use the configured :class:`pyramid.interfaces.ICSRFStoragePolicy` to + verify the CSRF token. + """ supplied_token = "" # We first check the headers for a csrf token, as that is significantly @@ -207,8 +229,8 @@ def check_csrf_token(request, if supplied_token == "" and token is not None: supplied_token = request.POST.get(token, "") - expected_token = get_csrf_token(request) - if strings_differ(bytes_(expected_token), bytes_(supplied_token)): + policy = request.registry.getUtility(ICSRFStoragePolicy) + if not policy.check_csrf_token(request, text_(supplied_token)): if raises: raise BadCSRFToken('check_csrf_token(): Invalid token') return False @@ -239,6 +261,7 @@ def check_csrf_origin(request, trusted_origins=None, raises=True): .. versionchanged:: 1.9 Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` + """ def _fail(reason): if raises: diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py index 853e8fcdd..ab83813c8 100644 --- a/pyramid/interfaces.py +++ b/pyramid/interfaces.py @@ -1010,6 +1010,16 @@ class ICSRFStoragePolicy(Interface): """ + def check_csrf_token(request, token): + """ Determine if the supplied ``token`` is valid. Most implementations + should simply compare the ``token`` to the current value of + ``get_csrf_token`` but it is possible to verify the token using + any mechanism necessary using this method. + + Returns ``True`` if the ``token`` is valid, otherwise ``False``. + + """ + class IIntrospector(Interface): def get(category_name, discriminator, default=None): diff --git a/pyramid/tests/test_csrf.py b/pyramid/tests/test_csrf.py index cd7ba2951..f01780ad8 100644 --- a/pyramid/tests/test_csrf.py +++ b/pyramid/tests/test_csrf.py @@ -1,61 +1,20 @@ import unittest -from zope.interface.interfaces import ComponentLookupError - from pyramid import testing from pyramid.config import Configurator -from pyramid.events import BeforeRender - - -class Test_get_csrf_token(unittest.TestCase): - def setUp(self): - self.config = testing.setUp() - - def _callFUT(self, *args, **kwargs): - from pyramid.csrf import get_csrf_token - return get_csrf_token(*args, **kwargs) - - def test_no_override_csrf_utility_registered(self): - request = testing.DummyRequest() - self._callFUT(request) - - def test_success(self): - self.config.set_csrf_storage_policy(DummyCSRF()) - request = testing.DummyRequest() - - csrf_token = self._callFUT(request) - - self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8') - - -class Test_new_csrf_token(unittest.TestCase): - def setUp(self): - self.config = testing.setUp() - - def _callFUT(self, *args, **kwargs): - from pyramid.csrf import new_csrf_token - return new_csrf_token(*args, **kwargs) - - def test_no_override_csrf_utility_registered(self): - request = testing.DummyRequest() - self._callFUT(request) - - def test_success(self): - self.config.set_csrf_storage_policy(DummyCSRF()) - request = testing.DummyRequest() - - csrf_token = self._callFUT(request) - - self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b') class TestLegacySessionCSRFStoragePolicy(unittest.TestCase): class MockSession(object): + def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'): + self.current_token = current_token + def new_csrf_token(self): - return 'e5e9e30a08b34ff9842ff7d2b958c14b' + self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b' + return self.current_token def get_csrf_token(self): - return '02821185e4c94269bdc38e6eeae0a2f8' + return self.current_token def _makeOne(self): from pyramid.csrf import LegacySessionCSRFStoragePolicy @@ -86,6 +45,13 @@ class TestLegacySessionCSRFStoragePolicy(unittest.TestCase): 'e5e9e30a08b34ff9842ff7d2b958c14b' ) + def test_check_csrf_token(self): + request = DummyRequest(session=self.MockSession('foo')) + + policy = self._makeOne() + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + class TestSessionCSRFStoragePolicy(unittest.TestCase): def _makeOne(self, **kw): @@ -121,6 +87,16 @@ class TestSessionCSRFStoragePolicy(unittest.TestCase): self.assertNotEqual(token, 'foo') self.assertEqual(token, policy.get_csrf_token(request)) + def test_check_csrf_token(self): + request = DummyRequest(session={}) + + policy = self._makeOne() + self.assertFalse(policy.check_csrf_token(request, 'foo')) + + request.session = {'_csrft_': 'foo'} + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + class TestCookieCSRFStoragePolicy(unittest.TestCase): def _makeOne(self, **kw): @@ -189,6 +165,57 @@ class TestCookieCSRFStoragePolicy(unittest.TestCase): self.assertNotEqual(token, 'foo') self.assertEqual(token, policy.get_csrf_token(request)) + def test_check_csrf_token(self): + request = DummyRequest() + + policy = self._makeOne() + self.assertFalse(policy.check_csrf_token(request, 'foo')) + + request.cookies = {'csrf_token': 'foo'} + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + +class Test_get_csrf_token(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from pyramid.csrf import get_csrf_token + return get_csrf_token(*args, **kwargs) + + def test_no_override_csrf_utility_registered(self): + request = testing.DummyRequest() + self._callFUT(request) + + def test_success(self): + self.config.set_csrf_storage_policy(DummyCSRF()) + request = testing.DummyRequest() + + csrf_token = self._callFUT(request) + + self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8') + + +class Test_new_csrf_token(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from pyramid.csrf import new_csrf_token + return new_csrf_token(*args, **kwargs) + + def test_no_override_csrf_utility_registered(self): + request = testing.DummyRequest() + self._callFUT(request) + + def test_success(self): + self.config.set_csrf_storage_policy(DummyCSRF()) + request = testing.DummyRequest() + + csrf_token = self._callFUT(request) + + self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b') + class Test_check_csrf_token(unittest.TestCase): def setUp(self): |
