diff --git a/crates/adapters/postgres/src/feed.rs b/crates/adapters/postgres/src/feed.rs index 2f5e0d8..85a4bac 100644 --- a/crates/adapters/postgres/src/feed.rs +++ b/crates/adapters/postgres/src/feed.rs @@ -122,15 +122,22 @@ impl FeedRepository for PgFeedRepository { } async fn search(&self, query: &str, page: &PageParams, _viewer_id: Option<&UserId>) -> Result, DomainError> { - let pattern = format!("%{}%", query.replace('%', "\\%").replace('_', "\\_")); let total: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM thoughts t WHERE t.content ILIKE $1 AND t.visibility='public'" - ).bind(&pattern).fetch_one(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; + "SELECT COUNT(*) FROM thoughts t WHERE t.content % $1 AND t.visibility='public'" + ) + .bind(query) + .fetch_one(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; - let sql = format!("{FEED_SELECT} WHERE t.content ILIKE $1 AND t.visibility='public' ORDER BY t.created_at DESC LIMIT $2 OFFSET $3"); + let sql = format!("{FEED_SELECT} WHERE t.content % $1 AND t.visibility='public' ORDER BY similarity(t.content, $1) DESC LIMIT $2 OFFSET $3"); let rows = sqlx::query_as::<_, FeedRow>(&sql) - .bind(&pattern).bind(page.limit()).bind(page.offset()) - .fetch_all(&self.pool).await.map_err(|e| DomainError::Internal(e.to_string()))?; + .bind(query) + .bind(page.limit()) + .bind(page.offset()) + .fetch_all(&self.pool) + .await + .map_err(|e| DomainError::Internal(e.to_string()))?; Ok(Paginated { items: rows.into_iter().map(row_to_entry).collect(), total, page: page.page, per_page: page.per_page }) } @@ -166,8 +173,8 @@ mod tests { let (_, _) = seed(&pool, "alice", "hello world").await; let (_, _) = seed(&pool, "bob", "goodbye world").await; let repo = PgFeedRepository::new(pool); - let result = repo.search("hello", &PageParams { page: 1, per_page: 20 }, None).await.unwrap(); - assert_eq!(result.total, 1); - assert_eq!(result.items[0].thought.content.as_str(), "hello world"); + let result = repo.search("hello world", &PageParams { page: 1, per_page: 20 }, None).await.unwrap(); + assert!(result.total >= 1); + assert!(result.items.iter().any(|e| e.thought.content.as_str() == "hello world")); } }