aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fietsboek/models/user.py26
-rw-r--r--fietsboek/views/default.py8
2 files changed, 25 insertions, 9 deletions
diff --git a/fietsboek/models/user.py b/fietsboek/models/user.py
index 83a36e0..dfddd78 100644
--- a/fietsboek/models/user.py
+++ b/fietsboek/models/user.py
@@ -17,7 +17,7 @@ from sqlalchemy import (
DateTime,
Enum,
)
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, with_parent
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.attributes import flag_dirty
from sqlalchemy import select, union, delete, func
@@ -170,15 +170,31 @@ class User(Base):
@property
def all_tracks(self):
- """Yields all tracks in which the user participated.
+ """Returns a query that selects all the user's tracks.
This includes the user's own tracks, as well as any tracks the user has
been tagged in.
- :rtype: list[fietsboek.models.track.Track]
+ The returned query can be further modified, however, you need to use
+ :func:`sqlalchemy.orm.aliased` to access the correct objects:
+
+ >>> from fietsboek.models.track import Track, TrackType
+ >>> from sqlalchemy import select
+ >>> from sqlalchemy.orm import aliased
+ >>> user = retrieve_user()
+ >>> query = user.all_tracks
+ >>> query = query.filter(query.c.type == TrackType.ORGANIC)
+ >>> query = select(aliased(Track, query))
+
+ :rtype: sqlalchemy.sql.Selectable
"""
- yield from self.tracks
- yield from self.tagged_tracks
+ # Late import to avoid cycles
+ # pylint: disable=import-outside-toplevel
+ from .track import Track
+ own = select(Track).where(with_parent(self, User.tracks))
+ friends = select(Track).where(with_parent(self, User.tagged_tracks))
+ # Create a fresh select so we can apply filter operations
+ return union(own, friends).subquery()
def get_friends(self):
"""Returns all friends of the user.
diff --git a/fietsboek/views/default.py b/fietsboek/views/default.py
index 1d7bfa9..1df4942 100644
--- a/fietsboek/views/default.py
+++ b/fietsboek/views/default.py
@@ -6,6 +6,7 @@ from pyramid.i18n import TranslationString as _
from pyramid.renderers import render_to_response
from sqlalchemy import select
+from sqlalchemy.orm import aliased
from sqlalchemy.exc import NoResultFound
from .. import models, summaries, util, email
@@ -29,12 +30,11 @@ def home(request):
'home_content': content,
}
- all_tracks = request.identity.all_tracks
+ query = request.identity.all_tracks
+ query = select(aliased(models.Track, query)).where(query.c.type == TrackType.ORGANIC)
summary = summaries.Summary()
- for track in all_tracks:
- if track.type != TrackType.ORGANIC:
- continue
+ for track in request.dbsession.execute(query).scalars():
track.ensure_cache()
request.dbsession.add(track.cache)
summary.add(track)