use axum::{ extract::{FromRef, FromRequestParts}, http::{header, header::AUTHORIZATION, request::Parts}, response::{IntoResponse, Redirect}, }; use domain::{errors::DomainError, value_objects::UserId}; use crate::{errors::ApiError, state::AppState}; pub struct AuthenticatedUser(pub UserId); impl FromRequestParts for AuthenticatedUser where AppState: FromRef, S: Send + Sync, { type Rejection = ApiError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let token = parts .headers .get(AUTHORIZATION) .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")) .ok_or_else(|| { ApiError(DomainError::Unauthorized( "Missing or invalid auth token".into(), )) })?; let user_id = app_state .app_ctx .auth_service .validate_token(token) .await?; Ok(AuthenticatedUser(user_id)) } } pub struct OptionalCookieUser(pub Option); pub struct RequiredCookieUser(pub UserId); fn extract_token_from_cookie(parts: &Parts) -> Option { parts .headers .get(header::COOKIE) .and_then(|v| v.to_str().ok()) .and_then(|cookies| { cookies .split(';') .find_map(|c| c.trim().strip_prefix("token=").map(str::to_string)) }) } impl FromRequestParts for OptionalCookieUser where AppState: FromRef, S: Send + Sync, { type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let Some(token) = extract_token_from_cookie(parts) else { return Ok(OptionalCookieUser(None)); }; let user_id = app_state .app_ctx .auth_service .validate_token(&token) .await .ok(); Ok(OptionalCookieUser(user_id)) } } impl FromRequestParts for RequiredCookieUser where AppState: FromRef, S: Send + Sync, { type Rejection = axum::response::Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let token = extract_token_from_cookie(parts) .ok_or_else(|| Redirect::to("/login").into_response())?; let user_id = app_state .app_ctx .auth_service .validate_token(&token) .await .map_err(|_| Redirect::to("/login").into_response())?; Ok(RequiredCookieUser(user_id)) } } #[cfg(test)] mod tests { use super::*; use axum::{ body::Body, http::{Request, StatusCode}, routing::get, Router, }; use tower::ServiceExt; async fn protected_handler(user: AuthenticatedUser) -> String { user.0.value().to_string() } fn test_router(state: crate::state::AppState) -> Router { Router::new() .route("/protected", get(protected_handler)) .with_state(state) } #[tokio::test] async fn missing_auth_header_returns_401() { use std::sync::Arc; use application::context::AppContext; struct PanicRepo; #[async_trait::async_trait] impl domain::ports::MovieRepository for PanicRepo { async fn get_movie_by_external_id(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movie_by_id(&self, _: &domain::value_objects::MovieId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movies_by_title_and_year(&self, _: &domain::value_objects::MovieTitle, _: &domain::value_objects::ReleaseYear) -> Result, domain::errors::DomainError> { panic!() } async fn upsert_movie(&self, _: &domain::models::Movie) -> Result<(), domain::errors::DomainError> { panic!() } async fn save_review(&self, _: &domain::models::Review) -> Result { panic!() } async fn query_diary(&self, _: &domain::models::DiaryFilter) -> Result, domain::errors::DomainError> { panic!() } async fn get_review_history(&self, _: &domain::value_objects::MovieId) -> Result { panic!() } async fn get_review_by_id(&self, _: &domain::value_objects::ReviewId) -> Result, domain::errors::DomainError> { panic!() } async fn delete_review(&self, _: &domain::value_objects::ReviewId) -> Result<(), domain::errors::DomainError> { panic!() } async fn delete_movie(&self, _: &domain::value_objects::MovieId) -> Result<(), domain::errors::DomainError> { panic!() } async fn query_activity_feed(&self, _: &domain::models::collections::PageParams) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_stats(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn get_user_history(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_trends(&self, _: &domain::value_objects::UserId) -> Result { panic!() } } struct PanicRenderer; impl crate::ports::HtmlRenderer for PanicRenderer { fn render_diary_page(&self, _: &domain::models::collections::Paginated, _: application::ports::HtmlPageContext) -> Result { panic!() } fn render_login_page(&self, _: application::ports::LoginPageData<'_>) -> Result { panic!() } fn render_register_page(&self, _: application::ports::RegisterPageData<'_>) -> Result { panic!() } fn render_new_review_page(&self, _: application::ports::NewReviewPageData<'_>) -> Result { panic!() } fn render_activity_feed_page(&self, _: application::ports::ActivityFeedPageData) -> Result { panic!() } fn render_users_page(&self, _: application::ports::UsersPageData) -> Result { panic!() } fn render_profile_page(&self, _: application::ports::ProfilePageData) -> Result { panic!() } } struct PanicRssRenderer; impl crate::ports::RssFeedRenderer for PanicRssRenderer { fn render_feed(&self, _: &[domain::models::DiaryEntry]) -> Result { panic!() } } struct PanicMeta; struct PanicFetcher; struct PanicStorage; struct PanicEvent; struct PanicHasher; struct PanicAuth; struct PanicUserRepo; #[async_trait::async_trait] impl domain::ports::MetadataClient for PanicMeta { async fn fetch_movie_metadata(&self, _: &domain::ports::MetadataSearchCriteria) -> Result { panic!() } async fn get_poster_url(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterFetcherClient for PanicFetcher { async fn fetch_poster_bytes(&self, _: &domain::value_objects::PosterUrl) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterStorage for PanicStorage { async fn store_poster(&self, _: &domain::value_objects::MovieId, _: &[u8]) -> Result { panic!() } async fn get_poster(&self, _: &domain::value_objects::PosterPath) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::EventPublisher for PanicEvent { async fn publish(&self, _: &domain::events::DomainEvent) -> Result<(), domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PasswordHasher for PanicHasher { async fn hash(&self, _: &str) -> Result { panic!() } async fn verify(&self, _: &str, _: &domain::value_objects::PasswordHash) -> Result { panic!() } } #[async_trait::async_trait] impl domain::ports::AuthService for PanicAuth { async fn generate_token(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn validate_token(&self, _: &str) -> Result { panic!() } } #[async_trait::async_trait] impl domain::ports::UserRepository for PanicUserRepo { async fn find_by_email(&self, _: &domain::value_objects::Email) -> Result, domain::errors::DomainError> { panic!() } async fn save(&self, _: &domain::models::User) -> Result<(), domain::errors::DomainError> { panic!() } async fn find_by_id(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn list_with_stats(&self) -> Result, domain::errors::DomainError> { panic!() } } let state = crate::state::AppState { app_ctx: AppContext { repository: Arc::new(PanicRepo), metadata_client: Arc::new(PanicMeta), poster_fetcher: Arc::new(PanicFetcher), poster_storage: Arc::new(PanicStorage), event_publisher: Arc::new(PanicEvent), auth_service: Arc::new(PanicAuth), password_hasher: Arc::new(PanicHasher), user_repository: Arc::new(PanicUserRepo), config: application::config::AppConfig { allow_registration: false }, }, html_renderer: Arc::new(PanicRenderer), rss_renderer: Arc::new(PanicRssRenderer), }; let app = test_router(state); let response = app .oneshot( Request::builder() .uri("/protected") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } // Reusable helpers for cookie extractor tests async fn optional_cookie_handler(user: OptionalCookieUser) -> String { match user.0 { Some(id) => id.value().to_string(), None => "none".to_string(), } } async fn required_cookie_handler(user: RequiredCookieUser) -> String { user.0.value().to_string() } fn test_router_optional(state: crate::state::AppState) -> Router { Router::new() .route("/optional", get(optional_cookie_handler)) .with_state(state) } fn test_router_required(state: crate::state::AppState) -> Router { Router::new() .route("/required", get(required_cookie_handler)) .with_state(state) } struct RejectingAuth; #[async_trait::async_trait] impl domain::ports::AuthService for RejectingAuth { async fn generate_token(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn validate_token(&self, _: &str) -> Result { Err(domain::errors::DomainError::Unauthorized("bad token".into())) } } fn panic_state() -> crate::state::AppState { use std::sync::Arc; use application::context::AppContext; struct PanicRepo2; #[async_trait::async_trait] impl domain::ports::MovieRepository for PanicRepo2 { async fn get_movie_by_external_id(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movie_by_id(&self, _: &domain::value_objects::MovieId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movies_by_title_and_year(&self, _: &domain::value_objects::MovieTitle, _: &domain::value_objects::ReleaseYear) -> Result, domain::errors::DomainError> { panic!() } async fn upsert_movie(&self, _: &domain::models::Movie) -> Result<(), domain::errors::DomainError> { panic!() } async fn save_review(&self, _: &domain::models::Review) -> Result { panic!() } async fn query_diary(&self, _: &domain::models::DiaryFilter) -> Result, domain::errors::DomainError> { panic!() } async fn get_review_history(&self, _: &domain::value_objects::MovieId) -> Result { panic!() } async fn get_review_by_id(&self, _: &domain::value_objects::ReviewId) -> Result, domain::errors::DomainError> { panic!() } async fn delete_review(&self, _: &domain::value_objects::ReviewId) -> Result<(), domain::errors::DomainError> { panic!() } async fn delete_movie(&self, _: &domain::value_objects::MovieId) -> Result<(), domain::errors::DomainError> { panic!() } async fn query_activity_feed(&self, _: &domain::models::collections::PageParams) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_stats(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn get_user_history(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_trends(&self, _: &domain::value_objects::UserId) -> Result { panic!() } } struct PanicMeta2; struct PanicFetcher2; struct PanicStorage2; struct PanicEvent2; struct PanicHasher2; struct PanicUserRepo2; #[async_trait::async_trait] impl domain::ports::MetadataClient for PanicMeta2 { async fn fetch_movie_metadata(&self, _: &domain::ports::MetadataSearchCriteria) -> Result { panic!() } async fn get_poster_url(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterFetcherClient for PanicFetcher2 { async fn fetch_poster_bytes(&self, _: &domain::value_objects::PosterUrl) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterStorage for PanicStorage2 { async fn store_poster(&self, _: &domain::value_objects::MovieId, _: &[u8]) -> Result { panic!() } async fn get_poster(&self, _: &domain::value_objects::PosterPath) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::EventPublisher for PanicEvent2 { async fn publish(&self, _: &domain::events::DomainEvent) -> Result<(), domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PasswordHasher for PanicHasher2 { async fn hash(&self, _: &str) -> Result { panic!() } async fn verify(&self, _: &str, _: &domain::value_objects::PasswordHash) -> Result { panic!() } } #[async_trait::async_trait] impl domain::ports::AuthService for PanicAuth2 { async fn generate_token(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn validate_token(&self, _: &str) -> Result { panic!() } } #[async_trait::async_trait] impl domain::ports::UserRepository for PanicUserRepo2 { async fn find_by_email(&self, _: &domain::value_objects::Email) -> Result, domain::errors::DomainError> { panic!() } async fn save(&self, _: &domain::models::User) -> Result<(), domain::errors::DomainError> { panic!() } async fn find_by_id(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn list_with_stats(&self) -> Result, domain::errors::DomainError> { panic!() } } struct PanicRenderer2; impl crate::ports::HtmlRenderer for PanicRenderer2 { fn render_diary_page(&self, _: &domain::models::collections::Paginated, _: application::ports::HtmlPageContext) -> Result { panic!() } fn render_login_page(&self, _: application::ports::LoginPageData<'_>) -> Result { panic!() } fn render_register_page(&self, _: application::ports::RegisterPageData<'_>) -> Result { panic!() } fn render_new_review_page(&self, _: application::ports::NewReviewPageData<'_>) -> Result { panic!() } fn render_activity_feed_page(&self, _: application::ports::ActivityFeedPageData) -> Result { panic!() } fn render_users_page(&self, _: application::ports::UsersPageData) -> Result { panic!() } fn render_profile_page(&self, _: application::ports::ProfilePageData) -> Result { panic!() } } struct PanicRssRenderer2; impl crate::ports::RssFeedRenderer for PanicRssRenderer2 { fn render_feed(&self, _: &[domain::models::DiaryEntry]) -> Result { panic!() } } struct PanicAuth2; crate::state::AppState { app_ctx: AppContext { repository: Arc::new(PanicRepo2), metadata_client: Arc::new(PanicMeta2), poster_fetcher: Arc::new(PanicFetcher2), poster_storage: Arc::new(PanicStorage2), event_publisher: Arc::new(PanicEvent2), auth_service: Arc::new(PanicAuth2), password_hasher: Arc::new(PanicHasher2), user_repository: Arc::new(PanicUserRepo2), config: application::config::AppConfig { allow_registration: false }, }, html_renderer: Arc::new(PanicRenderer2), rss_renderer: Arc::new(PanicRssRenderer2), } } fn rejecting_state() -> crate::state::AppState { use std::sync::Arc; use application::context::AppContext; struct PanicRepo3; #[async_trait::async_trait] impl domain::ports::MovieRepository for PanicRepo3 { async fn get_movie_by_external_id(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movie_by_id(&self, _: &domain::value_objects::MovieId) -> Result, domain::errors::DomainError> { panic!() } async fn get_movies_by_title_and_year(&self, _: &domain::value_objects::MovieTitle, _: &domain::value_objects::ReleaseYear) -> Result, domain::errors::DomainError> { panic!() } async fn upsert_movie(&self, _: &domain::models::Movie) -> Result<(), domain::errors::DomainError> { panic!() } async fn save_review(&self, _: &domain::models::Review) -> Result { panic!() } async fn query_diary(&self, _: &domain::models::DiaryFilter) -> Result, domain::errors::DomainError> { panic!() } async fn get_review_history(&self, _: &domain::value_objects::MovieId) -> Result { panic!() } async fn get_review_by_id(&self, _: &domain::value_objects::ReviewId) -> Result, domain::errors::DomainError> { panic!() } async fn delete_review(&self, _: &domain::value_objects::ReviewId) -> Result<(), domain::errors::DomainError> { panic!() } async fn delete_movie(&self, _: &domain::value_objects::MovieId) -> Result<(), domain::errors::DomainError> { panic!() } async fn query_activity_feed(&self, _: &domain::models::collections::PageParams) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_stats(&self, _: &domain::value_objects::UserId) -> Result { panic!() } async fn get_user_history(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn get_user_trends(&self, _: &domain::value_objects::UserId) -> Result { panic!() } } struct PanicMeta3; struct PanicFetcher3; struct PanicStorage3; struct PanicEvent3; struct PanicHasher3; struct PanicUserRepo3; #[async_trait::async_trait] impl domain::ports::MetadataClient for PanicMeta3 { async fn fetch_movie_metadata(&self, _: &domain::ports::MetadataSearchCriteria) -> Result { panic!() } async fn get_poster_url(&self, _: &domain::value_objects::ExternalMetadataId) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterFetcherClient for PanicFetcher3 { async fn fetch_poster_bytes(&self, _: &domain::value_objects::PosterUrl) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PosterStorage for PanicStorage3 { async fn store_poster(&self, _: &domain::value_objects::MovieId, _: &[u8]) -> Result { panic!() } async fn get_poster(&self, _: &domain::value_objects::PosterPath) -> Result, domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::EventPublisher for PanicEvent3 { async fn publish(&self, _: &domain::events::DomainEvent) -> Result<(), domain::errors::DomainError> { panic!() } } #[async_trait::async_trait] impl domain::ports::PasswordHasher for PanicHasher3 { async fn hash(&self, _: &str) -> Result { panic!() } async fn verify(&self, _: &str, _: &domain::value_objects::PasswordHash) -> Result { panic!() } } #[async_trait::async_trait] impl domain::ports::UserRepository for PanicUserRepo3 { async fn find_by_email(&self, _: &domain::value_objects::Email) -> Result, domain::errors::DomainError> { panic!() } async fn save(&self, _: &domain::models::User) -> Result<(), domain::errors::DomainError> { panic!() } async fn find_by_id(&self, _: &domain::value_objects::UserId) -> Result, domain::errors::DomainError> { panic!() } async fn list_with_stats(&self) -> Result, domain::errors::DomainError> { panic!() } } struct PanicRenderer3; impl crate::ports::HtmlRenderer for PanicRenderer3 { fn render_diary_page(&self, _: &domain::models::collections::Paginated, _: application::ports::HtmlPageContext) -> Result { panic!() } fn render_login_page(&self, _: application::ports::LoginPageData<'_>) -> Result { panic!() } fn render_register_page(&self, _: application::ports::RegisterPageData<'_>) -> Result { panic!() } fn render_new_review_page(&self, _: application::ports::NewReviewPageData<'_>) -> Result { panic!() } fn render_activity_feed_page(&self, _: application::ports::ActivityFeedPageData) -> Result { panic!() } fn render_users_page(&self, _: application::ports::UsersPageData) -> Result { panic!() } fn render_profile_page(&self, _: application::ports::ProfilePageData) -> Result { panic!() } } struct PanicRssRenderer3; impl crate::ports::RssFeedRenderer for PanicRssRenderer3 { fn render_feed(&self, _: &[domain::models::DiaryEntry]) -> Result { panic!() } } crate::state::AppState { app_ctx: AppContext { repository: Arc::new(PanicRepo3), metadata_client: Arc::new(PanicMeta3), poster_fetcher: Arc::new(PanicFetcher3), poster_storage: Arc::new(PanicStorage3), event_publisher: Arc::new(PanicEvent3), auth_service: Arc::new(RejectingAuth), password_hasher: Arc::new(PanicHasher3), user_repository: Arc::new(PanicUserRepo3), config: application::config::AppConfig { allow_registration: false }, }, html_renderer: Arc::new(PanicRenderer3), rss_renderer: Arc::new(PanicRssRenderer3), } } #[tokio::test] async fn optional_cookie_user_returns_none_without_cookie() { let app = test_router_optional(panic_state()); let response = app .oneshot(Request::builder().uri("/optional").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); assert_eq!(&body[..], b"none"); } #[tokio::test] async fn optional_cookie_user_returns_none_with_invalid_token() { let app = test_router_optional(rejecting_state()); let response = app .oneshot( Request::builder() .uri("/optional") .header("cookie", "token=bad.token.here") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); assert_eq!(&body[..], b"none"); } #[tokio::test] async fn required_cookie_user_redirects_without_cookie() { let app = test_router_required(panic_state()); let response = app .oneshot(Request::builder().uri("/required").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(response.status(), StatusCode::SEE_OTHER); assert_eq!(response.headers().get("location").unwrap(), "/login"); } #[tokio::test] async fn required_cookie_user_redirects_with_invalid_token() { let app = test_router_required(rejecting_state()); let response = app .oneshot( Request::builder() .uri("/required") .header("cookie", "token=bad.token.here") .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::SEE_OTHER); assert_eq!(response.headers().get("location").unwrap(), "/login"); } }