From 89d8823d2eee5239d87107483b5b3b5d857e2cb6 Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Thu, 13 Nov 2025 22:07:05 +0100 Subject: check DB connectivity before updating --- fietsboek/updater/__init__.py | 19 +++++++++++++++++++ fietsboek/updater/cli.py | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/fietsboek/updater/__init__.py b/fietsboek/updater/__init__.py index 42e40f4..1803e09 100644 --- a/fietsboek/updater/__init__.py +++ b/fietsboek/updater/__init__.py @@ -331,6 +331,25 @@ class Updater: return UpdateState.OUTDATED return state + def check_connectivity(self) -> str | None: + """Checks whether the data directory and the SQL server accessible. + + Returns ``None`` if there are no problems, or a string describing the + error. + + :return: Whether there is a connection error. + """ + data_dir = Path(self.settings["fietsboek.data_dir"]) + if not data_dir.exists(): + return "data directory does not exist" + + engine = sqlalchemy.create_engine(self.settings["sqlalchemy.url"]) + try: + with engine.connect(): + pass + except sqlalchemy.exc.OperationalError as exc: + return f"could not connect to database\n\n{exc}" + class UpdateScript: """Represents an update script.""" diff --git a/fietsboek/updater/cli.py b/fietsboek/updater/cli.py index 9b7d92e..271e7a1 100644 --- a/fietsboek/updater/cli.py +++ b/fietsboek/updater/cli.py @@ -8,6 +8,7 @@ migrating the configuration. """ import logging.config +import sys import click @@ -32,6 +33,21 @@ def user_confirm(verb): click.confirm("Proceed?", abort=True) +def check_connectivity(updater: Updater): + """Makes sure that the updater can connect to the database. + + Aborts the program if not. + + :param updater: The updater. + """ + error = updater.check_connectivity() + if error is None: + return + click.secho("Error: ", fg="red", nl=False) + click.echo(error) + sys.exit(1) + + @click.group( help=__doc__, context_settings={"help_option_names": ["-h", "--help"]}, @@ -62,6 +78,7 @@ def update(ctx, config, version, force): logging.config.fileConfig(config) updater = Updater(config) updater.load() + check_connectivity(updater) if version and not updater.exists(version): ctx.fail(f"Version {version!r} not found") @@ -107,6 +124,7 @@ def downgrade(ctx, config, version, force): logging.config.fileConfig(config) updater = Updater(config) updater.load() + check_connectivity(updater) if version and not updater.exists(version): ctx.fail(f"Version {version!r} not found") @@ -158,6 +176,7 @@ def status(config): logging.config.fileConfig(config) updater = Updater(config) updater.load() + check_connectivity(updater) current = updater.current_versions() heads = updater.heads() click.secho("Current versions:", fg="yellow") -- cgit v1.2.3