From 6e1d245e8faf6072e9fe20bf185624cf0c10da9b Mon Sep 17 00:00:00 2001 From: Chris McDonough Date: Mon, 16 Jan 2012 15:47:59 -0500 Subject: have _FileIter return self from __iter__ so .close has a chance to be called --- pyramid/static.py | 17 ++++++++++++----- pyramid/tests/test_static.py | 27 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/pyramid/static.py b/pyramid/static.py index b82bc1bc4..8788d016d 100644 --- a/pyramid/static.py +++ b/pyramid/static.py @@ -68,7 +68,7 @@ class _FileResponse(Response): if 'wsgi.file_wrapper' in environ: app_iter = environ['wsgi.file_wrapper'](f, _BLOCK_SIZE) else: - app_iter = _FileIter(open(path, 'rb')) + app_iter = _FileIter(open(path, 'rb'), _BLOCK_SIZE) self.app_iter = app_iter # assignment of content_length must come after assignment of app_iter self.content_length = content_length @@ -76,13 +76,20 @@ class _FileResponse(Response): self.cache_expires = cache_max_age class _FileIter(object): - block_size = _BLOCK_SIZE - - def __init__(self, file): + def __init__(self, file, block_size): self.file = file + self.block_size = block_size def __iter__(self): - return iter(lambda: self.file.read(self.block_size), b'') + return self + + def next(self): + val = self.file.read(self.block_size) + if not val: + raise StopIteration + return val + + __next__ = next # py3 def close(self): self.file.close() diff --git a/pyramid/tests/test_static.py b/pyramid/tests/test_static.py index 4edd2728e..3d6fbe893 100644 --- a/pyramid/tests/test_static.py +++ b/pyramid/tests/test_static.py @@ -1,5 +1,6 @@ import datetime import unittest +import io # 5 years from now (more or less) fiveyrsfuture = datetime.datetime.utcnow() + datetime.timedelta(5*365) @@ -382,6 +383,32 @@ class Test_patch_mimetypes(unittest.TestCase): result = self._callFUT(module) self.assertEqual(result, False) +class Test_FileIter(unittest.TestCase): + def _makeOne(self, file, block_size): + from pyramid.static import _FileIter + return _FileIter(file, block_size) + + def test___iter__(self): + f = io.BytesIO(b'abc') + inst = self._makeOne(f, 1) + self.assertEqual(inst.__iter__(), inst) + + def test_iteration(self): + data = b'abcdef' + f = io.BytesIO(b'abcdef') + inst = self._makeOne(f, 1) + r = b'' + for x in inst: + self.assertEqual(len(x), 1) + r+=x + self.assertEqual(r, data) + + def test_close(self): + f = io.BytesIO(b'abc') + inst = self._makeOne(f, 1) + inst.close() + self.assertTrue(f.closed) + class DummyContext: pass -- cgit v1.2.3