summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris McDonough <chrism@agendaless.com>2009-01-25 05:11:21 +0000
committerChris McDonough <chrism@agendaless.com>2009-01-25 05:11:21 +0000
commitc8cab3395432983c2165dce196ad5204e420a900 (patch)
tree3ef0a4b85c612513ac79779db28763f39e11907c
parente5cf7dbec2ccda7d2e4d79815ac441acf2ab1061 (diff)
downloadpyramid-c8cab3395432983c2165dce196ad5204e420a900.tar.gz
pyramid-c8cab3395432983c2165dce196ad5204e420a900.tar.bz2
pyramid-c8cab3395432983c2165dce196ad5204e420a900.zip
Optimize flatten a bit.
-rw-r--r--repoze/bfg/security.py11
-rw-r--r--repoze/bfg/tests/test_security.py31
2 files changed, 37 insertions, 5 deletions
diff --git a/repoze/bfg/security.py b/repoze/bfg/security.py
index 4e61b0ed9..693f253d0 100644
--- a/repoze/bfg/security.py
+++ b/repoze/bfg/security.py
@@ -312,12 +312,15 @@ def flatten(x):
[1, 2, [3, 4], (5, 6)]
>>> flatten([[[1,2,3], (42,None)], [4,5], [6], 7, MyVector(8,9,10)])
[1, 2, 3, 42, None, 4, 5, 6, 7, 8, 9, 10]"""
- if isinstance(x, basestring):
+ if not hasattr(x, '__iter__'):
return [x]
+ return _flatten(x)
+
+def _flatten(iterable):
result = []
- for el in x:
- if hasattr(el, "__iter__") and not isinstance(el, basestring):
- result.extend(flatten(el))
+ for el in iterable:
+ if hasattr(el, "__iter__"):
+ result.extend(_flatten(el))
else:
result.append(el)
return result
diff --git a/repoze/bfg/tests/test_security.py b/repoze/bfg/tests/test_security.py
index e9f35f57f..4209f8d3d 100644
--- a/repoze/bfg/tests/test_security.py
+++ b/repoze/bfg/tests/test_security.py
@@ -603,7 +603,36 @@ class TestACLDenied(unittest.TestCase):
self.assertEqual(str(denied), msg)
self.failUnless('<ACLDenied instance at ' in repr(denied))
self.failUnless("with msg %r>" % msg in repr(denied))
-
+
+class TestFlatten(unittest.TestCase):
+ def _callFUT(self, item):
+ from repoze.bfg.security import flatten
+ return flatten(item)
+
+ def test_str(self):
+ result = self._callFUT('a')
+ self.assertEqual(result, ['a'])
+
+ def test_unicode(self):
+ result = self._callFUT(u'a')
+ self.assertEqual(result, [u'a'])
+
+ def test_flat_sequence(self):
+ result = self._callFUT([1, 2, 3])
+ self.assertEqual(result, [1, 2, 3])
+
+ def test_singly_nested_sequence(self):
+ result = self._callFUT([1, [2, 3]])
+ self.assertEqual(result, [1, 2, 3])
+
+ def test_doubly_nested_sequence(self):
+ result = self._callFUT([1, [2, [3]]])
+ self.assertEqual(result, [1, 2, 3])
+
+ def test_mix_str_unicode_sequence(self):
+ result = self._callFUT([1, [2, [3]], u'a', ('b', set(['c', 'd']))])
+ self.assertEqual(result, [1, 2, 3, u'a', 'b', 'c', 'd'])
+
class DummyContext:
pass