From ecba9267cf487a5e919588c5b9b22ecb6739c375 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Thu, 14 May 2026 16:03:55 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20compute=20liked=5Fby=5Fviewer/boosted=5F?= =?UTF-8?q?by=5Fviewer=20from=20DB=20=E2=80=94=20viewer=5Fid=20was=20ignor?= =?UTF-8?q?ed=20in=20all=20feed=20queries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/adapters/postgres/src/feed.rs | 44 +++++++++++++++++------- crates/application/src/use_cases/feed.rs | 4 +-- crates/presentation/src/handlers/feed.rs | 3 +- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/crates/adapters/postgres/src/feed.rs b/crates/adapters/postgres/src/feed.rs index bd9b553..664fca8 100644 --- a/crates/adapters/postgres/src/feed.rs +++ b/crates/adapters/postgres/src/feed.rs @@ -45,9 +45,19 @@ struct FeedRow { like_count: i64, boost_count: i64, reply_count: i64, + liked_by_viewer: bool, + boosted_by_viewer: bool, } -const FEED_SELECT: &str = " +fn feed_select(viewer: Option) -> String { + let viewer_checks = match viewer { + Some(uid) => format!( + "EXISTS(SELECT 1 FROM likes WHERE user_id='{uid}' AND thought_id=t.id) AS liked_by_viewer, + EXISTS(SELECT 1 FROM boosts WHERE user_id='{uid}' AND thought_id=t.id) AS boosted_by_viewer" + ), + None => "false AS liked_by_viewer, false AS boosted_by_viewer".to_string(), + }; + format!(" SELECT t.id AS thought_id, t.user_id AS t_user_id, t.content, t.in_reply_to_id, t.in_reply_to_url, t.ap_id AS t_ap_id, @@ -60,8 +70,10 @@ const FEED_SELECT: &str = " u.created_at AS author_created_at, u.updated_at AS author_updated_at, (SELECT COUNT(*) FROM likes l WHERE l.thought_id=t.id) AS like_count, (SELECT COUNT(*) FROM boosts b WHERE b.thought_id=t.id) AS boost_count, - (SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count - FROM thoughts t JOIN users u ON u.id=t.user_id"; + (SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count, + {viewer_checks} + FROM thoughts t JOIN users u ON u.id=t.user_id") +} fn row_to_entry(r: FeedRow) -> FeedEntry { let thought = Thought { @@ -89,18 +101,20 @@ fn row_to_entry(r: FeedRow) -> FeedEntry { public_key: r.public_key, private_key: r.private_key, created_at: r.author_created_at, updated_at: r.author_updated_at, }; - FeedEntry { thought, author, like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, liked_by_viewer: false, boosted_by_viewer: false } + FeedEntry { thought, author, like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, liked_by_viewer: r.liked_by_viewer, boosted_by_viewer: r.boosted_by_viewer } } #[async_trait] impl FeedRepository for PgFeedRepository { - async fn home_feed(&self, following_ids: &[UserId], page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { + async fn home_feed(&self, following_ids: &[UserId], page: &PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { let ids: Vec = following_ids.iter().map(|id| id.as_uuid()).collect(); + let viewer = viewer_id.map(|v| v.as_uuid()); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.user_id=ANY($1) AND t.visibility='public'" ).bind(&ids).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; - let sql = format!("{FEED_SELECT} WHERE t.user_id=ANY($1) AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); + let sel = feed_select(viewer); + let sql = format!("{sel} WHERE t.user_id=ANY($1) AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(&ids).bind(page.limit()).bind(page.offset()) .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; @@ -108,12 +122,14 @@ impl FeedRepository for PgFeedRepository { Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } - async fn public_feed(&self, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { + async fn public_feed(&self, page: &PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { + let viewer = viewer_id.map(|v| v.as_uuid()); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.local=true AND t.visibility='public'" ).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; - let sql = format!("{FEED_SELECT} WHERE t.local=true AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $1 OFFSET $2"); + let sel = feed_select(viewer); + let sql = format!("{sel} WHERE t.local=true AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $1 OFFSET $2"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(page.limit()).bind(page.offset()) .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; @@ -121,7 +137,8 @@ impl FeedRepository for PgFeedRepository { Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } - async fn search(&self, query: &str, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { + async fn search(&self, query: &str, page: &PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { + let viewer = viewer_id.map(|v| v.as_uuid()); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.content % $1 AND t.visibility='public'" ) @@ -130,7 +147,8 @@ impl FeedRepository for PgFeedRepository { .await .map_err(|e| DomainError::Internal(e.to_string()))?; - let sql = format!("{FEED_SELECT} WHERE t.content % $1 AND t.visibility='public' ORDER BY similarity(t.content, $1) DESC LIMIT $2 OFFSET $3"); + let sel = feed_select(viewer); + let sql = format!("{sel} WHERE t.content % $1 AND t.visibility='public' ORDER BY similarity(t.content, $1) DESC LIMIT $2 OFFSET $3"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(query) .bind(page.limit()) @@ -142,7 +160,8 @@ impl FeedRepository for PgFeedRepository { Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } - async fn tag_feed(&self, tag_name: &str, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { + async fn tag_feed(&self, tag_name: &str, page: &PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { + let viewer = viewer_id.map(|v| v.as_uuid()); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t JOIN thought_tags tt ON tt.thought_id = t.id @@ -154,8 +173,9 @@ impl FeedRepository for PgFeedRepository { .await .map_err(|e| DomainError::Internal(e.to_string()))?; + let sel = feed_select(viewer); let sql = format!( - "{FEED_SELECT} + "{sel} JOIN thought_tags tt ON tt.thought_id = t.id JOIN tags tg ON tg.id = tt.tag_id WHERE tg.name = $1 AND t.visibility = 'public' diff --git a/crates/application/src/use_cases/feed.rs b/crates/application/src/use_cases/feed.rs index ef5d38a..24c1c46 100644 --- a/crates/application/src/use_cases/feed.rs +++ b/crates/application/src/use_cases/feed.rs @@ -29,8 +29,8 @@ pub async fn get_following(follows: &dyn FollowRepository, user_id: &UserId, pag follows.list_following(user_id, &page).await } -pub async fn get_by_tag(feed: &dyn FeedRepository, tag_name: &str, page: PageParams) -> Result, DomainError> { - feed.tag_feed(tag_name, &page, None).await +pub async fn get_by_tag(feed: &dyn FeedRepository, tag_name: &str, page: PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { + feed.tag_feed(tag_name, &page, viewer_id).await } pub async fn search(feed: &dyn FeedRepository, query: &str, page: PageParams, viewer_id: Option<&UserId>) -> Result, DomainError> { diff --git a/crates/presentation/src/handlers/feed.rs b/crates/presentation/src/handlers/feed.rs index 117141a..e942f57 100644 --- a/crates/presentation/src/handlers/feed.rs +++ b/crates/presentation/src/handlers/feed.rs @@ -158,10 +158,11 @@ pub async fn get_popular_tags( pub async fn tag_thoughts_handler( State(s): State, Path(tag_name): Path, + OptionalAuthUser(viewer): OptionalAuthUser, Query(q): Query, ) -> Result, ApiError> { let page = PageParams { page: q.page(), per_page: q.per_page() }; - let result = get_by_tag(&*s.feed, &tag_name, page).await?; + let result = get_by_tag(&*s.feed, &tag_name, page, viewer.as_ref()).await?; Ok(Json(serde_json::json!({ "tag": tag_name, "total": result.total,