diff options
| author | Michael Merickel <michael@merickel.org> | 2018-10-14 21:11:41 -0500 |
|---|---|---|
| committer | Michael Merickel <michael@merickel.org> | 2018-10-14 21:11:41 -0500 |
| commit | 3670c2cdb732d378ba6d38e72e7cd875ff726aa9 (patch) | |
| tree | 5213452a778c992d42602efe7d3b3655a349abd5 /tests/test_csrf.py | |
| parent | 2b024920847481592b1a13d4006d2a9fa8881d72 (diff) | |
| download | pyramid-3670c2cdb732d378ba6d38e72e7cd875ff726aa9.tar.gz pyramid-3670c2cdb732d378ba6d38e72e7cd875ff726aa9.tar.bz2 pyramid-3670c2cdb732d378ba6d38e72e7cd875ff726aa9.zip | |
move tests out of the package
Diffstat (limited to 'tests/test_csrf.py')
| -rw-r--r-- | tests/test_csrf.py | 420 |
1 files changed, 420 insertions, 0 deletions
diff --git a/tests/test_csrf.py b/tests/test_csrf.py new file mode 100644 index 000000000..234d4434c --- /dev/null +++ b/tests/test_csrf.py @@ -0,0 +1,420 @@ +import unittest + +from pyramid import testing +from pyramid.config import Configurator + + +class TestLegacySessionCSRFStoragePolicy(unittest.TestCase): + class MockSession(object): + def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'): + self.current_token = current_token + + def new_csrf_token(self): + self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b' + return self.current_token + + def get_csrf_token(self): + return self.current_token + + def _makeOne(self): + from pyramid.csrf import LegacySessionCSRFStoragePolicy + return LegacySessionCSRFStoragePolicy() + + def test_register_session_csrf_policy(self): + from pyramid.csrf import LegacySessionCSRFStoragePolicy + from pyramid.interfaces import ICSRFStoragePolicy + + config = Configurator() + config.set_csrf_storage_policy(self._makeOne()) + config.commit() + + policy = config.registry.queryUtility(ICSRFStoragePolicy) + + self.assertTrue(isinstance(policy, LegacySessionCSRFStoragePolicy)) + + def test_session_csrf_implementation_delegates_to_session(self): + policy = self._makeOne() + request = DummyRequest(session=self.MockSession()) + + self.assertEqual( + policy.get_csrf_token(request), + '02821185e4c94269bdc38e6eeae0a2f8' + ) + self.assertEqual( + policy.new_csrf_token(request), + 'e5e9e30a08b34ff9842ff7d2b958c14b' + ) + + def test_check_csrf_token(self): + request = DummyRequest(session=self.MockSession('foo')) + + policy = self._makeOne() + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + + +class TestSessionCSRFStoragePolicy(unittest.TestCase): + def _makeOne(self, **kw): + from pyramid.csrf import SessionCSRFStoragePolicy + return SessionCSRFStoragePolicy(**kw) + + def test_register_session_csrf_policy(self): + from pyramid.csrf import SessionCSRFStoragePolicy + from pyramid.interfaces import ICSRFStoragePolicy + + config = Configurator() + config.set_csrf_storage_policy(self._makeOne()) + config.commit() + + policy = config.registry.queryUtility(ICSRFStoragePolicy) + + self.assertTrue(isinstance(policy, SessionCSRFStoragePolicy)) + + def test_it_creates_a_new_token(self): + request = DummyRequest(session={}) + + policy = self._makeOne() + policy._token_factory = lambda: 'foo' + self.assertEqual(policy.get_csrf_token(request), 'foo') + + def test_get_csrf_token_returns_the_new_token(self): + request = DummyRequest(session={'_csrft_': 'foo'}) + + policy = self._makeOne() + self.assertEqual(policy.get_csrf_token(request), 'foo') + + token = policy.new_csrf_token(request) + self.assertNotEqual(token, 'foo') + self.assertEqual(token, policy.get_csrf_token(request)) + + def test_check_csrf_token(self): + request = DummyRequest(session={}) + + policy = self._makeOne() + self.assertFalse(policy.check_csrf_token(request, 'foo')) + + request.session = {'_csrft_': 'foo'} + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + + +class TestCookieCSRFStoragePolicy(unittest.TestCase): + def _makeOne(self, **kw): + from pyramid.csrf import CookieCSRFStoragePolicy + return CookieCSRFStoragePolicy(**kw) + + def test_register_cookie_csrf_policy(self): + from pyramid.csrf import CookieCSRFStoragePolicy + from pyramid.interfaces import ICSRFStoragePolicy + + config = Configurator() + config.set_csrf_storage_policy(self._makeOne()) + config.commit() + + policy = config.registry.queryUtility(ICSRFStoragePolicy) + + self.assertTrue(isinstance(policy, CookieCSRFStoragePolicy)) + + def test_get_cookie_csrf_with_no_existing_cookie_sets_cookies(self): + response = MockResponse() + request = DummyRequest() + + policy = self._makeOne() + token = policy.get_csrf_token(request) + request.response_callback(request, response) + self.assertEqual( + response.headerlist, + [('Set-Cookie', 'csrf_token={}; Path=/; SameSite=Lax'.format( + token))] + ) + + def test_get_cookie_csrf_nondefault_samesite(self): + response = MockResponse() + request = DummyRequest() + + policy = self._makeOne(samesite=None) + token = policy.get_csrf_token(request) + request.response_callback(request, response) + self.assertEqual( + response.headerlist, + [('Set-Cookie', 'csrf_token={}; Path=/'.format(token))] + ) + + def test_existing_cookie_csrf_does_not_set_cookie(self): + request = DummyRequest() + request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'} + + policy = self._makeOne() + token = policy.get_csrf_token(request) + + self.assertEqual( + token, + 'e6f325fee5974f3da4315a8ccf4513d2' + ) + self.assertIsNone(request.response_callback) + + def test_new_cookie_csrf_with_existing_cookie_sets_cookies(self): + request = DummyRequest() + request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'} + + policy = self._makeOne() + token = policy.new_csrf_token(request) + + response = MockResponse() + request.response_callback(request, response) + self.assertEqual( + response.headerlist, + [('Set-Cookie', 'csrf_token={}; Path=/; SameSite=Lax'.format(token) + )] + ) + + def test_get_csrf_token_returns_the_new_token(self): + request = DummyRequest() + request.cookies = {'csrf_token': 'foo'} + + policy = self._makeOne() + self.assertEqual(policy.get_csrf_token(request), 'foo') + + token = policy.new_csrf_token(request) + self.assertNotEqual(token, 'foo') + self.assertEqual(token, policy.get_csrf_token(request)) + + def test_check_csrf_token(self): + request = DummyRequest() + + policy = self._makeOne() + self.assertFalse(policy.check_csrf_token(request, 'foo')) + + request.cookies = {'csrf_token': 'foo'} + self.assertTrue(policy.check_csrf_token(request, 'foo')) + self.assertFalse(policy.check_csrf_token(request, 'bar')) + +class Test_get_csrf_token(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from pyramid.csrf import get_csrf_token + return get_csrf_token(*args, **kwargs) + + def test_no_override_csrf_utility_registered(self): + request = testing.DummyRequest() + self._callFUT(request) + + def test_success(self): + self.config.set_csrf_storage_policy(DummyCSRF()) + request = testing.DummyRequest() + + csrf_token = self._callFUT(request) + + self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8') + + +class Test_new_csrf_token(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from pyramid.csrf import new_csrf_token + return new_csrf_token(*args, **kwargs) + + def test_no_override_csrf_utility_registered(self): + request = testing.DummyRequest() + self._callFUT(request) + + def test_success(self): + self.config.set_csrf_storage_policy(DummyCSRF()) + request = testing.DummyRequest() + + csrf_token = self._callFUT(request) + + self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b') + + +class Test_check_csrf_token(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + # set up CSRF + self.config.set_default_csrf_options(require_csrf=False) + + def _callFUT(self, *args, **kwargs): + from ..csrf import check_csrf_token + return check_csrf_token(*args, **kwargs) + + def test_success_token(self): + request = testing.DummyRequest() + request.method = "POST" + request.POST = {'csrf_token': request.session.get_csrf_token()} + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_success_header(self): + request = testing.DummyRequest() + request.headers['X-CSRF-Token'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request, header='X-CSRF-Token'), True) + + def test_success_default_token(self): + request = testing.DummyRequest() + request.method = "POST" + request.POST = {'csrf_token': request.session.get_csrf_token()} + self.assertEqual(self._callFUT(request), True) + + def test_success_default_header(self): + request = testing.DummyRequest() + request.headers['X-CSRF-Token'] = request.session.get_csrf_token() + self.assertEqual(self._callFUT(request), True) + + def test_failure_raises(self): + from pyramid.exceptions import BadCSRFToken + request = testing.DummyRequest() + self.assertRaises(BadCSRFToken, self._callFUT, request, + 'csrf_token') + + def test_failure_no_raises(self): + request = testing.DummyRequest() + result = self._callFUT(request, 'csrf_token', raises=False) + self.assertEqual(result, False) + + +class Test_check_csrf_token_without_defaults_configured(unittest.TestCase): + def setUp(self): + self.config = testing.setUp() + + def _callFUT(self, *args, **kwargs): + from ..csrf import check_csrf_token + return check_csrf_token(*args, **kwargs) + + def test_success_token(self): + request = testing.DummyRequest() + request.method = "POST" + request.POST = {'csrf_token': request.session.get_csrf_token()} + self.assertEqual(self._callFUT(request, token='csrf_token'), True) + + def test_failure_raises(self): + from pyramid.exceptions import BadCSRFToken + request = testing.DummyRequest() + self.assertRaises(BadCSRFToken, self._callFUT, request, + 'csrf_token') + + def test_failure_no_raises(self): + request = testing.DummyRequest() + result = self._callFUT(request, 'csrf_token', raises=False) + self.assertEqual(result, False) + + +class Test_check_csrf_origin(unittest.TestCase): + def _callFUT(self, *args, **kwargs): + from ..csrf import check_csrf_origin + return check_csrf_origin(*args, **kwargs) + + def test_success_with_http(self): + request = testing.DummyRequest() + request.scheme = "http" + self.assertTrue(self._callFUT(request)) + + def test_success_with_https_and_referrer(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = "https://example.com/login/" + request.registry.settings = {} + self.assertTrue(self._callFUT(request)) + + def test_success_with_https_and_origin(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.headers = {"Origin": "https://example.com/"} + request.referrer = "https://not-example.com/" + request.registry.settings = {} + self.assertTrue(self._callFUT(request)) + + def test_success_with_additional_trusted_host(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = "https://not-example.com/login/" + request.registry.settings = { + "pyramid.csrf_trusted_origins": ["not-example.com"], + } + self.assertTrue(self._callFUT(request)) + + def test_success_with_nonstandard_port(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com:8080" + request.host_port = "8080" + request.referrer = "https://example.com:8080/login/" + request.registry.settings = {} + self.assertTrue(self._callFUT(request)) + + def test_fails_with_wrong_host(self): + from pyramid.exceptions import BadCSRFOrigin + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = "https://not-example.com/login/" + request.registry.settings = {} + self.assertRaises(BadCSRFOrigin, self._callFUT, request) + self.assertFalse(self._callFUT(request, raises=False)) + + def test_fails_with_no_origin(self): + from pyramid.exceptions import BadCSRFOrigin + request = testing.DummyRequest() + request.scheme = "https" + request.referrer = None + self.assertRaises(BadCSRFOrigin, self._callFUT, request) + self.assertFalse(self._callFUT(request, raises=False)) + + def test_fails_when_http_to_https(self): + from pyramid.exceptions import BadCSRFOrigin + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = "http://example.com/evil/" + request.registry.settings = {} + self.assertRaises(BadCSRFOrigin, self._callFUT, request) + self.assertFalse(self._callFUT(request, raises=False)) + + def test_fails_with_nonstandard_port(self): + from pyramid.exceptions import BadCSRFOrigin + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com:8080" + request.host_port = "8080" + request.referrer = "https://example.com/login/" + request.registry.settings = {} + self.assertRaises(BadCSRFOrigin, self._callFUT, request) + self.assertFalse(self._callFUT(request, raises=False)) + + +class DummyRequest(object): + registry = None + session = None + response_callback = None + + def __init__(self, registry=None, session=None): + self.registry = registry + self.session = session + self.cookies = {} + + def add_response_callback(self, callback): + self.response_callback = callback + + +class MockResponse(object): + def __init__(self): + self.headerlist = [] + + +class DummyCSRF(object): + def new_csrf_token(self, request): + return 'e5e9e30a08b34ff9842ff7d2b958c14b' + + def get_csrf_token(self, request): + return '02821185e4c94269bdc38e6eeae0a2f8' |
