diff options
| -rw-r--r-- | pyramid/scaffolds/__init__.py | 4 | ||||
| -rw-r--r-- | pyramid/scripts/pcreate.py | 96 | ||||
| -rw-r--r-- | pyramid/tests/test_scaffolds/test_init.py | 5 | ||||
| -rw-r--r-- | pyramid/tests/test_scripts/test_pcreate.py | 46 |
4 files changed, 120 insertions, 31 deletions
diff --git a/pyramid/scaffolds/__init__.py b/pyramid/scaffolds/__init__.py index c993ce5f9..4e811a42b 100644 --- a/pyramid/scaffolds/__init__.py +++ b/pyramid/scaffolds/__init__.py @@ -18,10 +18,6 @@ class PyramidTemplate(Template): misnamings (such as naming a package "site" or naming a package logger "root". """ - if vars['package'] == 'site': - raise ValueError('Sorry, you may not name your package "site". ' - 'The package name "site" has a special meaning in ' - 'Python. Please name it anything except "site".') vars['random_string'] = native_(binascii.hexlify(os.urandom(20))) package_logger = vars['package'] if package_logger == 'root': diff --git a/pyramid/scripts/pcreate.py b/pyramid/scripts/pcreate.py index f6376f575..1e8074fc5 100644 --- a/pyramid/scripts/pcreate.py +++ b/pyramid/scripts/pcreate.py @@ -8,12 +8,17 @@ import os.path import pkg_resources import re import sys +from pyramid.compat import input_ _bad_chars_re = re.compile('[^a-zA-Z0-9_]') def main(argv=sys.argv, quiet=False): command = PCreateCommand(argv, quiet) - return command.run() + try: + return command.run() + except KeyboardInterrupt: # pragma: no cover + return 1 + class PCreateCommand(object): verbosity = 1 # required @@ -52,6 +57,13 @@ class PCreateCommand(object): dest='interactive', action='store_true', help='When a file would be overwritten, interrogate') + parser.add_option('--ignore-conflicting-name', + dest='force_bad_name', + action='store_true', + default=False, + help='Do create a project even if the chosen name ' + 'is the name of an already existing / importable ' + 'package.') pyramid_dist = pkg_resources.get_distribution("pyramid") @@ -69,25 +81,19 @@ class PCreateCommand(object): self.out('') self.show_scaffolds() return 2 - if not self.options.scaffold_name: - self.out('You must provide at least one scaffold name: -s <scaffold name>') - self.out('') - self.show_scaffolds() - return 2 - if not self.args: - self.out('You must provide a project name') - return 2 - available = [x.name for x in self.scaffolds] - diff = set(self.options.scaffold_name).difference(available) - if diff: - self.out('Unavailable scaffolds: %s' % list(diff)) + + if not self.validate_input(): return 2 + return self.render_scaffolds() - def render_scaffolds(self): - options = self.options - args = self.args - output_dir = os.path.abspath(os.path.normpath(args[0])) + @property + def output_path(self): + return os.path.abspath(os.path.normpath(self.args[0])) + + @property + def project_vars(self): + output_dir = self.output_path project_name = os.path.basename(os.path.split(output_dir)[1]) pkg_name = _bad_chars_re.sub( '', project_name.lower().replace('-', '_')) @@ -111,17 +117,22 @@ class PCreateCommand(object): else: pyramid_docs_branch = 'latest' - vars = { + return { 'project': project_name, 'package': pkg_name, 'egg': egg_name, 'pyramid_version': pyramid_version, 'pyramid_docs_branch': pyramid_docs_branch, - } - for scaffold_name in options.scaffold_name: + } + + + def render_scaffolds(self): + props = self.project_vars + output_dir = self.output_path + for scaffold_name in self.options.scaffold_name: for scaffold in self.scaffolds: if scaffold.name == scaffold_name: - scaffold.run(self, output_dir, vars) + scaffold.run(self, output_dir, props) return 0 def show_scaffolds(self): @@ -154,5 +165,48 @@ class PCreateCommand(object): if not self.quiet: print(msg) + def validate_input(self): + if not self.options.scaffold_name: + self.out('You must provide at least one scaffold name: -s <scaffold name>') + self.out('') + self.show_scaffolds() + return False + if not self.args: + self.out('You must provide a project name') + return False + available = [x.name for x in self.scaffolds] + diff = set(self.options.scaffold_name).difference(available) + if diff: + self.out('Unavailable scaffolds: %s' % ", ".join(sorted(diff))) + return False + + pkg_name = self.project_vars['package'] + + if pkg_name == 'site' and not self.options.force_bad_name: + self.out('The package name "site" has a special meaning in ' + 'Python. Are you sure you want to use it as your ' + 'project\'s name?') + return self.confirm_bad_name('Really use "{0}"?: '.format(pkg_name)) + + # check if pkg_name can be imported (i.e. already exists in current + # $PYTHON_PATH, if so - let the user confirm + pkg_exists = True + try: + __import__(pkg_name, globals(), locals(), [], 0) # use absolute imports + except ImportError as error: + pkg_exists = False + if not pkg_exists: + return True + + if self.options.force_bad_name: + return True + self.out('A package named "{0}" already exists, are you sure you want ' + 'to use it as your project\'s name?'.format(pkg_name)) + return self.confirm_bad_name('Really use "{0}"?: '.format(pkg_name)) + + def confirm_bad_name(self, prompt): # pragma: no cover + answer = input_('{0} [y|N]: '.format(prompt)) + return answer.strip().lower() == 'y' + if __name__ == '__main__': # pragma: no cover sys.exit(main() or 0) diff --git a/pyramid/tests/test_scaffolds/test_init.py b/pyramid/tests/test_scaffolds/test_init.py index 4988e66ff..f4d1b287a 100644 --- a/pyramid/tests/test_scaffolds/test_init.py +++ b/pyramid/tests/test_scaffolds/test_init.py @@ -12,11 +12,6 @@ class TestPyramidTemplate(unittest.TestCase): self.assertTrue(vars['random_string']) self.assertEqual(vars['package_logger'], 'one') - def test_pre_site(self): - inst = self._makeOne() - vars = {'package':'site'} - self.assertRaises(ValueError, inst.pre, 'command', 'output dir', vars) - def test_pre_root(self): inst = self._makeOne() vars = {'package':'root'} diff --git a/pyramid/tests/test_scripts/test_pcreate.py b/pyramid/tests/test_scripts/test_pcreate.py index 63e5e6368..eaa7c1464 100644 --- a/pyramid/tests/test_scripts/test_pcreate.py +++ b/pyramid/tests/test_scripts/test_pcreate.py @@ -1,5 +1,6 @@ import unittest + class TestPCreateCommand(unittest.TestCase): def setUp(self): from pyramid.compat import NativeIO @@ -15,7 +16,8 @@ class TestPCreateCommand(unittest.TestCase): def _makeOne(self, *args, **kw): effargs = ['pcreate'] effargs.extend(args) - cmd = self._getTargetClass()(effargs, **kw) + tgt_class = kw.pop('target_class', self._getTargetClass()) + cmd = tgt_class(effargs, **kw) cmd.out = self.out return cmd @@ -220,6 +222,48 @@ class TestPCreateCommand(unittest.TestCase): 'pyramid_version': '0.10.1dev', 'pyramid_docs_branch': 'master'}) + def test_confirm_override_conflicting_name(self): + from pyramid.scripts.pcreate import PCreateCommand + class YahInputPCreateCommand(PCreateCommand): + def confirm_bad_name(self, pkg_name): + return True + cmd = self._makeOne('-s', 'dummy', 'Unittest', target_class=YahInputPCreateCommand) + scaffold = DummyScaffold('dummy') + cmd.scaffolds = [scaffold] + cmd.pyramid_dist = DummyDist("0.10.1dev") + result = cmd.run() + self.assertEqual(result, 0) + self.assertEqual( + scaffold.vars, + {'project': 'Unittest', 'egg': 'Unittest', 'package': 'unittest', + 'pyramid_version': '0.10.1dev', + 'pyramid_docs_branch': 'master'}) + + def test_force_override_conflicting_name(self): + cmd = self._makeOne('-s', 'dummy', 'Unittest', '--ignore-conflicting-name') + scaffold = DummyScaffold('dummy') + cmd.scaffolds = [scaffold] + cmd.pyramid_dist = DummyDist("0.10.1dev") + result = cmd.run() + self.assertEqual(result, 0) + self.assertEqual( + scaffold.vars, + {'project': 'Unittest', 'egg': 'Unittest', 'package': 'unittest', + 'pyramid_version': '0.10.1dev', + 'pyramid_docs_branch': 'master'}) + + def test_force_override_site_name(self): + from pyramid.scripts.pcreate import PCreateCommand + class NayInputPCreateCommand(PCreateCommand): + def confirm_bad_name(self, pkg_name): + return False + cmd = self._makeOne('-s', 'dummy', 'Site', target_class=NayInputPCreateCommand) + scaffold = DummyScaffold('dummy') + cmd.scaffolds = [scaffold] + cmd.pyramid_dist = DummyDist("0.10.1dev") + result = cmd.run() + self.assertEqual(result, 2) + class Test_main(unittest.TestCase): def _callFUT(self, argv): |
