diff --git a/crates/adapters/postgres/src/thought.rs b/crates/adapters/postgres/src/thought.rs index d28bfd5..a5a82d0 100644 --- a/crates/adapters/postgres/src/thought.rs +++ b/crates/adapters/postgres/src/thought.rs @@ -1,2 +1,236 @@ -pub struct PgThoughtRepository { _pool: sqlx::PgPool } -impl PgThoughtRepository { pub fn new(pool: sqlx::PgPool) -> Self { Self { _pool: pool } } } +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use sqlx::PgPool; +use domain::{ + errors::DomainError, + models::{ + feed::{FeedEntry, PageParams, Paginated}, + thought::{Thought, Visibility}, + user::User, + }, + ports::ThoughtRepository, + value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username}, +}; + +pub struct PgThoughtRepository { pool: PgPool } +impl PgThoughtRepository { pub fn new(pool: PgPool) -> Self { Self { pool } } } + +#[derive(sqlx::FromRow)] +pub(crate) struct ThoughtRow { + pub id: uuid::Uuid, + pub user_id: uuid::Uuid, + pub content: String, + pub in_reply_to_id: Option, + pub in_reply_to_url: Option, + pub ap_id: Option, + pub visibility: String, + pub content_warning: Option, + pub sensitive: bool, + pub local: bool, + pub created_at: DateTime, + pub updated_at: Option>, +} + +impl From for Thought { + fn from(r: ThoughtRow) -> Self { + Thought { + id: ThoughtId::from_uuid(r.id), + user_id: UserId::from_uuid(r.user_id), + content: Content::new_remote(r.content), + in_reply_to_id: r.in_reply_to_id.map(ThoughtId::from_uuid), + in_reply_to_url: r.in_reply_to_url, + ap_id: r.ap_id, + visibility: Visibility::from_str(&r.visibility), + content_warning: r.content_warning, + sensitive: r.sensitive, + local: r.local, + created_at: r.created_at, + updated_at: r.updated_at, + } + } +} + +const THOUGHT_SELECT: &str = + "SELECT id,user_id,content,in_reply_to_id,in_reply_to_url,ap_id,visibility,content_warning,sensitive,local,created_at,updated_at FROM thoughts"; + +#[async_trait] +impl ThoughtRepository for PgThoughtRepository { + async fn save(&self, t: &Thought) -> Result<(), DomainError> { + sqlx::query( + "INSERT INTO thoughts(id,user_id,content,in_reply_to_id,in_reply_to_url,ap_id,visibility,content_warning,sensitive,local,created_at) + VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) + ON CONFLICT(id) DO UPDATE SET content=EXCLUDED.content,updated_at=NOW()" + ) + .bind(t.id.as_uuid()) + .bind(t.user_id.as_uuid()) + .bind(t.content.as_str()) + .bind(t.in_reply_to_id.as_ref().map(|x| x.as_uuid())) + .bind(&t.in_reply_to_url) + .bind(&t.ap_id) + .bind(t.visibility.as_str()) + .bind(&t.content_warning) + .bind(t.sensitive) + .bind(t.local) + .bind(t.created_at) + .execute(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string())) + .map(|_| ()) + } + + async fn find_by_id(&self, id: &ThoughtId) -> Result, DomainError> { + sqlx::query_as::<_, ThoughtRow>(&format!("{THOUGHT_SELECT} WHERE id=$1")) + .bind(id.as_uuid()) + .fetch_optional(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string())) + .map(|o| o.map(Thought::from)) + } + + async fn delete(&self, id: &ThoughtId, user_id: &UserId) -> Result<(), DomainError> { + let r = sqlx::query("DELETE FROM thoughts WHERE id=$1 AND user_id=$2") + .bind(id.as_uuid()) + .bind(user_id.as_uuid()) + .execute(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + if r.rows_affected() == 0 { return Err(DomainError::NotFound); } + Ok(()) + } + + async fn update_content(&self, id: &ThoughtId, content: &Content) -> Result<(), DomainError> { + sqlx::query("UPDATE thoughts SET content=$2,updated_at=NOW() WHERE id=$1") + .bind(id.as_uuid()) + .bind(content.as_str()) + .execute(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string())) + .map(|_| ()) + } + + async fn get_thread(&self, id: &ThoughtId) -> Result, DomainError> { + sqlx::query_as::<_, ThoughtRow>( + &format!("{THOUGHT_SELECT} WHERE id=$1 OR in_reply_to_id=$1 ORDER BY created_at ASC") + ) + .bind(id.as_uuid()) + .fetch_all(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string())) + .map(|rows| rows.into_iter().map(Thought::from).collect()) + } + + async fn list_by_user(&self, user_id: &UserId, page: &PageParams) -> Result, DomainError> { + let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM thoughts WHERE user_id=$1") + .bind(user_id.as_uuid()) + .fetch_one(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + let rows = sqlx::query_as::<_, ThoughtRow>( + &format!("{THOUGHT_SELECT} WHERE user_id=$1 ORDER BY created_at DESC LIMIT $2 OFFSET $3") + ) + .bind(user_id.as_uuid()) + .bind(page.limit()) + .bind(page.offset()) + .fetch_all(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + let author = sqlx::query_as::<_, crate::user::UserRow>( + "SELECT id,username,email,password_hash,display_name,bio,avatar_url,header_url,custom_css,local,ap_id,inbox_url,public_key,private_key,created_at,updated_at FROM users WHERE id=$1" + ) + .bind(user_id.as_uuid()) + .fetch_optional(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))? + .ok_or(DomainError::NotFound)?; + let author = User::from(author); + + let items = rows.into_iter().map(|r| { + let thought = Thought::from(r); + FeedEntry { + thought, + author: author.clone(), + like_count: 0, + boost_count: 0, + reply_count: 0, + liked_by_viewer: false, + boosted_by_viewer: false, + } + }).collect(); + + Ok(Paginated { items, total, page: page.page, per_page: page.per_page }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use domain::{models::{thought::{Thought, Visibility}, user::User}, value_objects::*}; + use crate::user::PgUserRepository; + use domain::ports::UserRepository; + + async fn seed_user(pool: &sqlx::PgPool, username: &str, email: &str) -> User { + let repo = PgUserRepository::new(pool.clone()); + let u = User::new_local( + UserId::new(), + Username::new(username).unwrap(), + Email::new(email).unwrap(), + PasswordHash("h".into()), + ); + repo.save(&u).await.unwrap(); + u + } + + #[sqlx::test(migrations = "./migrations")] + async fn save_and_find_thought(pool: sqlx::PgPool) { + let user = seed_user(&pool, "alice", "alice@ex.com").await; + let repo = PgThoughtRepository::new(pool); + let t = Thought::new_local( + ThoughtId::new(), + user.id.clone(), + Content::new_local("hello world").unwrap(), + None, + Visibility::Public, + None, + false, + ); + repo.save(&t).await.unwrap(); + let found = repo.find_by_id(&t.id).await.unwrap().unwrap(); + assert_eq!(found.content.as_str(), "hello world"); + assert!(found.local); + } + + #[sqlx::test(migrations = "./migrations")] + async fn delete_thought(pool: sqlx::PgPool) { + let user = seed_user(&pool, "bob", "bob@ex.com").await; + let repo = PgThoughtRepository::new(pool); + let t = Thought::new_local(ThoughtId::new(), user.id.clone(), Content::new_local("bye").unwrap(), None, Visibility::Public, None, false); + repo.save(&t).await.unwrap(); + repo.delete(&t.id, &user.id).await.unwrap(); + assert!(repo.find_by_id(&t.id).await.unwrap().is_none()); + } + + #[sqlx::test(migrations = "./migrations")] + async fn delete_wrong_owner_returns_not_found(pool: sqlx::PgPool) { + let alice = seed_user(&pool, "alice", "alice@ex.com").await; + let bob = seed_user(&pool, "bob", "bob@ex.com").await; + let repo = PgThoughtRepository::new(pool); + let t = Thought::new_local(ThoughtId::new(), alice.id.clone(), Content::new_local("secret").unwrap(), None, Visibility::Public, None, false); + repo.save(&t).await.unwrap(); + let err = repo.delete(&t.id, &bob.id).await.unwrap_err(); + assert!(matches!(err, DomainError::NotFound)); + } + + #[sqlx::test(migrations = "./migrations")] + async fn get_thread_returns_root_and_replies(pool: sqlx::PgPool) { + let user = seed_user(&pool, "charlie", "charlie@ex.com").await; + let repo = PgThoughtRepository::new(pool); + let root = Thought::new_local(ThoughtId::new(), user.id.clone(), Content::new_local("root").unwrap(), None, Visibility::Public, None, false); + let reply = Thought::new_local(ThoughtId::new(), user.id.clone(), Content::new_local("reply").unwrap(), Some(root.id.clone()), Visibility::Public, None, false); + repo.save(&root).await.unwrap(); + repo.save(&reply).await.unwrap(); + let thread = repo.get_thread(&root.id).await.unwrap(); + assert_eq!(thread.len(), 2); + } +}