//! Axum request extractors that resolve the calling user from an auth token.
//!
//! Every extractor below obtains an `AppState` with `AppState::from_ref(state)` and hands the
//! token it found to `app_ctx.auth_service.validate_token(..)`. They differ in where the token
//! is read from and in what happens when it is missing or rejected.
//!
//! # Extractors
//!
//! * `AuthenticatedUser(pub UserId)` reads the `Authorization` header. A header value that
//!   `to_str()` cannot convert is treated as absent. The value must start with the literal
//!   prefix `"Bearer "`; the remainder is the token. If the header is absent or the prefix does
//!   not match, the rejection is
//!   `ApiError(DomainError::Unauthorized("Missing or invalid auth token"))`. A failure from
//!   `validate_token` is propagated with `?` as the `ApiError` rejection.
//! * `OptionalCookieUser(pub Option)` reads the token with `extract_token_from_cookie`. Its
//!   rejection type is `std::convert::Infallible`: with no token cookie it yields `None`, and a
//!   `validate_token` error is turned into `None` by `.ok()`.
//! * `RequiredCookieUser(pub UserId)` reads the token with `extract_token_from_cookie`. Its
//!   rejection type is a ready-made `axum::response::Response`: both a missing cookie and a
//!   `validate_token` error produce `Redirect::to("/login")`. The validation error itself is
//!   discarded (`map_err(|_| ..)`).
//!
//! # `extract_token_from_cookie`
//!
//! Takes the value returned by `headers.get(header::COOKIE)`, splits it on `;`, trims each
//! piece, and returns the text after `token=` for the first piece that starts with that prefix,
//! copied into an owned string. A piece that is exactly `token=` therefore yields an empty
//! token rather than `None`; the extractors pass it on to `validate_token` unchanged.
//!
//! # Tests (`#[cfg(test)] mod tests`)
//!
//! * `Panic` is a unit struct that implements each port trait, and both renderer traits, with
//!   `panic!()` bodies, so any call into a dependency the test did not expect aborts the test.
//! * `RejectingAuth` implements `AuthService`; its `validate_token` always returns
//!   `DomainError::Unauthorized("bad token")`.
//! * `make_test_state(auth_service)` builds an `AppState` in which every dependency is `Panic`
//!   except the supplied `auth_service`.
//! * `router_protected`, `router_optional` and `router_required` each mount one GET route whose
//!   handler returns the extracted user id as text (`"none"` for an absent optional user).
//! * The five tests assert:
//!   - no `Authorization` header -> `401 UNAUTHORIZED`;
//!   - optional extractor without a cookie -> `200 OK` with body `none`;
//!   - optional extractor with a rejected token -> `200 OK` with body `none`;
//!   - required extractor without a cookie -> `303 SEE_OTHER` with `location: /login`;
//!   - required extractor with a rejected token -> `303 SEE_OTHER` with `location: /login`.
//!   None of the tests exercises a token that validates successfully.
//!
//! # Review notes
//!
//! * NOTE(review): this copy of the file looks damaged. `S` is named in the `where` clauses but
//!   never declared on the `impl`; `Result`, `Option`, `Arc` and `Paginated` appear without type
//!   arguments; fragments such as `Result, DomainError>` remain. The `//` section markers inside
//!   the tests module also share a physical line with the code that follows them. Presumably
//!   the angle-bracket generics and the line breaks were lost in transit - TODO restore from
//!   version control before building.
//! * NOTE(review): `strip_prefix("Bearer ")` is an exact, case-sensitive match, so a header such
//!   as `bearer <token>` is rejected. Confirm whether clients may send other casings.
//! * NOTE(review): `headers.get(..)` is a single-value lookup. A request that carries several
//!   `Cookie` header fields presumably has only the first one searched - confirm against the
//!   `http::HeaderMap::get` documentation.
use axum::{ extract::{FromRef, FromRequestParts}, http::{header, header::AUTHORIZATION, request::Parts}, response::{IntoResponse, Redirect}, }; use domain::{errors::DomainError, value_objects::UserId}; use crate::{errors::ApiError, state::AppState}; pub struct AuthenticatedUser(pub UserId); impl FromRequestParts for AuthenticatedUser where AppState: FromRef, S: Send + Sync, { type Rejection = ApiError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let token = parts .headers .get(AUTHORIZATION) .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")) .ok_or_else(|| { ApiError(DomainError::Unauthorized( "Missing or invalid auth token".into(), )) })?; let user_id = app_state .app_ctx .auth_service .validate_token(token) .await?; Ok(AuthenticatedUser(user_id)) } } pub struct OptionalCookieUser(pub Option); pub struct RequiredCookieUser(pub UserId); fn extract_token_from_cookie(parts: &Parts) -> Option { parts .headers .get(header::COOKIE) .and_then(|v| v.to_str().ok()) .and_then(|cookies| { cookies .split(';') .find_map(|c| c.trim().strip_prefix("token=").map(str::to_string)) }) } impl FromRequestParts for OptionalCookieUser where AppState: FromRef, S: Send + Sync, { type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let Some(token) = extract_token_from_cookie(parts) else { return Ok(OptionalCookieUser(None)); }; let user_id = app_state .app_ctx .auth_service .validate_token(&token) .await .ok(); Ok(OptionalCookieUser(user_id)) } } impl FromRequestParts for RequiredCookieUser where AppState: FromRef, S: Send + Sync, { type Rejection = axum::response::Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let app_state = AppState::from_ref(state); let token = extract_token_from_cookie(parts) .ok_or_else(|| Redirect::to("/login").into_response())?; let user_id = 
app_state .app_ctx .auth_service .validate_token(&token) .await .map_err(|_| Redirect::to("/login").into_response())?; Ok(RequiredCookieUser(user_id)) } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use axum::{ body::Body, http::{Request, StatusCode}, routing::get, Router, }; use application::{config::AppConfig, context::AppContext}; use domain::{ errors::DomainError, events::DomainEvent, models::{DiaryEntry, DiaryFilter, FeedEntry, Movie, Review, ReviewHistory, UserStats, UserTrends, collections::{PageParams, Paginated}}, ports::{ AuthService, DiaryRepository, EventPublisher, GeneratedToken, MetadataClient, MovieRepository, PasswordHasher, PosterFetcherClient, PosterStorage, ReviewRepository, StatsRepository, UserRepository, }, value_objects::{ Email, ExternalMetadataId, MovieId, MovieTitle, PasswordHash, PosterPath, PosterUrl, ReleaseYear, ReviewId, UserId, }, }; use tower::ServiceExt; // --- Panic stubs (defined once) --- struct Panic; #[async_trait::async_trait] impl MovieRepository for Panic { async fn get_movie_by_external_id(&self, _: &ExternalMetadataId) -> Result, DomainError> { panic!() } async fn get_movie_by_id(&self, _: &MovieId) -> Result, DomainError> { panic!() } async fn get_movies_by_title_and_year(&self, _: &MovieTitle, _: &ReleaseYear) -> Result, DomainError> { panic!() } async fn upsert_movie(&self, _: &Movie) -> Result<(), DomainError> { panic!() } async fn delete_movie(&self, _: &MovieId) -> Result<(), DomainError> { panic!() } } #[async_trait::async_trait] impl ReviewRepository for Panic { async fn save_review(&self, _: &Review) -> Result { panic!() } async fn get_review_by_id(&self, _: &ReviewId) -> Result, DomainError> { panic!() } async fn delete_review(&self, _: &ReviewId) -> Result<(), DomainError> { panic!() } } #[async_trait::async_trait] impl DiaryRepository for Panic { async fn query_diary(&self, _: &DiaryFilter) -> Result, DomainError> { panic!() } async fn query_activity_feed(&self, _: &PageParams) -> Result, 
DomainError> { panic!() } async fn get_review_history(&self, _: &MovieId) -> Result { panic!() } async fn get_user_history(&self, _: &UserId) -> Result, DomainError> { panic!() } } #[async_trait::async_trait] impl StatsRepository for Panic { async fn get_user_stats(&self, _: &UserId) -> Result { panic!() } async fn get_user_trends(&self, _: &UserId) -> Result { panic!() } } #[async_trait::async_trait] impl MetadataClient for Panic { async fn fetch_movie_metadata(&self, _: &domain::ports::MetadataSearchCriteria) -> Result { panic!() } async fn get_poster_url(&self, _: &ExternalMetadataId) -> Result, DomainError> { panic!() } } #[async_trait::async_trait] impl PosterFetcherClient for Panic { async fn fetch_poster_bytes(&self, _: &PosterUrl) -> Result, DomainError> { panic!() } } #[async_trait::async_trait] impl PosterStorage for Panic { async fn store_poster(&self, _: &MovieId, _: &[u8]) -> Result { panic!() } async fn get_poster(&self, _: &PosterPath) -> Result, DomainError> { panic!() } } #[async_trait::async_trait] impl AuthService for Panic { async fn generate_token(&self, _: &UserId) -> Result { panic!() } async fn validate_token(&self, _: &str) -> Result { panic!() } } #[async_trait::async_trait] impl PasswordHasher for Panic { async fn hash(&self, _: &str) -> Result { panic!() } async fn verify(&self, _: &str, _: &PasswordHash) -> Result { panic!() } } #[async_trait::async_trait] impl UserRepository for Panic { async fn find_by_email(&self, _: &Email) -> Result, DomainError> { panic!() } async fn save(&self, _: &domain::models::User) -> Result<(), DomainError> { panic!() } async fn find_by_id(&self, _: &UserId) -> Result, DomainError> { panic!() } async fn find_by_username(&self, _: &domain::value_objects::Username) -> Result, DomainError> { panic!() } async fn list_with_stats(&self) -> Result, DomainError> { panic!() } } #[async_trait::async_trait] impl EventPublisher for Panic { async fn publish(&self, _: &DomainEvent) -> Result<(), DomainError> { panic!() } 
} impl crate::ports::HtmlRenderer for Panic { fn render_diary_page(&self, _: &Paginated, _: application::ports::HtmlPageContext) -> Result { panic!() } fn render_login_page(&self, _: application::ports::LoginPageData<'_>) -> Result { panic!() } fn render_register_page(&self, _: application::ports::RegisterPageData<'_>) -> Result { panic!() } fn render_new_review_page(&self, _: application::ports::NewReviewPageData<'_>) -> Result { panic!() } fn render_activity_feed_page(&self, _: application::ports::ActivityFeedPageData) -> Result { panic!() } fn render_users_page(&self, _: application::ports::UsersPageData) -> Result { panic!() } fn render_profile_page(&self, _: application::ports::ProfilePageData) -> Result { panic!() } fn render_following_page(&self, _: application::ports::FollowingPageData) -> Result { panic!() } fn render_followers_page(&self, _: application::ports::FollowersPageData) -> Result { panic!() } } impl crate::ports::RssFeedRenderer for Panic { fn render_feed(&self, _: &[DiaryEntry], _: &str) -> Result { panic!() } } struct RejectingAuth; #[async_trait::async_trait] impl AuthService for RejectingAuth { async fn generate_token(&self, _: &UserId) -> Result { panic!() } async fn validate_token(&self, _: &str) -> Result { Err(DomainError::Unauthorized("bad token".into())) } } // --- Single state factory — only auth_service varies --- fn make_test_state(auth_service: Arc) -> crate::state::AppState { let repo = Arc::new(Panic); crate::state::AppState { app_ctx: AppContext { movie_repository: Arc::clone(&repo) as _, review_repository: Arc::clone(&repo) as _, diary_repository: Arc::clone(&repo) as _, stats_repository: Arc::clone(&repo) as _, metadata_client: Arc::clone(&repo) as _, poster_fetcher: Arc::clone(&repo) as _, poster_storage: Arc::clone(&repo) as _, event_publisher: Arc::clone(&repo) as _, password_hasher: Arc::clone(&repo) as _, user_repository: Arc::clone(&repo) as _, auth_service, config: AppConfig { allow_registration: false, base_url: 
"http://localhost:3000".to_string(), rate_limit: 20 }, }, html_renderer: Arc::new(Panic), rss_renderer: Arc::new(Panic), ap_service: Arc::new(activitypub::NoopActivityPubService), } } // --- Routers --- async fn protected_handler(user: AuthenticatedUser) -> String { user.0.value().to_string() } async fn optional_cookie_handler(user: OptionalCookieUser) -> String { match user.0 { Some(id) => id.value().to_string(), None => "none".to_string() } } async fn required_cookie_handler(user: RequiredCookieUser) -> String { user.0.value().to_string() } fn router_protected(state: crate::state::AppState) -> Router { Router::new().route("/protected", get(protected_handler)).with_state(state) } fn router_optional(state: crate::state::AppState) -> Router { Router::new().route("/optional", get(optional_cookie_handler)).with_state(state) } fn router_required(state: crate::state::AppState) -> Router { Router::new().route("/required", get(required_cookie_handler)).with_state(state) } // --- Tests --- #[tokio::test] async fn missing_auth_header_returns_401() { let app = router_protected(make_test_state(Arc::new(Panic))); let resp = app.oneshot(Request::builder().uri("/protected").body(Body::empty()).unwrap()).await.unwrap(); assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn optional_cookie_user_returns_none_without_cookie() { let app = router_optional(make_test_state(Arc::new(Panic))); let resp = app.oneshot(Request::builder().uri("/optional").body(Body::empty()).unwrap()).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap(); assert_eq!(&body[..], b"none"); } #[tokio::test] async fn optional_cookie_user_returns_none_with_invalid_token() { let app = router_optional(make_test_state(Arc::new(RejectingAuth))); let resp = app.oneshot(Request::builder().uri("/optional").header("cookie", "token=bad.token.here").body(Body::empty()).unwrap()).await.unwrap(); assert_eq!(resp.status(), 
StatusCode::OK); let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap(); assert_eq!(&body[..], b"none"); } #[tokio::test] async fn required_cookie_user_redirects_without_cookie() { let app = router_required(make_test_state(Arc::new(Panic))); let resp = app.oneshot(Request::builder().uri("/required").body(Body::empty()).unwrap()).await.unwrap(); assert_eq!(resp.status(), StatusCode::SEE_OTHER); assert_eq!(resp.headers().get("location").unwrap(), "/login"); } #[tokio::test] async fn required_cookie_user_redirects_with_invalid_token() { let app = router_required(make_test_state(Arc::new(RejectingAuth))); let resp = app.oneshot(Request::builder().uri("/required").header("cookie", "token=bad.token.here").body(Body::empty()).unwrap()).await.unwrap(); assert_eq!(resp.status(), StatusCode::SEE_OTHER); assert_eq!(resp.headers().get("location").unwrap(), "/login"); } }