aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fietsboek/updater/__init__.py19
-rw-r--r--fietsboek/updater/cli.py19
2 files changed, 38 insertions, 0 deletions
diff --git a/fietsboek/updater/__init__.py b/fietsboek/updater/__init__.py
index 42e40f4..1803e09 100644
--- a/fietsboek/updater/__init__.py
+++ b/fietsboek/updater/__init__.py
@@ -331,6 +331,25 @@ class Updater:
return UpdateState.OUTDATED
return state
+ def check_connectivity(self) -> str | None:
+ """Checks whether the data directory and the SQL server accessible.
+
+ Returns ``None`` if there are no problems, or a string describing the
+ error.
+
+ :return: Whether there is a connection error.
+ """
+ data_dir = Path(self.settings["fietsboek.data_dir"])
+ if not data_dir.exists():
+ return "data directory does not exist"
+
+ engine = sqlalchemy.create_engine(self.settings["sqlalchemy.url"])
+ try:
+ with engine.connect():
+ pass
+ except sqlalchemy.exc.OperationalError as exc:
+ return f"could not connect to database\n\n{exc}"
+
class UpdateScript:
"""Represents an update script."""
diff --git a/fietsboek/updater/cli.py b/fietsboek/updater/cli.py
index 9b7d92e..271e7a1 100644
--- a/fietsboek/updater/cli.py
+++ b/fietsboek/updater/cli.py
@@ -8,6 +8,7 @@ migrating the configuration.
"""
import logging.config
+import sys
import click
@@ -32,6 +33,21 @@ def user_confirm(verb):
click.confirm("Proceed?", abort=True)
+def check_connectivity(updater: Updater):
+ """Makes sure that the updater can connect to the database.
+
+ Aborts the program if not.
+
+ :param updater: The updater.
+ """
+ error = updater.check_connectivity()
+ if error is None:
+ return
+ click.secho("Error: ", fg="red", nl=False)
+ click.echo(error)
+ sys.exit(1)
+
+
@click.group(
help=__doc__,
context_settings={"help_option_names": ["-h", "--help"]},
@@ -62,6 +78,7 @@ def update(ctx, config, version, force):
logging.config.fileConfig(config)
updater = Updater(config)
updater.load()
+ check_connectivity(updater)
if version and not updater.exists(version):
ctx.fail(f"Version {version!r} not found")
@@ -107,6 +124,7 @@ def downgrade(ctx, config, version, force):
logging.config.fileConfig(config)
updater = Updater(config)
updater.load()
+ check_connectivity(updater)
if version and not updater.exists(version):
ctx.fail(f"Version {version!r} not found")
@@ -158,6 +176,7 @@ def status(config):
logging.config.fileConfig(config)
updater = Updater(config)
updater.load()
+ check_connectivity(updater)
current = updater.current_versions()
heads = updater.heads()
click.secho("Current versions:", fg="yellow")