aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fietsboek/views/browse.py94
1 files changed, 57 insertions, 37 deletions
diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py
index ede6b21..aed7d0a 100644
--- a/fietsboek/views/browse.py
+++ b/fietsboek/views/browse.py
@@ -1,14 +1,18 @@
"""Views for browsing all tracks."""
import datetime
+from collections.abc import Callable, Iterable
+from enum import Enum
from io import RawIOBase
-from typing import List
from zipfile import ZIP_DEFLATED, ZipFile
from pyramid.httpexceptions import HTTPBadRequest, HTTPForbidden, HTTPNotFound
+from pyramid.request import Request
from pyramid.response import Response
from pyramid.view import view_config
from sqlalchemy import func, or_, select
-from sqlalchemy.orm import aliased
+from sqlalchemy.orm import DeclarativeMeta, aliased
+from sqlalchemy.orm.util import AliasedClass
+from sqlalchemy.sql import Selectable
from .. import models, util
from ..models.track import TrackType, TrackWithMetadata
@@ -25,31 +29,35 @@ class Stream(RawIOBase):
super().__init__()
self.buffer = []
- def write(self, b):
+ # The following definition violates the substitution principle, so mypy
+ # would complain. However, I think we're good acting like we take only
+ # "bytes".
+ def write(self, b: bytes) -> int: # type: ignore
+ b = bytes(b)
self.buffer.append(b)
return len(b)
- def readall(self):
+ def readall(self) -> bytes:
buf = self.buffer
self.buffer = []
return b"".join(buf)
-def _get_int(request, name):
+def _get_int(request: Request, name: str) -> int:
try:
return int(request.params.get(name))
except ValueError as exc:
raise HTTPBadRequest(f"Invalid integer in {name!r}") from exc
-def _get_date(request, name):
+def _get_date(request: Request, name: str) -> datetime.date:
try:
return datetime.date.fromisoformat(request.params.get(name))
except ValueError as exc:
raise HTTPBadRequest(f"Invalid date in {name!r}") from exc
-def _get_enum(enum, value):
+def _get_enum(enum: type[Enum], value: str) -> Enum:
try:
return enum[value]
except KeyError as exc:
@@ -59,7 +67,9 @@ def _get_enum(enum, value):
class Filter:
"""A class representing a filter that the user can apply to the track list."""
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
"""Compile the filter into the SQL query.
Returns the modified query.
@@ -75,7 +85,7 @@ class Filter:
# pylint: disable=unused-argument
return query
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
"""Check if the given track matches the filter.
:param track: The track to check.
@@ -87,40 +97,50 @@ class Filter:
class LambdaFilter(Filter):
"""A :class:`Filter` that works by provided lambda functions."""
- def __init__(self, compiler, matcher):
+ def __init__(
+ self,
+ compiler: Callable[[Selectable, AliasedClass, type[DeclarativeMeta]], Selectable],
+ matcher: Callable[[TrackWithMetadata], bool],
+ ):
self.compiler = compiler
self.matcher = matcher
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
return self.compiler(query, track, track_cache)
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
return self.matcher(track)
class SearchFilter(Filter):
"""A :class:`Filter` that looks for the given search terms."""
- def __init__(self, search_terms):
+ def __init__(self, search_terms: Iterable[str]):
self.search_terms = search_terms
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
for term in self.search_terms:
term = term.lower()
query = query.where(func.lower(track.title).contains(term))
return query
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
return all(term.lower() in track.title.lower() for term in self.search_terms)
class TagFilter(Filter):
"""A :class:`Filter` that looks for the given tags."""
- def __init__(self, tags):
+ def __init__(self, tags: Iterable[str]):
self.tags = tags
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
lower_tags = [tag.lower() for tag in self.tags]
for tag in lower_tags:
exists_query = (
@@ -132,7 +152,7 @@ class TagFilter(Filter):
query = query.where(exists_query)
return query
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
lower_track_tags = {tag.lower() for tag in track.text_tags()}
lower_tags = {tag.lower() for tag in self.tags}
return all(tag in lower_track_tags for tag in lower_tags)
@@ -141,10 +161,12 @@ class TagFilter(Filter):
class PersonFilter(Filter):
"""A :class:`Filter` that looks for the given tagged people, based on their name."""
- def __init__(self, names):
+ def __init__(self, names: Iterable[str]):
self.names = names
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
lower_names = [name.lower() for name in self.names]
for name in lower_names:
tpa = models.track.track_people_assoc
@@ -164,7 +186,7 @@ class PersonFilter(Filter):
query = query.where(or_(exists_query, is_owner))
return query
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
lower_names = {person.name.lower() for person in track.tagged_people}
lower_names.add(track.owner.name.lower())
return all(name.lower() in lower_names for name in self.names)
@@ -173,10 +195,12 @@ class PersonFilter(Filter):
class UserTaggedFilter(Filter):
"""A :class:`Filter` that looks for a specific user to be tagged."""
- def __init__(self, user):
+ def __init__(self, user: models.User):
self.user = user
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
tpa = models.track.track_people_assoc
return query.where(
or_(
@@ -190,39 +214,39 @@ class UserTaggedFilter(Filter):
)
)
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
return track.owner == self.user or self.user in track.tagged_people
class FilterCollection(Filter):
"""A class that applies multiple :class:`Filter`."""
- def __init__(self, filters):
+ def __init__(self, filters: Iterable[Filter]):
self._filters = filters
def __bool__(self):
return bool(self._filters)
- def compile(self, query, track, track_cache):
+ def compile(
+ self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta]
+ ) -> Selectable:
for filty in self._filters:
query = filty.compile(query, track, track_cache)
return query
- def apply(self, track):
+ def apply(self, track: TrackWithMetadata) -> bool:
return all(filty.apply(track) for filty in self._filters)
@classmethod
- def parse(cls, request):
+ def parse(cls, request: Request) -> "FilterCollection":
"""Parse the filters from the given request.
:raises HTTPBadRequest: If the filters are malformed.
:param request: The request.
- :type request: pyramid.request.Request
:return: The parsed filter.
- :rtype: FilterCollection
"""
# pylint: disable=singleton-comparison
- filters: List[Filter] = []
+ filters: list[Filter] = []
if request.params.get("search-terms"):
term = request.params.get("search-terms").strip()
filters.append(SearchFilter([term]))
@@ -321,13 +345,11 @@ class FilterCollection(Filter):
@view_config(
route_name="browse", renderer="fietsboek:templates/browse.jinja2", request_method="GET"
)
-def browse(request):
+def browse(request: Request) -> Response:
"""Returns the page that lets a user browse all visible tracks.
:param request: The Pyramid request.
- :type request: pyramid.request.Request
:return: The HTTP response.
- :rtype: pyramid.response.Response
"""
filters = FilterCollection.parse(request)
track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery())
@@ -348,13 +370,11 @@ def browse(request):
@view_config(route_name="track-archive", request_method="GET")
-def archive(request):
+def archive(request: Request) -> Response:
"""Packs multiple tracks into a single archive.
:param request: The Pyramid request.
- :type request: pyramid.request.Request
:return: The HTTP response.
- :rtype: pyramid.response.Response
"""
track_ids = set(map(int, request.params.getall("track_id[]")))
tracks = (