diff options
| author | Daniel Schadt <kingdread@gmx.de> | 2023-05-09 19:09:43 +0200 | 
|---|---|---|
| committer | Daniel Schadt <kingdread@gmx.de> | 2023-05-09 19:09:43 +0200 | 
| commit | 4ad36ba65e33d32bd2ca1bae0cbc299b244369d8 (patch) | |
| tree | 94839c024d33be0c6f4ac742ff3dc333ee896447 | |
| parent | 5b40a857d02b8768c5bc14306d1934ed354b38d6 (diff) | |
| download | fietsboek-4ad36ba65e33d32bd2ca1bae0cbc299b244369d8.tar.gz fietsboek-4ad36ba65e33d32bd2ca1bae0cbc299b244369d8.tar.bz2 fietsboek-4ad36ba65e33d32bd2ca1bae0cbc299b244369d8.zip  | |
properly type hint browse.py
This is in preparation of adding the sorting feature.
| -rw-r--r-- | fietsboek/views/browse.py | 94 | 
1 files changed, 57 insertions, 37 deletions
diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py index ede6b21..aed7d0a 100644 --- a/fietsboek/views/browse.py +++ b/fietsboek/views/browse.py @@ -1,14 +1,18 @@  """Views for browsing all tracks."""  import datetime +from collections.abc import Callable, Iterable +from enum import Enum  from io import RawIOBase -from typing import List  from zipfile import ZIP_DEFLATED, ZipFile  from pyramid.httpexceptions import HTTPBadRequest, HTTPForbidden, HTTPNotFound +from pyramid.request import Request  from pyramid.response import Response  from pyramid.view import view_config  from sqlalchemy import func, or_, select -from sqlalchemy.orm import aliased +from sqlalchemy.orm import DeclarativeMeta, aliased +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql import Selectable  from .. import models, util  from ..models.track import TrackType, TrackWithMetadata @@ -25,31 +29,35 @@ class Stream(RawIOBase):          super().__init__()          self.buffer = [] -    def write(self, b): +    # The following definition violates the substitution principle, so mypy +    # would complain. However, I think we're good acting like we take only +    # "bytes". +    def write(self, b: bytes) -> int:  # type: ignore +        b = bytes(b)          self.buffer.append(b)          return len(b) -    def readall(self): +    def readall(self) -> bytes:          buf = self.buffer          self.buffer = []          return b"".join(buf) -def _get_int(request, name): +def _get_int(request: Request, name: str) -> int:      try:          return int(request.params.get(name))      except ValueError as exc:          raise HTTPBadRequest(f"Invalid integer in {name!r}") from exc -def _get_date(request, name): +def _get_date(request: Request, name: str) -> datetime.date:      try:          return datetime.date.fromisoformat(request.params.get(name))      except ValueError as exc:          raise HTTPBadRequest(f"Invalid date in {name!r}") from exc -def _get_enum(enum, value): +def _get_enum(enum: type[Enum], value: str) -> Enum:      try:          return enum[value]      except KeyError as exc: @@ -59,7 +67,9 @@ def _get_enum(enum, value):  class Filter:      """A class representing a filter that the user can apply to the track list.""" -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          """Compile the filter into the SQL query.          Returns the modified query. @@ -75,7 +85,7 @@ class Filter:          # pylint: disable=unused-argument          return query -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          """Check if the given track matches the filter.          :param track: The track to check. @@ -87,40 +97,50 @@ class Filter:  class LambdaFilter(Filter):      """A :class:`Filter` that works by provided lambda functions.""" -    def __init__(self, compiler, matcher): +    def __init__( +        self, +        compiler: Callable[[Selectable, AliasedClass, type[DeclarativeMeta]], Selectable], +        matcher: Callable[[TrackWithMetadata], bool], +    ):          self.compiler = compiler          self.matcher = matcher -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          return self.compiler(query, track, track_cache) -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          return self.matcher(track)  class SearchFilter(Filter):      """A :class:`Filter` that looks for the given search terms.""" -    def __init__(self, search_terms): +    def __init__(self, search_terms: Iterable[str]):          self.search_terms = search_terms -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          for term in self.search_terms:              term = term.lower()              query = query.where(func.lower(track.title).contains(term))          return query -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          return all(term.lower() in track.title.lower() for term in self.search_terms)  class TagFilter(Filter):      """A :class:`Filter` that looks for the given tags.""" -    def __init__(self, tags): +    def __init__(self, tags: Iterable[str]):          self.tags = tags -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          lower_tags = [tag.lower() for tag in self.tags]          for tag in lower_tags:              exists_query = ( @@ -132,7 +152,7 @@ class TagFilter(Filter):              query = query.where(exists_query)          return query -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          lower_track_tags = {tag.lower() for tag in track.text_tags()}          lower_tags = {tag.lower() for tag in self.tags}          return all(tag in lower_track_tags for tag in lower_tags) @@ -141,10 +161,12 @@ class TagFilter(Filter):  class PersonFilter(Filter):      """A :class:`Filter` that looks for the given tagged people, based on their name.""" -    def __init__(self, names): +    def __init__(self, names: Iterable[str]):          self.names = names -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          lower_names = [name.lower() for name in self.names]          for name in lower_names:              tpa = models.track.track_people_assoc @@ -164,7 +186,7 @@ class PersonFilter(Filter):              query = query.where(or_(exists_query, is_owner))          return query -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          lower_names = {person.name.lower() for person in track.tagged_people}          lower_names.add(track.owner.name.lower())          return all(name.lower() in lower_names for name in self.names) @@ -173,10 +195,12 @@ class PersonFilter(Filter):  class UserTaggedFilter(Filter):      """A :class:`Filter` that looks for a specific user to be tagged.""" -    def __init__(self, user): +    def __init__(self, user: models.User):          self.user = user -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          tpa = models.track.track_people_assoc          return query.where(              or_( @@ -190,39 +214,39 @@ class UserTaggedFilter(Filter):              )          ) -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          return track.owner == self.user or self.user in track.tagged_people  class FilterCollection(Filter):      """A class that applies multiple :class:`Filter`.""" -    def __init__(self, filters): +    def __init__(self, filters: Iterable[Filter]):          self._filters = filters      def __bool__(self):          return bool(self._filters) -    def compile(self, query, track, track_cache): +    def compile( +        self, query: Selectable, track: AliasedClass, track_cache: type[DeclarativeMeta] +    ) -> Selectable:          for filty in self._filters:              query = filty.compile(query, track, track_cache)          return query -    def apply(self, track): +    def apply(self, track: TrackWithMetadata) -> bool:          return all(filty.apply(track) for filty in self._filters)      @classmethod -    def parse(cls, request): +    def parse(cls, request: Request) -> "FilterCollection":          """Parse the filters from the given request.          :raises HTTPBadRequest: If the filters are malformed.          :param request: The request. -        :type request: pyramid.request.Request          :return: The parsed filter. -        :rtype: FilterCollection          """          # pylint: disable=singleton-comparison -        filters: List[Filter] = [] +        filters: list[Filter] = []          if request.params.get("search-terms"):              term = request.params.get("search-terms").strip()              filters.append(SearchFilter([term])) @@ -321,13 +345,11 @@ class FilterCollection(Filter):  @view_config(      route_name="browse", renderer="fietsboek:templates/browse.jinja2", request_method="GET"  ) -def browse(request): +def browse(request: Request) -> Response:      """Returns the page that lets a user browse all visible tracks.      :param request: The Pyramid request. -    :type request: pyramid.request.Request      :return: The HTTP response. -    :rtype: pyramid.response.Response      """      filters = FilterCollection.parse(request)      track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery()) @@ -348,13 +370,11 @@ def browse(request):  @view_config(route_name="track-archive", request_method="GET") -def archive(request): +def archive(request: Request) -> Response:      """Packs multiple tracks into a single archive.      :param request: The Pyramid request. -    :type request: pyramid.request.Request      :return: The HTTP response. -    :rtype: pyramid.response.Response      """      track_ids = set(map(int, request.params.getall("track_id[]")))      tracks = (  | 
