aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fietsboek/updater/__init__.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/fietsboek/updater/__init__.py b/fietsboek/updater/__init__.py
index 3611e9f..b136a6c 100644
--- a/fietsboek/updater/__init__.py
+++ b/fietsboek/updater/__init__.py
@@ -86,8 +86,20 @@ class Updater:
# Ensure that each script has an entry
self.backward_dependencies = {script.id: [] for script in self.scripts.values()}
for script in self.scripts.values():
+ down_alembic = None
for prev_id in script.previous:
self.backward_dependencies[prev_id].append(script.id)
+ possible_alembic = self.scripts[prev_id].alembic_version
+ if down_alembic is None:
+ down_alembic = possible_alembic
+ elif down_alembic != possible_alembic:
+ LOGGER.error(
+ "Invalid update graph - two different down alembics for script %s",
+ script.id,
+ )
+ raise ValueError(f"Two alembic downgrades for {script.id}")
+ down_alembic = possible_alembic
+ script.down_alembic = down_alembic
def exists(self, revision_id):
"""Checks if the revision with the given ID exists.
@@ -291,6 +303,7 @@ class UpdateScript:
self.module = importlib.util.module_from_spec(spec) # type: ignore
assert self.module
exec(source, self.module.__dict__) # pylint: disable=exec-used
+ self.down_alembic = None
def __repr__(self):
return f"<{__name__}.{self.__class__.__name__} name={self.name!r} id={self.id!r}>"
@@ -355,10 +368,11 @@ class UpdateScript:
"""
LOGGER.info("[down] Running pre-alembic task for %s", self.id)
self.module.Down().pre_alembic(config)
- LOGGER.info("[down] Running alembic downgrade for %s to %s", self.id, self.alembic_version)
- alembic.command.downgrade(alembic_config, "-1")
- LOGGER.info("[down] Running post-alembic task for %s", self.id)
- self.module.Down().post_alembic(config)
+ if self.down_alembic:
+ LOGGER.info("[down] Running alembic downgrade for %s to %s", self.id, self.down_alembic)
+ alembic.command.downgrade(alembic_config, self.down_alembic)
+ LOGGER.info("[down] Running post-alembic task for %s", self.id)
+ self.module.Down().post_alembic(config)
def _filename_to_modname(name):