feat(postgres-search): PgSearchRepository using pg_trgm

This commit is contained in:
2026-05-14 09:26:36 +02:00
parent a3534317de
commit ebf0aaab58
2 changed files with 285 additions and 0 deletions

View File

@@ -2,3 +2,15 @@
name = "postgres-search" name = "postgres-search"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies]
domain = { workspace = true }
sqlx = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
async-trait = { workspace = true }
[dev-dependencies]
tokio = { workspace = true, features = ["full"] }
sqlx = { workspace = true, features = ["migrate"] }
postgres = { workspace = true }

View File

@@ -0,0 +1,273 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use domain::{
errors::DomainError,
models::{
feed::{FeedEntry, PageParams, Paginated},
thought::Thought,
user::User,
},
ports::SearchPort,
value_objects::{Content, Email, PasswordHash, ThoughtId, UserId, Username},
};
use domain::models::thought::Visibility;
pub struct PgSearchRepository { pool: PgPool }
impl PgSearchRepository { pub fn new(pool: PgPool) -> Self { Self { pool } } }
#[derive(sqlx::FromRow)]
struct FeedRow {
thought_id: uuid::Uuid,
t_user_id: uuid::Uuid,
content: String,
in_reply_to_id: Option<uuid::Uuid>,
in_reply_to_url: Option<String>,
t_ap_id: Option<String>,
visibility: String,
content_warning: Option<String>,
sensitive: bool,
t_local: bool,
thought_created_at: DateTime<Utc>,
updated_at: Option<DateTime<Utc>>,
author_id: uuid::Uuid,
username: String,
email: String,
password_hash: String,
display_name: Option<String>,
bio: Option<String>,
avatar_url: Option<String>,
header_url: Option<String>,
custom_css: Option<String>,
author_local: bool,
u_ap_id: Option<String>,
inbox_url: Option<String>,
public_key: Option<String>,
private_key: Option<String>,
author_created_at: DateTime<Utc>,
author_updated_at: DateTime<Utc>,
like_count: i64,
boost_count: i64,
reply_count: i64,
}
const FEED_SELECT: &str = "
SELECT
t.id AS thought_id, t.user_id AS t_user_id, t.content,
t.in_reply_to_id, t.in_reply_to_url, t.ap_id AS t_ap_id,
t.visibility, t.content_warning, t.sensitive, t.local AS t_local,
t.created_at AS thought_created_at, t.updated_at,
u.id AS author_id, u.username, u.email, u.password_hash,
u.display_name, u.bio, u.avatar_url, u.header_url, u.custom_css,
u.local AS author_local, u.ap_id AS u_ap_id, u.inbox_url,
u.public_key, u.private_key,
u.created_at AS author_created_at, u.updated_at AS author_updated_at,
(SELECT COUNT(*) FROM likes l WHERE l.thought_id=t.id) AS like_count,
(SELECT COUNT(*) FROM boosts b WHERE b.thought_id=t.id) AS boost_count,
(SELECT COUNT(*) FROM thoughts r WHERE r.in_reply_to_id=t.id) AS reply_count
FROM thoughts t JOIN users u ON u.id=t.user_id";
fn row_to_entry(r: FeedRow) -> FeedEntry {
let thought = Thought {
id: ThoughtId::from_uuid(r.thought_id),
user_id: UserId::from_uuid(r.t_user_id),
content: Content::new_remote(r.content),
in_reply_to_id: r.in_reply_to_id.map(ThoughtId::from_uuid),
in_reply_to_url: r.in_reply_to_url,
ap_id: r.t_ap_id,
visibility: Visibility::from_str(&r.visibility),
content_warning: r.content_warning,
sensitive: r.sensitive,
local: r.t_local,
created_at: r.thought_created_at,
updated_at: r.updated_at,
};
let author = User {
id: UserId::from_uuid(r.author_id),
username: Username::from_trusted(r.username),
email: Email::from_trusted(r.email),
password_hash: PasswordHash(r.password_hash),
display_name: r.display_name, bio: r.bio,
avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css,
local: r.author_local, ap_id: r.u_ap_id, inbox_url: r.inbox_url,
public_key: r.public_key, private_key: r.private_key,
created_at: r.author_created_at, updated_at: r.author_updated_at,
};
FeedEntry { thought, author, like_count: r.like_count, boost_count: r.boost_count, reply_count: r.reply_count, liked_by_viewer: false, boosted_by_viewer: false }
}
#[derive(sqlx::FromRow)]
struct UserRow {
id: uuid::Uuid,
username: String,
email: String,
password_hash: String,
display_name: Option<String>,
bio: Option<String>,
avatar_url: Option<String>,
header_url: Option<String>,
custom_css: Option<String>,
local: bool,
ap_id: Option<String>,
inbox_url: Option<String>,
public_key: Option<String>,
private_key: Option<String>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
impl From<UserRow> for User {
fn from(r: UserRow) -> Self {
User {
id: UserId::from_uuid(r.id),
username: Username::from_trusted(r.username),
email: Email::from_trusted(r.email),
password_hash: PasswordHash(r.password_hash),
display_name: r.display_name, bio: r.bio,
avatar_url: r.avatar_url, header_url: r.header_url, custom_css: r.custom_css,
local: r.local, ap_id: r.ap_id, inbox_url: r.inbox_url,
public_key: r.public_key, private_key: r.private_key,
created_at: r.created_at, updated_at: r.updated_at,
}
}
}
const USER_SELECT: &str =
"SELECT id,username,email,password_hash,display_name,bio,avatar_url,header_url,\
custom_css,local,ap_id,inbox_url,public_key,private_key,created_at,updated_at FROM users";
#[async_trait]
impl SearchPort for PgSearchRepository {
async fn search_thoughts(
&self,
query: &str,
page: &PageParams,
_viewer_id: Option<&UserId>,
) -> Result<Paginated<FeedEntry>, DomainError> {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM thoughts t
WHERE t.content % $1 AND t.visibility='public'"
)
.bind(query)
.fetch_one(&self.pool)
.await
.map_err(|e| DomainError::Internal(e.to_string()))?;
let sql = format!(
"{FEED_SELECT}
WHERE t.content % $1 AND t.visibility='public'
ORDER BY similarity(t.content, $1) DESC
LIMIT $2 OFFSET $3"
);
let rows = sqlx::query_as::<_, FeedRow>(&sql)
.bind(query)
.bind(page.limit())
.bind(page.offset())
.fetch_all(&self.pool)
.await
.map_err(|e| DomainError::Internal(e.to_string()))?;
Ok(Paginated {
items: rows.into_iter().map(row_to_entry).collect(),
total,
page: page.page,
per_page: page.per_page,
})
}
async fn search_users(
&self,
query: &str,
page: &PageParams,
) -> Result<Paginated<User>, DomainError> {
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM users u
WHERE u.local=true AND (u.username % $1 OR u.display_name % $1)"
)
.bind(query)
.fetch_one(&self.pool)
.await
.map_err(|e| DomainError::Internal(e.to_string()))?;
let sql = format!(
"{USER_SELECT}
WHERE local=true AND (username % $1 OR display_name % $1)
ORDER BY similarity(username || ' ' || COALESCE(display_name,''), $1) DESC
LIMIT $2 OFFSET $3"
);
let rows = sqlx::query_as::<_, UserRow>(&sql)
.bind(query)
.bind(page.limit())
.bind(page.offset())
.fetch_all(&self.pool)
.await
.map_err(|e| DomainError::Internal(e.to_string()))?;
Ok(Paginated {
items: rows.into_iter().map(User::from).collect(),
total,
page: page.page,
per_page: page.per_page,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use domain::{
models::{thought::{Thought, Visibility}, user::User},
ports::{SearchPort, ThoughtRepository, UserRepository},
value_objects::*,
};
async fn seed_thought(pool: &sqlx::PgPool, username: &str, content: &str) -> (User, Thought) {
use postgres::{thought::PgThoughtRepository, user::PgUserRepository};
let urepo = PgUserRepository::new(pool.clone());
let trepo = PgThoughtRepository::new(pool.clone());
let u = User::new_local(
UserId::new(),
Username::new(username).unwrap(),
Email::new(format!("{username}@ex.com")).unwrap(),
PasswordHash("h".into()),
);
urepo.save(&u).await.unwrap();
let t = Thought::new_local(
ThoughtId::new(), u.id.clone(),
Content::new_local(content).unwrap(),
None, Visibility::Public, None, false,
);
trepo.save(&t).await.unwrap();
(u, t)
}
#[sqlx::test(migrations = "../postgres/migrations")]
async fn search_thoughts_finds_by_keyword(pool: sqlx::PgPool) {
seed_thought(&pool, "alice", "hello world").await;
seed_thought(&pool, "bob", "goodbye universe").await;
let repo = PgSearchRepository::new(pool);
let result = repo.search_thoughts("hello world", &PageParams { page: 1, per_page: 20 }, None).await.unwrap();
assert_eq!(result.total, 1);
assert_eq!(result.items[0].thought.content.as_str(), "hello world");
}
#[sqlx::test(migrations = "../postgres/migrations")]
async fn search_users_finds_by_username(pool: sqlx::PgPool) {
use postgres::user::PgUserRepository;
let urepo = PgUserRepository::new(pool.clone());
let alice = User::new_local(UserId::new(), Username::new("alice_search").unwrap(), Email::new("alice@ex.com").unwrap(), PasswordHash("h".into()));
urepo.save(&alice).await.unwrap();
let repo = PgSearchRepository::new(pool);
let result = repo.search_users("alice", &PageParams { page: 1, per_page: 20 }).await.unwrap();
assert!(!result.items.is_empty());
assert!(result.items.iter().any(|u| u.username.as_str() == "alice_search"));
}
#[sqlx::test(migrations = "../postgres/migrations")]
async fn search_thoughts_returns_empty_for_no_match(pool: sqlx::PgPool) {
seed_thought(&pool, "alice", "hello world").await;
let repo = PgSearchRepository::new(pool);
let result = repo.search_thoughts("zzzzzzzzz", &PageParams { page: 1, per_page: 20 }, None).await.unwrap();
assert_eq!(result.total, 0);
}
}