diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/test_config/test_security.py | 7 | ||||
| -rw-r--r-- | tests/test_csrf.py | 42 | ||||
| -rw-r--r-- | tests/test_viewderivers.py | 22 |
3 files changed, 68 insertions, 3 deletions
diff --git a/tests/test_config/test_security.py b/tests/test_config/test_security.py index 0ae199239..9a9ea9f7e 100644 --- a/tests/test_config/test_security.py +++ b/tests/test_config/test_security.py @@ -158,6 +158,7 @@ class ConfiguratorSecurityMethodsTests(unittest.TestCase): list(sorted(result.safe_methods)), ['GET', 'HEAD', 'OPTIONS', 'TRACE'], ) + self.assertTrue(result.check_origin) self.assertFalse(result.allow_no_origin) self.assertTrue(result.callback is None) @@ -174,7 +175,8 @@ class ConfiguratorSecurityMethodsTests(unittest.TestCase): token='DUMMY', header=None, safe_methods=('PUT',), - allow_no_origin=True, + check_origin=False, + allow_no_origin=False, callback=callback, ) result = config.registry.getUtility(IDefaultCSRFOptions) @@ -182,5 +184,6 @@ class ConfiguratorSecurityMethodsTests(unittest.TestCase): self.assertEqual(result.token, 'DUMMY') self.assertEqual(result.header, None) self.assertEqual(list(sorted(result.safe_methods)), ['PUT']) - self.assertTrue(result.allow_no_origin) + self.assertFalse(result.check_origin) + self.assertFalse(result.allow_no_origin) self.assertTrue(result.callback is callback) diff --git a/tests/test_csrf.py b/tests/test_csrf.py index f93a1afde..ae998ec95 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -387,8 +387,48 @@ class Test_check_csrf_origin(unittest.TestCase): request = testing.DummyRequest() request.scheme = "https" request.referrer = None - self.assertRaises(BadCSRFOrigin, self._callFUT, request) + self.assertRaises( + BadCSRFOrigin, self._callFUT, request, allow_no_origin=False + ) + self.assertFalse( + self._callFUT(request, raises=False, allow_no_origin=False) + ) + + def test_fail_with_null_origin(self): + from pyramid.exceptions import BadCSRFOrigin + + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = None + request.headers = {'Origin': 'null'} + request.registry.settings = {} self.assertFalse(self._callFUT(request, raises=False)) + self.assertRaises(BadCSRFOrigin, self._callFUT, request) + + def test_success_with_null_origin_and_setting(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.referrer = None + request.headers = {'Origin': 'null'} + request.registry.settings = {"pyramid.csrf_trusted_origins": ["null"]} + self.assertTrue(self._callFUT(request, raises=False)) + + def test_success_with_multiple_origins(self): + request = testing.DummyRequest() + request.scheme = "https" + request.host = "example.com" + request.host_port = "443" + request.headers = { + 'Origin': 'https://google.com https://not-example.com' + } + request.registry.settings = { + "pyramid.csrf_trusted_origins": ["not-example.com"] + } + self.assertTrue(self._callFUT(request, raises=False)) def test_fails_when_http_to_https(self): from pyramid.exceptions import BadCSRFOrigin diff --git a/tests/test_viewderivers.py b/tests/test_viewderivers.py index 12a903eaa..e47296b50 100644 --- a/tests/test_viewderivers.py +++ b/tests/test_viewderivers.py @@ -1414,6 +1414,28 @@ class TestDeriveView(unittest.TestCase): result = view(None, request) self.assertTrue(result is response) + def test_csrf_view_disables_origin_check(self): + response = DummyResponse() + + def inner_view(request): + return response + + self.config.set_default_csrf_options( + require_csrf=True, check_origin=False + ) + request = self._makeRequest() + request.scheme = "https" + request.domain = "example.com" + request.host_port = "443" + request.referrer = None + request.method = 'POST' + request.headers = {"Origin": "https://evil-example.com"} + request.session = DummySession({'csrf_token': 'foo'}) + request.POST = {'csrf_token': 'foo'} + view = self.config._derive_view(inner_view, require_csrf=True) + result = view(None, request) + self.assertTrue(result is response) + def test_csrf_view_allow_no_origin(self): response = DummyResponse() |
