From 4ad36ba65e33d32bd2ca1bae0cbc299b244369d8 Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Tue, 9 May 2023 19:09:43 +0200 Subject: properly type hint browse.py This is in preparation of adding the sorting feature. --- fietsboek/views/browse.py | 94 ++++++++++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py index ede6b21..aed7d0a 100644 --- a/fietsboek/views/browse.py +++ b/fietsboek/views/browse.py @@ -1,14 +1,18 @@ """Views for browsing all tracks.""" import datetime +from collections.abc import Callable, Iterable +from enum import Enum from io import RawIOBase -from typing import List from zipfile import ZIP_DEFLATED, ZipFile from pyramid.httpexceptions import HTTPBadRequest, HTTPForbidden, HTTPNotFound +from pyramid.request import Request from pyramid.response import Response from pyramid.view import view_config from sqlalchemy import func, or_, select -from sqlalchemy.orm import aliased +from sqlalchemy.orm import DeclarativeMeta, aliased +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql import Selectable from .. import models, util from ..models.track import TrackType, TrackWithMetadata @@ -25,31 +29,35 @@ class Stream(RawIOBase): super().__init__() self.buffer = [] - def write(self, b): + # The following definition violates the substitution principle, so mypy + # would complain. However, I think we're good acting like we take only + # "bytes". + def write(self, b: bytes) -> int: # type: ignore + b = bytes(b) self.buffer.append(b) return len(b) - def readall(self): + def readall(self) -> bytes: buf = self.buffer self.buffer = [] return b"".join(buf) -def _get_int(request, name): +def _get_int(request: Request, name: str) -> int: try: return int(request.params.get(name)) except ValueError as exc: raise HTTPBadRequest(f"Invalid integer in {name!r}") from exc -def _get_date(request, name): +def _get_date(request: Request, name: str) -> datetime.date: try: return datetime.date.fromisoformat(request.params.get(name)) except ValueError as exc: raise HTTPBadRequest(f"Invalid date in {name!r}") from exc -def _get_enum(enum, value): +def _get_enum(enum: type[Enum], value: str) -> Enum: try: return enum[value] except KeyError as exc: @@ -59,7 +67,9 @@ def _get_enum(enum, value): class Filter: """A class representing a filter that the user can apply to the track list.""" - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: """Compile the filter into the SQL query. Returns the modified query. @@ -75,7 +85,7 @@ class Filter: # pylint: disable=unused-argument return query - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: """Check if the given track matches the filter. :param track: The track to check. @@ -87,40 +97,50 @@ class Filter: class LambdaFilter(Filter): """A :class:`Filter` that works by provided lambda functions.""" - def __init__(self, compiler, matcher): + def __init__( + self, + compiler: Callable[[Selectable, AliasedClass, type[DeclarativeMeta]], Selectable], + matcher: Callable[[TrackWithMetadata], bool], + ): self.compiler = compiler self.matcher = matcher - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: return self.compiler(query, track, track_cache) - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: return self.matcher(track) class SearchFilter(Filter): """A :class:`Filter` that looks for the given search terms.""" - def __init__(self, search_terms): + def __init__(self, search_terms: Iterable[str]): self.search_terms = search_terms - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: for term in self.search_terms: term = term.lower() query = query.where(func.lower(track.title).contains(term)) return query - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: return all(term.lower() in track.title.lower() for term in self.search_terms) class TagFilter(Filter): """A :class:`Filter` that looks for the given tags.""" - def __init__(self, tags): + def __init__(self, tags: Iterable[str]): self.tags = tags - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: lower_tags = [tag.lower() for tag in self.tags] for tag in lower_tags: exists_query = ( @@ -132,7 +152,7 @@ class TagFilter(Filter): query = query.where(exists_query) return query - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: lower_track_tags = {tag.lower() for tag in track.text_tags()} lower_tags = {tag.lower() for tag in self.tags} return all(tag in lower_track_tags for tag in lower_tags) @@ -141,10 +161,12 @@ class TagFilter(Filter): class PersonFilter(Filter): """A :class:`Filter` that looks for the given tagged people, based on their name.""" - def __init__(self, names): + def __init__(self, names: Iterable[str]): self.names = names - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: lower_names = [name.lower() for name in self.names] for name in lower_names: tpa = models.track.track_people_assoc @@ -164,7 +186,7 @@ class PersonFilter(Filter): query = query.where(or_(exists_query, is_owner)) return query - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: lower_names = {person.name.lower() for person in track.tagged_people} lower_names.add(track.owner.name.lower()) return all(name.lower() in lower_names for name in self.names) @@ -173,10 +195,12 @@ class PersonFilter(Filter): class UserTaggedFilter(Filter): """A :class:`Filter` that looks for a specific user to be tagged.""" - def __init__(self, user): + def __init__(self, user: models.User): self.user = user - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: tpa = models.track.track_people_assoc return query.where( or_( @@ -190,39 +214,39 @@ class UserTaggedFilter(Filter): ) ) - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: return track.owner == self.user or self.user in track.tagged_people class FilterCollection(Filter): """A class that applies multiple :class:`Filter`.""" - def __init__(self, filters): + def __init__(self, filters: Iterable[Filter]): self._filters = filters def __bool__(self): return bool(self._filters) - def compile(self, query, track, track_cache): + def compile( + self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] + ) -> Selectable: for filty in self._filters: query = filty.compile(query, track, track_cache) return query - def apply(self, track): + def apply(self, track: TrackWithMetadata) -> bool: return all(filty.apply(track) for filty in self._filters) @classmethod - def parse(cls, request): + def parse(cls, request: Request) -> "FilterCollection": """Parse the filters from the given request. :raises HTTPBadRequest: If the filters are malformed. :param request: The request. - :type request: pyramid.request.Request :return: The parsed filter. - :rtype: FilterCollection """ # pylint: disable=singleton-comparison - filters: List[Filter] = [] + filters: list[Filter] = [] if request.params.get("search-terms"): term = request.params.get("search-terms").strip() filters.append(SearchFilter([term])) @@ -321,13 +345,11 @@ class FilterCollection(Filter): @view_config( route_name="browse", renderer="fietsboek:templates/browse.jinja2", request_method="GET" ) -def browse(request): +def browse(request: Request) -> Response: """Returns the page that lets a user browse all visible tracks. :param request: The Pyramid request. - :type request: pyramid.request.Request :return: The HTTP response. - :rtype: pyramid.response.Response """ filters = FilterCollection.parse(request) track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery()) @@ -348,13 +370,11 @@ def browse(request): @view_config(route_name="track-archive", request_method="GET") -def archive(request): +def archive(request: Request) -> Response: """Packs multiple tracks into a single archive. :param request: The Pyramid request. - :type request: pyramid.request.Request :return: The HTTP response. - :rtype: pyramid.response.Response """ track_ids = set(map(int, request.params.getall("track_id[]"))) tracks = ( -- cgit v1.2.3