use async_trait::async_trait; use chrono::{DateTime, Utc}; use sqlx::PgPool; use domain::{ errors::DomainError, models::{feed::{FeedEntry, PageParams, Paginated}, thought::Thought, user::User}, ports::FeedRepository, value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username}, }; use domain::models::thought::Visibility; pub struct PgFeedRepository { pool: PgPool } impl PgFeedRepository { pub fn new(pool: PgPool) -> Self { Self { pool } } } #[derive(sqlx::FromRow)] struct FeedRow { thought_id: uuid::Uuid, t_user_id: uuid::Uuid, content: String, in_reply_to_id: Option, in_reply_to_url: Option, t_ap_id: Option, visibility: String, content_warning: Option, sensitive: bool, t_local: bool, thought_created_at: DateTime, updated_at: Option>, author_id: uuid::Uuid, username: String, email: String, password_hash: String, display_name: Option, bio: Option, avatar_url: Option, header_url: Option, custom_css: Option, author_local: bool, u_ap_id: Option, inbox_url: Option, public_key: Option, private_key: Option, author_created_at: DateTime, author_updated_at: DateTime, like_count: i64, boost_count: i64, reply_count: i64, } const FEED_SELECT: &str = " SELECT t.id AS thought_id, t.user_id AS t_user_id, t.content, t.in_reply_to_id, t.in_reply_to_url, t.ap_id AS t_ap_id, t.visibility, t.content_warning, t.sensitive, t.local AS t_local, t.created_at AS thought_created_at, t.updated_at, u.id AS author_id, u.username, u.email, u.password_hash, u.display_name, u.bio, u.avatar_url, u.header_url, u.custom_css, u.local AS author_local, u.ap_id AS u_ap_id, u.inbox_url, u.public_key, u.private_key, u.created_at AS author_created_at, u.updated_at AS author_updated_at, (SELECT COUNT(*) FROM likes l WHERE l.thought_id=t.id) AS like_count, (SELECT COUNT(*) FROM boosts b WHERE b.thought_id=t.id) AS boost_count, (SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count FROM thoughts t JOIN users u ON u.id=t.user_id"; fn row_to_entry(r: FeedRow) -> FeedEntry { let thought = Thought { id: ThoughtId::from_uuid(r.thought_id), user_id: UserId::from_uuid(r.t_user_id), content: Content::new_remote(r.content), in_reply_to_id: r.in_reply_to_id.map(ThoughtId::from_uuid), in_reply_to_url: r.in_reply_to_url, ap_id: r.t_ap_id, visibility: Visibility::from_str(&r.visibility), content_warning: r.content_warning, sensitive: r.sensitive, local: r.t_local, created_at: r.thought_created_at, updated_at: r.updated_at, }; let author = User { id: UserId::from_uuid(r.author_id), username: Username::from_trusted(r.username), email: Email::from_trusted(r.email), password_hash: PasswordHash(r.password_hash), display_name: r.display_name, bio: r.bio, avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css, local: r.author_local, ap_id: r.u_ap_id, inbox_url: r.inbox_url, public_key: r.public_key, private_key: r.private_key, created_at: r.author_created_at, updated_at: r.author_updated_at, }; FeedEntry { thought, author, like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, liked_by_viewer: false, boosted_by_viewer: false } } #[async_trait] impl FeedRepository for PgFeedRepository { async fn home_feed(&self, following_ids: &[UserId], page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { let ids: Vec = following_ids.iter().map(|id| id.as_uuid()).collect(); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.user_id=ANY($1) AND t.visibility='public'" ).bind(&ids).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; let sql = format!("{FEED_SELECT} WHERE t.user_id=ANY($1) AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(&ids).bind(page.limit()).bind(page.offset()) .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } async fn public_feed(&self, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.local=true AND t.visibility='public'" ).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; let sql = format!("{FEED_SELECT} WHERE t.local=true AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $1 OFFSET $2"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(page.limit()).bind(page.offset()) .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } async fn search(&self, query: &str, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { let pattern = format!("%{}%", query.replace('%', "\\%").replace('_', "\\_")); let total: i64 = sqlx::query_scalar( "SELECT COUNT(*) FROM thoughts t WHERE t.content ILIKE $1 AND t.visibility='public'" ).bind(&pattern).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; let sql = format!("{FEED_SELECT} WHERE t.content ILIKE $1 AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); let rows = sqlx::query_as::<_, FeedRow>(&sql) .bind(&pattern).bind(page.limit()).bind(page.offset()) .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } } #[cfg(test)] mod tests { use super::*; use domain::{models::{thought::{Thought, Visibility}, user::User}, ports::{ThoughtRepository, UserRepository}, value_objects::*}; use crate::{thought::PgThoughtRepository, user::PgUserRepository}; async fn seed(pool: &sqlx::PgPool, username: &str, content: &str) -> (User, Thought) { let urepo = PgUserRepository::new(pool.clone()); let trepo = PgThoughtRepository::new(pool.clone()); let u = User::new_local(UserId::new(), Username::new(username).unwrap(), Email::new(format!("{username}@ex.com")).unwrap(), PasswordHash("h".into())); urepo.save(&u).await.unwrap(); let t = Thought::new_local(ThoughtId::new(), u.id.clone(), Content::new_local(content).unwrap(), None, Visibility::Public, None, false); trepo.save(&t).await.unwrap(); (u, t) } #[sqlx::test(migrations = "./migrations")] async fn public_feed_returns_local_thoughts(pool: sqlx::PgPool) { let (_, _) = seed(&pool, "alice", "hello").await; let repo = PgFeedRepository::new(pool); let result = repo.public_feed(&PageParams { page: 1, per_page: 20 }, None).await.unwrap(); assert_eq!(result.total, 1); assert_eq!(result.items[0].thought.content.as_str(), "hello"); } #[sqlx::test(migrations = "./migrations")] async fn search_returns_matching_thoughts(pool: sqlx::PgPool) { let (_, _) = seed(&pool, "alice", "hello world").await; let (_, _) = seed(&pool, "bob", "goodbye world").await; let repo = PgFeedRepository::new(pool); let result = repo.search("hello", &PageParams { page: 1, per_page: 20 }, None).await.unwrap(); assert_eq!(result.total, 1); assert_eq!(result.items[0].thought.content.as_str(), "hello world"); } }