diff options
-rw-r--r-- | fietsboek/models/user.py | 26 | ||||
-rw-r--r-- | fietsboek/views/default.py | 8 |
2 files changed, 25 insertions, 9 deletions
diff --git a/fietsboek/models/user.py b/fietsboek/models/user.py index 83a36e0..dfddd78 100644 --- a/fietsboek/models/user.py +++ b/fietsboek/models/user.py @@ -17,7 +17,7 @@ from sqlalchemy import ( DateTime, Enum, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, with_parent from sqlalchemy.orm.session import object_session from sqlalchemy.orm.attributes import flag_dirty from sqlalchemy import select, union, delete, func @@ -170,15 +170,31 @@ class User(Base): @property def all_tracks(self): - """Yields all tracks in which the user participated. + """Returns a query that selects all the user's tracks. This includes the user's own tracks, as well as any tracks the user has been tagged in. - :rtype: list[fietsboek.models.track.Track] + The returned query can be further modified, however, you need to use + :func:`sqlalchemy.orm.aliased` to access the correct objects: + + >>> from fietsboek.models.track import Track, TrackType + >>> from sqlalchemy import select + >>> from sqlalchemy.orm import aliased + >>> user = retrieve_user() + >>> query = user.all_tracks + >>> query = query.filter(query.c.type == TrackType.ORGANIC) + >>> query = select(aliased(Track, query)) + + :rtype: sqlalchemy.sql.Selectable """ - yield from self.tracks - yield from self.tagged_tracks + # Late import to avoid cycles + # pylint: disable=import-outside-toplevel + from .track import Track + own = select(Track).where(with_parent(self, User.tracks)) + friends = select(Track).where(with_parent(self, User.tagged_tracks)) + # Create a fresh select so we can apply filter operations + return union(own, friends).subquery() def get_friends(self): """Returns all friends of the user. diff --git a/fietsboek/views/default.py b/fietsboek/views/default.py index 1d7bfa9..1df4942 100644 --- a/fietsboek/views/default.py +++ b/fietsboek/views/default.py @@ -6,6 +6,7 @@ from pyramid.i18n import TranslationString as _ from pyramid.renderers import render_to_response from sqlalchemy import select +from sqlalchemy.orm import aliased from sqlalchemy.exc import NoResultFound from .. import models, summaries, util, email @@ -29,12 +30,11 @@ def home(request): 'home_content': content, } - all_tracks = request.identity.all_tracks + query = request.identity.all_tracks + query = select(aliased(models.Track, query)).where(query.c.type == TrackType.ORGANIC) summary = summaries.Summary() - for track in all_tracks: - if track.type != TrackType.ORGANIC: - continue + for track in request.dbsession.execute(query).scalars(): track.ensure_cache() request.dbsession.add(track.cache) summary.add(track) |