diff options
-rw-r--r-- | fietsboek/updater/__init__.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/fietsboek/updater/__init__.py b/fietsboek/updater/__init__.py index 3611e9f..b136a6c 100644 --- a/fietsboek/updater/__init__.py +++ b/fietsboek/updater/__init__.py @@ -86,8 +86,20 @@ class Updater: # Ensure that each script has an entry self.backward_dependencies = {script.id: [] for script in self.scripts.values()} for script in self.scripts.values(): + down_alembic = None for prev_id in script.previous: self.backward_dependencies[prev_id].append(script.id) + possible_alembic = self.scripts[prev_id].alembic_version + if down_alembic is None: + down_alembic = possible_alembic + elif down_alembic != possible_alembic: + LOGGER.error( + "Invalid update graph - two different down alembics for script %s", + script.id, + ) + raise ValueError(f"Two alembic downgrades for {script.id}") + down_alembic = possible_alembic + script.down_alembic = down_alembic def exists(self, revision_id): """Checks if the revision with the given ID exists. @@ -291,6 +303,7 @@ class UpdateScript: self.module = importlib.util.module_from_spec(spec) # type: ignore assert self.module exec(source, self.module.__dict__) # pylint: disable=exec-used + self.down_alembic = None def __repr__(self): return f"<{__name__}.{self.__class__.__name__} name={self.name!r} id={self.id!r}>" @@ -355,10 +368,11 @@ class UpdateScript: """ LOGGER.info("[down] Running pre-alembic task for %s", self.id) self.module.Down().pre_alembic(config) - LOGGER.info("[down] Running alembic downgrade for %s to %s", self.id, self.alembic_version) - alembic.command.downgrade(alembic_config, "-1") - LOGGER.info("[down] Running post-alembic task for %s", self.id) - self.module.Down().post_alembic(config) + if self.down_alembic: + LOGGER.info("[down] Running alembic downgrade for %s to %s", self.id, self.down_alembic) + alembic.command.downgrade(alembic_config, self.down_alembic) + LOGGER.info("[down] Running post-alembic task for %s", self.id) + self.module.Down().post_alembic(config) def _filename_to_modname(name): |