feat: Implement face clustering and media retrieval for persons
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -751,6 +751,12 @@ version = "2.9.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dbscan"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9654109c8a07c62a71b9d443e72dde6c6e1364a9e8654126be122c58919d8a85"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "der"
|
name = "der"
|
||||||
version = "0.7.10"
|
version = "0.7.10"
|
||||||
@@ -1810,6 +1816,7 @@ dependencies = [
|
|||||||
"axum-extra",
|
"axum-extra",
|
||||||
"bytes",
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"dbscan",
|
||||||
"futures",
|
"futures",
|
||||||
"headers",
|
"headers",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
|
|||||||
@@ -36,3 +36,4 @@ tower = { version = "0.5.2", features = ["util"] }
|
|||||||
tower-http = { version = "0.6.6", features = ["fs", "trace"] }
|
tower-http = { version = "0.6.6", features = ["fs", "trace"] }
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
|
||||||
|
dbscan = "0.3.1"
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ use libertas_core::{
|
|||||||
};
|
};
|
||||||
use libertas_infra::factory::{
|
use libertas_infra::factory::{
|
||||||
build_album_repository, build_album_share_repository, build_database_pool,
|
build_album_repository, build_album_share_repository, build_database_pool,
|
||||||
build_face_region_repository, build_media_metadata_repository, build_media_repository,
|
build_face_embedding_repository, build_face_region_repository, build_media_metadata_repository,
|
||||||
build_person_repository, build_person_share_repository, build_tag_repository,
|
build_media_repository, build_person_repository, build_person_share_repository,
|
||||||
build_user_repository,
|
build_tag_repository, build_user_repository,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -40,6 +40,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult<AppState> {
|
|||||||
let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?;
|
let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?;
|
||||||
let person_share_repo =
|
let person_share_repo =
|
||||||
build_person_share_repository(&config.database, db_pool.clone()).await?;
|
build_person_share_repository(&config.database, db_pool.clone()).await?;
|
||||||
|
let face_embedding_repo =
|
||||||
|
build_face_embedding_repository(&config.database, db_pool.clone()).await?;
|
||||||
|
|
||||||
let hasher = Arc::new(Argon2Hasher::default());
|
let hasher = Arc::new(Argon2Hasher::default());
|
||||||
let tokenizer = Arc::new(JwtGenerator::new(config.jwt_secret.clone()));
|
let tokenizer = Arc::new(JwtGenerator::new(config.jwt_secret.clone()));
|
||||||
@@ -81,6 +83,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult<AppState> {
|
|||||||
person_repo.clone(),
|
person_repo.clone(),
|
||||||
face_region_repo.clone(),
|
face_region_repo.clone(),
|
||||||
person_share_repo.clone(),
|
person_share_repo.clone(),
|
||||||
|
face_embedding_repo.clone(),
|
||||||
|
media_repo.clone(),
|
||||||
authorization_service.clone(),
|
authorization_service.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,12 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
|
extractors::query_options::ApiListMediaOptions,
|
||||||
middleware::auth::UserId,
|
middleware::auth::UserId,
|
||||||
schema::{
|
schema::{
|
||||||
AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MergePersonRequest,
|
AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MediaResponse,
|
||||||
PersonResponse, SetPersonThumbnailRequest, SharePersonRequest, UpdatePersonRequest,
|
MergePersonRequest, PaginatedResponse, PersonResponse, SetPersonThumbnailRequest,
|
||||||
|
SharePersonRequest, UpdatePersonRequest, map_paginated_response,
|
||||||
},
|
},
|
||||||
state::AppState,
|
state::AppState,
|
||||||
};
|
};
|
||||||
@@ -31,6 +33,8 @@ pub fn people_routes() -> Router<AppState> {
|
|||||||
)
|
)
|
||||||
.route("/{person_id}/merge", post(merge_person))
|
.route("/{person_id}/merge", post(merge_person))
|
||||||
.route("/{person_id}/thumbnail", put(set_person_thumbnail))
|
.route("/{person_id}/thumbnail", put(set_person_thumbnail))
|
||||||
|
.route("/cluster", post(cluster_faces))
|
||||||
|
.route("/{person_id}/media", get(list_media_for_person))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn face_routes() -> Router<AppState> {
|
pub fn face_routes() -> Router<AppState> {
|
||||||
@@ -176,3 +180,29 @@ async fn set_person_thumbnail(
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn cluster_faces(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
UserId(user_id): UserId,
|
||||||
|
) -> Result<StatusCode, ApiError> {
|
||||||
|
state
|
||||||
|
.person_service
|
||||||
|
.cluster_unassigned_faces(user_id)
|
||||||
|
.await?;
|
||||||
|
Ok(StatusCode::OK)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_media_for_person(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
UserId(user_id): UserId,
|
||||||
|
Path(person_id): Path<Uuid>,
|
||||||
|
ApiListMediaOptions(options): ApiListMediaOptions,
|
||||||
|
) -> Result<Json<PaginatedResponse<MediaResponse>>, ApiError> {
|
||||||
|
let core_paginated_result = state
|
||||||
|
.person_service
|
||||||
|
.list_media_for_person(person_id, user_id, options)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let api_response = map_paginated_response(core_paginated_result);
|
||||||
|
Ok(Json(api_response))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
use std::sync::Arc;
|
use dbscan::{self, Classification};
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use libertas_core::{
|
use libertas_core::{
|
||||||
authz,
|
authz,
|
||||||
error::{CoreError, CoreResult},
|
error::{CoreError, CoreResult},
|
||||||
models::{FaceRegion, Person, PersonPermission},
|
models::{FaceRegion, Media, Person, PersonPermission},
|
||||||
repositories::{FaceRegionRepository, PersonRepository, PersonShareRepository},
|
repositories::{
|
||||||
|
FaceEmbeddingRepository, FaceRegionRepository, MediaRepository, PersonRepository,
|
||||||
|
PersonShareRepository,
|
||||||
|
},
|
||||||
|
schema::{ListMediaOptions, PaginatedResponse},
|
||||||
services::{AuthorizationService, PersonService},
|
services::{AuthorizationService, PersonService},
|
||||||
};
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -14,6 +19,8 @@ pub struct PersonServiceImpl {
|
|||||||
person_repo: Arc<dyn PersonRepository>,
|
person_repo: Arc<dyn PersonRepository>,
|
||||||
face_repo: Arc<dyn FaceRegionRepository>,
|
face_repo: Arc<dyn FaceRegionRepository>,
|
||||||
person_share_repo: Arc<dyn PersonShareRepository>,
|
person_share_repo: Arc<dyn PersonShareRepository>,
|
||||||
|
face_embedding_repo: Arc<dyn FaceEmbeddingRepository>,
|
||||||
|
media_repo: Arc<dyn MediaRepository>,
|
||||||
auth_service: Arc<dyn AuthorizationService>,
|
auth_service: Arc<dyn AuthorizationService>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,12 +29,16 @@ impl PersonServiceImpl {
|
|||||||
person_repo: Arc<dyn PersonRepository>,
|
person_repo: Arc<dyn PersonRepository>,
|
||||||
face_repo: Arc<dyn FaceRegionRepository>,
|
face_repo: Arc<dyn FaceRegionRepository>,
|
||||||
person_share_repo: Arc<dyn PersonShareRepository>,
|
person_share_repo: Arc<dyn PersonShareRepository>,
|
||||||
|
face_embedding_repo: Arc<dyn FaceEmbeddingRepository>,
|
||||||
|
media_repo: Arc<dyn MediaRepository>,
|
||||||
auth_service: Arc<dyn AuthorizationService>,
|
auth_service: Arc<dyn AuthorizationService>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
person_repo,
|
person_repo,
|
||||||
face_repo,
|
face_repo,
|
||||||
person_share_repo,
|
person_share_repo,
|
||||||
|
face_embedding_repo,
|
||||||
|
media_repo,
|
||||||
auth_service,
|
auth_service,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -40,6 +51,13 @@ impl PersonServiceImpl {
|
|||||||
.ok_or(CoreError::NotFound("Person".to_string(), person_id))?;
|
.ok_or(CoreError::NotFound("Person".to_string(), person_id))?;
|
||||||
Ok(person)
|
Ok(person)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Decodes a raw byte buffer into `f32` values.
///
/// Every consecutive group of four bytes is read as one little-endian
/// `f32`. Trailing bytes that do not fill a complete group are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        // `chunks_exact(4)` only yields 4-byte slices, so indices 0..=3 are in bounds.
        floats.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    floats
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -245,4 +263,83 @@ impl PersonService for PersonServiceImpl {
|
|||||||
.set_thumbnail_media_id(person_id, face_region.media_id)
|
.set_thumbnail_media_id(person_id, face_region.media_id)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()> {
|
||||||
|
let embedding_data = self
|
||||||
|
.face_embedding_repo
|
||||||
|
.list_unassigned_by_user(user_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if embedding_data.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let embeddings_f32: Vec<Vec<f32>> = embedding_data
|
||||||
|
.iter()
|
||||||
|
.map(|data| Self::bytes_to_f32_vec(&data.embedding))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let scan = dbscan::Model::new(0.4, 2);
|
||||||
|
let clusters = scan.run(&embeddings_f32);
|
||||||
|
|
||||||
|
let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
"DBSCAN found {} clusters",
|
||||||
|
clusters
|
||||||
|
.iter()
|
||||||
|
.filter(|c| match c {
|
||||||
|
Classification::Core(_) | Classification::Edge(_) => true,
|
||||||
|
Classification::Noise => false,
|
||||||
|
})
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
|
||||||
|
for (i, classification) in clusters.iter().enumerate() {
|
||||||
|
match classification {
|
||||||
|
Classification::Core(cluster_id) | Classification::Edge(cluster_id) => {
|
||||||
|
cluster_map.entry(*cluster_id).or_default().push(i);
|
||||||
|
}
|
||||||
|
Classification::Noise => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (_cluster_id, indices) in cluster_map {
|
||||||
|
let person_name = format!("Person {}", Uuid::new_v4());
|
||||||
|
let new_person = self.create_person(&person_name, user_id).await?;
|
||||||
|
|
||||||
|
let face_ids: Vec<Uuid> = indices
|
||||||
|
.iter()
|
||||||
|
.map(|&i| embedding_data[i].face_region_id)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for face_id in face_ids {
|
||||||
|
self.face_repo
|
||||||
|
.update_person_id(face_id, new_person.id)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_media_for_person(
|
||||||
|
&self,
|
||||||
|
person_id: Uuid,
|
||||||
|
user_id: Uuid,
|
||||||
|
options: ListMediaOptions,
|
||||||
|
) -> CoreResult<PaginatedResponse<Media>> {
|
||||||
|
self.auth_service
|
||||||
|
.check_permission(Some(user_id), authz::Permission::ViewPerson(person_id))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (data, total_items) = self
|
||||||
|
.media_repo
|
||||||
|
.list_by_person_id(person_id, &options)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let pagination = options.pagination.unwrap();
|
||||||
|
let response = PaginatedResponse::new(data, pagination.page, pagination.limit, total_items);
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ pub trait MediaRepository: Send + Sync {
|
|||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
options: &ListMediaOptions,
|
options: &ListMediaOptions,
|
||||||
) -> CoreResult<(Vec<Media>, i64)>;
|
) -> CoreResult<(Vec<Media>, i64)>;
|
||||||
|
/// Lists the media associated with `person_id`, honoring `options`.
///
/// Returns the selected media together with an `i64` total. In the
/// Postgres implementation the total is the count of distinct media that
/// have a face region assigned to the person.
async fn list_by_person_id(
|
||||||
|
&self,
|
||||||
|
person_id: Uuid,
|
||||||
|
options: &ListMediaOptions,
|
||||||
|
) -> CoreResult<(Vec<Media>, i64)>;
|
||||||
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()>;
|
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()>;
|
||||||
async fn delete(&self, id: Uuid) -> CoreResult<()>;
|
async fn delete(&self, id: Uuid) -> CoreResult<()>;
|
||||||
}
|
}
|
||||||
@@ -139,4 +144,5 @@ pub trait FaceEmbeddingRepository: Send + Sync {
|
|||||||
&self,
|
&self,
|
||||||
face_region_id: Uuid,
|
face_region_id: Uuid,
|
||||||
) -> CoreResult<Option<FaceEmbedding>>;
|
) -> CoreResult<Option<FaceEmbedding>>;
|
||||||
|
/// Lists the face embeddings of `user_id` that are not yet assigned to a person.
///
/// The Postgres implementation selects embeddings whose face region has
/// `person_id IS NULL` and whose media has `owner_id = user_id`.
async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult<Vec<FaceEmbedding>>;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -127,6 +127,15 @@ pub trait PersonService: Send + Sync {
|
|||||||
face_region_id: Uuid,
|
face_region_id: Uuid,
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
) -> CoreResult<()>;
|
) -> CoreResult<()>;
|
||||||
|
|
||||||
|
/// Groups the unassigned faces of `user_id` into newly created persons.
///
/// `PersonServiceImpl` does this by running DBSCAN over the stored face
/// embeddings and creating one person per resulting cluster.
async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()>;
|
||||||
|
|
||||||
|
/// Lists the media associated with `person_id` on behalf of `user_id`.
///
/// `options` carries the listing parameters. The result is a
/// `PaginatedResponse` of core `Media` values.
async fn list_media_for_person(
|
||||||
|
&self,
|
||||||
|
person_id: Uuid,
|
||||||
|
user_id: Uuid,
|
||||||
|
options: ListMediaOptions,
|
||||||
|
) -> CoreResult<PaginatedResponse<Media>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|||||||
@@ -58,4 +58,23 @@ impl FaceEmbeddingRepository for PostgresFaceEmbeddingRepository {
|
|||||||
|
|
||||||
Ok(pg_embedding.map(FaceEmbedding::from))
|
Ok(pg_embedding.map(FaceEmbedding::from))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetches every face embedding whose face region has no person assigned
/// (`fr.person_id IS NULL`) and whose media is owned by `user_id`.
///
/// Rows are read as `PostgresFaceEmbedding` and converted to the core
/// `FaceEmbedding` type. Any sqlx error is mapped to
/// `CoreError::Database` carrying the error's text.
async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult<Vec<FaceEmbedding>> {
|
||||||
|
let pg_embeddings = sqlx::query_as!(
|
||||||
|
PostgresFaceEmbedding,
|
||||||
|
r#"
|
||||||
|
SELECT fe.id, fe.face_region_id, fe.model_id, fe.embedding
|
||||||
|
FROM face_embeddings fe
|
||||||
|
JOIN face_regions fr ON fe.face_region_id = fr.id
|
||||||
|
JOIN media m ON fr.media_id = m.id
|
||||||
|
WHERE fr.person_id IS NULL AND m.owner_id = $1
|
||||||
|
"#,
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(pg_embeddings.into_iter().map(FaceEmbedding::from).collect())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,6 +162,64 @@ impl MediaRepository for PostgresMediaRepository {
|
|||||||
Ok((media_list, total_items_result))
|
Ok((media_list, total_items_result))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Lists media that have at least one face region assigned to `person_id`.
///
/// Two queries are built with `sqlx::QueryBuilder`:
/// - a count query (`COUNT(DISTINCT media.id)`) that yields the total, and
/// - a data query that selects the media columns, grouped by `media.id`.
///
/// Both receive the same filters from `query_builder.apply_filters_to_query`.
/// Only the data query is additionally sorted and paginated. Returns the
/// converted media list together with the total. Any sqlx error is mapped
/// to `CoreError::Database` carrying the error's text.
///
/// NOTE(review): filters are applied after a `WHERE` clause has already
/// been pushed — confirm `apply_filters_to_query` appends with `AND` in
/// that case.
async fn list_by_person_id(
|
||||||
|
&self,
|
||||||
|
person_id: Uuid,
|
||||||
|
options: &ListMediaOptions,
|
||||||
|
) -> CoreResult<(Vec<Media>, i64)> {
|
||||||
|
// Count query: DISTINCT keeps a media item with several matching face
// regions from being counted more than once.
let count_base_sql = "
|
||||||
|
SELECT COUNT(DISTINCT media.id) as total
|
||||||
|
FROM media
|
||||||
|
JOIN face_regions fr ON media.id = fr.media_id
|
||||||
|
";
|
||||||
|
let mut count_query = sqlx::QueryBuilder::new(count_base_sql);
|
||||||
|
count_query.push(" WHERE fr.person_id = ");
|
||||||
|
count_query.push_bind(person_id);
|
||||||
|
|
||||||
|
let (mut count_query, _metadata_filter_count) = self
|
||||||
|
.query_builder
|
||||||
|
.apply_filters_to_query(count_query, options)?;
|
||||||
|
|
||||||
|
let total_items_result = count_query
|
||||||
|
.build_query_scalar()
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||||
|
|
||||||
|
// Data query: same join and person filter as the count query.
let data_base_sql = "
|
||||||
|
SELECT media.id, media.owner_id, media.storage_path,
|
||||||
|
media.original_filename, media.mime_type,
|
||||||
|
media.hash, media.created_at, media.thumbnail_path
|
||||||
|
FROM media
|
||||||
|
JOIN face_regions fr ON media.id = fr.media_id
|
||||||
|
";
|
||||||
|
let mut data_query = sqlx::QueryBuilder::new(data_base_sql);
|
||||||
|
data_query.push(" WHERE fr.person_id = ");
|
||||||
|
data_query.push_bind(person_id);
|
||||||
|
|
||||||
|
let (mut data_query, _metadata_filter_count) = self
|
||||||
|
.query_builder
|
||||||
|
.apply_filters_to_query(data_query, options)?;
|
||||||
|
|
||||||
|
// Collapse the rows produced by multiple face regions on the same media.
data_query.push(" GROUP BY media.id ");
|
||||||
|
|
||||||
|
let data_query = self
|
||||||
|
.query_builder
|
||||||
|
.apply_sorting_to_query(data_query, options)?;
|
||||||
|
let mut data_query = self
|
||||||
|
.query_builder
|
||||||
|
.apply_pagination_to_query(data_query, options)?;
|
||||||
|
|
||||||
|
let pg_media = data_query
|
||||||
|
.build_query_as::<PostgresMedia>()
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||||
|
|
||||||
|
let media_list = pg_media.into_iter().map(|m| m.into()).collect();
|
||||||
|
Ok((media_list, total_items_result))
|
||||||
|
}
|
||||||
|
|
||||||
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()> {
|
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()> {
|
||||||
sqlx::query!(
|
sqlx::query!(
|
||||||
r#"
|
r#"
|
||||||
|
|||||||
Reference in New Issue
Block a user