aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fietsboek/templates/browse.jinja22
-rw-r--r--fietsboek/views/browse.py263
2 files changed, 200 insertions, 65 deletions
diff --git a/fietsboek/templates/browse.jinja2 b/fietsboek/templates/browse.jinja2
index 6f3e0eb..f8937bb 100644
--- a/fietsboek/templates/browse.jinja2
+++ b/fietsboek/templates/browse.jinja2
@@ -85,7 +85,7 @@
{% endmacro %}
{{ render_switch("switchOnlyMyTracks", "show-only[]", "mine", _("page.browse.filter.my_tracks.only")) }}
{{ render_switch("switchOnlyFriendsTracks", "show-only[]", "friends", _("page.browse.filter.friends_tracks_only")) }}
- {{ render_switch("switchOnlyMeTagged", "user-tagged", "on", _("page.browse.filter.me_tagged_only")) }}
+ {{ render_switch("switchOnlyMeTagged", "show-only[]", "user-tagged", _("page.browse.filter.me_tagged_only")) }}
{% endif %}
</div>
diff --git a/fietsboek/views/browse.py b/fietsboek/views/browse.py
index 03d9bf8..462905f 100644
--- a/fietsboek/views/browse.py
+++ b/fietsboek/views/browse.py
@@ -7,10 +7,11 @@ from pyramid.view import view_config
from pyramid.httpexceptions import HTTPForbidden, HTTPNotFound, HTTPBadRequest
from pyramid.response import Response
-from sqlalchemy import select
+from sqlalchemy import select, func, or_
from sqlalchemy.orm import aliased
from .. import models, util
+from ..models.track import TrackType
class Stream(RawIOBase):
@@ -48,10 +49,143 @@ def _get_date(request, name):
raise HTTPBadRequest(f'Invalid date in {name!r}') from exc
-class TrackFilters:
- """A filter that applies user-given filters to a track."""
- # pylint: disable=fixme
- # TODO: We should also do some of those in SQL, if possible.
+def _get_enum(enum, value):
+ try:
+ return enum[value]
+ except KeyError as exc:
+ raise HTTPBadRequest(f'Invalid enum value {value!r}') from exc
+
+
+class Filter:
+ """A class representing a filter that the user can apply to the track list."""
+
+ def compile(self, query, track, track_cache):
+ """Compile the filter into the SQL query.
+
+ Returns the modified query.
+
+ This method is optional, as a pure-Python filtering can be done via
+ :meth:`apply`.
+
+ :param query: The original query to be modified, selecting over all tracks.
+ :param track: The track mapper.
+ :param track_cache: The track cache mapper.
+ :return: The modified query.
+ """
+ # pylint: disable=unused-argument
+ return query
+
+ def apply(self, track):
+ """Check if the given track matches the filter.
+
+ :param track: The track to check.
+ :return: ``True`` if the track matches.
+ """
+ raise NotImplementedError
+
+
+class LambdaFilter(Filter):
+ """A :class:`Filter` that works by provided lambda functions."""
+
+ def __init__(self, compiler, matcher):
+ self.compiler = compiler
+ self.matcher = matcher
+
+ def compile(self, query, track, track_cache):
+ return self.compiler(query, track, track_cache)
+
+ def apply(self, track):
+ return self.matcher(track)
+
+
+class SearchFilter(Filter):
+ """A :class:`Filter` that looks for the given search terms."""
+
+ def __init__(self, search_terms):
+ self.search_terms = search_terms
+
+ def compile(self, query, track, track_cache):
+ for term in self.search_terms:
+ term = term.lower()
+ query = query.where(func.lower(track.title).contains(term))
+ return query
+
+ def apply(self, track):
+ return all(term.lower() in track.title.lower() for term in self.search_terms)
+
+
+class TagFilter(Filter):
+ """A :class:`Filter` that looks for the given tags."""
+
+ def __init__(self, tags):
+ self.tags = tags
+
+ def compile(self, query, track, track_cache):
+ lower_tags = [tag.lower() for tag in self.tags]
+ for tag in lower_tags:
+ exists_query = (select(models.Tag)
+ .where(models.Tag.track_id == track.id)
+ .where(func.lower(models.Tag.tag) == tag)
+ .exists())
+ query = query.where(exists_query)
+ return query
+
+ def apply(self, track):
+ lower_track_tags = {tag.lower() for tag in track.text_tags()}
+ lower_tags = {tag.lower() for tag in self.tags}
+ return all(tag in lower_track_tags for tag in lower_tags)
+
+
+class PersonFilter(Filter):
+ """A :class:`Filter` that looks for the given tagged people, based on their name."""
+
+ def __init__(self, names):
+ self.names = names
+
+ def compile(self, query, track, track_cache):
+ lower_names = [name.lower() for name in self.names]
+ for name in lower_names:
+ tpa = models.track.track_people_assoc
+ exists_query = (select(tpa)
+ .join(models.User, tpa.c.user_id == models.User.id)
+ .where(tpa.c.track_id == track.id)
+ .where(func.lower(models.User.name) == name)
+ .exists())
+ is_owner = (select(models.User.id)
+ .where(models.User.id == track.owner_id)
+ .where(func.lower(models.User.name) == name)
+ .exists())
+ query = query.where(or_(exists_query, is_owner))
+ return query
+
+ def apply(self, track):
+ lower_names = {person.name.lower() for person in track.tagged_people}
+ lower_names.add(track.owner.name.lower())
+ return all(name.lower() in lower_names for name in self.names)
+
+
+class UserTaggedFilter(Filter):
+ """A :class:`Filter` that looks for a specific user to be tagged."""
+
+ def __init__(self, user):
+ self.user = user
+
+ def compile(self, query, track, track_cache):
+ tpa = models.track.track_people_assoc
+ return query.where(or_(
+ track.owner == self.user,
+ (select(tpa)
+ .where(tpa.c.track_id == track.id)
+ .where(tpa.c.user_id == self.user.id)
+ .exists()),
+ ))
+
+ def apply(self, track):
+ return track.owner == self.user or self.user in track.tagged_people
+
+
+class FilterCollection(Filter):
+ """A class that applies multiple :class:`Filter`."""
def __init__(self, filters):
self._filters = filters
@@ -59,15 +193,13 @@ class TrackFilters:
def __bool__(self):
return bool(self._filters)
- def apply(self, track):
- """Apply the filters to the track.
+ def compile(self, query, track, track_cache):
+ for filty in self._filters:
+ query = filty.compile(query, track, track_cache)
+ return query
- :param track: The track.
- :type track: fietsboek.models.track.Track
- :return: Whether the track matches the filters.
- :rtype: bool
- """
- return all(f(track) for f in self._filters)
+ def apply(self, track):
+ return all(filty.apply(track) for filty in self._filters)
@classmethod
def parse(cls, request):
@@ -77,87 +209,83 @@ class TrackFilters:
:param request: The request.
:type request: pyramid.request.Request
:return: The parsed filter.
- :rtype: TrackFilters
+ :rtype: FilterCollection
"""
+ # pylint: disable=singleton-comparison
filters = []
if request.params.get('search-terms'):
term = request.params.get('search-terms').strip()
- filters.append(lambda track: term.lower() in track.title.lower())
+ filters.append(SearchFilter([term]))
if request.params.get('tags'):
tags = [tag.strip() for tag in request.params.get('tags').split('&&')]
tags = list(filter(bool, tags))
-
- def has_tags(track):
- lower_tags = {tag.lower() for tag in track.text_tags()}
- return all(tag.lower() in lower_tags for tag in tags)
-
- filters.append(has_tags)
+ filters.append(TagFilter(tags))
if request.params.get('tagged-person'):
names = [name.strip() for name in request.params.get('tagged-person').split('&&')]
names = list(filter(bool, names))
-
- def has_people(track):
- peoples_names = [person.name for person in track.tagged_people]
- peoples_names.append(track.owner.name)
- peoples_names = set(map(str.lower, peoples_names))
- print(peoples_names)
- return all(name.lower() in peoples_names for name in names)
-
- filters.append(has_people)
+ filters.append(PersonFilter(names))
if request.params.get('min-length'):
# Value is given in km, so convert it to m
min_length = _get_int(request, "min-length") * 1000
- filters.append(lambda track: track.length >= min_length)
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache:
+ query.where(or_(track_cache.length >= min_length,
+ track_cache.length == None)), # noqa: E711
+ lambda track: track.length >= min_length,
+ ))
if request.params.get('max-length'):
max_length = _get_int(request, "max-length") * 1000
- filters.append(lambda track: track.length <= max_length)
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache:
+ query.where(or_(track_cache.length <= max_length,
+ track_cache.length == None)), # noqa: E711
+ lambda track: track.length <= max_length,
+ ))
if request.params.get('min-date'):
min_date = _get_date(request, "min-date")
- filters.append(lambda track: track.date.date() >= min_date)
+ min_date = datetime.datetime.combine(min_date, datetime.time.min)
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache: query.where(track.date_raw >= min_date),
+ lambda track: track.date.replace(tzinfo=None) >= min_date,
+ ))
if request.params.get('max-date'):
max_date = _get_date(request, "max-date")
- filters.append(lambda track: track.date.date() <= max_date)
+ max_date = datetime.datetime.combine(max_date, datetime.time.max)
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache: query.where(track.date_raw <= max_date),
+ lambda track: track.date.replace(tzinfo=None) <= max_date,
+ ))
if "mine" in request.params.getall('show-only[]'):
- filters.append(lambda track: track.owner == request.identity)
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache: query.where(track.owner == request.identity),
+ lambda track: track.owner == request.identity,
+ ))
- if "friends" in request.params.getall('show-only[]'):
- filters.append(lambda track: request.identity and
- track.owner in request.identity.get_friends())
+ if "friends" in request.params.getall('show-only[]') and request.identity:
+ friend_ids = {friend.id for friend in request.identity.get_friends()}
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache: query.where(track.owner_id.in_(friend_ids)),
+ lambda track: track.owner in request.identity.get_friends(),
+ ))
- if request.params.get('user-tagged'):
- filters.append(lambda track: request.identity and
- (track.owner == request.identity or
- request.identity in track.tagged_people))
+ if "user-tagged" in request.params.getall('show-only[]') and request.identity:
+ filters.append(UserTaggedFilter(request.identity))
if 'type[]' in request.params:
- filters.append(lambda track: track.type.name in request.params.getall('type[]'))
-
- return TrackFilters(filters)
+ types = {_get_enum(TrackType, value) for value in request.params.getall('type[]')}
+ filters.append(LambdaFilter(
+ lambda query, track, track_cache: query.where(track.type.in_(types)),
+ lambda track: track.type in types,
+ ))
-
-def visible_tracks(dbsession, user):
- """Returns all visible tracks for the given user.
-
- The user might be ``None``, in which case all public tracks are returned.
-
- :param dbsession: The database session.
- :type dbsession: sqlalchemy.orm.session.Session
- :param user: The user to get the tracks for.
- :type user: fietsboek.models.user.User
- :return: The list of visible tracks.
- :rtype: ~collections.abc.Iterable[fietsboek.models.track.Track]
- """
- temp_track = aliased(models.Track, models.User.visible_tracks_query(user).subquery())
- query = select(temp_track).order_by(temp_track.date_raw.desc())
- tracks = dbsession.execute(query).scalars()
- return tracks
+ return cls(filters)
@view_config(route_name="browse", renderer="fietsboek:templates/browse.jinja2",
@@ -170,8 +298,15 @@ def browse(request):
:return: The HTTP response.
:rtype: pyramid.response.Response
"""
- filters = TrackFilters.parse(request)
- tracks = visible_tracks(request.dbsession, request.identity)
+ filters = FilterCollection.parse(request)
+ track = aliased(models.Track, models.User.visible_tracks_query(request.identity).subquery())
+
+ # Build our query
+ query = select(track).join(models.TrackCache)
+ query = filters.compile(query, track, models.TrackCache)
+ query = query.order_by(track.date_raw.desc())
+
+ tracks = request.dbsession.execute(query).scalars()
tracks = [track for track in tracks if filters.apply(track)]
return {
'tracks': tracks,