From a2aaf098ca5ac897a5e0e1b2302acaf0562f8aad Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Thu, 30 Mar 2023 20:30:43 +0200 Subject: type hint the updater --- fietsboek/updater/__init__.py | 79 ++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 47 deletions(-) diff --git a/fietsboek/updater/__init__.py b/fietsboek/updater/__init__.py index 9e2e7f0..fa7f5c6 100644 --- a/fietsboek/updater/__init__.py +++ b/fietsboek/updater/__init__.py @@ -6,7 +6,7 @@ import logging import random import string from pathlib import Path -from typing import List +from typing import Optional import alembic.command import alembic.config @@ -57,20 +57,19 @@ class Updater: them in the right order. """ - def __init__(self, config_path): + def __init__(self, config_path: str): self.config_path = config_path self.settings = pyramid.paster.get_appsettings(config_path) self.alembic_config = alembic.config.Config(config_path) - self.scripts = {} - self.forward_dependencies = {} - self.backward_dependencies = {} + self.scripts: dict[str, "UpdateScript"] = {} + self.forward_dependencies: dict[str, list[str]] = {} + self.backward_dependencies: dict[str, list[str]] = {} @property - def version_file(self): + def version_file(self) -> Path: """Returns the path to the version file. :return: The path to the data's version file. - :rytpe: pathlib.Path """ data_dir = Path(self.settings["fietsboek.data_dir"]) return data_dir / "VERSION" @@ -99,20 +98,17 @@ class Updater: down_alembic = possible_alembic script.down_alembic = down_alembic - def exists(self, revision_id): + def exists(self, revision_id: str) -> bool: """Checks if the revision with the given ID exists. :param revision_id: ID of the revision to check. - :type revision_id: str :return: True if the revision exists. - :rtype: bool """ return revision_id in self.scripts - def current_versions(self): + def current_versions(self) -> list[str]: """Reads the current version of the data. - :rtype: list[str] :return: The versions, or an empty list if no versions are found. """ try: @@ -121,7 +117,7 @@ class Updater: except FileNotFoundError: return [] - def _transitive_versions(self): + def _transitive_versions(self) -> set[str]: versions = set() queue = self.current_versions() while queue: @@ -131,21 +127,25 @@ class Updater: queue.extend(self.scripts[current].previous) return versions - def _reverse_versions(self): + def _reverse_versions(self) -> set[str]: all_versions = set(script.id for script in self.scripts.values()) return all_versions - self._transitive_versions() - def stamp(self, versions): + def stamp(self, versions: list[str]): """Stampts the given version into the version file. This does not run any updates, it simply updates the version information. :param version: The versions to stamp. - :type version: list[str] """ self.version_file.write_text("\n".join(versions), encoding="utf-8") - def _pick_updates(self, wanted, applied, dependencies): + def _pick_updates( + self, + wanted: str, + applied: set[str], + dependencies: dict[str, list[str]], + ) -> set[str]: to_apply = set() queue = [wanted] while queue: @@ -156,9 +156,9 @@ class Updater: queue.extend(dependencies[current]) return to_apply - def _make_schedule(self, wanted, dependencies): + def _make_schedule(self, wanted: set[str], dependencies: dict[str, list[str]]) -> list[str]: wanted = set(wanted) - queue: List[str] = [] + queue: list[str] = [] while wanted: next_updates = { update @@ -169,19 +169,18 @@ class Updater: wanted -= next_updates return queue - def _stamp_versions(self, old, new): + def _stamp_versions(self, old: list[str], new: list[str]): versions = self.current_versions() versions = [version for version in versions if version not in old] versions.extend(new) self.stamp(versions) - def upgrade(self, target): + def upgrade(self, target: str): """Run the tasks to upgrade to the given target. This ensures that all previous migrations are also run. :param target: The target revision. - :type target: str """ # First, we figure out which tasks we have already applied and which # still need applying. This is pretty much a BFS over the current @@ -198,13 +197,12 @@ class Updater: script.upgrade(self.settings, self.alembic_config) self._stamp_versions(script.previous, [script.id]) - def downgrade(self, target): + def downgrade(self, target: str): """Run the tasks to downgrade to the given target. This ensures that all succeeding down-migrations are also run. :param target: The target revision. - :type target: str """ # This is basically the same as upgrade() but with the reverse # dependencies instead. @@ -218,16 +216,14 @@ class Updater: script.downgrade(self.settings, self.alembic_config) self._stamp_versions([script.id], script.previous) - def new_revision(self, revision_id=None): + def new_revision(self, revision_id: Optional[str] = None) -> str: """Creates a new revision with the current versions as dependencies and the current alembic version. :param revision_id: The revision ID to use. By default, a random string will be generated. - :type revision_id: str :return: The filename of the revision file in the ``updater/`` directory. - :rtype: str """ if not revision_id: revision_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16)) @@ -260,15 +256,14 @@ class Updater: fobj.write(revision) return filename - def heads(self): + def heads(self) -> list[str]: """Returns all "heads", that are the latest revisions. :return: The heads. - :rtype: list[str] """ return [rev_id for (rev_id, deps) in self.backward_dependencies.items() if not deps] - def has_applied(self, revision_id, backward=False): + def has_applied(self, revision_id: str, backward: bool = False) -> bool: """Checks whether the given revision is applied. By default, this checks if a given update is applied, i.e. the current @@ -281,12 +276,9 @@ class Updater: :meth:`exists` to check whether the revision actually exists. :param revision_id: The revision to check. - :type revision_id: str :param backward: Whether to switch the comparison direction. - :type backward: bool :return: ``True`` if the current version at least matches the asked revision ID. - :rtype: bool """ if not backward: return revision_id in self._transitive_versions() @@ -296,45 +288,42 @@ class Updater: class UpdateScript: """Represents an update script.""" - def __init__(self, source, name): + def __init__(self, source: str, name: str): self.name = name spec = importlib.util.spec_from_loader(f"{__name__}.{name}", None) self.module = importlib.util.module_from_spec(spec) # type: ignore assert self.module exec(source, self.module.__dict__) # pylint: disable=exec-used - self.down_alembic = None + self.down_alembic: Optional[str] = None def __repr__(self): return f"<{__name__}.{self.__class__.__name__} name={self.name!r} id={self.id!r}>" @property - def id(self): + def id(self) -> str: """Returns the ID of the update. - :rtype: str :return: The id of the update """ return self.module.update_id @property - def previous(self): + def previous(self) -> list[str]: """Returns all dependencies of the update. - :rtype: list[str] :return: The IDs of all dependencies of the update. """ return getattr(self.module, "previous", []) @property - def alembic_version(self): + def alembic_version(self) -> str: """Returns the alembic revisions of the update. - :rtype: list[str] :return: The needed alembic revisions. """ return self.module.alembic_revision - def upgrade(self, config, alembic_config): + def upgrade(self, config: dict, alembic_config: alembic.config.Config): """Runs the upgrade migrations of this update script. This first runs the pre_alembic task, then the alembic migration, and @@ -344,9 +333,7 @@ class UpdateScript: executed. :param config: The app configuration. - :type config: dict :param alembic_config: The alembic config to use. - :type alembic_config: alembic.config.Config """ LOGGER.info("[up] Running pre-alembic task for %s", self.id) self.module.Up().pre_alembic(config) @@ -355,15 +342,13 @@ class UpdateScript: LOGGER.info("[up] Running post-alembic task for %s", self.id) self.module.Up().post_alembic(config) - def downgrade(self, config, alembic_config): + def downgrade(self, config: dict, alembic_config: alembic.config.Config): """Runs the downgrade migrations of this update script. See also :meth:`upgrade`. :param config: The app configuration. - :type config: dict :param alembic_config: The alembic config to use. - :type alembic_config: alembic.config.Config """ LOGGER.info("[down] Running pre-alembic task for %s", self.id) self.module.Down().pre_alembic(config) -- cgit v1.2.3