diff --git a/crates/adapters/postgres-search/Cargo.toml b/crates/adapters/postgres-search/Cargo.toml index ec88563..071d1c2 100644 --- a/crates/adapters/postgres-search/Cargo.toml +++ b/crates/adapters/postgres-search/Cargo.toml @@ -2,3 +2,15 @@ name = "postgres-search" version = "0.1.0" edition = "2021" + +[dependencies] +domain = { workspace = true } +sqlx = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +async-trait = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } +sqlx = { workspace = true, features = ["migrate"] } +postgres = { workspace = true } diff --git a/crates/adapters/postgres-search/src/lib.rs b/crates/adapters/postgres-search/src/lib.rs index e69de29..0a2f898 100644 --- a/crates/adapters/postgres-search/src/lib.rs +++ b/crates/adapters/postgres-search/src/lib.rs @@ -0,0 +1,273 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use sqlx::PgPool; +use domain::{ + errors::DomainError, + models::{ + feed::{FeedEntry, PageParams, Paginated}, + thought::Thought, + user::User, + }, + ports::SearchPort, + value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username}, +}; +use domain::models::thought::Visibility; + +pub struct PgSearchRepository { pool: PgPool } +impl PgSearchRepository { pub fn new(pool: PgPool) -> Self { Self { pool } } } + +#[derive(sqlx::FromRow)] +struct FeedRow { + thought_id: uuid::Uuid, + t_user_id: uuid::Uuid, + content: String, + in_reply_to_id: Option, + in_reply_to_url: Option, + t_ap_id: Option, + visibility: String, + content_warning: Option, + sensitive: bool, + t_local: bool, + thought_created_at: DateTime, + updated_at: Option>, + author_id: uuid::Uuid, + username: String, + email: String, + password_hash: String, + display_name: Option, + bio: Option, + avatar_url: Option, + header_url: Option, + custom_css: Option, + author_local: bool, + u_ap_id: Option, + inbox_url: Option, + public_key: Option, + private_key: Option, + author_created_at: DateTime, + author_updated_at: DateTime, + like_count: i64, + boost_count: i64, + reply_count: i64, +} + +const FEED_SELECT: &str = " + SELECT + t.id AS thought_id, t.user_id AS t_user_id, t.content, + t.in_reply_to_id, t.in_reply_to_url, t.ap_id AS t_ap_id, + t.visibility, t.content_warning, t.sensitive, t.local AS t_local, + t.created_at AS thought_created_at, t.updated_at, + u.id AS author_id, u.username, u.email, u.password_hash, + u.display_name, u.bio, u.avatar_url, u.header_url, u.custom_css, + u.local AS author_local, u.ap_id AS u_ap_id, u.inbox_url, + u.public_key, u.private_key, + u.created_at AS author_created_at, u.updated_at AS author_updated_at, + (SELECT COUNT(*) FROM likes l WHERE l.thought_id=t.id) AS like_count, + (SELECT COUNT(*) FROM boosts b WHERE b.thought_id=t.id) AS boost_count, + (SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count + FROM thoughts t JOIN users u ON u.id=t.user_id"; + +fn row_to_entry(r: FeedRow) -> FeedEntry { + let thought = Thought { + id: ThoughtId::from_uuid(r.thought_id), + user_id: UserId::from_uuid(r.t_user_id), + content: Content::new_remote(r.content), + in_reply_to_id: r.in_reply_to_id.map(ThoughtId::from_uuid), + in_reply_to_url: r.in_reply_to_url, + ap_id: r.t_ap_id, + visibility: Visibility::from_str(&r.visibility), + content_warning: r.content_warning, + sensitive: r.sensitive, + local: r.t_local, + created_at: r.thought_created_at, + updated_at: r.updated_at, + }; + let author = User { + id: UserId::from_uuid(r.author_id), + username: Username::from_trusted(r.username), + email: Email::from_trusted(r.email), + password_hash: PasswordHash(r.password_hash), + display_name: r.display_name, bio: r.bio, + avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css, + local: r.author_local, ap_id: r.u_ap_id, inbox_url: r.inbox_url, + public_key: r.public_key, private_key: r.private_key, + created_at: r.author_created_at, updated_at: r.author_updated_at, + }; + FeedEntry { thought, author, like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, liked_by_viewer: false, boosted_by_viewer: false } +} + +#[derive(sqlx::FromRow)] +struct UserRow { + id: uuid::Uuid, + username: String, + email: String, + password_hash: String, + display_name: Option, + bio: Option, + avatar_url: Option, + header_url: Option, + custom_css: Option, + local: bool, + ap_id: Option, + inbox_url: Option, + public_key: Option, + private_key: Option, + created_at: DateTime, + updated_at: DateTime, +} + +impl From for User { + fn from(r: UserRow) -> Self { + User { + id: UserId::from_uuid(r.id), + username: Username::from_trusted(r.username), + email: Email::from_trusted(r.email), + password_hash: PasswordHash(r.password_hash), + display_name: r.display_name, bio: r.bio, + avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css, + local: r.local, ap_id: r.ap_id, inbox_url: r.inbox_url, + public_key: r.public_key, private_key: r.private_key, + created_at: r.created_at, updated_at: r.updated_at, + } + } +} + +const USER_SELECT: &str = + "SELECT id,username,email,password_hash,display_name,bio,avatar_url,header_url,\ + custom_css,local,ap_id,inbox_url,public_key,private_key,created_at,updated_at FROM users"; + +#[async_trait] +impl SearchPort for PgSearchRepository { + async fn search_thoughts( + &self, + query: &str, + page: &PageParams, + _viewer_id: Option<&UserId>, + ) -> Result, DomainError> { + let total: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM thoughts t + WHERE t.content % $1 AND t.visibility='public'" + ) + .bind(query) + .fetch_one(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + let sql = format!( + "{FEED_SELECT} + WHERE t.content % $1 AND t.visibility='public' + ORDER BY similarity(t.content, $1) DESC + LIMIT $2 OFFSET $3" + ); + let rows = sqlx::query_as::<_, FeedRow>(&sql) + .bind(query) + .bind(page.limit()) + .bind(page.offset()) + .fetch_all(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + Ok(Paginated { + items: rows.into_iter().map(row_to_entry).collect(), + total, + page: page.page, + per_page: page.per_page, + }) + } + + async fn search_users( + &self, + query: &str, + page: &PageParams, + ) -> Result, DomainError> { + let total: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM users u + WHERE u.local=true AND (u.username % $1 OR u.display_name % $1)" + ) + .bind(query) + .fetch_one(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + let sql = format!( + "{USER_SELECT} + WHERE local=true AND (username % $1 OR display_name % $1) + ORDER BY similarity(username || ' ' || COALESCE(display_name,''), $1) DESC + LIMIT $2 OFFSET $3" + ); + let rows = sqlx::query_as::<_, UserRow>(&sql) + .bind(query) + .bind(page.limit()) + .bind(page.offset()) + .fetch_all(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; + + Ok(Paginated { + items: rows.into_iter().map(User::from).collect(), + total, + page: page.page, + per_page: page.per_page, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use domain::{ + models::{thought::{Thought, Visibility}, user::User}, + ports::{SearchPort, ThoughtRepository, UserRepository}, + value_objects::*, + }; + + async fn seed_thought(pool: &sqlx::PgPool, username: &str, content: &str) -> (User, Thought) { + use postgres::{thought::PgThoughtRepository, user::PgUserRepository}; + let urepo = PgUserRepository::new(pool.clone()); + let trepo = PgThoughtRepository::new(pool.clone()); + let u = User::new_local( + UserId::new(), + Username::new(username).unwrap(), + Email::new(format!("{username}@ex.com")).unwrap(), + PasswordHash("h".into()), + ); + urepo.save(&u).await.unwrap(); + let t = Thought::new_local( + ThoughtId::new(), u.id.clone(), + Content::new_local(content).unwrap(), + None, Visibility::Public, None, false, + ); + trepo.save(&t).await.unwrap(); + (u, t) + } + + #[sqlx::test(migrations = "../postgres/migrations")] + async fn search_thoughts_finds_by_keyword(pool: sqlx::PgPool) { + seed_thought(&pool, "alice", "hello world").await; + seed_thought(&pool, "bob", "goodbye universe").await; + let repo = PgSearchRepository::new(pool); + let result = repo.search_thoughts("hello world", &PageParams { page: 1, per_page: 20 }, None).await.unwrap(); + assert_eq!(result.total, 1); + assert_eq!(result.items[0].thought.content.as_str(), "hello world"); + } + + #[sqlx::test(migrations = "../postgres/migrations")] + async fn search_users_finds_by_username(pool: sqlx::PgPool) { + use postgres::user::PgUserRepository; + let urepo = PgUserRepository::new(pool.clone()); + let alice = User::new_local(UserId::new(), Username::new("alice_search").unwrap(), Email::new("alice@ex.com").unwrap(), PasswordHash("h".into())); + urepo.save(&alice).await.unwrap(); + let repo = PgSearchRepository::new(pool); + let result = repo.search_users("alice", &PageParams { page: 1, per_page: 20 }).await.unwrap(); + assert!(!result.items.is_empty()); + assert!(result.items.iter().any(|u| u.username.as_str() == "alice_search")); + } + + #[sqlx::test(migrations = "../postgres/migrations")] + async fn search_thoughts_returns_empty_for_no_match(pool: sqlx::PgPool) { + seed_thought(&pool, "alice", "hello world").await; + let repo = PgSearchRepository::new(pool); + let result = repo.search_thoughts("zzzzzzzzz", &PageParams { page: 1, per_page: 20 }, None).await.unwrap(); + assert_eq!(result.total, 0); + } +}