use crate::db_error::IntoDbResult; use async_trait::async_trait; use chrono::{DateTime, Utc}; use domain::{ errors::DomainError, models::{ feed::{FeedEntry, Paginated}, thought::{Thought, Visibility}, user::User, }, ports::{FeedFilter, FeedRepository, FeedRequest, FeedScope, FeedSort}, value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username}, }; use sqlx::PgPool; pub struct PgFeedRepository { pool: PgPool, } impl PgFeedRepository { pub fn new(pool: PgPool) -> Self { Self { pool } } } #[derive(sqlx::FromRow)] struct FeedRow { thought_id: uuid::Uuid, t_user_id: uuid::Uuid, content: String, in_reply_to_id: Option, visibility: String, content_warning: Option, sensitive: bool, t_local: bool, thought_created_at: DateTime, updated_at: Option>, note_extensions: Option, author_id: uuid::Uuid, username: String, email: String, password_hash: String, display_name: Option, bio: Option, avatar_url: Option, header_url: Option, custom_css: Option, author_local: bool, author_created_at: DateTime, author_updated_at: DateTime, like_count: i64, boost_count: i64, reply_count: i64, liked_by_viewer: bool, boosted_by_viewer: bool, } fn federation_following_clause(follower: Option) -> String { match follower { Some(fid) => format!( " OR t.user_id IN ( SELECT u2.id FROM users u2 JOIN federation_following ff ON u2.ap_id = ff.remote_actor_url WHERE ff.local_user_id = '{fid}' )" ), None => String::new(), } } fn feed_select(viewer: Option) -> String { let viewer_checks = match viewer { Some(uid) => format!( "EXISTS(SELECT 1 FROM likes WHERE user_id='{uid}' AND thought_id=t.id) AS liked_by_viewer, EXISTS(SELECT 1 FROM boosts WHERE user_id='{uid}' AND thought_id=t.id) AS boosted_by_viewer" ), None => "false AS liked_by_viewer, false AS boosted_by_viewer".to_string(), }; format!( " SELECT t.id AS thought_id, t.user_id AS t_user_id, t.content, t.in_reply_to_id, t.visibility, t.content_warning, t.sensitive, t.local AS t_local, t.created_at AS thought_created_at, t.updated_at, t.note_extensions, u.id AS author_id, CASE WHEN NOT u.local AND ra.handle IS NOT NULL AND ra.handle != '' THEN '@' || ra.handle || CASE WHEN ra.handle NOT LIKE '%@%' THEN '@' || SPLIT_PART(ra.url, '/', 3) ELSE '' END ELSE u.username END AS username, u.email, u.password_hash, COALESCE(ra.display_name, u.display_name) AS display_name, u.bio, COALESCE(ra.avatar_url, u.avatar_url) AS avatar_url, u.header_url, u.custom_css, u.local AS author_local, u.created_at AS author_created_at, u.updated_at AS author_updated_at, (SELECT COUNT(*) FROM likes l WHERE l.thought_id=t.id) AS like_count, (SELECT COUNT(*) FROM boosts b WHERE b.thought_id=t.id) AS boost_count, (SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count, {viewer_checks} FROM thoughts t JOIN users u ON u.id=t.user_id LEFT JOIN remote_actors ra ON u.ap_id = ra.url" ) } fn row_to_entry(r: FeedRow, viewer: Option) -> Result { let thought = Thought { id: ThoughtId::from_uuid(r.thought_id), user_id: UserId::from_uuid(r.t_user_id), content: Content::new_remote(r.content), in_reply_to_id: r.in_reply_to_id.map(ThoughtId::from_uuid), visibility: Visibility::from_db_str(&r.visibility)?, content_warning: r.content_warning, sensitive: r.sensitive, local: r.t_local, created_at: r.thought_created_at, updated_at: r.updated_at, note_extensions: r.note_extensions, }; let author = User { id: UserId::from_uuid(r.author_id), username: Username::from_trusted(r.username), email: Email::from_trusted(r.email), password_hash: PasswordHash(r.password_hash), display_name: r.display_name, bio: r.bio, avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css, local: r.author_local, created_at: r.author_created_at, updated_at: r.author_updated_at, }; Ok(FeedEntry { thought, author, stats: domain::models::feed::EngagementStats { like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, }, viewer: viewer.map(|_| domain::models::feed::ViewerContext { liked: r.liked_by_viewer, boosted: r.boosted_by_viewer, }), }) } fn order_by_clause(sort: &FeedSort, scope: &FeedScope) -> &'static str { if matches!(scope, FeedScope::Search { .. }) { return "ORDER BY similarity(t.content, $1) DESC"; } match sort { FeedSort::Newest => "ORDER BY t.created_at DESC", FeedSort::Oldest => "ORDER BY t.created_at ASC", FeedSort::MostLiked => "ORDER BY like_count DESC, t.created_at DESC", FeedSort::MostBoosted => "ORDER BY boost_count DESC, t.created_at DESC", FeedSort::MostDiscussed => "ORDER BY reply_count DESC, t.created_at DESC", } } fn filter_clauses(f: &FeedFilter) -> String { let mut s = String::new(); if f.originals_only { s += " AND t.in_reply_to_id IS NULL"; } if f.replies_only { s += " AND t.in_reply_to_id IS NOT NULL"; } if f.local_only { s += " AND t.local = true"; } if f.hide_sensitive { s += " AND t.sensitive = false"; } s } #[async_trait] impl FeedRepository for PgFeedRepository { async fn query(&self, req: &FeedRequest) -> Result, DomainError> { let viewer = req.query.viewer_id.as_ref().map(|v| v.as_uuid()); let page = &req.query.page; let filter = filter_clauses(&req.options.filter); let order = order_by_clause(&req.options.sort, &req.query.scope); match &req.query.scope { FeedScope::Home { following_ids } => { let ids: Vec = following_ids.iter().map(|id| id.as_uuid()).collect(); let fed_clause = federation_following_clause(viewer); let count_sql = format!( "SELECT COUNT(*) FROM thoughts t WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct'{}", fed_clause, filter ); let total: i64 = sqlx::query_scalar(&count_sql) .bind(&ids) .fetch_one(&self.pool) .await .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} WHERE (t.user_id=ANY($1){}) AND t.visibility != 'direct'{} {} LIMIT $2 OFFSET $3", fed_clause, filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(&ids) .bind(page.limit()) .bind(page.offset()) .fetch_all(&self.pool) .await .into_domain()?; Ok(Paginated { items: rows .into_iter() .map(|r| row_to_entry(r, viewer)) .collect::, _>>()?, total, page: page.page, per_page: page.per_page, }) } FeedScope::Public => { let count_sql = format!( "SELECT COUNT(*) FROM thoughts t WHERE t.local=true AND t.visibility='public'{}", filter ); let total: i64 = sqlx::query_scalar(&count_sql) .fetch_one(&self.pool) .await .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} WHERE t.local=true AND t.visibility='public'{} {} LIMIT $1 OFFSET $2", filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(page.limit()) .bind(page.offset()) .fetch_all(&self.pool) .await .into_domain()?; Ok(Paginated { items: rows .into_iter() .map(|r| row_to_entry(r, viewer)) .collect::, _>>()?, total, page: page.page, per_page: page.per_page, }) } FeedScope::Search { query } => { let count_sql = format!( "SELECT COUNT(*) FROM thoughts t WHERE t.content % $1 AND t.visibility='public'{}", filter ); let total: i64 = sqlx::query_scalar(&count_sql) .bind(query) .fetch_one(&self.pool) .await .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} WHERE t.content % $1 AND t.visibility='public'{} {} LIMIT $2 OFFSET $3", filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(query) .bind(page.limit()) .bind(page.offset()) .fetch_all(&self.pool) .await .into_domain()?; Ok(Paginated { items: rows .into_iter() .map(|r| row_to_entry(r, viewer)) .collect::, _>>()?, total, page: page.page, per_page: page.per_page, }) } FeedScope::Tag { tag_name } => { let count_sql = format!( "SELECT COUNT(*) FROM thoughts t JOIN thought_tags tt ON tt.thought_id = t.id JOIN tags tg ON tg.id = tt.tag_id WHERE tg.name = $1 AND t.visibility = 'public'{}", filter ); let total: i64 = sqlx::query_scalar(&count_sql) .bind(tag_name) .fetch_one(&self.pool) .await .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} JOIN thought_tags tt ON tt.thought_id = t.id JOIN tags tg ON tg.id = tt.tag_id WHERE tg.name = $1 AND t.visibility = 'public'{} {} LIMIT $2 OFFSET $3", filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(tag_name) .bind(page.limit()) .bind(page.offset()) .fetch_all(&self.pool) .await .into_domain()?; Ok(Paginated { items: rows .into_iter() .map(|r| row_to_entry(r, viewer)) .collect::, _>>()?, total, page: page.page, per_page: page.per_page, }) } FeedScope::User { user_id } => { let uid = user_id.as_uuid(); let viewer_uuid = viewer.unwrap_or(uuid::Uuid::nil()); let count_sql = format!( "SELECT COUNT(*) FROM thoughts t WHERE t.user_id = $1 AND ($2::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $2 AND following_id = $1 AND state = 'accepted'))))){}", filter ); let total: i64 = sqlx::query_scalar(&count_sql) .bind(uid) .bind(viewer_uuid) .fetch_one(&self.pool) .await .into_domain()?; let sel = feed_select(viewer); let sql = format!( "{sel} WHERE t.user_id = $1 AND ($4::uuid = $1 OR (t.visibility != 'direct' AND (t.visibility IN ('public', 'unlisted') OR (t.visibility = 'followers' AND EXISTS(SELECT 1 FROM follows WHERE follower_id = $4 AND following_id = $1 AND state = 'accepted'))))){} {} LIMIT $2 OFFSET $3", filter, order ); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(uid) .bind(page.limit()) .bind(page.offset()) .bind(viewer_uuid) .fetch_all(&self.pool) .await .into_domain()?; Ok(Paginated { items: rows .into_iter() .map(|r| row_to_entry(r, viewer)) .collect::, _>>()?, total, page: page.page, per_page: page.per_page, }) } } } } #[cfg(test)] mod tests;