feat: Add thumbnail management for albums and people, implement face embedding functionality
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE albums
|
||||
ADD COLUMN thumbnail_media_id UUID REFERENCES media(id) ON DELETE SET NULL;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE people
|
||||
ADD COLUMN thumbnail_media_id UUID REFERENCES media(id) ON DELETE SET NULL;
|
||||
@@ -0,0 +1,7 @@
|
||||
CREATE TABLE face_embeddings (
|
||||
id UUID PRIMARY KEY,
|
||||
face_region_id UUID NOT NULL REFERENCES face_regions(id) ON DELETE CASCADE,
|
||||
model_id SMALLINT NOT NULL,
|
||||
embedding BYTEA NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_face_embeddings_region_id ON face_embeddings (face_region_id);
|
||||
@@ -2,12 +2,22 @@ use axum::{
|
||||
Json, Router,
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
routing::{get, post},
|
||||
routing::{get, post, put},
|
||||
};
|
||||
use libertas_core::schema::{
|
||||
AddMediaToAlbumData, CreateAlbumData, ShareAlbumData, UpdateAlbumData,
|
||||
};
|
||||
use libertas_core::schema::{AddMediaToAlbumData, CreateAlbumData, ShareAlbumData, UpdateAlbumData};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::ApiError, middleware::auth::UserId, schema::{AddMediaToAlbumRequest, AlbumResponse, CreateAlbumRequest, ShareAlbumRequest, UpdateAlbumRequest}, state::AppState};
|
||||
use crate::{
|
||||
error::ApiError,
|
||||
middleware::auth::UserId,
|
||||
schema::{
|
||||
AddMediaToAlbumRequest, AlbumResponse, CreateAlbumRequest, SetThumbnailRequest,
|
||||
ShareAlbumRequest, UpdateAlbumRequest,
|
||||
},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
async fn create_album(
|
||||
State(state): State<AppState>,
|
||||
@@ -110,6 +120,19 @@ async fn delete_album(
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn set_album_thumbnail(
|
||||
State(state): State<AppState>,
|
||||
UserId(user_id): UserId,
|
||||
Path(album_id): Path<Uuid>,
|
||||
Json(payload): Json<SetThumbnailRequest>,
|
||||
) -> Result<StatusCode, ApiError> {
|
||||
state
|
||||
.album_service
|
||||
.set_album_thumbnail(album_id, payload.media_id, user_id)
|
||||
.await?;
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
pub fn album_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", post(create_album).get(list_user_albums))
|
||||
@@ -119,6 +142,7 @@ pub fn album_routes() -> Router<AppState> {
|
||||
.put(update_album)
|
||||
.delete(delete_album),
|
||||
)
|
||||
.route("/{id}/thumbnail", put(set_album_thumbnail))
|
||||
.route("/{id}/media", post(add_media_to_album))
|
||||
.route("/{id}/share", post(share_album))
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
middleware::auth::UserId,
|
||||
schema::{
|
||||
AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MergePersonRequest,
|
||||
PersonResponse, SharePersonRequest, UpdatePersonRequest,
|
||||
PersonResponse, SetPersonThumbnailRequest, SharePersonRequest, UpdatePersonRequest,
|
||||
},
|
||||
state::AppState,
|
||||
};
|
||||
@@ -30,6 +30,7 @@ pub fn people_routes() -> Router<AppState> {
|
||||
post(share_person).delete(unshare_person),
|
||||
)
|
||||
.route("/{person_id}/merge", post(merge_person))
|
||||
.route("/{person_id}/thumbnail", put(set_person_thumbnail))
|
||||
}
|
||||
|
||||
pub fn face_routes() -> Router<AppState> {
|
||||
@@ -162,3 +163,16 @@ async fn merge_person(
|
||||
.await?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn set_person_thumbnail(
|
||||
State(state): State<AppState>,
|
||||
UserId(user_id): UserId,
|
||||
Path(person_id): Path<Uuid>,
|
||||
Json(payload): Json<SetPersonThumbnailRequest>,
|
||||
) -> Result<StatusCode, ApiError> {
|
||||
state
|
||||
.person_service
|
||||
.set_person_thumbnail(person_id, payload.face_region_id, user_id)
|
||||
.await?;
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
@@ -279,3 +279,13 @@ pub struct ServeFileQuery {
|
||||
#[serde(default)]
|
||||
pub strip: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetThumbnailRequest {
|
||||
pub media_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SetPersonThumbnailRequest {
|
||||
pub face_region_id: Uuid,
|
||||
}
|
||||
|
||||
@@ -170,4 +170,22 @@ impl AlbumService for AlbumServiceImpl {
|
||||
let media = self.album_repo.list_media_by_album_id(album_id).await?;
|
||||
Ok(PublicAlbumBundle { album, media })
|
||||
}
|
||||
|
||||
async fn set_album_thumbnail(
|
||||
&self,
|
||||
album_id: Uuid,
|
||||
media_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()> {
|
||||
self.auth_service
|
||||
.check_permission(Some(user_id), Permission::EditAlbum(album_id))
|
||||
.await?;
|
||||
self.auth_service
|
||||
.check_permission(Some(user_id), Permission::ViewMedia(media_id))
|
||||
.await?;
|
||||
|
||||
self.album_repo
|
||||
.set_thumbnail_media_id(album_id, media_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +120,6 @@ impl AuthorizationService for AuthorizationServiceImpl {
|
||||
|
||||
if let Some(ref user) = user {
|
||||
if authz::is_admin(user) {
|
||||
// [cite: 115]
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
@@ -135,7 +134,6 @@ impl AuthorizationService for AuthorizationServiceImpl {
|
||||
|
||||
if let Some(id) = user_id {
|
||||
if authz::is_owner(id, &media) {
|
||||
// [cite: 117]
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -144,7 +142,6 @@ impl AuthorizationService for AuthorizationServiceImpl {
|
||||
.is_media_in_shared_album(media_id, id)
|
||||
.await?
|
||||
{
|
||||
// [cite: 118-119]
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,4 +215,34 @@ impl PersonService for PersonServiceImpl {
|
||||
|
||||
self.person_repo.delete(source_person_id).await
|
||||
}
|
||||
|
||||
async fn set_person_thumbnail(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
face_region_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()> {
|
||||
self.auth_service
|
||||
.check_permission(Some(user_id), authz::Permission::EditPerson(person_id))
|
||||
.await?;
|
||||
|
||||
let face_region =
|
||||
self.face_repo
|
||||
.find_by_id(face_region_id)
|
||||
.await?
|
||||
.ok_or(CoreError::NotFound(
|
||||
"FaceRegion".to_string(),
|
||||
face_region_id,
|
||||
))?;
|
||||
|
||||
if face_region.person_id != Some(person_id) {
|
||||
return Err(CoreError::Validation(
|
||||
"FaceRegion does not belong to the specified person".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
self.person_repo
|
||||
.set_thumbnail_media_id(person_id, face_region.media_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,3 +15,10 @@ pub struct BoundingBox {
|
||||
pub trait FaceDetector: Send + Sync {
|
||||
async fn detect_faces(&self, image_bytes: &[u8]) -> CoreResult<Vec<BoundingBox>>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait FaceEmbedder: Send + Sync {
|
||||
/// Generates a feature vector for a cropped face image.
|
||||
/// The image bytes should be a pre-cropped face.
|
||||
async fn generate_embedding(&self, image_bytes: &[u8]) -> CoreResult<Vec<f32>>;
|
||||
}
|
||||
|
||||
@@ -39,10 +39,20 @@ pub enum FaceDetectorRuntime {
|
||||
RemoteNats { subject: String },
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum FaceEmbedderRuntime {
|
||||
Tract,
|
||||
Onnx,
|
||||
RemoteNats { subject: String },
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct AiConfig {
|
||||
pub face_detector_runtime: FaceDetectorRuntime,
|
||||
pub face_detector_model_path: Option<String>,
|
||||
pub face_embedder_runtime: FaceEmbedderRuntime,
|
||||
pub face_embedder_model_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
|
||||
@@ -194,3 +194,11 @@ pub struct PersonShare {
|
||||
pub user_id: uuid::Uuid,
|
||||
pub permission: PersonPermission,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FaceEmbedding {
|
||||
pub id: uuid::Uuid,
|
||||
pub face_region_id: uuid::Uuid,
|
||||
pub model_id: i16,
|
||||
pub embedding: Vec<u8>,
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ use crate::{
|
||||
error::CoreResult,
|
||||
models::Media,
|
||||
repositories::{
|
||||
AlbumRepository, FaceRegionRepository, MediaMetadataRepository, MediaRepository,
|
||||
PersonRepository, TagRepository, UserRepository,
|
||||
AlbumRepository, FaceEmbeddingRepository, FaceRegionRepository, MediaMetadataRepository,
|
||||
MediaRepository, PersonRepository, TagRepository, UserRepository,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -24,6 +24,7 @@ pub struct PluginContext {
|
||||
pub tag_repo: Arc<dyn TagRepository>,
|
||||
pub person_repo: Arc<dyn PersonRepository>,
|
||||
pub face_region_repo: Arc<dyn FaceRegionRepository>,
|
||||
pub face_embedding_repo: Arc<dyn FaceEmbeddingRepository>,
|
||||
pub media_library_path: String,
|
||||
pub config: Arc<AppConfig>,
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ use uuid::Uuid;
|
||||
use crate::{
|
||||
error::CoreResult,
|
||||
models::{
|
||||
Album, AlbumPermission, FaceRegion, Media, MediaMetadata, Person, PersonPermission, Tag,
|
||||
User,
|
||||
Album, AlbumPermission, FaceEmbedding, FaceRegion, Media, MediaMetadata, Person,
|
||||
PersonPermission, Tag, User,
|
||||
},
|
||||
schema::{ListMediaOptions, MediaImportBundle},
|
||||
};
|
||||
@@ -43,6 +43,7 @@ pub trait AlbumRepository: Send + Sync {
|
||||
async fn delete(&self, id: Uuid) -> CoreResult<()>;
|
||||
async fn list_media_by_album_id(&self, album_id: Uuid) -> CoreResult<Vec<Media>>;
|
||||
async fn is_media_in_public_album(&self, media_id: Uuid) -> CoreResult<bool>;
|
||||
async fn set_thumbnail_media_id(&self, album_id: Uuid, media_id: Uuid) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -90,6 +91,7 @@ pub trait PersonRepository: Send + Sync {
|
||||
async fn list_by_user(&self, user_id: Uuid) -> CoreResult<Vec<Person>>;
|
||||
async fn update(&self, person: Person) -> CoreResult<()>;
|
||||
async fn delete(&self, id: Uuid) -> CoreResult<()>;
|
||||
async fn set_thumbnail_media_id(&self, person_id: Uuid, media_id: Uuid) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -129,3 +131,12 @@ pub trait PersonShareRepository: Send + Sync {
|
||||
pub trait MediaImportRepository: Send + Sync {
|
||||
async fn create_media_bundle(&self, bundle: MediaImportBundle) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait FaceEmbeddingRepository: Send + Sync {
|
||||
async fn create(&self, embedding: &FaceEmbedding) -> CoreResult<()>;
|
||||
async fn find_by_face_region_id(
|
||||
&self,
|
||||
face_region_id: Uuid,
|
||||
) -> CoreResult<Option<FaceEmbedding>>;
|
||||
}
|
||||
|
||||
@@ -52,6 +52,12 @@ pub trait AlbumService: Send + Sync {
|
||||
) -> CoreResult<Album>;
|
||||
async fn delete_album(&self, album_id: Uuid, user_id: Uuid) -> CoreResult<()>;
|
||||
async fn get_public_album_bundle(&self, album_id: Uuid) -> CoreResult<PublicAlbumBundle>;
|
||||
async fn set_album_thumbnail(
|
||||
&self,
|
||||
album_id: Uuid,
|
||||
media_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -114,6 +120,13 @@ pub trait PersonService: Send + Sync {
|
||||
source_person_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()>;
|
||||
|
||||
async fn set_person_thumbnail(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
face_region_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod remote_detector;
|
||||
pub mod tract_detector;
|
||||
pub mod tract_embedder;
|
||||
|
||||
89
libertas_infra/src/ai/tract_embedder.rs
Normal file
89
libertas_infra/src/ai/tract_embedder.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use async_trait::async_trait;
|
||||
use image::imageops;
|
||||
use libertas_core::{
|
||||
ai::FaceEmbedder,
|
||||
error::{CoreError, CoreResult},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tokio::task;
|
||||
|
||||
use tract_onnx::{prelude::*, tract_core::ndarray::Array4};
|
||||
|
||||
type TractModel = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
|
||||
|
||||
pub struct TractFaceEmbedder {
|
||||
model: Arc<TractModel>,
|
||||
}
|
||||
|
||||
impl TractFaceEmbedder {
|
||||
pub fn new(model_path: &str) -> CoreResult<Self> {
|
||||
let model = tract_onnx::onnx()
|
||||
.model_for_path(model_path)
|
||||
.map_err(|e| CoreError::Config(format!("Failed to load embedding model: {}", e)))?
|
||||
.with_input_fact(0, f32::fact([1, 112, 112, 3]).into())
|
||||
.map_err(|e| CoreError::Config(format!("Failed to set input fact: {}", e)))?
|
||||
.into_optimized()
|
||||
.map_err(|e| CoreError::Config(format!("Failed to optimize model: {}", e)))?
|
||||
.into_runnable()
|
||||
.map_err(|e| CoreError::Config(format!("Failed to make model runnable: {}", e)))?;
|
||||
|
||||
Ok(Self {
|
||||
model: Arc::new(model),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FaceEmbedder for TractFaceEmbedder {
|
||||
async fn generate_embedding(&self, image_bytes: &[u8]) -> CoreResult<Vec<f32>> {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let image_bytes = image_bytes.to_vec();
|
||||
let model = self.model.clone();
|
||||
|
||||
let embedding = task::spawn_blocking(move || {
|
||||
println!("Running face embedding locally on the CPU...");
|
||||
|
||||
let img = image::load_from_memory(&image_bytes)
|
||||
.map_err(|e| CoreError::Unknown(format!("Failed to load cropped face: {}", e)))?;
|
||||
|
||||
let resized = imageops::resize(&img, 112, 112, imageops::FilterType::Triangle);
|
||||
|
||||
let tensor: Tensor = Array4::from_shape_fn((1, 112, 112, 3), |(_, y, x, c)| {
|
||||
(resized.get_pixel(x as u32, y as u32)[c] as f32 - 127.5) / 128.0
|
||||
})
|
||||
.into();
|
||||
|
||||
let result = model
|
||||
.run(tvec!(tensor.into()))
|
||||
.map_err(|e| CoreError::Unknown(format!("Failed to run embedding model: {}", e)))?;
|
||||
|
||||
let output_tensor = result[0].to_array_view::<f32>().map_err(|e| {
|
||||
CoreError::Unknown(format!("Failed to convert output tensor: {}", e))
|
||||
})?;
|
||||
|
||||
let output_vec: Vec<f32> = output_tensor.as_slice().unwrap_or(&[]).to_vec();
|
||||
|
||||
if output_vec.is_empty() {
|
||||
return Err(CoreError::Unknown(
|
||||
"Embedding model returned empty output".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let norm = (output_vec.iter().map(|&x| x * x).sum::<f32>()).sqrt();
|
||||
if norm > 1e-5 {
|
||||
let normalized_vec: Vec<f32> = output_vec.iter().map(|&x| x / norm).collect();
|
||||
Ok(normalized_vec)
|
||||
} else {
|
||||
Ok(output_vec)
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|e| CoreError::Unknown(format!("Embedding task failed: {}", e)))?;
|
||||
|
||||
let duration = start_time.elapsed();
|
||||
println!("Face embedding generated in {} ms", duration.as_millis());
|
||||
|
||||
embedding
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,6 @@ pub enum PostgresMediaMetadataSource {
|
||||
TrackInfo,
|
||||
}
|
||||
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct PostgresUser {
|
||||
pub id: Uuid,
|
||||
@@ -116,4 +115,12 @@ pub struct PostgresPersonShared {
|
||||
pub owner_id: Uuid,
|
||||
pub name: String,
|
||||
pub permission: PostgresPersonPermission,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
pub struct PostgresFaceEmbedding {
|
||||
pub id: Uuid,
|
||||
pub face_region_id: Uuid,
|
||||
pub model_id: i16,
|
||||
pub embedding: Vec<u8>,
|
||||
}
|
||||
|
||||
@@ -177,3 +177,19 @@ pub async fn build_media_import_repository(
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build_face_embedding_repository(
|
||||
_db_config: &DatabaseConfig,
|
||||
pool: DatabasePool,
|
||||
) -> CoreResult<Arc<dyn libertas_core::repositories::FaceEmbeddingRepository>> {
|
||||
match pool {
|
||||
DatabasePool::Postgres(pg_pool) => Ok(Arc::new(
|
||||
crate::repositories::face_embedding_repository::PostgresFaceEmbeddingRepository::new(
|
||||
pg_pool,
|
||||
),
|
||||
)),
|
||||
DatabasePool::Sqlite(_sqlite_pool) => Err(CoreError::Database(
|
||||
"Sqlite face embedding repository not implemented".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
use libertas_core::models::{Album, AlbumPermission, AlbumShare, FaceRegion, Media, MediaMetadata, MediaMetadataSource, Person, PersonPermission, Role, Tag, User};
|
||||
use libertas_core::models::{
|
||||
Album, AlbumPermission, AlbumShare, FaceEmbedding, FaceRegion, Media, MediaMetadata,
|
||||
MediaMetadataSource, Person, PersonPermission, Role, Tag, User,
|
||||
};
|
||||
|
||||
use crate::db_models::{PostgresAlbum, PostgresAlbumPermission, PostgresAlbumShare, PostgresFaceRegion, PostgresMedia, PostgresMediaMetadata, PostgresMediaMetadataSource, PostgresPerson, PostgresPersonPermission, PostgresPersonShared, PostgresRole, PostgresTag, PostgresUser};
|
||||
use crate::db_models::{
|
||||
PostgresAlbum, PostgresAlbumPermission, PostgresAlbumShare, PostgresFaceEmbedding,
|
||||
PostgresFaceRegion, PostgresMedia, PostgresMediaMetadata, PostgresMediaMetadataSource,
|
||||
PostgresPerson, PostgresPersonPermission, PostgresPersonShared, PostgresRole, PostgresTag,
|
||||
PostgresUser,
|
||||
};
|
||||
|
||||
impl From<PostgresRole> for Role {
|
||||
fn from(pg_role: PostgresRole) -> Self {
|
||||
@@ -186,4 +194,15 @@ impl From<PostgresPersonShared> for (Person, PersonPermission) {
|
||||
let permission = PersonPermission::from(pg_shared.permission);
|
||||
(person, permission)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PostgresFaceEmbedding> for FaceEmbedding {
|
||||
fn from(pg_embedding: PostgresFaceEmbedding) -> Self {
|
||||
Self {
|
||||
id: pg_embedding.id,
|
||||
face_region_id: pg_embedding.face_region_id,
|
||||
model_id: pg_embedding.model_id,
|
||||
embedding: pg_embedding.embedding,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,4 +166,21 @@ impl AlbumRepository for PostgresAlbumRepository {
|
||||
|
||||
Ok(result.exists)
|
||||
}
|
||||
|
||||
async fn set_thumbnail_media_id(&self, album_id: Uuid, media_id: Uuid) -> CoreResult<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE albums
|
||||
SET thumbnail_media_id = $1
|
||||
WHERE id = $2
|
||||
"#,
|
||||
media_id,
|
||||
album_id
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
61
libertas_infra/src/repositories/face_embedding_repository.rs
Normal file
61
libertas_infra/src/repositories/face_embedding_repository.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use async_trait::async_trait;
|
||||
use libertas_core::{
|
||||
error::{CoreError, CoreResult},
|
||||
models::FaceEmbedding,
|
||||
repositories::FaceEmbeddingRepository,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::db_models::PostgresFaceEmbedding;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PostgresFaceEmbeddingRepository {
|
||||
pool: PgPool,
|
||||
}
|
||||
|
||||
impl PostgresFaceEmbeddingRepository {
|
||||
pub fn new(pool: PgPool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FaceEmbeddingRepository for PostgresFaceEmbeddingRepository {
|
||||
async fn create(&self, embedding: &FaceEmbedding) -> CoreResult<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO face_embeddings (id, face_region_id, model_id, embedding)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
embedding.id,
|
||||
embedding.face_region_id,
|
||||
embedding.model_id,
|
||||
embedding.embedding
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn find_by_face_region_id(
|
||||
&self,
|
||||
face_region_id: Uuid,
|
||||
) -> CoreResult<Option<FaceEmbedding>> {
|
||||
let pg_embedding = sqlx::query_as!(
|
||||
PostgresFaceEmbedding,
|
||||
r#"
|
||||
SELECT id, face_region_id, model_id, embedding
|
||||
FROM face_embeddings
|
||||
WHERE face_region_id = $1
|
||||
"#,
|
||||
face_region_id
|
||||
)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
|
||||
Ok(pg_embedding.map(FaceEmbedding::from))
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod album_repository;
|
||||
pub mod album_share_repository;
|
||||
pub mod face_embedding_repository;
|
||||
pub mod face_region_repository;
|
||||
pub mod media_import_repository;
|
||||
pub mod media_metadata_repository;
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use async_trait::async_trait;
|
||||
use libertas_core::{error::{CoreError, CoreResult}, models::Person, repositories::PersonRepository};
|
||||
use libertas_core::{
|
||||
error::{CoreError, CoreResult},
|
||||
models::Person,
|
||||
repositories::PersonRepository,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -95,4 +99,20 @@ impl PersonRepository for PostgresPersonRepository {
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_thumbnail_media_id(&self, person_id: Uuid, media_id: Uuid) -> CoreResult<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE people
|
||||
SET thumbnail_media_id = $1
|
||||
WHERE id = $2
|
||||
"#,
|
||||
media_id,
|
||||
person_id
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ use std::{path::PathBuf, sync::Arc};
|
||||
use futures_util::StreamExt;
|
||||
use libertas_core::plugins::PluginContext;
|
||||
use libertas_infra::factory::{
|
||||
build_album_repository, build_database_pool, build_face_region_repository,
|
||||
build_media_metadata_repository, build_media_repository, build_person_repository,
|
||||
build_tag_repository, build_user_repository,
|
||||
build_album_repository, build_database_pool, build_face_embedding_repository,
|
||||
build_face_region_repository, build_media_metadata_repository, build_media_repository,
|
||||
build_person_repository, build_tag_repository, build_user_repository,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tokio::fs;
|
||||
@@ -45,6 +45,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let tag_repo = build_tag_repository(&config.database, db_pool.clone()).await?;
|
||||
let person_repo = build_person_repository(&config.database, db_pool.clone()).await?;
|
||||
let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?;
|
||||
let face_embedding_repo =
|
||||
build_face_embedding_repository(&config.database, db_pool.clone()).await?;
|
||||
|
||||
let context = Arc::new(PluginContext {
|
||||
media_repo,
|
||||
@@ -53,6 +55,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
tag_repo,
|
||||
person_repo,
|
||||
face_region_repo,
|
||||
face_embedding_repo,
|
||||
metadata_repo,
|
||||
media_library_path: config.media_library_path.clone(),
|
||||
config: Arc::new(config.clone()),
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use libertas_core::{
|
||||
ai::FaceDetector,
|
||||
config::{AiConfig, AppConfig, FaceDetectorRuntime},
|
||||
ai::{FaceDetector, FaceEmbedder},
|
||||
config::{AiConfig, AppConfig, FaceDetectorRuntime, FaceEmbedderRuntime},
|
||||
error::{CoreError, CoreResult},
|
||||
models::Media,
|
||||
plugins::{MediaProcessorPlugin, PluginContext},
|
||||
};
|
||||
use libertas_infra::ai::{
|
||||
remote_detector::RemoteNatsFaceDetector, tract_detector::TractFaceDetector,
|
||||
tract_embedder::TractFaceEmbedder,
|
||||
};
|
||||
|
||||
use crate::plugins::{
|
||||
exif_reader::ExifReaderPlugin, face_detector::FaceDetectionPlugin, thumbnail::ThumbnailPlugin,
|
||||
xmp_writer::XmpWriterPlugin,
|
||||
embedding_generator::EmbeddingGeneratorPlugin, exif_reader::ExifReaderPlugin,
|
||||
face_detector::FaceDetectionPlugin, thumbnail::ThumbnailPlugin, xmp_writer::XmpWriterPlugin,
|
||||
};
|
||||
|
||||
pub struct PluginManager {
|
||||
@@ -25,7 +26,7 @@ impl PluginManager {
|
||||
let mut plugins: Vec<Arc<dyn MediaProcessorPlugin>> = Vec::new();
|
||||
|
||||
if let Some(ai_config) = &config.ai_config {
|
||||
match build_face_detector(ai_config, nats_client) {
|
||||
match build_face_detector(ai_config, nats_client.clone()) {
|
||||
Ok(detector) => {
|
||||
plugins.push(Arc::new(FaceDetectionPlugin::new(detector)));
|
||||
println!("FaceDetectionPlugin loaded.");
|
||||
@@ -34,6 +35,16 @@ impl PluginManager {
|
||||
eprintln!("Failed to load FaceDetectionPlugin: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
match build_face_embedder(ai_config, nats_client.clone()) {
|
||||
Ok(embedder) => {
|
||||
plugins.push(Arc::new(EmbeddingGeneratorPlugin::new(embedder)));
|
||||
println!("EmbeddingGeneratorPlugin loaded.");
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load EmbeddingGeneratorPlugin: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
plugins.push(Arc::new(ExifReaderPlugin));
|
||||
@@ -86,3 +97,27 @@ fn build_face_detector(
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_face_embedder(
|
||||
config: &AiConfig,
|
||||
_nats_client: async_nats::Client,
|
||||
) -> CoreResult<Box<dyn FaceEmbedder>> {
|
||||
match &config.face_embedder_runtime {
|
||||
FaceEmbedderRuntime::Tract => {
|
||||
let model_path =
|
||||
config
|
||||
.face_embedder_model_path
|
||||
.as_deref()
|
||||
.ok_or(CoreError::Config(
|
||||
"Tract runtime needs 'face_embedder_model_path'".to_string(),
|
||||
))?;
|
||||
Ok(Box::new(TractFaceEmbedder::new(model_path)?))
|
||||
}
|
||||
FaceEmbedderRuntime::Onnx => {
|
||||
unimplemented!("ONNX face embedder not implemented yet");
|
||||
}
|
||||
FaceEmbedderRuntime::RemoteNats { subject: _ } => {
|
||||
unimplemented!("RemoteNats face embedder not implemented yet");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
110
libertas_worker/src/plugins/embedding_generator.rs
Normal file
110
libertas_worker/src/plugins/embedding_generator.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use std::{io::Cursor, path::PathBuf};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use image::{ImageFormat, ImageReader};
|
||||
use libertas_core::{
|
||||
ai::FaceEmbedder,
|
||||
error::{CoreError, CoreResult},
|
||||
models::{FaceEmbedding, Media},
|
||||
plugins::{MediaProcessorPlugin, PluginContext, PluginData},
|
||||
};
|
||||
use tokio::fs;
|
||||
|
||||
pub struct EmbeddingGeneratorPlugin {
|
||||
embedder: Box<dyn FaceEmbedder>,
|
||||
model_id: i16,
|
||||
}
|
||||
|
||||
impl EmbeddingGeneratorPlugin {
|
||||
pub fn new(embedder: Box<dyn FaceEmbedder>) -> Self {
|
||||
Self {
|
||||
embedder,
|
||||
model_id: 1, // todo: come from config or something
|
||||
}
|
||||
}
|
||||
|
||||
fn f32_vec_to_bytes(vec: &[f32]) -> Vec<u8> {
|
||||
vec.iter().flat_map(|&f| f.to_le_bytes()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MediaProcessorPlugin for EmbeddingGeneratorPlugin {
|
||||
fn name(&self) -> &'static str {
|
||||
"embedding_generator"
|
||||
}
|
||||
|
||||
async fn process(&self, media: &Media, context: &PluginContext) -> CoreResult<PluginData> {
|
||||
if !media.mime_type.starts_with("image/") {
|
||||
return Ok(PluginData {
|
||||
message: "Not an image, skipping.".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 1. Get all face regions for this media
|
||||
let faces = context.face_region_repo.find_by_media_id(media.id).await?;
|
||||
if faces.is_empty() {
|
||||
return Ok(PluginData {
|
||||
message: "No faces found to embed.".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Load the full original image
|
||||
let file_path = PathBuf::from(&context.media_library_path).join(&media.storage_path);
|
||||
let image_bytes = fs::read(file_path).await?;
|
||||
let img = ImageReader::new(Cursor::new(&image_bytes))
|
||||
.with_guessed_format()?
|
||||
.decode()
|
||||
.map_err(|e| CoreError::Unknown(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let mut new_embeddings = 0;
|
||||
|
||||
for face in faces {
|
||||
// 3. Check if embedding already exists
|
||||
if context
|
||||
.face_embedding_repo
|
||||
.find_by_face_region_id(face.id)
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// 4. Crop the face from the main image
|
||||
let cropped_face = img.crop_imm(
|
||||
face.x_min as u32,
|
||||
face.y_min as u32,
|
||||
(face.x_max - face.x_min) as u32,
|
||||
(face.y_max - face.y_min) as u32,
|
||||
);
|
||||
|
||||
// 5. Convert cropped image back to bytes (as JPEG)
|
||||
let mut buf = Cursor::new(Vec::new());
|
||||
cropped_face
|
||||
.write_to(&mut buf, ImageFormat::Jpeg)
|
||||
.map_err(|e| {
|
||||
CoreError::Unknown(format!("Failed to encode cropped image: {}", e))
|
||||
})?;
|
||||
let cropped_bytes = buf.into_inner();
|
||||
|
||||
// 6. Generate the embedding
|
||||
let embedding_f32 = self.embedder.generate_embedding(&cropped_bytes).await?;
|
||||
let embedding_bytes = Self::f32_vec_to_bytes(&embedding_f32);
|
||||
|
||||
// 7. Save to database
|
||||
let embedding_model = FaceEmbedding {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
face_region_id: face.id,
|
||||
model_id: self.model_id,
|
||||
embedding: embedding_bytes,
|
||||
};
|
||||
|
||||
context.face_embedding_repo.create(&embedding_model).await?;
|
||||
new_embeddings += 1;
|
||||
}
|
||||
|
||||
Ok(PluginData {
|
||||
message: format!("Generated {} new embeddings.", new_embeddings),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod embedding_generator;
|
||||
pub mod exif_reader;
|
||||
pub mod face_detector;
|
||||
pub mod thumbnail;
|
||||
|
||||
Reference in New Issue
Block a user