summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Merickel <mmerickel@users.noreply.github.com>2016-10-17 22:22:10 -0500
committerGitHub <noreply@github.com>2016-10-17 22:22:10 -0500
commite73ae375581539ed42aa97d7cd6e96e6fbd64c79 (patch)
tree10758a9c6980205c752e94e040fdb9433620859b
parent325fc180ccf93716cdd1c959257a9864fcbee359 (diff)
parent17fa5e3ce891064231707bf30413b38b89bd6d7f (diff)
downloadpyramid-e73ae375581539ed42aa97d7cd6e96e6fbd64c79.tar.gz
pyramid-e73ae375581539ed42aa97d7cd6e96e6fbd64c79.tar.bz2
pyramid-e73ae375581539ed42aa97d7cd6e96e6fbd64c79.zip
Merge pull request #2778 from mmerickel/per-request-csrf
add a callback hook to set_default_csrf_options for disabling checks …
-rw-r--r--pyramid/config/security.py19
-rw-r--r--pyramid/interfaces.py1
-rw-r--r--pyramid/tests/test_config/test_security.py6
-rw-r--r--pyramid/tests/test_viewderivers.py28
-rw-r--r--pyramid/viewderivers.py7
5 files changed, 57 insertions, 4 deletions
diff --git a/pyramid/config/security.py b/pyramid/config/security.py
index e387eade9..02732c042 100644
--- a/pyramid/config/security.py
+++ b/pyramid/config/security.py
@@ -169,6 +169,7 @@ class SecurityConfiguratorMixin(object):
token='csrf_token',
header='X-CSRF-Token',
safe_methods=('GET', 'HEAD', 'OPTIONS', 'TRACE'),
+ callback=None,
):
"""
Set the default CSRF options used by subsequent view registrations.
@@ -192,8 +193,20 @@ class SecurityConfiguratorMixin(object):
never be automatically checked for CSRF tokens.
Default: ``('GET', 'HEAD', 'OPTIONS', TRACE')``.
+ If ``callback`` is set, it must be a callable accepting ``(request)``
+ and returning ``True`` if the request should be checked for a valid
+ CSRF token. This callback allows an application to support
+ alternate authentication methods that do not rely on cookies which
+ are not subject to CSRF attacks. For example, if a request is
+ authenticated using the ``Authorization`` header instead of a cookie,
+ this may return ``False`` for that request so that clients do not
+ need to send the ``X-CSRF-Token` header. The callback is only tested
+ for non-safe methods as defined by ``safe_methods``.
+
"""
- options = DefaultCSRFOptions(require_csrf, token, header, safe_methods)
+ options = DefaultCSRFOptions(
+ require_csrf, token, header, safe_methods, callback,
+ )
def register():
self.registry.registerUtility(options, IDefaultCSRFOptions)
intr = self.introspectable('default csrf view options',
@@ -204,13 +217,15 @@ class SecurityConfiguratorMixin(object):
intr['token'] = token
intr['header'] = header
intr['safe_methods'] = as_sorted_tuple(safe_methods)
+ intr['callback'] = callback
self.action(IDefaultCSRFOptions, register, order=PHASE1_CONFIG,
introspectables=(intr,))
@implementer(IDefaultCSRFOptions)
class DefaultCSRFOptions(object):
- def __init__(self, require_csrf, token, header, safe_methods):
+ def __init__(self, require_csrf, token, header, safe_methods, callback):
self.require_csrf = require_csrf
self.token = token
self.header = header
self.safe_methods = frozenset(safe_methods)
+ self.callback = callback
diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py
index 114f802aa..c1ddea63f 100644
--- a/pyramid/interfaces.py
+++ b/pyramid/interfaces.py
@@ -925,6 +925,7 @@ class IDefaultCSRFOptions(Interface):
token = Attribute('The key to be matched in the body of the request.')
header = Attribute('The header to be matched with the CSRF token.')
safe_methods = Attribute('A set of safe methods that skip CSRF checks.')
+ callback = Attribute('A callback to disable CSRF checks per-request.')
class ISessionFactory(Interface):
""" An interface representing a factory which accepts a request object and
diff --git a/pyramid/tests/test_config/test_security.py b/pyramid/tests/test_config/test_security.py
index e461bfd4a..5db8e21fc 100644
--- a/pyramid/tests/test_config/test_security.py
+++ b/pyramid/tests/test_config/test_security.py
@@ -108,14 +108,18 @@ class ConfiguratorSecurityMethodsTests(unittest.TestCase):
self.assertEqual(result.header, 'X-CSRF-Token')
self.assertEqual(list(sorted(result.safe_methods)),
['GET', 'HEAD', 'OPTIONS', 'TRACE'])
+ self.assertTrue(result.callback is None)
def test_changing_set_default_csrf_options(self):
from pyramid.interfaces import IDefaultCSRFOptions
config = self._makeOne(autocommit=True)
+ def callback(request): return True
config.set_default_csrf_options(
- require_csrf=False, token='DUMMY', header=None, safe_methods=('PUT',))
+ require_csrf=False, token='DUMMY', header=None,
+ safe_methods=('PUT',), callback=callback)
result = config.registry.getUtility(IDefaultCSRFOptions)
self.assertEqual(result.require_csrf, False)
self.assertEqual(result.token, 'DUMMY')
self.assertEqual(result.header, None)
self.assertEqual(list(sorted(result.safe_methods)), ['PUT'])
+ self.assertTrue(result.callback is callback)
diff --git a/pyramid/tests/test_viewderivers.py b/pyramid/tests/test_viewderivers.py
index 676c6f66a..51d0bd367 100644
--- a/pyramid/tests/test_viewderivers.py
+++ b/pyramid/tests/test_viewderivers.py
@@ -1291,6 +1291,34 @@ class TestDeriveView(unittest.TestCase):
view = self.config._derive_view(inner_view)
self.assertRaises(BadCSRFToken, lambda: view(None, request))
+ def test_csrf_view_enabled_via_callback(self):
+ def callback(request):
+ return True
+ from pyramid.exceptions import BadCSRFToken
+ def inner_view(request): pass
+ request = self._makeRequest()
+ request.scheme = "http"
+ request.method = 'POST'
+ request.session = DummySession({'csrf_token': 'foo'})
+ self.config.set_default_csrf_options(require_csrf=True, callback=callback)
+ view = self.config._derive_view(inner_view)
+ self.assertRaises(BadCSRFToken, lambda: view(None, request))
+
+ def test_csrf_view_disabled_via_callback(self):
+ def callback(request):
+ return False
+ response = DummyResponse()
+ def inner_view(request):
+ return response
+ request = self._makeRequest()
+ request.scheme = "http"
+ request.method = 'POST'
+ request.session = DummySession({'csrf_token': 'foo'})
+ self.config.set_default_csrf_options(require_csrf=True, callback=callback)
+ view = self.config._derive_view(inner_view)
+ result = view(None, request)
+ self.assertTrue(result is response)
+
def test_csrf_view_uses_custom_csrf_token(self):
response = DummyResponse()
def inner_view(request):
diff --git a/pyramid/viewderivers.py b/pyramid/viewderivers.py
index 513ddf022..4eb0ce704 100644
--- a/pyramid/viewderivers.py
+++ b/pyramid/viewderivers.py
@@ -481,11 +481,13 @@ def csrf_view(view, info):
token = 'csrf_token'
header = 'X-CSRF-Token'
safe_methods = frozenset(["GET", "HEAD", "OPTIONS", "TRACE"])
+ callback = None
else:
default_val = defaults.require_csrf
token = defaults.token
header = defaults.header
safe_methods = defaults.safe_methods
+ callback = defaults.callback
enabled = (
explicit_val is True or
@@ -501,7 +503,10 @@ def csrf_view(view, info):
wrapped_view = view
if enabled:
def csrf_view(context, request):
- if request.method not in safe_methods:
+ if (
+ request.method not in safe_methods and
+ (callback is None or callback(request))
+ ):
check_csrf_origin(request, raises=True)
check_csrf_token(request, token, header, raises=True)
return view(context, request)