diff --git a/libertas_api/migrations/20251115202359_add_thumbnail_media_id_to_albums.sql b/libertas_api/migrations/20251115202359_add_thumbnail_media_id_to_albums.sql new file mode 100644 index 0000000..a4355b0 --- /dev/null +++ b/libertas_api/migrations/20251115202359_add_thumbnail_media_id_to_albums.sql @@ -0,0 +1,2 @@ +ALTER TABLE albums +ADD COLUMN thumbnail_media_id UUID REFERENCES media(id) ON DELETE SET NULL; \ No newline at end of file diff --git a/libertas_api/migrations/20251115203140_add_thumbnail_media_id_to_people.sql b/libertas_api/migrations/20251115203140_add_thumbnail_media_id_to_people.sql new file mode 100644 index 0000000..2ce76d1 --- /dev/null +++ b/libertas_api/migrations/20251115203140_add_thumbnail_media_id_to_people.sql @@ -0,0 +1,2 @@ +ALTER TABLE people +ADD COLUMN thumbnail_media_id UUID REFERENCES media(id) ON DELETE SET NULL; \ No newline at end of file diff --git a/libertas_api/migrations/20251115210437_create_face_embeddings.sql b/libertas_api/migrations/20251115210437_create_face_embeddings.sql new file mode 100644 index 0000000..baae7a8 --- /dev/null +++ b/libertas_api/migrations/20251115210437_create_face_embeddings.sql @@ -0,0 +1,7 @@ +CREATE TABLE face_embeddings ( + id UUID PRIMARY KEY, + face_region_id UUID NOT NULL REFERENCES face_regions(id) ON DELETE CASCADE, + model_id SMALLINT NOT NULL, + embedding BYTEA NOT NULL +); +CREATE UNIQUE INDEX idx_face_embeddings_region_id ON face_embeddings (face_region_id); \ No newline at end of file diff --git a/libertas_api/src/handlers/album_handlers.rs b/libertas_api/src/handlers/album_handlers.rs index 6612b49..87904f1 100644 --- a/libertas_api/src/handlers/album_handlers.rs +++ b/libertas_api/src/handlers/album_handlers.rs @@ -2,12 +2,22 @@ use axum::{ Json, Router, extract::{Path, State}, http::StatusCode, - routing::{get, post}, + routing::{get, post, put}, +}; +use libertas_core::schema::{ + AddMediaToAlbumData, CreateAlbumData, ShareAlbumData, UpdateAlbumData, }; -use libertas_core::schema::{AddMediaToAlbumData, CreateAlbumData, ShareAlbumData, UpdateAlbumData}; use uuid::Uuid; -use crate::{error::ApiError, middleware::auth::UserId, schema::{AddMediaToAlbumRequest, AlbumResponse, CreateAlbumRequest, ShareAlbumRequest, UpdateAlbumRequest}, state::AppState}; +use crate::{ + error::ApiError, + middleware::auth::UserId, + schema::{ + AddMediaToAlbumRequest, AlbumResponse, CreateAlbumRequest, SetThumbnailRequest, + ShareAlbumRequest, UpdateAlbumRequest, + }, + state::AppState, +}; async fn create_album( State(state): State, @@ -110,6 +120,19 @@ async fn delete_album( Ok(StatusCode::NO_CONTENT) } +async fn set_album_thumbnail( + State(state): State, + UserId(user_id): UserId, + Path(album_id): Path, + Json(payload): Json, +) -> Result { + state + .album_service + .set_album_thumbnail(album_id, payload.media_id, user_id) + .await?; + Ok(StatusCode::OK) +} + pub fn album_routes() -> Router { Router::new() .route("/", post(create_album).get(list_user_albums)) @@ -119,6 +142,7 @@ pub fn album_routes() -> Router { .put(update_album) .delete(delete_album), ) + .route("/{id}/thumbnail", put(set_album_thumbnail)) .route("/{id}/media", post(add_media_to_album)) .route("/{id}/share", post(share_album)) } diff --git a/libertas_api/src/handlers/person_handlers.rs b/libertas_api/src/handlers/person_handlers.rs index 28db4b0..94b1763 100644 --- a/libertas_api/src/handlers/person_handlers.rs +++ b/libertas_api/src/handlers/person_handlers.rs @@ -13,7 +13,7 @@ use crate::{ middleware::auth::UserId, schema::{ AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MergePersonRequest, - PersonResponse, SharePersonRequest, UpdatePersonRequest, + PersonResponse, SetPersonThumbnailRequest, SharePersonRequest, UpdatePersonRequest, }, state::AppState, }; @@ -30,6 +30,7 @@ pub fn people_routes() -> Router { post(share_person).delete(unshare_person), ) .route("/{person_id}/merge", post(merge_person)) + .route("/{person_id}/thumbnail", put(set_person_thumbnail)) } pub fn face_routes() -> Router { @@ -162,3 +163,16 @@ async fn merge_person( .await?; Ok(StatusCode::NO_CONTENT) } + +async fn set_person_thumbnail( + State(state): State, + UserId(user_id): UserId, + Path(person_id): Path, + Json(payload): Json, +) -> Result { + state + .person_service + .set_person_thumbnail(person_id, payload.face_region_id, user_id) + .await?; + Ok(StatusCode::OK) +} diff --git a/libertas_api/src/schema.rs b/libertas_api/src/schema.rs index e28d9d8..609099e 100644 --- a/libertas_api/src/schema.rs +++ b/libertas_api/src/schema.rs @@ -279,3 +279,13 @@ pub struct ServeFileQuery { #[serde(default)] pub strip: bool, } + +#[derive(Deserialize)] +pub struct SetThumbnailRequest { + pub media_id: Uuid, +} + +#[derive(Deserialize)] +pub struct SetPersonThumbnailRequest { + pub face_region_id: Uuid, +} diff --git a/libertas_api/src/services/album_service.rs b/libertas_api/src/services/album_service.rs index 3f5df9c..f723c05 100644 --- a/libertas_api/src/services/album_service.rs +++ b/libertas_api/src/services/album_service.rs @@ -170,4 +170,22 @@ impl AlbumService for AlbumServiceImpl { let media = self.album_repo.list_media_by_album_id(album_id).await?; Ok(PublicAlbumBundle { album, media }) } + + async fn set_album_thumbnail( + &self, + album_id: Uuid, + media_id: Uuid, + user_id: Uuid, + ) -> CoreResult<()> { + self.auth_service + .check_permission(Some(user_id), Permission::EditAlbum(album_id)) + .await?; + self.auth_service + .check_permission(Some(user_id), Permission::ViewMedia(media_id)) + .await?; + + self.album_repo + .set_thumbnail_media_id(album_id, media_id) + .await + } } diff --git a/libertas_api/src/services/authorization_service.rs b/libertas_api/src/services/authorization_service.rs index bb5b482..a56631c 100644 --- a/libertas_api/src/services/authorization_service.rs +++ b/libertas_api/src/services/authorization_service.rs @@ -120,7 +120,6 @@ impl AuthorizationService for AuthorizationServiceImpl { if let Some(ref user) = user { if authz::is_admin(user) { - // [cite: 115] return Ok(()); } } @@ -135,7 +134,6 @@ impl AuthorizationService for AuthorizationServiceImpl { if let Some(id) = user_id { if authz::is_owner(id, &media) { - // [cite: 117] return Ok(()); } @@ -144,7 +142,6 @@ impl AuthorizationService for AuthorizationServiceImpl { .is_media_in_shared_album(media_id, id) .await? { - // [cite: 118-119] return Ok(()); } } diff --git a/libertas_api/src/services/person_service.rs b/libertas_api/src/services/person_service.rs index 173bb4d..302cef0 100644 --- a/libertas_api/src/services/person_service.rs +++ b/libertas_api/src/services/person_service.rs @@ -215,4 +215,34 @@ impl PersonService for PersonServiceImpl { self.person_repo.delete(source_person_id).await } + + async fn set_person_thumbnail( + &self, + person_id: Uuid, + face_region_id: Uuid, + user_id: Uuid, + ) -> CoreResult<()> { + self.auth_service + .check_permission(Some(user_id), authz::Permission::EditPerson(person_id)) + .await?; + + let face_region = + self.face_repo + .find_by_id(face_region_id) + .await? + .ok_or(CoreError::NotFound( + "FaceRegion".to_string(), + face_region_id, + ))?; + + if face_region.person_id != Some(person_id) { + return Err(CoreError::Validation( + "FaceRegion does not belong to the specified person".to_string(), + )); + } + + self.person_repo + .set_thumbnail_media_id(person_id, face_region.media_id) + .await + } } diff --git a/libertas_core/src/ai.rs b/libertas_core/src/ai.rs index c9d4e6f..088744a 100644 --- a/libertas_core/src/ai.rs +++ b/libertas_core/src/ai.rs @@ -15,3 +15,10 @@ pub struct BoundingBox { pub trait FaceDetector: Send + Sync { async fn detect_faces(&self, image_bytes: &[u8]) -> CoreResult>; } + +#[async_trait] +pub trait FaceEmbedder: Send + Sync { + /// Generates a feature vector for a cropped face image. + /// The image bytes should be a pre-cropped face. + async fn generate_embedding(&self, image_bytes: &[u8]) -> CoreResult>; +} diff --git a/libertas_core/src/config.rs b/libertas_core/src/config.rs index 8d33725..ffc4f52 100644 --- a/libertas_core/src/config.rs +++ b/libertas_core/src/config.rs @@ -39,10 +39,20 @@ pub enum FaceDetectorRuntime { RemoteNats { subject: String }, } +#[derive(Deserialize, Clone, Debug)] +#[serde(rename_all = "lowercase")] +pub enum FaceEmbedderRuntime { + Tract, + Onnx, + RemoteNats { subject: String }, +} + #[derive(Deserialize, Clone, Debug)] pub struct AiConfig { pub face_detector_runtime: FaceDetectorRuntime, pub face_detector_model_path: Option, + pub face_embedder_runtime: FaceEmbedderRuntime, + pub face_embedder_model_path: Option, } #[derive(Deserialize, Clone, Debug)] diff --git a/libertas_core/src/models.rs b/libertas_core/src/models.rs index 0cb75d2..8e2208c 100644 --- a/libertas_core/src/models.rs +++ b/libertas_core/src/models.rs @@ -194,3 +194,11 @@ pub struct PersonShare { pub user_id: uuid::Uuid, pub permission: PersonPermission, } + +#[derive(Debug, Clone)] +pub struct FaceEmbedding { + pub id: uuid::Uuid, + pub face_region_id: uuid::Uuid, + pub model_id: i16, + pub embedding: Vec, +} diff --git a/libertas_core/src/plugins.rs b/libertas_core/src/plugins.rs index 025bcf4..0a84462 100644 --- a/libertas_core/src/plugins.rs +++ b/libertas_core/src/plugins.rs @@ -7,8 +7,8 @@ use crate::{ error::CoreResult, models::Media, repositories::{ - AlbumRepository, FaceRegionRepository, MediaMetadataRepository, MediaRepository, - PersonRepository, TagRepository, UserRepository, + AlbumRepository, FaceEmbeddingRepository, FaceRegionRepository, MediaMetadataRepository, + MediaRepository, PersonRepository, TagRepository, UserRepository, }, }; @@ -24,6 +24,7 @@ pub struct PluginContext { pub tag_repo: Arc, pub person_repo: Arc, pub face_region_repo: Arc, + pub face_embedding_repo: Arc, pub media_library_path: String, pub config: Arc, } diff --git a/libertas_core/src/repositories.rs b/libertas_core/src/repositories.rs index 4280e8e..d861422 100644 --- a/libertas_core/src/repositories.rs +++ b/libertas_core/src/repositories.rs @@ -4,8 +4,8 @@ use uuid::Uuid; use crate::{ error::CoreResult, models::{ - Album, AlbumPermission, FaceRegion, Media, MediaMetadata, Person, PersonPermission, Tag, - User, + Album, AlbumPermission, FaceEmbedding, FaceRegion, Media, MediaMetadata, Person, + PersonPermission, Tag, User, }, schema::{ListMediaOptions, MediaImportBundle}, }; @@ -43,6 +43,7 @@ pub trait AlbumRepository: Send + Sync { async fn delete(&self, id: Uuid) -> CoreResult<()>; async fn list_media_by_album_id(&self, album_id: Uuid) -> CoreResult>; async fn is_media_in_public_album(&self, media_id: Uuid) -> CoreResult; + async fn set_thumbnail_media_id(&self, album_id: Uuid, media_id: Uuid) -> CoreResult<()>; } #[async_trait] @@ -90,6 +91,7 @@ pub trait PersonRepository: Send + Sync { async fn list_by_user(&self, user_id: Uuid) -> CoreResult>; async fn update(&self, person: Person) -> CoreResult<()>; async fn delete(&self, id: Uuid) -> CoreResult<()>; + async fn set_thumbnail_media_id(&self, person_id: Uuid, media_id: Uuid) -> CoreResult<()>; } #[async_trait] @@ -129,3 +131,12 @@ pub trait PersonShareRepository: Send + Sync { pub trait MediaImportRepository: Send + Sync { async fn create_media_bundle(&self, bundle: MediaImportBundle) -> CoreResult<()>; } + +#[async_trait] +pub trait FaceEmbeddingRepository: Send + Sync { + async fn create(&self, embedding: &FaceEmbedding) -> CoreResult<()>; + async fn find_by_face_region_id( + &self, + face_region_id: Uuid, + ) -> CoreResult>; +} diff --git a/libertas_core/src/services.rs b/libertas_core/src/services.rs index 91dd3fd..363fc14 100644 --- a/libertas_core/src/services.rs +++ b/libertas_core/src/services.rs @@ -52,6 +52,12 @@ pub trait AlbumService: Send + Sync { ) -> CoreResult; async fn delete_album(&self, album_id: Uuid, user_id: Uuid) -> CoreResult<()>; async fn get_public_album_bundle(&self, album_id: Uuid) -> CoreResult; + async fn set_album_thumbnail( + &self, + album_id: Uuid, + media_id: Uuid, + user_id: Uuid, + ) -> CoreResult<()>; } #[async_trait] @@ -114,6 +120,13 @@ pub trait PersonService: Send + Sync { source_person_id: Uuid, user_id: Uuid, ) -> CoreResult<()>; + + async fn set_person_thumbnail( + &self, + person_id: Uuid, + face_region_id: Uuid, + user_id: Uuid, + ) -> CoreResult<()>; } #[async_trait] diff --git a/libertas_infra/src/ai/mod.rs b/libertas_infra/src/ai/mod.rs index 046bcf2..6dc624c 100644 --- a/libertas_infra/src/ai/mod.rs +++ b/libertas_infra/src/ai/mod.rs @@ -1,2 +1,3 @@ pub mod remote_detector; pub mod tract_detector; +pub mod tract_embedder; diff --git a/libertas_infra/src/ai/tract_embedder.rs b/libertas_infra/src/ai/tract_embedder.rs new file mode 100644 index 0000000..194e537 --- /dev/null +++ b/libertas_infra/src/ai/tract_embedder.rs @@ -0,0 +1,89 @@ +use async_trait::async_trait; +use image::imageops; +use libertas_core::{ + ai::FaceEmbedder, + error::{CoreError, CoreResult}, +}; +use std::sync::Arc; +use tokio::task; + +use tract_onnx::{prelude::*, tract_core::ndarray::Array4}; + +type TractModel = SimplePlan, Graph>>; + +pub struct TractFaceEmbedder { + model: Arc, +} + +impl TractFaceEmbedder { + pub fn new(model_path: &str) -> CoreResult { + let model = tract_onnx::onnx() + .model_for_path(model_path) + .map_err(|e| CoreError::Config(format!("Failed to load embedding model: {}", e)))? + .with_input_fact(0, f32::fact([1, 112, 112, 3]).into()) + .map_err(|e| CoreError::Config(format!("Failed to set input fact: {}", e)))? + .into_optimized() + .map_err(|e| CoreError::Config(format!("Failed to optimize model: {}", e)))? + .into_runnable() + .map_err(|e| CoreError::Config(format!("Failed to make model runnable: {}", e)))?; + + Ok(Self { + model: Arc::new(model), + }) + } +} + +#[async_trait] +impl FaceEmbedder for TractFaceEmbedder { + async fn generate_embedding(&self, image_bytes: &[u8]) -> CoreResult> { + let start_time = std::time::Instant::now(); + + let image_bytes = image_bytes.to_vec(); + let model = self.model.clone(); + + let embedding = task::spawn_blocking(move || { + println!("Running face embedding locally on the CPU..."); + + let img = image::load_from_memory(&image_bytes) + .map_err(|e| CoreError::Unknown(format!("Failed to load cropped face: {}", e)))?; + + let resized = imageops::resize(&img, 112, 112, imageops::FilterType::Triangle); + + let tensor: Tensor = Array4::from_shape_fn((1, 112, 112, 3), |(_, y, x, c)| { + (resized.get_pixel(x as u32, y as u32)[c] as f32 - 127.5) / 128.0 + }) + .into(); + + let result = model + .run(tvec!(tensor.into())) + .map_err(|e| CoreError::Unknown(format!("Failed to run embedding model: {}", e)))?; + + let output_tensor = result[0].to_array_view::().map_err(|e| { + CoreError::Unknown(format!("Failed to convert output tensor: {}", e)) + })?; + + let output_vec: Vec = output_tensor.as_slice().unwrap_or(&[]).to_vec(); + + if output_vec.is_empty() { + return Err(CoreError::Unknown( + "Embedding model returned empty output".to_string(), + )); + } + + let norm = (output_vec.iter().map(|&x| x * x).sum::()).sqrt(); + if norm > 1e-5 { + let normalized_vec: Vec = output_vec.iter().map(|&x| x / norm).collect(); + Ok(normalized_vec) + } else { + Ok(output_vec) + } + }) + .await + .map_err(|e| CoreError::Unknown(format!("Embedding task failed: {}", e)))?; + + let duration = start_time.elapsed(); + println!("Face embedding generated in {} ms", duration.as_millis()); + + embedding + } +} diff --git a/libertas_infra/src/db_models.rs b/libertas_infra/src/db_models.rs index c0d2760..1617cf1 100644 --- a/libertas_infra/src/db_models.rs +++ b/libertas_infra/src/db_models.rs @@ -18,7 +18,6 @@ pub enum PostgresMediaMetadataSource { TrackInfo, } - #[derive(sqlx::FromRow)] pub struct PostgresUser { pub id: Uuid, @@ -116,4 +115,12 @@ pub struct PostgresPersonShared { pub owner_id: Uuid, pub name: String, pub permission: PostgresPersonPermission, -} \ No newline at end of file +} + +#[derive(sqlx::FromRow)] +pub struct PostgresFaceEmbedding { + pub id: Uuid, + pub face_region_id: Uuid, + pub model_id: i16, + pub embedding: Vec, +} diff --git a/libertas_infra/src/factory.rs b/libertas_infra/src/factory.rs index c2e413a..f9851c0 100644 --- a/libertas_infra/src/factory.rs +++ b/libertas_infra/src/factory.rs @@ -177,3 +177,19 @@ pub async fn build_media_import_repository( )), } } + +pub async fn build_face_embedding_repository( + _db_config: &DatabaseConfig, + pool: DatabasePool, +) -> CoreResult> { + match pool { + DatabasePool::Postgres(pg_pool) => Ok(Arc::new( + crate::repositories::face_embedding_repository::PostgresFaceEmbeddingRepository::new( + pg_pool, + ), + )), + DatabasePool::Sqlite(_sqlite_pool) => Err(CoreError::Database( + "Sqlite face embedding repository not implemented".to_string(), + )), + } +} diff --git a/libertas_infra/src/mappers.rs b/libertas_infra/src/mappers.rs index 301ffe7..81d3ead 100644 --- a/libertas_infra/src/mappers.rs +++ b/libertas_infra/src/mappers.rs @@ -1,6 +1,14 @@ -use libertas_core::models::{Album, AlbumPermission, AlbumShare, FaceRegion, Media, MediaMetadata, MediaMetadataSource, Person, PersonPermission, Role, Tag, User}; +use libertas_core::models::{ + Album, AlbumPermission, AlbumShare, FaceEmbedding, FaceRegion, Media, MediaMetadata, + MediaMetadataSource, Person, PersonPermission, Role, Tag, User, +}; -use crate::db_models::{PostgresAlbum, PostgresAlbumPermission, PostgresAlbumShare, PostgresFaceRegion, PostgresMedia, PostgresMediaMetadata, PostgresMediaMetadataSource, PostgresPerson, PostgresPersonPermission, PostgresPersonShared, PostgresRole, PostgresTag, PostgresUser}; +use crate::db_models::{ + PostgresAlbum, PostgresAlbumPermission, PostgresAlbumShare, PostgresFaceEmbedding, + PostgresFaceRegion, PostgresMedia, PostgresMediaMetadata, PostgresMediaMetadataSource, + PostgresPerson, PostgresPersonPermission, PostgresPersonShared, PostgresRole, PostgresTag, + PostgresUser, +}; impl From for Role { fn from(pg_role: PostgresRole) -> Self { @@ -186,4 +194,15 @@ impl From for (Person, PersonPermission) { let permission = PersonPermission::from(pg_shared.permission); (person, permission) } -} \ No newline at end of file +} + +impl From for FaceEmbedding { + fn from(pg_embedding: PostgresFaceEmbedding) -> Self { + Self { + id: pg_embedding.id, + face_region_id: pg_embedding.face_region_id, + model_id: pg_embedding.model_id, + embedding: pg_embedding.embedding, + } + } +} diff --git a/libertas_infra/src/repositories/album_repository.rs b/libertas_infra/src/repositories/album_repository.rs index f42a1e6..69dd0a0 100644 --- a/libertas_infra/src/repositories/album_repository.rs +++ b/libertas_infra/src/repositories/album_repository.rs @@ -166,4 +166,21 @@ impl AlbumRepository for PostgresAlbumRepository { Ok(result.exists) } + + async fn set_thumbnail_media_id(&self, album_id: Uuid, media_id: Uuid) -> CoreResult<()> { + sqlx::query!( + r#" + UPDATE albums + SET thumbnail_media_id = $1 + WHERE id = $2 + "#, + media_id, + album_id + ) + .execute(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + + Ok(()) + } } diff --git a/libertas_infra/src/repositories/face_embedding_repository.rs b/libertas_infra/src/repositories/face_embedding_repository.rs new file mode 100644 index 0000000..0ada3a0 --- /dev/null +++ b/libertas_infra/src/repositories/face_embedding_repository.rs @@ -0,0 +1,61 @@ +use async_trait::async_trait; +use libertas_core::{ + error::{CoreError, CoreResult}, + models::FaceEmbedding, + repositories::FaceEmbeddingRepository, +}; +use sqlx::PgPool; +use uuid::Uuid; + +use crate::db_models::PostgresFaceEmbedding; + +#[derive(Clone)] +pub struct PostgresFaceEmbeddingRepository { + pool: PgPool, +} + +impl PostgresFaceEmbeddingRepository { + pub fn new(pool: PgPool) -> Self { + Self { pool } + } +} + +#[async_trait] +impl FaceEmbeddingRepository for PostgresFaceEmbeddingRepository { + async fn create(&self, embedding: &FaceEmbedding) -> CoreResult<()> { + sqlx::query!( + r#" + INSERT INTO face_embeddings (id, face_region_id, model_id, embedding) + VALUES ($1, $2, $3, $4) + "#, + embedding.id, + embedding.face_region_id, + embedding.model_id, + embedding.embedding + ) + .execute(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + Ok(()) + } + + async fn find_by_face_region_id( + &self, + face_region_id: Uuid, + ) -> CoreResult> { + let pg_embedding = sqlx::query_as!( + PostgresFaceEmbedding, + r#" + SELECT id, face_region_id, model_id, embedding + FROM face_embeddings + WHERE face_region_id = $1 + "#, + face_region_id + ) + .fetch_optional(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + + Ok(pg_embedding.map(FaceEmbedding::from)) + } +} diff --git a/libertas_infra/src/repositories/mod.rs b/libertas_infra/src/repositories/mod.rs index 83df769..5590626 100644 --- a/libertas_infra/src/repositories/mod.rs +++ b/libertas_infra/src/repositories/mod.rs @@ -1,5 +1,6 @@ pub mod album_repository; pub mod album_share_repository; +pub mod face_embedding_repository; pub mod face_region_repository; pub mod media_import_repository; pub mod media_metadata_repository; diff --git a/libertas_infra/src/repositories/person_repository.rs b/libertas_infra/src/repositories/person_repository.rs index 0adebe8..b2a2129 100644 --- a/libertas_infra/src/repositories/person_repository.rs +++ b/libertas_infra/src/repositories/person_repository.rs @@ -1,5 +1,9 @@ use async_trait::async_trait; -use libertas_core::{error::{CoreError, CoreResult}, models::Person, repositories::PersonRepository}; +use libertas_core::{ + error::{CoreError, CoreResult}, + models::Person, + repositories::PersonRepository, +}; use sqlx::PgPool; use uuid::Uuid; @@ -95,4 +99,20 @@ impl PersonRepository for PostgresPersonRepository { .map_err(|e| CoreError::Database(e.to_string()))?; Ok(()) } -} \ No newline at end of file + + async fn set_thumbnail_media_id(&self, person_id: Uuid, media_id: Uuid) -> CoreResult<()> { + sqlx::query!( + r#" + UPDATE people + SET thumbnail_media_id = $1 + WHERE id = $2 + "#, + media_id, + person_id + ) + .execute(&self.pool) + .await + .map_err(|e| CoreError::Database(e.to_string()))?; + Ok(()) + } +} diff --git a/libertas_worker/src/main.rs b/libertas_worker/src/main.rs index 8f471d5..ae42966 100644 --- a/libertas_worker/src/main.rs +++ b/libertas_worker/src/main.rs @@ -3,9 +3,9 @@ use std::{path::PathBuf, sync::Arc}; use futures_util::StreamExt; use libertas_core::plugins::PluginContext; use libertas_infra::factory::{ - build_album_repository, build_database_pool, build_face_region_repository, - build_media_metadata_repository, build_media_repository, build_person_repository, - build_tag_repository, build_user_repository, + build_album_repository, build_database_pool, build_face_embedding_repository, + build_face_region_repository, build_media_metadata_repository, build_media_repository, + build_person_repository, build_tag_repository, build_user_repository, }; use serde::Deserialize; use tokio::fs; @@ -45,6 +45,8 @@ async fn main() -> anyhow::Result<()> { let tag_repo = build_tag_repository(&config.database, db_pool.clone()).await?; let person_repo = build_person_repository(&config.database, db_pool.clone()).await?; let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?; + let face_embedding_repo = + build_face_embedding_repository(&config.database, db_pool.clone()).await?; let context = Arc::new(PluginContext { media_repo, @@ -53,6 +55,7 @@ async fn main() -> anyhow::Result<()> { tag_repo, person_repo, face_region_repo, + face_embedding_repo, metadata_repo, media_library_path: config.media_library_path.clone(), config: Arc::new(config.clone()), diff --git a/libertas_worker/src/plugin_manager.rs b/libertas_worker/src/plugin_manager.rs index 96c6bb8..4e0658b 100644 --- a/libertas_worker/src/plugin_manager.rs +++ b/libertas_worker/src/plugin_manager.rs @@ -1,19 +1,20 @@ use std::sync::Arc; use libertas_core::{ - ai::FaceDetector, - config::{AiConfig, AppConfig, FaceDetectorRuntime}, + ai::{FaceDetector, FaceEmbedder}, + config::{AiConfig, AppConfig, FaceDetectorRuntime, FaceEmbedderRuntime}, error::{CoreError, CoreResult}, models::Media, plugins::{MediaProcessorPlugin, PluginContext}, }; use libertas_infra::ai::{ remote_detector::RemoteNatsFaceDetector, tract_detector::TractFaceDetector, + tract_embedder::TractFaceEmbedder, }; use crate::plugins::{ - exif_reader::ExifReaderPlugin, face_detector::FaceDetectionPlugin, thumbnail::ThumbnailPlugin, - xmp_writer::XmpWriterPlugin, + embedding_generator::EmbeddingGeneratorPlugin, exif_reader::ExifReaderPlugin, + face_detector::FaceDetectionPlugin, thumbnail::ThumbnailPlugin, xmp_writer::XmpWriterPlugin, }; pub struct PluginManager { @@ -25,7 +26,7 @@ impl PluginManager { let mut plugins: Vec> = Vec::new(); if let Some(ai_config) = &config.ai_config { - match build_face_detector(ai_config, nats_client) { + match build_face_detector(ai_config, nats_client.clone()) { Ok(detector) => { plugins.push(Arc::new(FaceDetectionPlugin::new(detector))); println!("FaceDetectionPlugin loaded."); @@ -34,6 +35,16 @@ impl PluginManager { eprintln!("Failed to load FaceDetectionPlugin: {}", e); } } + + match build_face_embedder(ai_config, nats_client.clone()) { + Ok(embedder) => { + plugins.push(Arc::new(EmbeddingGeneratorPlugin::new(embedder))); + println!("EmbeddingGeneratorPlugin loaded."); + } + Err(e) => { + eprintln!("Failed to load EmbeddingGeneratorPlugin: {}", e); + } + } } plugins.push(Arc::new(ExifReaderPlugin)); @@ -86,3 +97,27 @@ fn build_face_detector( ))), } } + +fn build_face_embedder( + config: &AiConfig, + _nats_client: async_nats::Client, +) -> CoreResult> { + match &config.face_embedder_runtime { + FaceEmbedderRuntime::Tract => { + let model_path = + config + .face_embedder_model_path + .as_deref() + .ok_or(CoreError::Config( + "Tract runtime needs 'face_embedder_model_path'".to_string(), + ))?; + Ok(Box::new(TractFaceEmbedder::new(model_path)?)) + } + FaceEmbedderRuntime::Onnx => { + unimplemented!("ONNX face embedder not implemented yet"); + } + FaceEmbedderRuntime::RemoteNats { subject: _ } => { + unimplemented!("RemoteNats face embedder not implemented yet"); + } + } +} diff --git a/libertas_worker/src/plugins/embedding_generator.rs b/libertas_worker/src/plugins/embedding_generator.rs new file mode 100644 index 0000000..aa4f053 --- /dev/null +++ b/libertas_worker/src/plugins/embedding_generator.rs @@ -0,0 +1,110 @@ +use std::{io::Cursor, path::PathBuf}; + +use async_trait::async_trait; +use image::{ImageFormat, ImageReader}; +use libertas_core::{ + ai::FaceEmbedder, + error::{CoreError, CoreResult}, + models::{FaceEmbedding, Media}, + plugins::{MediaProcessorPlugin, PluginContext, PluginData}, +}; +use tokio::fs; + +pub struct EmbeddingGeneratorPlugin { + embedder: Box, + model_id: i16, +} + +impl EmbeddingGeneratorPlugin { + pub fn new(embedder: Box) -> Self { + Self { + embedder, + model_id: 1, // todo: come from config or something + } + } + + fn f32_vec_to_bytes(vec: &[f32]) -> Vec { + vec.iter().flat_map(|&f| f.to_le_bytes()).collect() + } +} + +#[async_trait] +impl MediaProcessorPlugin for EmbeddingGeneratorPlugin { + fn name(&self) -> &'static str { + "embedding_generator" + } + + async fn process(&self, media: &Media, context: &PluginContext) -> CoreResult { + if !media.mime_type.starts_with("image/") { + return Ok(PluginData { + message: "Not an image, skipping.".to_string(), + }); + } + + // 1. Get all face regions for this media + let faces = context.face_region_repo.find_by_media_id(media.id).await?; + if faces.is_empty() { + return Ok(PluginData { + message: "No faces found to embed.".to_string(), + }); + } + + // 2. Load the full original image + let file_path = PathBuf::from(&context.media_library_path).join(&media.storage_path); + let image_bytes = fs::read(file_path).await?; + let img = ImageReader::new(Cursor::new(&image_bytes)) + .with_guessed_format()? + .decode() + .map_err(|e| CoreError::Unknown(format!("Failed to decode image: {}", e)))?; + + let mut new_embeddings = 0; + + for face in faces { + // 3. Check if embedding already exists + if context + .face_embedding_repo + .find_by_face_region_id(face.id) + .await? + .is_some() + { + continue; + } + + // 4. Crop the face from the main image + let cropped_face = img.crop_imm( + face.x_min as u32, + face.y_min as u32, + (face.x_max - face.x_min) as u32, + (face.y_max - face.y_min) as u32, + ); + + // 5. Convert cropped image back to bytes (as JPEG) + let mut buf = Cursor::new(Vec::new()); + cropped_face + .write_to(&mut buf, ImageFormat::Jpeg) + .map_err(|e| { + CoreError::Unknown(format!("Failed to encode cropped image: {}", e)) + })?; + let cropped_bytes = buf.into_inner(); + + // 6. Generate the embedding + let embedding_f32 = self.embedder.generate_embedding(&cropped_bytes).await?; + let embedding_bytes = Self::f32_vec_to_bytes(&embedding_f32); + + // 7. Save to database + let embedding_model = FaceEmbedding { + id: uuid::Uuid::new_v4(), + face_region_id: face.id, + model_id: self.model_id, + embedding: embedding_bytes, + }; + + context.face_embedding_repo.create(&embedding_model).await?; + new_embeddings += 1; + } + + Ok(PluginData { + message: format!("Generated {} new embeddings.", new_embeddings), + }) + } +} diff --git a/libertas_worker/src/plugins/mod.rs b/libertas_worker/src/plugins/mod.rs index 28fc3e4..955a43b 100644 --- a/libertas_worker/src/plugins/mod.rs +++ b/libertas_worker/src/plugins/mod.rs @@ -1,3 +1,4 @@ +pub mod embedding_generator; pub mod exif_reader; pub mod face_detector; pub mod thumbnail;