From d444871829a3126b206f381635da53778c5864b5 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Sat, 15 Nov 2025 23:39:51 +0100 Subject: [PATCH] feat: Implement face clustering and media retrieval for persons --- Cargo.lock | 7 ++ libertas_api/Cargo.toml | 1 + libertas_api/src/factory.rs | 10 +- libertas_api/src/handlers/person_handlers.rs | 34 +++++- libertas_api/src/services/person_service.rs | 103 +++++++++++++++++- libertas_core/src/repositories.rs | 6 + libertas_core/src/services.rs | 9 ++ .../repositories/face_embedding_repository.rs | 19 ++++ .../src/repositories/media_repository.rs | 58 ++++++++++ 9 files changed, 239 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b5e4d3..318b53c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -751,6 +751,12 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +[[package]] +name = "dbscan" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9654109c8a07c62a71b9d443e72dde6c6e1364a9e8654126be122c58919d8a85" + [[package]] name = "der" version = "0.7.10" @@ -1810,6 +1816,7 @@ dependencies = [ "axum-extra", "bytes", "chrono", + "dbscan", "futures", "headers", "jsonwebtoken", diff --git a/libertas_api/Cargo.toml b/libertas_api/Cargo.toml index 081adfa..72df035 100644 --- a/libertas_api/Cargo.toml +++ b/libertas_api/Cargo.toml @@ -36,3 +36,4 @@ tower = { version = "0.5.2", features = ["util"] } tower-http = { version = "0.6.6", features = ["fs", "trace"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } +dbscan = "0.3.1" diff --git a/libertas_api/src/factory.rs b/libertas_api/src/factory.rs index b8d7f75..d7db29d 100644 --- a/libertas_api/src/factory.rs +++ b/libertas_api/src/factory.rs @@ -6,9 +6,9 @@ use libertas_core::{ }; use libertas_infra::factory::{ build_album_repository, build_album_share_repository, build_database_pool, - build_face_region_repository, build_media_metadata_repository, build_media_repository, - build_person_repository, build_person_share_repository, build_tag_repository, - build_user_repository, + build_face_embedding_repository, build_face_region_repository, build_media_metadata_repository, + build_media_repository, build_person_repository, build_person_share_repository, + build_tag_repository, build_user_repository, }; use crate::{ @@ -40,6 +40,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult { let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?; let person_share_repo = build_person_share_repository(&config.database, db_pool.clone()).await?; + let face_embedding_repo = + build_face_embedding_repository(&config.database, db_pool.clone()).await?; let hasher = Arc::new(Argon2Hasher::default()); let tokenizer = Arc::new(JwtGenerator::new(config.jwt_secret.clone())); @@ -81,6 +83,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult { person_repo.clone(), face_region_repo.clone(), person_share_repo.clone(), + face_embedding_repo.clone(), + media_repo.clone(), authorization_service.clone(), )); diff --git a/libertas_api/src/handlers/person_handlers.rs b/libertas_api/src/handlers/person_handlers.rs index 94b1763..001c9d5 100644 --- a/libertas_api/src/handlers/person_handlers.rs +++ b/libertas_api/src/handlers/person_handlers.rs @@ -10,10 +10,12 @@ use uuid::Uuid; use crate::{ error::ApiError, + extractors::query_options::ApiListMediaOptions, middleware::auth::UserId, schema::{ - AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MergePersonRequest, - PersonResponse, SetPersonThumbnailRequest, SharePersonRequest, UpdatePersonRequest, + AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MediaResponse, + MergePersonRequest, PaginatedResponse, PersonResponse, SetPersonThumbnailRequest, + SharePersonRequest, UpdatePersonRequest, map_paginated_response, }, state::AppState, }; @@ -31,6 +33,8 @@ pub fn people_routes() -> Router { ) .route("/{person_id}/merge", post(merge_person)) .route("/{person_id}/thumbnail", put(set_person_thumbnail)) + .route("/cluster", post(cluster_faces)) + .route("/{person_id}/media", get(list_media_for_person)) } pub fn face_routes() -> Router { @@ -176,3 +180,29 @@ async fn set_person_thumbnail( .await?; Ok(StatusCode::OK) } + +async fn cluster_faces( + State(state): State, + UserId(user_id): UserId, +) -> Result { + state + .person_service + .cluster_unassigned_faces(user_id) + .await?; + Ok(StatusCode::OK) +} + +async fn list_media_for_person( + State(state): State, + UserId(user_id): UserId, + Path(person_id): Path, + ApiListMediaOptions(options): ApiListMediaOptions, +) -> Result>, ApiError> { + let core_paginated_result = state + .person_service + .list_media_for_person(person_id, user_id, options) + .await?; + + let api_response = map_paginated_response(core_paginated_result); + Ok(Json(api_response)) +} diff --git a/libertas_api/src/services/person_service.rs b/libertas_api/src/services/person_service.rs index 302cef0..9e964f4 100644 --- a/libertas_api/src/services/person_service.rs +++ b/libertas_api/src/services/person_service.rs @@ -1,11 +1,16 @@ -use std::sync::Arc; +use dbscan::{self, Classification}; +use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use libertas_core::{ authz, error::{CoreError, CoreResult}, - models::{FaceRegion, Person, PersonPermission}, - repositories::{FaceRegionRepository, PersonRepository, PersonShareRepository}, + models::{FaceRegion, Media, Person, PersonPermission}, + repositories::{ + FaceEmbeddingRepository, FaceRegionRepository, MediaRepository, PersonRepository, + PersonShareRepository, + }, + schema::{ListMediaOptions, PaginatedResponse}, services::{AuthorizationService, PersonService}, }; use uuid::Uuid; @@ -14,6 +19,8 @@ pub struct PersonServiceImpl { person_repo: Arc, face_repo: Arc, person_share_repo: Arc, + face_embedding_repo: Arc, + media_repo: Arc, auth_service: Arc, } @@ -22,12 +29,16 @@ impl PersonServiceImpl { person_repo: Arc, face_repo: Arc, person_share_repo: Arc, + face_embedding_repo: Arc, + media_repo: Arc, auth_service: Arc, ) -> Self { Self { person_repo, face_repo, person_share_repo, + face_embedding_repo, + media_repo, auth_service, } } @@ -40,6 +51,13 @@ impl PersonServiceImpl { .ok_or(CoreError::NotFound("Person".to_string(), person_id))?; Ok(person) } + + fn bytes_to_f32_vec(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4]))) + .collect() + } } #[async_trait] @@ -245,4 +263,83 @@ impl PersonService for PersonServiceImpl { .set_thumbnail_media_id(person_id, face_region.media_id) .await } + + async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()> { + let embedding_data = self + .face_embedding_repo + .list_unassigned_by_user(user_id) + .await?; + + if embedding_data.is_empty() { + return Ok(()); + } + + let embeddings_f32: Vec> = embedding_data + .iter() + .map(|data| Self::bytes_to_f32_vec(&data.embedding)) + .collect(); + + let scan = dbscan::Model::new(0.4, 2); + let clusters = scan.run(&embeddings_f32); + + let mut cluster_map: HashMap> = HashMap::new(); + + tracing::info!( + "DBSCAN found {} clusters", + clusters + .iter() + .filter(|c| match c { + Classification::Core(_) | Classification::Edge(_) => true, + Classification::Noise => false, + }) + .count() + ); + + for (i, classification) in clusters.iter().enumerate() { + match classification { + Classification::Core(cluster_id) | Classification::Edge(cluster_id) => { + cluster_map.entry(*cluster_id).or_default().push(i); + } + Classification::Noise => {} + } + } + + for (_cluster_id, indices) in cluster_map { + let person_name = format!("Person {}", Uuid::new_v4()); + let new_person = self.create_person(&person_name, user_id).await?; + + let face_ids: Vec = indices + .iter() + .map(|&i| embedding_data[i].face_region_id) + .collect(); + + for face_id in face_ids { + self.face_repo + .update_person_id(face_id, new_person.id) + .await?; + } + } + + Ok(()) + } + + async fn list_media_for_person( + &self, + person_id: Uuid, + user_id: Uuid, + options: ListMediaOptions, + ) -> CoreResult> { + self.auth_service + .check_permission(Some(user_id), authz::Permission::ViewPerson(person_id)) + .await?; + + let (data, total_items) = self + .media_repo + .list_by_person_id(person_id, &options) + .await?; + + let pagination = options.pagination.unwrap(); + let response = PaginatedResponse::new(data, pagination.page, pagination.limit, total_items); + Ok(response) + } } diff --git a/libertas_core/src/repositories.rs b/libertas_core/src/repositories.rs index d861422..1f4a55f 100644 --- a/libertas_core/src/repositories.rs +++ b/libertas_core/src/repositories.rs @@ -20,6 +20,11 @@ pub trait MediaRepository: Send + Sync { user_id: Uuid, options: &ListMediaOptions, ) -> CoreResult<(Vec, i64)>; + async fn list_by_person_id( + &self, + person_id: Uuid, + options: &ListMediaOptions, + ) -> CoreResult<(Vec, i64)>; async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()>; async fn delete(&self, id: Uuid) -> CoreResult<()>; } @@ -139,4 +144,5 @@ pub trait FaceEmbeddingRepository: Send + Sync { &self, face_region_id: Uuid, ) -> CoreResult>; + async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult>; } diff --git a/libertas_core/src/services.rs b/libertas_core/src/services.rs index 363fc14..6203f7d 100644 --- a/libertas_core/src/services.rs +++ b/libertas_core/src/services.rs @@ -127,6 +127,15 @@ pub trait PersonService: Send + Sync { face_region_id: Uuid, user_id: Uuid, ) -> CoreResult<()>; + + async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()>; + + async fn list_media_for_person( + &self, + person_id: Uuid, + user_id: Uuid, + options: ListMediaOptions, + ) -> CoreResult>; } #[async_trait] diff --git a/libertas_infra/src/repositories/face_embedding_repository.rs b/libertas_infra/src/repositories/face_embedding_repository.rs index 0ada3a0..75b5ed0 100644 --- a/libertas_infra/src/repositories/face_embedding_repository.rs +++ b/libertas_infra/src/repositories/face_embedding_repository.rs @@ -58,4 +58,23 @@ impl FaceEmbeddingRepository for PostgresFaceEmbeddingRepository { Ok(pg_embedding.map(FaceEmbedding::from)) } + + async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult> { + let pg_embeddings = sqlx::query_as!( + PostgresFaceEmbedding, + r#" + SELECT fe.id, fe.face_region_id, fe.model_id, fe.embedding + FROM face_embeddings fe + JOIN face_regions fr ON fe.face_region_id = fr.id + JOIN media m ON fr.media_id = m.id + WHERE fr.person_id IS NULL AND m.owner_id = $1 + "#, + user_id + ) + .fetch_all(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + + Ok(pg_embeddings.into_iter().map(FaceEmbedding::from).collect()) + } } diff --git a/libertas_infra/src/repositories/media_repository.rs b/libertas_infra/src/repositories/media_repository.rs index d52b4e1..4de1b7e 100644 --- a/libertas_infra/src/repositories/media_repository.rs +++ b/libertas_infra/src/repositories/media_repository.rs @@ -162,6 +162,64 @@ impl MediaRepository for PostgresMediaRepository { Ok((media_list, total_items_result)) } + async fn list_by_person_id( + &self, + person_id: Uuid, + options: &ListMediaOptions, + ) -> CoreResult<(Vec, i64)> { + let count_base_sql = " + SELECT COUNT(DISTINCT media.id) as total + FROM media + JOIN face_regions fr ON media.id = fr.media_id + "; + let mut count_query = sqlx::QueryBuilder::new(count_base_sql); + count_query.push(" WHERE fr.person_id = "); + count_query.push_bind(person_id); + + let (mut count_query, _metadata_filter_count) = self + .query_builder + .apply_filters_to_query(count_query, options)?; + + let total_items_result = count_query + .build_query_scalar() + .fetch_one(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + + let data_base_sql = " + SELECT media.id, media.owner_id, media.storage_path, + media.original_filename, media.mime_type, + media.hash, media.created_at, media.thumbnail_path + FROM media + JOIN face_regions fr ON media.id = fr.media_id + "; + let mut data_query = sqlx::QueryBuilder::new(data_base_sql); + data_query.push(" WHERE fr.person_id = "); + data_query.push_bind(person_id); + + let (mut data_query, _metadata_filter_count) = self + .query_builder + .apply_filters_to_query(data_query, options)?; + + data_query.push(" GROUP BY media.id "); + + let data_query = self + .query_builder + .apply_sorting_to_query(data_query, options)?; + let mut data_query = self + .query_builder + .apply_pagination_to_query(data_query, options)?; + + let pg_media = data_query + .build_query_as::() + .fetch_all(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + + let media_list = pg_media.into_iter().map(|m| m.into()).collect(); + Ok((media_list, total_items_result)) + } + async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()> { sqlx::query!( r#"