From c71663c81f594b992c4d46d2e3136ddd6a9b5a21 Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Fri, 12 Aug 2022 23:21:49 +0200 Subject: implement browse filters in SQL This is a continuation of the previous two commits, in which we do the filtering in the SQL query instead of retrieving all objects and then filtering them in Python. The generated SQL can be quite complex, but 1) most of it comes from the logic of determining the visible tracks and 2) it is built piece-by-piece with small Python functions. Therefore, it should be okay. --- fietsboek/templates/browse.jinja2 | 2 +- fietsboek/views/browse.py | 263 ++++++++++++++++++++++++++++---------- 2 files changed, 200 insertions(+), 65 deletions(-) diff --git a/fietsboek/templates/browse.jinja2 b/fietsboek/templates/browse.jinja2 index 6f3e0eb..f8937bb 100644 --- a/fietsboek/templates/browse.jinja2 +++ b/fietsboek/templates/browse.jinja2 @@ -85,7 +85,7 @@ {% endmacro %} {{ render_switch("switchOnlyMyTracks", "show-only[]", "mine", _("page.browse.filter.my_tracks.only")) }} {{ render_switch("switchOnlyFriendsTracks", "show-only[]", "friends", _("page.browse.filter.friends_tracks_only")) }} - {{ render_switch("switchOnlyMeTagged", "user-tagged", "on", _("page.browse.filter.me_tagged_only")) }} + {{ render_switch("switchOnlyMeTagged", "show-only[]", "user-tagged", _("page.browse.filter.me_tagged_only")) }} {% endif %} diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py index 03d9bf8..462905f 100644 --- a/fietsboek/views/browse.py +++ b/fietsboek/views/browse.py @@ -7,10 +7,11 @@ from pyramid.view import view_config from pyramid.httpexceptions import HTTPForbidden, HTTPNotFound, HTTPBadRequest from pyramid.response import Response -from sqlalchemy import select +from sqlalchemy import select, func, or_ from sqlalchemy.orm import aliased from .. import models, util +from ..models.track import TrackType class Stream(RawIOBase): @@ -48,10 +49,143 @@ def _get_date(request, name): raise HTTPBadRequest(f'Invalid date in {name!r}') from exc -class TrackFilters: - """A filter that applies user-given filters to a track.""" - # pylint: disable=fixme - # TODO: We should also do some of those in SQL, if possible. +def _get_enum(enum, value): + try: + return enum[value] + except KeyError as exc: + raise HTTPBadRequest(f'Invalid enum value {value!r}') from exc + + +class Filter: + """A class representing a filter that the user can apply to the track list.""" + + def compile(self, query, track, track_cache): + """Compile the filter into the SQL query. + + Returns the modified query. + + This method is optional, as a pure-Python filtering can be done via + :meth:`apply`. + + :param query: The original query to be modified, selecting over all tracks. + :param track: The track mapper. + :param track_cache: The track cache mapper. + :return: The modified query. + """ + # pylint: disable=unused-argument + return query + + def apply(self, track): + """Check if the given track matches the filter. + + :param track: The track to check. + :return: ``True`` if the track matches. + """ + raise NotImplementedError + + +class LambdaFilter(Filter): + """A :class:`Filter` that works by provided lambda functions.""" + + def __init__(self, compiler, matcher): + self.compiler = compiler + self.matcher = matcher + + def compile(self, query, track, track_cache): + return self.compiler(query, track, track_cache) + + def apply(self, track): + return self.matcher(track) + + +class SearchFilter(Filter): + """A :class:`Filter` that looks for the given search terms.""" + + def __init__(self, search_terms): + self.search_terms = search_terms + + def compile(self, query, track, track_cache): + for term in self.search_terms: + term = term.lower() + query = query.where(func.lower(track.title).contains(term)) + return query + + def apply(self, track): + return all(term.lower() in track.title.lower() for term in self.search_terms) + + +class TagFilter(Filter): + """A :class:`Filter` that looks for the given tags.""" + + def __init__(self, tags): + self.tags = tags + + def compile(self, query, track, track_cache): + lower_tags = [tag.lower() for tag in self.tags] + for tag in lower_tags: + exists_query = (select(models.Tag) + .where(models.Tag.track_id == track.id) + .where(func.lower(models.Tag.tag) == tag) + .exists()) + query = query.where(exists_query) + return query + + def apply(self, track): + lower_track_tags = {tag.lower() for tag in track.text_tags()} + lower_tags = {tag.lower() for tag in self.tags} + return all(tag in lower_track_tags for tag in lower_tags) + + +class PersonFilter(Filter): + """A :class:`Filter` that looks for the given tagged people, based on their name.""" + + def __init__(self, names): + self.names = names + + def compile(self, query, track, track_cache): + lower_names = [name.lower() for name in self.names] + for name in lower_names: + tpa = models.track.track_people_assoc + exists_query = (select(tpa) + .join(models.User, tpa.c.user_id == models.User.id) + .where(tpa.c.track_id == track.id) + .where(func.lower(models.User.name) == name) + .exists()) + is_owner = (select(models.User.id) + .where(models.User.id == track.owner_id) + .where(func.lower(models.User.name) == name) + .exists()) + query = query.where(or_(exists_query, is_owner)) + return query + + def apply(self, track): + lower_names = {person.name.lower() for person in track.tagged_people} + lower_names.add(track.owner.name.lower()) + return all(name.lower() in lower_names for name in self.names) + + +class UserTaggedFilter(Filter): + """A :class:`Filter` that looks for a specific user to be tagged.""" + + def __init__(self, user): + self.user = user + + def compile(self, query, track, track_cache): + tpa = models.track.track_people_assoc + return query.where(or_( + track.owner == self.user, + (select(tpa) + .where(tpa.c.track_id == track.id) + .where(tpa.c.user_id == self.user.id) + .exists()), + )) + + def apply(self, track): + return track.owner == self.user or self.user in track.tagged_people + + +class FilterCollection(Filter): + """A class that applies multiple :class:`Filter`.""" def __init__(self, filters): self._filters = filters @@ -59,15 +193,13 @@ class TrackFilters: def __bool__(self): return bool(self._filters) - def apply(self, track): - """Apply the filters to the track. + def compile(self, query, track, track_cache): + for filty in self._filters: + query = filty.compile(query, track, track_cache) + return query - :param track: The track. - :type track: fietsboek.models.track.Track - :return: Whether the track matches the filters. - :rtype: bool - """ - return all(f(track) for f in self._filters) + def apply(self, track): + return all(filty.apply(track) for filty in self._filters) @classmethod def parse(cls, request): @@ -77,87 +209,83 @@ class TrackFilters: :param request: The request. :type request: pyramid.request.Request :return: The parsed filter. - :rtype: TrackFilters + :rtype: FilterCollection """ + # pylint: disable=singleton-comparison filters = [] if request.params.get('search-terms'): term = request.params.get('search-terms').strip() - filters.append(lambda track: term.lower() in track.title.lower()) + filters.append(SearchFilter([term])) if request.params.get('tags'): tags = [tag.strip() for tag in request.params.get('tags').split('&&')] tags = list(filter(bool, tags)) - - def has_tags(track): - lower_tags = {tag.lower() for tag in track.text_tags()} - return all(tag.lower() in lower_tags for tag in tags) - - filters.append(has_tags) + filters.append(TagFilter(tags)) if request.params.get('tagged-person'): names = [name.strip() for name in request.params.get('tagged-person').split('&&')] names = list(filter(bool, names)) - - def has_people(track): - peoples_names = [person.name for person in track.tagged_people] - peoples_names.append(track.owner.name) - peoples_names = set(map(str.lower, peoples_names)) - print(peoples_names) - return all(name.lower() in peoples_names for name in names) - - filters.append(has_people) + filters.append(PersonFilter(names)) if request.params.get('min-length'): # Value is given in km, so convert it to m min_length = _get_int(request, "min-length") * 1000 - filters.append(lambda track: track.length >= min_length) + filters.append(LambdaFilter( + lambda query, track, track_cache: + query.where(or_(track_cache.length >= min_length, + track_cache.length == None)), # noqa: E711 + lambda track: track.length >= min_length, + )) if request.params.get('max-length'): max_length = _get_int(request, "max-length") * 1000 - filters.append(lambda track: track.length <= max_length) + filters.append(LambdaFilter( + lambda query, track, track_cache: + query.where(or_(track_cache.length <= max_length, + track_cache.length == None)), # noqa: E711 + lambda track: track.length <= max_length, + )) if request.params.get('min-date'): min_date = _get_date(request, "min-date") - filters.append(lambda track: track.date.date() >= min_date) + min_date = datetime.datetime.combine(min_date, datetime.time.min) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.date_raw >= min_date), + lambda track: track.date.replace(tzinfo=None) >= min_date, + )) if request.params.get('max-date'): max_date = _get_date(request, "max-date") - filters.append(lambda track: track.date.date() <= max_date) + max_date = datetime.datetime.combine(max_date, datetime.time.max) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.date_raw <= max_date), + lambda track: track.date.replace(tzinfo=None) <= max_date, + )) if "mine" in request.params.getall('show-only[]'): - filters.append(lambda track: track.owner == request.identity) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.owner == request.identity), + lambda track: track.owner == request.identity, + )) - if "friends" in request.params.getall('show-only[]'): - filters.append(lambda track: request.identity and - track.owner in request.identity.get_friends()) + if "friends" in request.params.getall('show-only[]') and request.identity: + friend_ids = {friend.id for friend in request.identity.get_friends()} + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.owner_id.in_(friend_ids)), + lambda track: track.owner in request.identity.get_friends(), + )) - if request.params.get('user-tagged'): - filters.append(lambda track: request.identity and - (track.owner == request.identity or - request.identity in track.tagged_people)) + if "user-tagged" in request.params.getall('show-only[]') and request.identity: + filters.append(UserTaggedFilter(request.identity)) if 'type[]' in request.params: - filters.append(lambda track: track.type.name in request.params.getall('type[]')) - - return TrackFilters(filters) + types = {_get_enum(TrackType, value) for value in request.params.getall('type[]')} + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.type.in_(types)), + lambda track: track.type in types, + )) - -def visible_tracks(dbsession, user): - """Returns all visible tracks for the given user. - - The user might be ``None``, in which case all public tracks are returned. - - :param dbsession: The database session. - :type dbsession: sqlalchemy.orm.session.Session - :param user: The user to get the tracks for. - :type user: fietsboek.models.user.User - :return: The list of visible tracks. - :rtype: ~collections.abc.Iterable[fietsboek.models.track.Track] - """ - temp_track = aliased(models.Track, models.User.visible_tracks_query(user).subquery()) - query = select(temp_track).order_by(temp_track.date_raw.desc()) - tracks = dbsession.execute(query).scalars() - return tracks + return cls(filters) @view_config(route_name="browse", renderer="fietsboek:templates/browse.jinja2", @@ -170,8 +298,15 @@ def browse(request): :return: The HTTP response. :rtype: pyramid.response.Response """ - filters = TrackFilters.parse(request) - tracks = visible_tracks(request.dbsession, request.identity) + filters = FilterCollection.parse(request) + track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery()) + + # Build our query + query = select(track).join(models.TrackCache) + query = filters.compile(query, track, models.TrackCache) + query = query.order_by(track.date_raw.desc()) + + tracks = request.dbsession.execute(query).scalars() tracks = [track for track in tracks if filters.apply(track)] return { 'tracks': tracks, -- cgit v1.2.3