feat: Implement face clustering and media retrieval for persons
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -751,6 +751,12 @@ version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
||||
|
||||
[[package]]
|
||||
name = "dbscan"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9654109c8a07c62a71b9d443e72dde6c6e1364a9e8654126be122c58919d8a85"
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.10"
|
||||
@@ -1810,6 +1816,7 @@ dependencies = [
|
||||
"axum-extra",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"dbscan",
|
||||
"futures",
|
||||
"headers",
|
||||
"jsonwebtoken",
|
||||
|
||||
@@ -36,3 +36,4 @@ tower = { version = "0.5.2", features = ["util"] }
|
||||
tower-http = { version = "0.6.6", features = ["fs", "trace"] }
|
||||
tracing = "0.1.41"
|
||||
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
|
||||
dbscan = "0.3.1"
|
||||
|
||||
@@ -6,9 +6,9 @@ use libertas_core::{
|
||||
};
|
||||
use libertas_infra::factory::{
|
||||
build_album_repository, build_album_share_repository, build_database_pool,
|
||||
build_face_region_repository, build_media_metadata_repository, build_media_repository,
|
||||
build_person_repository, build_person_share_repository, build_tag_repository,
|
||||
build_user_repository,
|
||||
build_face_embedding_repository, build_face_region_repository, build_media_metadata_repository,
|
||||
build_media_repository, build_person_repository, build_person_share_repository,
|
||||
build_tag_repository, build_user_repository,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -40,6 +40,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult<AppState> {
|
||||
let face_region_repo = build_face_region_repository(&config.database, db_pool.clone()).await?;
|
||||
let person_share_repo =
|
||||
build_person_share_repository(&config.database, db_pool.clone()).await?;
|
||||
let face_embedding_repo =
|
||||
build_face_embedding_repository(&config.database, db_pool.clone()).await?;
|
||||
|
||||
let hasher = Arc::new(Argon2Hasher::default());
|
||||
let tokenizer = Arc::new(JwtGenerator::new(config.jwt_secret.clone()));
|
||||
@@ -81,6 +83,8 @@ pub async fn build_app_state(config: AppConfig) -> CoreResult<AppState> {
|
||||
person_repo.clone(),
|
||||
face_region_repo.clone(),
|
||||
person_share_repo.clone(),
|
||||
face_embedding_repo.clone(),
|
||||
media_repo.clone(),
|
||||
authorization_service.clone(),
|
||||
));
|
||||
|
||||
|
||||
@@ -10,10 +10,12 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
error::ApiError,
|
||||
extractors::query_options::ApiListMediaOptions,
|
||||
middleware::auth::UserId,
|
||||
schema::{
|
||||
AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MergePersonRequest,
|
||||
PersonResponse, SetPersonThumbnailRequest, SharePersonRequest, UpdatePersonRequest,
|
||||
AssignFaceRequest, CreatePersonRequest, FaceRegionResponse, MediaResponse,
|
||||
MergePersonRequest, PaginatedResponse, PersonResponse, SetPersonThumbnailRequest,
|
||||
SharePersonRequest, UpdatePersonRequest, map_paginated_response,
|
||||
},
|
||||
state::AppState,
|
||||
};
|
||||
@@ -31,6 +33,8 @@ pub fn people_routes() -> Router<AppState> {
|
||||
)
|
||||
.route("/{person_id}/merge", post(merge_person))
|
||||
.route("/{person_id}/thumbnail", put(set_person_thumbnail))
|
||||
.route("/cluster", post(cluster_faces))
|
||||
.route("/{person_id}/media", get(list_media_for_person))
|
||||
}
|
||||
|
||||
pub fn face_routes() -> Router<AppState> {
|
||||
@@ -176,3 +180,29 @@ async fn set_person_thumbnail(
|
||||
.await?;
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
async fn cluster_faces(
|
||||
State(state): State<AppState>,
|
||||
UserId(user_id): UserId,
|
||||
) -> Result<StatusCode, ApiError> {
|
||||
state
|
||||
.person_service
|
||||
.cluster_unassigned_faces(user_id)
|
||||
.await?;
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
async fn list_media_for_person(
|
||||
State(state): State<AppState>,
|
||||
UserId(user_id): UserId,
|
||||
Path(person_id): Path<Uuid>,
|
||||
ApiListMediaOptions(options): ApiListMediaOptions,
|
||||
) -> Result<Json<PaginatedResponse<MediaResponse>>, ApiError> {
|
||||
let core_paginated_result = state
|
||||
.person_service
|
||||
.list_media_for_person(person_id, user_id, options)
|
||||
.await?;
|
||||
|
||||
let api_response = map_paginated_response(core_paginated_result);
|
||||
Ok(Json(api_response))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
use dbscan::{self, Classification};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use libertas_core::{
|
||||
authz,
|
||||
error::{CoreError, CoreResult},
|
||||
models::{FaceRegion, Person, PersonPermission},
|
||||
repositories::{FaceRegionRepository, PersonRepository, PersonShareRepository},
|
||||
models::{FaceRegion, Media, Person, PersonPermission},
|
||||
repositories::{
|
||||
FaceEmbeddingRepository, FaceRegionRepository, MediaRepository, PersonRepository,
|
||||
PersonShareRepository,
|
||||
},
|
||||
schema::{ListMediaOptions, PaginatedResponse},
|
||||
services::{AuthorizationService, PersonService},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
@@ -14,6 +19,8 @@ pub struct PersonServiceImpl {
|
||||
person_repo: Arc<dyn PersonRepository>,
|
||||
face_repo: Arc<dyn FaceRegionRepository>,
|
||||
person_share_repo: Arc<dyn PersonShareRepository>,
|
||||
face_embedding_repo: Arc<dyn FaceEmbeddingRepository>,
|
||||
media_repo: Arc<dyn MediaRepository>,
|
||||
auth_service: Arc<dyn AuthorizationService>,
|
||||
}
|
||||
|
||||
@@ -22,12 +29,16 @@ impl PersonServiceImpl {
|
||||
person_repo: Arc<dyn PersonRepository>,
|
||||
face_repo: Arc<dyn FaceRegionRepository>,
|
||||
person_share_repo: Arc<dyn PersonShareRepository>,
|
||||
face_embedding_repo: Arc<dyn FaceEmbeddingRepository>,
|
||||
media_repo: Arc<dyn MediaRepository>,
|
||||
auth_service: Arc<dyn AuthorizationService>,
|
||||
) -> Self {
|
||||
Self {
|
||||
person_repo,
|
||||
face_repo,
|
||||
person_share_repo,
|
||||
face_embedding_repo,
|
||||
media_repo,
|
||||
auth_service,
|
||||
}
|
||||
}
|
||||
@@ -40,6 +51,13 @@ impl PersonServiceImpl {
|
||||
.ok_or(CoreError::NotFound("Person".to_string(), person_id))?;
|
||||
Ok(person)
|
||||
}
|
||||
|
||||
/// Decodes a byte buffer into `f32` values, reading each consecutive group
/// of four bytes as one little-endian float.
///
/// Trailing bytes that do not fill a complete 4-byte group are ignored, so a
/// buffer whose length is not a multiple of four yields `len / 4` values.
/// An empty buffer yields an empty vector.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        // `chunks_exact(4)` only ever yields slices of length 4, so direct
        // indexing cannot go out of bounds and no zero-filled fallback is
        // needed.
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -245,4 +263,83 @@ impl PersonService for PersonServiceImpl {
|
||||
.set_thumbnail_media_id(person_id, face_region.media_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()> {
|
||||
let embedding_data = self
|
||||
.face_embedding_repo
|
||||
.list_unassigned_by_user(user_id)
|
||||
.await?;
|
||||
|
||||
if embedding_data.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let embeddings_f32: Vec<Vec<f32>> = embedding_data
|
||||
.iter()
|
||||
.map(|data| Self::bytes_to_f32_vec(&data.embedding))
|
||||
.collect();
|
||||
|
||||
let scan = dbscan::Model::new(0.4, 2);
|
||||
let clusters = scan.run(&embeddings_f32);
|
||||
|
||||
let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
|
||||
|
||||
tracing::info!(
|
||||
"DBSCAN found {} clusters",
|
||||
clusters
|
||||
.iter()
|
||||
.filter(|c| match c {
|
||||
Classification::Core(_) | Classification::Edge(_) => true,
|
||||
Classification::Noise => false,
|
||||
})
|
||||
.count()
|
||||
);
|
||||
|
||||
for (i, classification) in clusters.iter().enumerate() {
|
||||
match classification {
|
||||
Classification::Core(cluster_id) | Classification::Edge(cluster_id) => {
|
||||
cluster_map.entry(*cluster_id).or_default().push(i);
|
||||
}
|
||||
Classification::Noise => {}
|
||||
}
|
||||
}
|
||||
|
||||
for (_cluster_id, indices) in cluster_map {
|
||||
let person_name = format!("Person {}", Uuid::new_v4());
|
||||
let new_person = self.create_person(&person_name, user_id).await?;
|
||||
|
||||
let face_ids: Vec<Uuid> = indices
|
||||
.iter()
|
||||
.map(|&i| embedding_data[i].face_region_id)
|
||||
.collect();
|
||||
|
||||
for face_id in face_ids {
|
||||
self.face_repo
|
||||
.update_person_id(face_id, new_person.id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_media_for_person(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
user_id: Uuid,
|
||||
options: ListMediaOptions,
|
||||
) -> CoreResult<PaginatedResponse<Media>> {
|
||||
self.auth_service
|
||||
.check_permission(Some(user_id), authz::Permission::ViewPerson(person_id))
|
||||
.await?;
|
||||
|
||||
let (data, total_items) = self
|
||||
.media_repo
|
||||
.list_by_person_id(person_id, &options)
|
||||
.await?;
|
||||
|
||||
let pagination = options.pagination.unwrap();
|
||||
let response = PaginatedResponse::new(data, pagination.page, pagination.limit, total_items);
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,11 @@ pub trait MediaRepository: Send + Sync {
|
||||
user_id: Uuid,
|
||||
options: &ListMediaOptions,
|
||||
) -> CoreResult<(Vec<Media>, i64)>;
|
||||
/// Lists media linked to `person_id`, honoring `options`.
///
/// Returns the media items together with an `i64` count. In the Postgres
/// implementation in this change, the count is the number of distinct
/// matching media before pagination.
async fn list_by_person_id(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
options: &ListMediaOptions,
|
||||
) -> CoreResult<(Vec<Media>, i64)>;
|
||||
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()>;
|
||||
async fn delete(&self, id: Uuid) -> CoreResult<()>;
|
||||
}
|
||||
@@ -139,4 +144,5 @@ pub trait FaceEmbeddingRepository: Send + Sync {
|
||||
&self,
|
||||
face_region_id: Uuid,
|
||||
) -> CoreResult<Option<FaceEmbedding>>;
|
||||
/// Lists face embeddings for `user_id` whose face region is not yet assigned
/// to a person. In the Postgres implementation in this change, this means
/// `face_regions.person_id IS NULL` on media owned by the user.
async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult<Vec<FaceEmbedding>>;
|
||||
}
|
||||
|
||||
@@ -127,6 +127,15 @@ pub trait PersonService: Send + Sync {
|
||||
face_region_id: Uuid,
|
||||
user_id: Uuid,
|
||||
) -> CoreResult<()>;
|
||||
|
||||
/// Clusters the unassigned faces belonging to `user_id`. The
/// `PersonServiceImpl` implementation in this change creates one new person
/// per cluster and assigns the member faces to it.
async fn cluster_unassigned_faces(&self, user_id: Uuid) -> CoreResult<()>;
|
||||
|
||||
/// Returns a paginated list of media for `person_id` on behalf of
/// `user_id`, honoring `options`. The `PersonServiceImpl` implementation in
/// this change checks `ViewPerson` permission before querying.
async fn list_media_for_person(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
user_id: Uuid,
|
||||
options: ListMediaOptions,
|
||||
) -> CoreResult<PaginatedResponse<Media>>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@@ -58,4 +58,23 @@ impl FaceEmbeddingRepository for PostgresFaceEmbeddingRepository {
|
||||
|
||||
Ok(pg_embedding.map(FaceEmbedding::from))
|
||||
}
|
||||
|
||||
/// Fetches every face embedding whose face region has no person assigned
/// (`fr.person_id IS NULL`) and whose media is owned by `user_id`.
///
/// The query joins `face_embeddings` -> `face_regions` -> `media`. Rows are
/// converted from `PostgresFaceEmbedding` into `FaceEmbedding`.
///
/// # Errors
/// Any sqlx failure is mapped to `CoreError::Database` carrying the error
/// text.
async fn list_unassigned_by_user(&self, user_id: Uuid) -> CoreResult<Vec<FaceEmbedding>> {
|
||||
let pg_embeddings = sqlx::query_as!(
|
||||
PostgresFaceEmbedding,
|
||||
r#"
|
||||
SELECT fe.id, fe.face_region_id, fe.model_id, fe.embedding
|
||||
FROM face_embeddings fe
|
||||
JOIN face_regions fr ON fe.face_region_id = fr.id
|
||||
JOIN media m ON fr.media_id = m.id
|
||||
WHERE fr.person_id IS NULL AND m.owner_id = $1
|
||||
"#,
|
||||
user_id
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
|
||||
Ok(pg_embeddings.into_iter().map(FaceEmbedding::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,6 +162,64 @@ impl MediaRepository for PostgresMediaRepository {
|
||||
Ok((media_list, total_items_result))
|
||||
}
|
||||
|
||||
/// Lists media that have at least one face region assigned to `person_id`.
///
/// Two queries are built with `sqlx::QueryBuilder`:
/// - a count query (`COUNT(DISTINCT media.id)`) with the person condition
///   and the option filters applied, but no sorting or pagination;
/// - a data query with the same condition and filters, grouped by
///   `media.id`, then sorted and paginated via `self.query_builder`.
///
/// Returns the page of media and the total count from the count query.
///
/// # Errors
/// Errors from the filter/sort/pagination helpers are propagated as-is; sqlx
/// failures are mapped to `CoreError::Database` carrying the error text.
async fn list_by_person_id(
|
||||
&self,
|
||||
person_id: Uuid,
|
||||
options: &ListMediaOptions,
|
||||
) -> CoreResult<(Vec<Media>, i64)> {
|
||||
// Count query: DISTINCT because a media row joins once per matching face region.
let count_base_sql = "
|
||||
SELECT COUNT(DISTINCT media.id) as total
|
||||
FROM media
|
||||
JOIN face_regions fr ON media.id = fr.media_id
|
||||
";
|
||||
let mut count_query = sqlx::QueryBuilder::new(count_base_sql);
|
||||
count_query.push(" WHERE fr.person_id = ");
|
||||
count_query.push_bind(person_id);
|
||||
|
||||
let (mut count_query, _metadata_filter_count) = self
|
||||
.query_builder
|
||||
.apply_filters_to_query(count_query, options)?;
|
||||
|
||||
let total_items_result = count_query
|
||||
.build_query_scalar()
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
|
||||
// Data query: same join and person condition as the count query.
let data_base_sql = "
|
||||
SELECT media.id, media.owner_id, media.storage_path,
|
||||
media.original_filename, media.mime_type,
|
||||
media.hash, media.created_at, media.thumbnail_path
|
||||
FROM media
|
||||
JOIN face_regions fr ON media.id = fr.media_id
|
||||
";
|
||||
let mut data_query = sqlx::QueryBuilder::new(data_base_sql);
|
||||
data_query.push(" WHERE fr.person_id = ");
|
||||
data_query.push_bind(person_id);
|
||||
|
||||
let (mut data_query, _metadata_filter_count) = self
|
||||
.query_builder
|
||||
.apply_filters_to_query(data_query, options)?;
|
||||
|
||||
// Collapse the per-face-region duplicates to one row per media.
data_query.push(" GROUP BY media.id ");
|
||||
|
||||
// Sorting is applied before pagination; each helper returns the builder.
let data_query = self
|
||||
.query_builder
|
||||
.apply_sorting_to_query(data_query, options)?;
|
||||
let mut data_query = self
|
||||
.query_builder
|
||||
.apply_pagination_to_query(data_query, options)?;
|
||||
|
||||
let pg_media = data_query
|
||||
.build_query_as::<PostgresMedia>()
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| CoreError::Database(e.to_string()))?;
|
||||
|
||||
let media_list = pg_media.into_iter().map(|m| m.into()).collect();
|
||||
Ok((media_list, total_items_result))
|
||||
}
|
||||
|
||||
async fn update_thumbnail_path(&self, id: Uuid, thumbnail_path: String) -> CoreResult<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
|
||||
Reference in New Issue
Block a user