From 0688ffe0ae2a7b3c9670803d6e7ee3ce7a49f747 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Thu, 28 May 2026 23:45:46 +0200 Subject: [PATCH] feat(backend): wire FeedRequest/FeedOptions sort+filter through all feed layers --- crates/adapters/postgres/src/feed/mod.rs | 146 +++++++++++------- crates/adapters/postgres/src/feed/tests.rs | 38 +++-- .../use_cases/federation_management/mod.rs | 13 +- crates/application/src/use_cases/feed.rs | 10 +- crates/domain/src/testing/mod.rs | 2 +- crates/presentation/src/handlers/feed.rs | 74 ++++++++- 6 files changed, 197 insertions(+), 86 deletions(-) diff --git a/crates/adapters/postgres/src/feed/mod.rs b/crates/adapters/postgres/src/feed/mod.rs index d0e884f..67c7afc 100644 --- a/crates/adapters/postgres/src/feed/mod.rs +++ b/crates/adapters/postgres/src/feed/mod.rs @@ -9,7 +9,7 @@ use domain::{ thought::{Thought, Visibility}, user::User, }, - ports::{FeedQuery, FeedRepository, FeedScope}, + ports::{FeedFilter, FeedRepository, FeedRequest, FeedScope, FeedSort}, value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username}, }; use sqlx::PgPool; @@ -151,28 +151,62 @@ fn row_to_entry(r: FeedRow, viewer: Option) -> Result &'static str { + if matches!(scope, FeedScope::Search { .. }) { + return "ORDER BY similarity(t.content, $1) DESC"; + } + match sort { + FeedSort::Newest => "ORDER BY t.created_at DESC", + FeedSort::Oldest => "ORDER BY t.created_at ASC", + FeedSort::MostLiked => "ORDER BY like_count DESC, t.created_at DESC", + FeedSort::MostBoosted => "ORDER BY boost_count DESC, t.created_at DESC", + FeedSort::MostDiscussed => "ORDER BY reply_count DESC, t.created_at DESC", + } +} + +fn filter_clauses(f: &FeedFilter) -> String { + let mut s = String::new(); + if f.originals_only { + s += " AND t.in_reply_to_id IS NULL"; + } + if f.replies_only { + s += " AND t.in_reply_to_id IS NOT NULL"; + } + if f.local_only { + s += " AND t.local = true"; + } + if f.hide_sensitive { + s += " AND t.sensitive = false"; + } + s +} + #[async_trait] impl FeedRepository for PgFeedRepository { - async fn query(&self, q: &FeedQuery) -> Result, DomainError> { - let viewer = q.viewer_id.as_ref().map(|v| v.as_uuid()); - let page = &q.page; + async fn query(&self, req: &FeedRequest) -> Result, DomainError> { + let viewer = req.query.viewer_id.as_ref().map(|v| v.as_uuid()); + let page = &req.query.page; + let filter = filter_clauses(&req.options.filter); + let order = order_by_clause(&req.options.sort, &req.query.scope); - match &q.scope { + match &req.query.scope { FeedScope::Home { following_ids } => { let ids: Vec = following_ids.iter().map(|id| id.as_uuid()).collect(); let fed_clause = federation_following_clause(viewer); let count_sql = format!( - "SELECT COUNT(*) FROM thoughts t WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct'", - fed_clause + "SELECT COUNT(*) FROM thoughts t WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct'{}", + fed_clause, filter ); let total: i64 = sqlx::query_scalar(&count_sql) .bind(&ids) .fetch_one(&self.pool) .await .into_domain()?; - let sel = feed_select(viewer); - let sql = format!("{sel} WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3", fed_clause); + let sql = format!( + "{sel} WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct'{} {} LIMIT $2 OFFSET $3", + fed_clause, filter, order + ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(&ids) .bind(page.limit()) @@ -180,7 +214,6 @@ impl FeedRepository for PgFeedRepository { .fetch_all(&self.pool) .await .into_domain()?; - Ok(Paginated { items: rows .into_iter() @@ -193,22 +226,25 @@ impl FeedRepository for PgFeedRepository { } FeedScope::Public => { - let total: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM thoughts t WHERE t.local=true AND t.visibility='public'", - ) - .fetch_one(&self.pool) - .await - .into_domain()?; - + let count_sql = format!( + "SELECT COUNT(*) FROM thoughts t WHERE t.local=true AND t.visibility='public'{}", + filter + ); + let total: i64 = sqlx::query_scalar(&count_sql) + .fetch_one(&self.pool) + .await + .into_domain()?; let sel = feed_select(viewer); - let sql = format!("{sel} WHERE t.local=true AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $1 OFFSET $2"); + let sql = format!( + "{sel} WHERE t.local=true AND t.visibility='public'{} {} LIMIT $1 OFFSET $2", + filter, order + ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(page.limit()) .bind(page.offset()) .fetch_all(&self.pool) .await .into_domain()?; - Ok(Paginated { items: rows .into_iter() @@ -221,16 +257,20 @@ impl FeedRepository for PgFeedRepository { } FeedScope::Search { query } => { - let total: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM thoughts t WHERE t.content % $1 AND t.visibility='public'", - ) - .bind(query) - .fetch_one(&self.pool) - .await - .into_domain()?; - + let count_sql = format!( + "SELECT COUNT(*) FROM thoughts t WHERE t.content % $1 AND t.visibility='public'{}", + filter + ); + let total: i64 = sqlx::query_scalar(&count_sql) + .bind(query) + .fetch_one(&self.pool) + .await + .into_domain()?; let sel = feed_select(viewer); - let sql = format!("{sel} WHERE t.content % $1 AND t.visibility='public' ORDER BY similarity(t.content, $1) DESC LIMIT $2 OFFSET $3"); + let sql = format!( + "{sel} WHERE t.content % $1 AND t.visibility='public'{} {} LIMIT $2 OFFSET $3", + filter, order + ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(query) .bind(page.limit()) @@ -238,7 +278,6 @@ impl FeedRepository for PgFeedRepository { .fetch_all(&self.pool) .await .into_domain()?; - Ok(Paginated { items: rows .into_iter() @@ -251,24 +290,25 @@ impl FeedRepository for PgFeedRepository { } FeedScope::Tag { tag_name } => { - let total: i64 = sqlx::query_scalar( + let count_sql = format!( "SELECT COUNT(*) FROM thoughts t JOIN thought_tags tt ON tt.thought_id = t.id JOIN tags tg ON tg.id = tt.tag_id - WHERE tg.name = $1 AND t.visibility = 'public'", - ) - .bind(tag_name) - .fetch_one(&self.pool) - .await - .into_domain()?; - + WHERE tg.name = $1 AND t.visibility = 'public'{}", + filter + ); + let total: i64 = sqlx::query_scalar(&count_sql) + .bind(tag_name) + .fetch_one(&self.pool) + .await + .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} JOIN thought_tags tt ON tt.thought_id = t.id JOIN tags tg ON tg.id = tt.tag_id - WHERE tg.name = $1 AND t.visibility = 'public' - ORDER BY t.created_at DESC LIMIT $2 OFFSET $3" + WHERE tg.name = $1 AND t.visibility = 'public'{} {} LIMIT $2 OFFSET $3", + filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(tag_name) @@ -277,7 +317,6 @@ impl FeedRepository for PgFeedRepository { .fetch_all(&self.pool) .await .into_domain()?; - Ok(Paginated { items: rows .into_iter() @@ -291,20 +330,22 @@ impl FeedRepository for PgFeedRepository { FeedScope::User { user_id } => { let uid = user_id.as_uuid(); - // Use nil UUID for unauthenticated viewers — won't match owner or follower checks. let viewer_uuid = viewer.unwrap_or(uuid::Uuid::nil()); - - let total: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM thoughts t WHERE t.user_id = $1 AND ($2::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $2 AND following_id = $1 AND state = 'accepted')))))", - ) - .bind(uid) - .bind(viewer_uuid) - .fetch_one(&self.pool) - .await - .into_domain()?; - + let count_sql = format!( + "SELECT COUNT(*) FROM thoughts t WHERE t.user_id = $1 AND ($2::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $2 AND following_id = $1 AND state = 'accepted'))))){}", + filter + ); + let total: i64 = sqlx::query_scalar(&count_sql) + .bind(uid) + .bind(viewer_uuid) + .fetch_one(&self.pool) + .await + .into_domain()?; let sel = feed_select(viewer); - let sql = format!("{sel} WHERE t.user_id = $1 AND ($4::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $4 AND following_id = $1 AND state = 'accepted'))))) ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); + let sql = format!( + "{sel} WHERE t.user_id = $1 AND ($4::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $4 AND following_id = $1 AND state = 'accepted'))))){} {} LIMIT $2 OFFSET $3", + filter, order + ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(uid) .bind(page.limit()) @@ -313,7 +354,6 @@ impl FeedRepository for PgFeedRepository { .fetch_all(&self.pool) .await .into_domain()?; - Ok(Paginated { items: rows .into_iter() diff --git a/crates/adapters/postgres/src/feed/tests.rs b/crates/adapters/postgres/src/feed/tests.rs index 51f8b79..a2669de 100644 --- a/crates/adapters/postgres/src/feed/tests.rs +++ b/crates/adapters/postgres/src/feed/tests.rs @@ -6,7 +6,7 @@ use domain::{ thought::{NewThought, Thought, Visibility}, user::User, }, - ports::{FeedQuery, ThoughtRepository, UserWriter}, + ports::{FeedOptions, FeedQuery, FeedRequest, ThoughtRepository, UserWriter}, value_objects::*, }; @@ -38,13 +38,16 @@ async fn public_feed_returns_local_thoughts(pool: sqlx::PgPool) { let (_, _) = seed(&pool, "alice", "hello").await; let repo = PgFeedRepository::new(pool); let result = repo - .query(&FeedQuery::public( - PageParams { - page: 1, - per_page: 20, - }, - None, - )) + .query(&FeedRequest { + query: FeedQuery::public( + PageParams { + page: 1, + per_page: 20, + }, + None, + ), + options: FeedOptions::default(), + }) .await .unwrap(); assert_eq!(result.total, 1); @@ -57,14 +60,17 @@ async fn search_returns_matching_thoughts(pool: sqlx::PgPool) { let (_, _) = seed(&pool, "bob", "goodbye world").await; let repo = PgFeedRepository::new(pool); let result = repo - .query(&FeedQuery::search( - "hello world", - PageParams { - page: 1, - per_page: 20, - }, - None, - )) + .query(&FeedRequest { + query: FeedQuery::search( + "hello world", + PageParams { + page: 1, + per_page: 20, + }, + None, + ), + options: FeedOptions::default(), + }) .await .unwrap(); assert!(result.total >= 1); diff --git a/crates/application/src/use_cases/federation_management/mod.rs b/crates/application/src/use_cases/federation_management/mod.rs index 50c714e..f921223 100644 --- a/crates/application/src/use_cases/federation_management/mod.rs +++ b/crates/application/src/use_cases/federation_management/mod.rs @@ -9,8 +9,8 @@ use domain::{ }, ports::{ EventPublisher, FederationActionPort, FederationFollowPort, FederationFollowRequestPort, - FederationSchedulerPort, FeedQuery, FeedRepository, FollowRepository, - RemoteActorConnectionRepository, UserReader, + FederationSchedulerPort, FeedOptions, FeedQuery, FeedRepository, FeedRequest, + FollowRepository, RemoteActorConnectionRepository, UserReader, }, value_objects::UserId, }; @@ -136,11 +136,10 @@ pub async fn get_remote_actor_posts( None => ap_repo.intern_remote_actor(&actor.url).await?, }; let result = feed - .query(&FeedQuery::user( - author_id, - page.clone(), - viewer_id.cloned(), - )) + .query(&FeedRequest { + query: FeedQuery::user(author_id, page.clone(), viewer_id.cloned()), + options: FeedOptions::default(), + }) .await?; if let Some(outbox_url) = actor.outbox_url { let _ = scheduler diff --git a/crates/application/src/use_cases/feed.rs b/crates/application/src/use_cases/feed.rs index b16e057..b45235d 100644 --- a/crates/application/src/use_cases/feed.rs +++ b/crates/application/src/use_cases/feed.rs @@ -1,7 +1,7 @@ use domain::{ errors::DomainError, models::feed::{FeedEntry, PageParams, Paginated}, - ports::{FeedQuery, FeedRepository, FollowRepository}, + ports::{FeedOptions, FeedQuery, FeedRepository, FeedRequest, FollowRepository}, value_objects::UserId, }; @@ -10,9 +10,13 @@ pub async fn get_home_feed( follows: &dyn FollowRepository, user_id: &UserId, page: PageParams, + opts: FeedOptions, ) -> Result, DomainError> { let mut following_ids = follows.get_accepted_following_ids(user_id).await?; following_ids.push(user_id.clone()); - feed.query(&FeedQuery::home(user_id.clone(), following_ids, page)) - .await + feed.query(&FeedRequest { + query: FeedQuery::home(user_id.clone(), following_ids, page), + options: opts, + }) + .await } diff --git a/crates/domain/src/testing/mod.rs b/crates/domain/src/testing/mod.rs index 9ad9518..30e0a69 100644 --- a/crates/domain/src/testing/mod.rs +++ b/crates/domain/src/testing/mod.rs @@ -882,7 +882,7 @@ impl RemoteActorConnectionRepository for TestStore { impl FeedRepository for TestStore { async fn query( &self, - _q: &crate::ports::FeedQuery, + _req: &crate::ports::FeedRequest, ) -> Result, DomainError> { Ok(Paginated { items: vec![], diff --git a/crates/presentation/src/handlers/feed.rs b/crates/presentation/src/handlers/feed.rs index ed07d53..1efadda 100644 --- a/crates/presentation/src/handlers/feed.rs +++ b/crates/presentation/src/handlers/feed.rs @@ -17,11 +17,53 @@ use axum::{ use domain::{ models::feed::PageParams, ports::{ - FederationActionPort, FeedQuery, FeedRepository, FollowRepository, SearchPort, - TagRepository, UserRepository, + FederationActionPort, FeedFilter, FeedOptions, FeedQuery, FeedRepository, FeedRequest, + FeedSort, FollowRepository, SearchPort, TagRepository, UserRepository, }, }; +#[derive(serde::Deserialize, Default)] +pub struct FeedOptionsQuery { + pub sort: Option, + pub originals_only: Option, + pub replies_only: Option, + pub local_only: Option, + pub hide_sensitive: Option, +} + +impl TryFrom for FeedOptions { + type Error = crate::errors::ApiError; + + fn try_from(q: FeedOptionsQuery) -> Result { + if q.originals_only.unwrap_or(false) && q.replies_only.unwrap_or(false) { + return Err(crate::errors::ApiError::BadRequest( + "originals_only and replies_only are mutually exclusive".to_string(), + )); + } + let sort = match q.sort.as_deref() { + None | Some("newest") => FeedSort::Newest, + Some("oldest") => FeedSort::Oldest, + Some("most_liked") => FeedSort::MostLiked, + Some("most_boosted") => FeedSort::MostBoosted, + Some("most_discussed") => FeedSort::MostDiscussed, + Some(other) => { + return Err(crate::errors::ApiError::BadRequest(format!( + "unknown sort value: {other}" + ))) + } + }; + Ok(FeedOptions { + sort, + filter: FeedFilter { + originals_only: q.originals_only.unwrap_or(false), + replies_only: q.replies_only.unwrap_or(false), + local_only: q.local_only.unwrap_or(false), + hide_sensitive: q.hide_sensitive.unwrap_or(false), + }, + }) + } +} + deps_struct!(FeedDeps { feed: FeedRepository, follows: FollowRepository, @@ -62,12 +104,14 @@ pub async fn home_feed( Deps(d): Deps, AuthUser(uid): AuthUser, Query(q): Query, + Query(opts_q): Query, ) -> Result, ApiError> { let page = PageParams { page: q.page(), per_page: q.per_page(), }; - let result = get_home_feed(&*d.feed, &*d.follows, &uid, page).await?; + let opts = FeedOptions::try_from(opts_q)?; + let result = get_home_feed(&*d.feed, &*d.follows, &uid, page, opts).await?; Ok(Json(serde_json::json!({ "items": result.items.iter().map(to_thought_response).collect::>(), "total": result.total, @@ -85,12 +129,20 @@ pub async fn public_feed( Deps(d): Deps, OptionalAuthUser(viewer): OptionalAuthUser, Query(q): Query, + Query(opts_q): Query, ) -> Result, ApiError> { let page = PageParams { page: q.page(), per_page: q.per_page(), }; - let result = d.feed.query(&FeedQuery::public(page, viewer)).await?; + let opts = FeedOptions::try_from(opts_q)?; + let result = d + .feed + .query(&FeedRequest { + query: FeedQuery::public(page, viewer), + options: opts, + }) + .await?; Ok(Json(serde_json::json!({ "items": result.items.iter().map(to_thought_response).collect::>(), "total": result.total, @@ -222,15 +274,20 @@ pub async fn user_thoughts_handler( Path(username): Path, OptionalAuthUser(viewer): OptionalAuthUser, Query(q): Query, + Query(opts_q): Query, ) -> Result, ApiError> { let user = get_user_by_username(&*d.users, &username).await?; let page = PageParams { page: q.page(), per_page: q.per_page(), }; + let opts = FeedOptions::try_from(opts_q)?; let result = d .feed - .query(&FeedQuery::user(user.id.clone(), page, viewer)) + .query(&FeedRequest { + query: FeedQuery::user(user.id.clone(), page, viewer), + options: opts, + }) .await?; Ok(Json(serde_json::json!({ "total": result.total, @@ -273,14 +330,19 @@ pub async fn tag_thoughts_handler( Path(tag_name): Path, OptionalAuthUser(viewer): OptionalAuthUser, Query(q): Query, + Query(opts_q): Query, ) -> Result, ApiError> { let page = PageParams { page: q.page(), per_page: q.per_page(), }; + let opts = FeedOptions::try_from(opts_q)?; let result = d .feed - .query(&FeedQuery::tag(&tag_name, page, viewer)) + .query(&FeedRequest { + query: FeedQuery::tag(&tag_name, page, viewer), + options: opts, + }) .await?; Ok(Json(serde_json::json!({ "tag": tag_name,