diff options
-rw-r--r-- | fietsboek/templates/browse.jinja2 | 2 | ||||
-rw-r--r-- | fietsboek/views/browse.py | 263 |
2 files changed, 200 insertions, 65 deletions
diff --git a/fietsboek/templates/browse.jinja2 b/fietsboek/templates/browse.jinja2 index 6f3e0eb..f8937bb 100644 --- a/fietsboek/templates/browse.jinja2 +++ b/fietsboek/templates/browse.jinja2 @@ -85,7 +85,7 @@ {% endmacro %} {{ render_switch("switchOnlyMyTracks", "show-only[]", "mine", _("page.browse.filter.my_tracks.only")) }} {{ render_switch("switchOnlyFriendsTracks", "show-only[]", "friends", _("page.browse.filter.friends_tracks_only")) }} - {{ render_switch("switchOnlyMeTagged", "user-tagged", "on", _("page.browse.filter.me_tagged_only")) }} + {{ render_switch("switchOnlyMeTagged", "show-only[]", "user-tagged", _("page.browse.filter.me_tagged_only")) }} {% endif %} </div> diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py index 03d9bf8..462905f 100644 --- a/fietsboek/views/browse.py +++ b/fietsboek/views/browse.py @@ -7,10 +7,11 @@ from pyramid.view import view_config from pyramid.httpexceptions import HTTPForbidden, HTTPNotFound, HTTPBadRequest from pyramid.response import Response -from sqlalchemy import select +from sqlalchemy import select, func, or_ from sqlalchemy.orm import aliased from .. import models, util +from ..models.track import TrackType class Stream(RawIOBase): @@ -48,10 +49,143 @@ def _get_date(request, name): raise HTTPBadRequest(f'Invalid date in {name!r}') from exc -class TrackFilters: - """A filter that applies user-given filters to a track.""" - # pylint: disable=fixme - # TODO: We should also do some of those in SQL, if possible. +def _get_enum(enum, value): + try: + return enum[value] + except KeyError as exc: + raise HTTPBadRequest(f'Invalid enum value {value!r}') from exc + + +class Filter: + """A class representing a filter that the user can apply to the track list.""" + + def compile(self, query, track, track_cache): + """Compile the filter into the SQL query. + + Returns the modified query. + + This method is optional, as a pure-Python filtering can be done via + :meth:`apply`. + + :param query: The original query to be modified, selecting over all tracks. + :param track: The track mapper. + :param track_cache: The track cache mapper. + :return: The modified query. + """ + # pylint: disable=unused-argument + return query + + def apply(self, track): + """Check if the given track matches the filter. + + :param track: The track to check. + :return: ``True`` if the track matches. + """ + raise NotImplementedError + + +class LambdaFilter(Filter): + """A :class:`Filter` that works by provided lambda functions.""" + + def __init__(self, compiler, matcher): + self.compiler = compiler + self.matcher = matcher + + def compile(self, query, track, track_cache): + return self.compiler(query, track, track_cache) + + def apply(self, track): + return self.matcher(track) + + +class SearchFilter(Filter): + """A :class:`Filter` that looks for the given search terms.""" + + def __init__(self, search_terms): + self.search_terms = search_terms + + def compile(self, query, track, track_cache): + for term in self.search_terms: + term = term.lower() + query = query.where(func.lower(track.title).contains(term)) + return query + + def apply(self, track): + return all(term.lower() in track.title.lower() for term in self.search_terms) + + +class TagFilter(Filter): + """A :class:`Filter` that looks for the given tags.""" + + def __init__(self, tags): + self.tags = tags + + def compile(self, query, track, track_cache): + lower_tags = [tag.lower() for tag in self.tags] + for tag in lower_tags: + exists_query = (select(models.Tag) + .where(models.Tag.track_id == track.id) + .where(func.lower(models.Tag.tag) == tag) + .exists()) + query = query.where(exists_query) + return query + + def apply(self, track): + lower_track_tags = {tag.lower() for tag in track.text_tags()} + lower_tags = {tag.lower() for tag in self.tags} + return all(tag in lower_track_tags for tag in lower_tags) + + +class PersonFilter(Filter): + """A :class:`Filter` that looks for the given tagged people, based on their name.""" + + def __init__(self, names): + self.names = names + + def compile(self, query, track, track_cache): + lower_names = [name.lower() for name in self.names] + for name in lower_names: + tpa = models.track.track_people_assoc + exists_query = (select(tpa) + .join(models.User, tpa.c.user_id == models.User.id) + .where(tpa.c.track_id == track.id) + .where(func.lower(models.User.name) == name) + .exists()) + is_owner = (select(models.User.id) + .where(models.User.id == track.owner_id) + .where(func.lower(models.User.name) == name) + .exists()) + query = query.where(or_(exists_query, is_owner)) + return query + + def apply(self, track): + lower_names = {person.name.lower() for person in track.tagged_people} + lower_names.add(track.owner.name.lower()) + return all(name.lower() in lower_names for name in self.names) + + +class UserTaggedFilter(Filter): + """A :class:`Filter` that looks for a specific user to be tagged.""" + + def __init__(self, user): + self.user = user + + def compile(self, query, track, track_cache): + tpa = models.track.track_people_assoc + return query.where(or_( + track.owner == self.user, + (select(tpa) + .where(tpa.c.track_id == track.id) + .where(tpa.c.user_id == self.user.id) + .exists()), + )) + + def apply(self, track): + return track.owner == self.user or self.user in track.tagged_people + + +class FilterCollection(Filter): + """A class that applies multiple :class:`Filter`.""" def __init__(self, filters): self._filters = filters @@ -59,15 +193,13 @@ class TrackFilters: def __bool__(self): return bool(self._filters) - def apply(self, track): - """Apply the filters to the track. + def compile(self, query, track, track_cache): + for filty in self._filters: + query = filty.compile(query, track, track_cache) + return query - :param track: The track. - :type track: fietsboek.models.track.Track - :return: Whether the track matches the filters. - :rtype: bool - """ - return all(f(track) for f in self._filters) + def apply(self, track): + return all(filty.apply(track) for filty in self._filters) @classmethod def parse(cls, request): @@ -77,87 +209,83 @@ class TrackFilters: :param request: The request. :type request: pyramid.request.Request :return: The parsed filter. - :rtype: TrackFilters + :rtype: FilterCollection """ + # pylint: disable=singleton-comparison filters = [] if request.params.get('search-terms'): term = request.params.get('search-terms').strip() - filters.append(lambda track: term.lower() in track.title.lower()) + filters.append(SearchFilter([term])) if request.params.get('tags'): tags = [tag.strip() for tag in request.params.get('tags').split('&&')] tags = list(filter(bool, tags)) - - def has_tags(track): - lower_tags = {tag.lower() for tag in track.text_tags()} - return all(tag.lower() in lower_tags for tag in tags) - - filters.append(has_tags) + filters.append(TagFilter(tags)) if request.params.get('tagged-person'): names = [name.strip() for name in request.params.get('tagged-person').split('&&')] names = list(filter(bool, names)) - - def has_people(track): - peoples_names = [person.name for person in track.tagged_people] - peoples_names.append(track.owner.name) - peoples_names = set(map(str.lower, peoples_names)) - print(peoples_names) - return all(name.lower() in peoples_names for name in names) - - filters.append(has_people) + filters.append(PersonFilter(names)) if request.params.get('min-length'): # Value is given in km, so convert it to m min_length = _get_int(request, "min-length") * 1000 - filters.append(lambda track: track.length >= min_length) + filters.append(LambdaFilter( + lambda query, track, track_cache: + query.where(or_(track_cache.length >= min_length, + track_cache.length == None)), # noqa: E711 + lambda track: track.length >= min_length, + )) if request.params.get('max-length'): max_length = _get_int(request, "max-length") * 1000 - filters.append(lambda track: track.length <= max_length) + filters.append(LambdaFilter( + lambda query, track, track_cache: + query.where(or_(track_cache.length <= max_length, + track_cache.length == None)), # noqa: E711 + lambda track: track.length <= max_length, + )) if request.params.get('min-date'): min_date = _get_date(request, "min-date") - filters.append(lambda track: track.date.date() >= min_date) + min_date = datetime.datetime.combine(min_date, datetime.time.min) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.date_raw >= min_date), + lambda track: track.date.replace(tzinfo=None) >= min_date, + )) if request.params.get('max-date'): max_date = _get_date(request, "max-date") - filters.append(lambda track: track.date.date() <= max_date) + max_date = datetime.datetime.combine(max_date, datetime.time.max) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.date_raw <= max_date), + lambda track: track.date.replace(tzinfo=None) <= max_date, + )) if "mine" in request.params.getall('show-only[]'): - filters.append(lambda track: track.owner == request.identity) + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.owner == request.identity), + lambda track: track.owner == request.identity, + )) - if "friends" in request.params.getall('show-only[]'): - filters.append(lambda track: request.identity and - track.owner in request.identity.get_friends()) + if "friends" in request.params.getall('show-only[]') and request.identity: + friend_ids = {friend.id for friend in request.identity.get_friends()} + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.owner_id.in_(friend_ids)), + lambda track: track.owner in request.identity.get_friends(), + )) - if request.params.get('user-tagged'): - filters.append(lambda track: request.identity and - (track.owner == request.identity or - request.identity in track.tagged_people)) + if "user-tagged" in request.params.getall('show-only[]') and request.identity: + filters.append(UserTaggedFilter(request.identity)) if 'type[]' in request.params: - filters.append(lambda track: track.type.name in request.params.getall('type[]')) - - return TrackFilters(filters) + types = {_get_enum(TrackType, value) for value in request.params.getall('type[]')} + filters.append(LambdaFilter( + lambda query, track, track_cache: query.where(track.type.in_(types)), + lambda track: track.type in types, + )) - -def visible_tracks(dbsession, user): - """Returns all visible tracks for the given user. - - The user might be ``None``, in which case all public tracks are returned. - - :param dbsession: The database session. - :type dbsession: sqlalchemy.orm.session.Session - :param user: The user to get the tracks for. - :type user: fietsboek.models.user.User - :return: The list of visible tracks. - :rtype: ~collections.abc.Iterable[fietsboek.models.track.Track] - """ - temp_track = aliased(models.Track, models.User.visible_tracks_query(user).subquery()) - query = select(temp_track).order_by(temp_track.date_raw.desc()) - tracks = dbsession.execute(query).scalars() - return tracks + return cls(filters) @view_config(route_name="browse", renderer="fietsboek:templates/browse.jinja2", @@ -170,8 +298,15 @@ def browse(request): :return: The HTTP response. :rtype: pyramid.response.Response """ - filters = TrackFilters.parse(request) - tracks = visible_tracks(request.dbsession, request.identity) + filters = FilterCollection.parse(request) + track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery()) + + # Build our query + query = select(track).join(models.TrackCache) + query = filters.compile(query, track, models.TrackCache) + query = query.order_by(track.date_raw.desc()) + + tracks = request.dbsession.execute(query).scalars() tracks = [track for track in tracks if filters.apply(track)] return { 'tracks': tracks, |