From a5f9e8ae9ea9abbf75b300ed3c3d9f05f2dc917e Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 20:31:57 +0100 Subject: [PATCH] feat: Implement flexible authentication supporting JWT, OIDC, and session modes, alongside new configuration options and refactored auth layer setup. --- Cargo.lock | 2 - notes-api/Cargo.toml | 28 +- notes-api/src/auth.rs | 98 +---- notes-api/src/config.rs | 118 ++++++ notes-api/src/dto.rs | 24 +- notes-api/src/error.rs | 11 + notes-api/src/extractors.rs | 133 ++++++ notes-api/src/main.rs | 128 +++--- notes-api/src/routes/auth.rs | 574 +++++++++++++++++++++----- notes-api/src/routes/import_export.rs | 23 +- notes-api/src/routes/mod.rs | 5 +- notes-api/src/routes/notes.rs | 80 +--- notes-api/src/routes/tags.rs | 43 +- notes-api/src/state.rs | 100 ++++- notes-domain/src/errors.rs | 6 + notes-domain/src/services.rs | 56 +-- notes-infra/src/auth/mod.rs | 6 +- notes-infra/src/session_store.rs | 1 + 18 files changed, 1022 insertions(+), 414 deletions(-) create mode 100644 notes-api/src/extractors.rs diff --git a/Cargo.lock b/Cargo.lock index d28ae67..09ce3ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2336,7 +2336,6 @@ dependencies = [ "anyhow", "async-trait", "axum 0.8.8", - "axum-login", "chrono", "dotenvy", "k-core", @@ -2351,7 +2350,6 @@ dependencies = [ "tower 0.5.2", "tower-http", "tower-sessions", - "tower-sessions-sqlx-store", "tracing", "tracing-subscriber", "uuid", diff --git a/notes-api/Cargo.toml b/notes-api/Cargo.toml index cd58238..69cfe7a 100644 --- a/notes-api/Cargo.toml +++ b/notes-api/Cargo.toml @@ -5,16 +5,14 @@ edition = "2024" default-run = "notes-api" [features] -default = ["sqlite", "smart-features"] -sqlite = [ - "notes-infra/sqlite", - "tower-sessions-sqlx-store/sqlite", -] -postgres = [ - "notes-infra/postgres", - "tower-sessions-sqlx-store/postgres", -] +default = ["sqlite", "smart-features", "auth-oidc", "auth-jwt"] +sqlite = ["notes-infra/sqlite"] +postgres = ["notes-infra/postgres"] smart-features = ["notes-infra/smart-features", "notes-infra/broker-nats"] +auth-axum-login = ["notes-infra/auth-axum-login"] +auth-oidc = ["notes-infra/auth-oidc"] +auth-jwt = ["notes-infra/auth-jwt"] +auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"] [dependencies] notes-domain = { path = "../notes-domain" } @@ -28,9 +26,7 @@ tower = "0.5.2" tower-http = { version = "0.6.2", features = ["cors", "trace"] } # Authentication -axum-login = "0.18" -tower-sessions = "0.14" -tower-sessions-sqlx-store = { version = "0.15", features = ["sqlite"] } + password-auth = "1.0" time = "0.3" async-trait = "0.1.89" @@ -64,5 +60,9 @@ k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features "db-sqlx", "sqlite", "http", - "auth","sessions-db" -] } \ No newline at end of file + "auth", + "sessions-db", +] } + + +tower-sessions = "0.14.0" diff --git a/notes-api/src/auth.rs b/notes-api/src/auth.rs index 2ad1be0..a7a930b 100644 --- a/notes-api/src/auth.rs +++ b/notes-api/src/auth.rs @@ -1,87 +1,27 @@ -//! Authentication logic using axum-login +//! Authentication logic +//! +//! Proxies to infra implementation if enabled. +#[cfg(feature = "auth-axum-login")] use std::sync::Arc; -use axum_login::{AuthnBackend, UserId}; -use password_auth::verify_password; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; +#[cfg(feature = "auth-axum-login")] +use notes_domain::UserRepository; +#[cfg(feature = "auth-axum-login")] +use notes_infra::session_store::{InfraSessionStore, SessionManagerLayer}; +#[cfg(feature = "auth-axum-login")] use crate::error::ApiError; -use notes_domain::{User, UserRepository}; -/// Wrapper around domain User to implement AuthUser -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthUser(pub User); +#[cfg(feature = "auth-axum-login")] +pub use notes_infra::auth::axum_login::{AuthManagerLayer, AuthSession, AuthUser, Credentials}; -impl axum_login::AuthUser for AuthUser { - type Id = Uuid; - - fn id(&self) -> Self::Id { - self.0.id - } - - fn session_auth_hash(&self) -> &[u8] { - // Use password hash to invalidate sessions if password changes - self.0 - .password_hash - .as_ref() - .map(|s| s.as_bytes()) - .unwrap_or(&[]) - } -} - -#[derive(Clone)] -pub struct AuthBackend { - pub user_repo: Arc, -} - -impl AuthBackend { - pub fn new(user_repo: Arc) -> Self { - Self { user_repo } - } -} - -#[derive(Clone, Debug, Deserialize)] -pub struct Credentials { - pub email: String, - pub password: String, -} - -impl AuthnBackend for AuthBackend { - type User = AuthUser; - type Credentials = Credentials; - type Error = ApiError; - - async fn authenticate( - &self, - creds: Self::Credentials, - ) -> Result, Self::Error> { - let user = self - .user_repo - .find_by_email(&creds.email) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - if let Some(user) = user { - if let Some(hash) = &user.password_hash { - // Verify password - if verify_password(&creds.password, hash).is_ok() { - return Ok(Some(AuthUser(user))); - } - } - } - - Ok(None) - } - - async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { - let user = self - .user_repo - .find_by_id(*user_id) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(user.map(AuthUser)) - } +#[cfg(feature = "auth-axum-login")] +pub async fn setup_auth_layer( + session_layer: SessionManagerLayer, + user_repo: Arc, +) -> Result { + notes_infra::auth::axum_login::setup_auth_layer(session_layer, user_repo) + .await + .map_err(|e| ApiError::Internal(e.to_string())) } diff --git a/notes-api/src/config.rs b/notes-api/src/config.rs index c20f5ee..7d35876 100644 --- a/notes-api/src/config.rs +++ b/notes-api/src/config.rs @@ -1,7 +1,32 @@ #[cfg(feature = "smart-features")] use notes_infra::factory::{EmbeddingProvider, VectorProvider}; +use serde::Deserialize; use std::env; +/// Authentication mode - determines how the API authenticates requests +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthMode { + /// Session-based authentication using cookies (default for backward compatibility) + #[default] + Session, + /// JWT-based authentication using Bearer tokens + Jwt, + /// Support both session and JWT authentication (try JWT first, then session) + Both, +} + +impl AuthMode { + /// Parse auth mode from string + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "jwt" => AuthMode::Jwt, + "both" => AuthMode::Both, + _ => AuthMode::Session, + } + } +} + /// Server configuration #[derive(Debug, Clone)] pub struct Config { @@ -16,6 +41,31 @@ pub struct Config { #[cfg(feature = "smart-features")] pub vector_provider: VectorProvider, pub broker_url: String, + + pub secure_cookie: bool, + + pub db_max_connections: u32, + + pub db_min_connections: u32, + + // OIDC configuration + pub oidc_issuer: Option, + pub oidc_client_id: Option, + pub oidc_client_secret: Option, + pub oidc_redirect_url: Option, + pub oidc_resource_id: Option, + + // Auth mode configuration + pub auth_mode: AuthMode, + + // JWT configuration + pub jwt_secret: Option, + pub jwt_issuer: Option, + pub jwt_audience: Option, + pub jwt_expiry_hours: u64, + + /// Whether the application is running in production mode + pub is_production: bool, } impl Default for Config { @@ -36,6 +86,20 @@ impl Default for Config { collection: "notes".to_string(), }, broker_url: "nats://localhost:4222".to_string(), + secure_cookie: false, + db_max_connections: 5, + db_min_connections: 1, + oidc_issuer: None, + oidc_client_id: None, + oidc_client_secret: None, + oidc_redirect_url: None, + oidc_resource_id: None, + auth_mode: AuthMode::Session, + jwt_secret: None, + jwt_issuer: None, + jwt_audience: None, + jwt_expiry_hours: 24, + is_production: false, } } } @@ -89,6 +153,46 @@ impl Config { let broker_url = env::var("BROKER_URL").unwrap_or_else(|_| "nats://localhost:4222".to_string()); + let secure_cookie = env::var("SECURE_COOKIE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(false); + + let db_max_connections = env::var("DB_MAX_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(5); + + let db_min_connections = env::var("DB_MIN_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1); + + let oidc_issuer = env::var("OIDC_ISSUER").ok(); + let oidc_client_id = env::var("OIDC_CLIENT_ID").ok(); + let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok(); + let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok(); + let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok(); + + // Auth mode configuration + let auth_mode = env::var("AUTH_MODE") + .map(|s| AuthMode::from_str(&s)) + .unwrap_or_default(); + + // JWT configuration + let jwt_secret = env::var("JWT_SECRET").ok(); + let jwt_issuer = env::var("JWT_ISSUER").ok(); + let jwt_audience = env::var("JWT_AUDIENCE").ok(); + let jwt_expiry_hours = env::var("JWT_EXPIRY_HOURS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(24); + + let is_production = env::var("PRODUCTION") + .or_else(|_| env::var("RUST_ENV")) + .map(|v| v.to_lowercase() == "production" || v == "1" || v == "true") + .unwrap_or(false); + Self { host, port, @@ -101,6 +205,20 @@ impl Config { #[cfg(feature = "smart-features")] vector_provider, broker_url, + secure_cookie, + db_max_connections, + db_min_connections, + oidc_issuer, + oidc_client_id, + oidc_client_secret, + oidc_redirect_url, + oidc_resource_id, + auth_mode, + jwt_secret, + jwt_issuer, + jwt_audience, + jwt_expiry_hours, + is_production, } } } diff --git a/notes-api/src/dto.rs b/notes-api/src/dto.rs index 46d9fbd..c424759 100644 --- a/notes-api/src/dto.rs +++ b/notes-api/src/dto.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use validator::Validate; -use notes_domain::{Note, Tag}; +use notes_domain::{Email, Note, Password, Tag}; /// Request to create a new note #[derive(Debug, Deserialize, Validate)] @@ -118,30 +118,24 @@ pub struct RenameTagRequest { } /// Login request -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize)] pub struct LoginRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + pub email: Email, + pub password: Password, } /// Register request -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize)] pub struct RegisterRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + pub email: Email, + pub password: Password, } /// User response DTO #[derive(Debug, Serialize)] pub struct UserResponse { pub id: Uuid, - pub email: String, + pub email: Email, pub created_at: DateTime, } @@ -160,7 +154,7 @@ impl From for NoteVersionResponse { Self { id: version.id, note_id: version.note_id, - title: version.title.unwrap_or_default(), // Convert Option to String + title: version.title.unwrap_or_default(), content: version.content, created_at: version.created_at, } diff --git a/notes-api/src/error.rs b/notes-api/src/error.rs index 6b2fb67..5ea5be7 100644 --- a/notes-api/src/error.rs +++ b/notes-api/src/error.rs @@ -26,6 +26,9 @@ pub enum ApiError { #[error("Forbidden: {0}")] Forbidden(String), + + #[error("Unauthorized: {0}")] + Unauthorized(String), } /// Error response body @@ -96,6 +99,14 @@ impl IntoResponse for ApiError { details: Some(msg.clone()), }, ), + + ApiError::Unauthorized(msg) => ( + StatusCode::UNAUTHORIZED, + ErrorResponse { + error: "Unauthorized".to_string(), + details: Some(msg.clone()), + }, + ), }; (status, Json(error_response)).into_response() diff --git a/notes-api/src/extractors.rs b/notes-api/src/extractors.rs new file mode 100644 index 0000000..3ec5a54 --- /dev/null +++ b/notes-api/src/extractors.rs @@ -0,0 +1,133 @@ +//! Auth extractors for API handlers +//! +//! Provides the `CurrentUser` extractor that works with both session and JWT auth. + +use axum::{extract::FromRequestParts, http::request::Parts}; +use notes_domain::User; + +use crate::config::AuthMode; +use crate::error::ApiError; +use crate::state::AppState; + +/// Extracted current user from the request. +/// +/// This extractor supports multiple authentication methods based on the configured `AuthMode`: +/// - `Session`: Uses axum-login session cookies +/// - `Jwt`: Uses Bearer token in Authorization header +/// - `Both`: Tries JWT first, then falls back to session +pub struct CurrentUser(pub User); + +impl FromRequestParts for CurrentUser { + type Rejection = ApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let auth_mode = state.config.auth_mode; + + // Try JWT first if enabled + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + match try_jwt_auth(parts, state).await { + Ok(Some(user)) => return Ok(CurrentUser(user)), + Ok(None) => { + // No JWT token present, continue to session auth if Both mode + if auth_mode == AuthMode::Jwt { + return Err(ApiError::Unauthorized( + "Missing or invalid Authorization header".to_string(), + )); + } + } + Err(e) => { + // JWT was present but invalid + tracing::debug!("JWT auth failed: {}", e); + if auth_mode == AuthMode::Jwt { + return Err(e); + } + // In Both mode, continue to try session + } + } + } + + // Try session auth if enabled + #[cfg(feature = "auth-axum-login")] + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + if let Some(user) = try_session_auth(parts).await? { + return Ok(CurrentUser(user)); + } + } + + Err(ApiError::Unauthorized("Not authenticated".to_string())) + } +} + +/// Try to authenticate using JWT Bearer token +#[cfg(feature = "auth-jwt")] +async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result, ApiError> { + use axum::http::header::AUTHORIZATION; + + // Get Authorization header + let auth_header = match parts.headers.get(AUTHORIZATION) { + Some(header) => header, + None => return Ok(None), // No header = no JWT auth attempted + }; + + let auth_str = auth_header + .to_str() + .map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?; + + // Extract Bearer token + let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| { + ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string()) + })?; + + // Get JWT validator + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?; + + // Validate token + let claims = validator.validate_token(token).map_err(|e| { + tracing::debug!("JWT validation failed: {:?}", e); + match e { + notes_infra::auth::jwt::JwtError::Expired => { + ApiError::Unauthorized("Token expired".to_string()) + } + notes_infra::auth::jwt::JwtError::InvalidFormat => { + ApiError::Unauthorized("Invalid token format".to_string()) + } + _ => ApiError::Unauthorized("Token validation failed".to_string()), + } + })?; + + // Fetch user from database by ID (subject contains user ID) + let user_id: uuid::Uuid = claims + .sub + .parse() + .map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?; + + let user = state + .user_service + .find_by_id(user_id) + .await + .map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?; + + Ok(Some(user)) +} + +/// Try to authenticate using session cookie +#[cfg(feature = "auth-axum-login")] +async fn try_session_auth(parts: &mut Parts) -> Result, ApiError> { + use notes_infra::auth::axum_login::AuthSession; + + // Check if AuthSession extension is present (added by auth middleware) + if let Some(auth_session) = parts.extensions.get::() { + if let Some(auth_user) = &auth_session.user { + return Ok(Some(auth_user.0.clone())); + } + } + + Ok(None) +} diff --git a/notes-api/src/main.rs b/notes-api/src/main.rs index 99abea2..d9fc0e3 100644 --- a/notes-api/src/main.rs +++ b/notes-api/src/main.rs @@ -2,17 +2,15 @@ //! //! A high-performance, self-hosted note-taking API following hexagonal architecture. -use k_core::{ - db::DatabasePool, - http::server::{ServerConfig, apply_standard_middleware}, -}; +use k_core::http::server::{ServerConfig, apply_standard_middleware}; +use std::net::SocketAddr; use std::{sync::Arc, time::Duration as StdDuration}; use time::Duration; +use tokio::net::TcpListener; +use tower_sessions::cookie::SameSite; +use tower_sessions::{Expiry, SessionManagerLayer}; use axum::Router; -use axum_login::AuthManagerLayerBuilder; - -use tower_sessions::{Expiry, SessionManagerLayer}; use notes_infra::run_migrations; @@ -20,13 +18,15 @@ mod auth; mod config; mod dto; mod error; +mod extractors; mod routes; mod state; -use auth::AuthBackend; use config::Config; use state::AppState; +use crate::config::AuthMode; + #[tokio::main] async fn main() -> anyhow::Result<()> { k_core::logging::init("notes_api"); @@ -53,9 +53,6 @@ async fn main() -> anyhow::Result<()> { build_note_repository, build_session_store, build_tag_repository, build_user_repository, }; - // Create a default user for development - create_dev_user(&db_pool).await.ok(); - // Create repositories via factory let note_repo = build_note_repository(&db_pool) .await @@ -105,20 +102,16 @@ async fn main() -> anyhow::Result<()> { let state = AppState::new( note_repo, tag_repo, - user_repo.clone(), #[cfg(feature = "smart-features")] link_repo, note_service, tag_service, user_service, config.clone(), - ); + ) + .await?; - // Auth backend - let backend = AuthBackend::new(user_repo); // no idea what now with this - - // Session layer - // Use the factory to build the session store, agnostic of the underlying DB + // Build session store (needed for OIDC flow even in JWT mode) let session_store = build_session_store(&db_pool) .await .map_err(|e| anyhow::anyhow!(e))?; @@ -128,28 +121,24 @@ async fn main() -> anyhow::Result<()> { .map_err(|e| anyhow::anyhow!(e))?; let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) // Set to true in prod + .with_secure(config.secure_cookie) + .with_same_site(SameSite::Lax) .with_expiry(Expiry::OnInactivity(Duration::days(7))); - let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build(); - let server_config = ServerConfig { cors_origins: config.cors_allowed_origins.clone(), session_secret: Some(config.session_secret.clone()), }; - let app = Router::new() - .nest("/api/v1", routes::api_v1_router()) - .layer(auth_layer) - .with_state(state); - + // Build the app with appropriate auth layers based on config + let app = build_app(state, session_layer, user_repo, &config).await?; let app = apply_standard_middleware(app, &server_config); - let addr = format!("{}:{}", config.host, config.port); - let listener = tokio::net::TcpListener::bind(&addr).await?; + let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?; + let listener = TcpListener::bind(addr).await?; - tracing::info!("🚀 K-Notes API server running at http://{}", addr); - tracing::info!("🔒 Authentication enabled (axum-login)"); + tracing::info!("🚀 API server running at http://{}", addr); + log_auth_info(&config); tracing::info!("📝 API endpoints available at /api/v1/..."); axum::serve(listener, app).await?; @@ -157,32 +146,61 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn create_dev_user(pool: &DatabasePool) -> anyhow::Result<()> { - use notes_domain::{Email, User}; - use notes_infra::factory::build_user_repository; - use password_auth::generate_hash; - use uuid::Uuid; +/// Build the application router with appropriate auth layers +#[allow(unused_variables)] // config/user_repo used conditionally based on features +async fn build_app( + state: AppState, + session_layer: SessionManagerLayer, + user_repo: std::sync::Arc, + config: &Config, +) -> anyhow::Result { + let app = Router::new() + .nest("/api/v1", routes::api_v1_router()) + .with_state(state); - let user_repo = build_user_repository(pool) - .await - .map_err(|e| anyhow::anyhow!(e))?; - - // Check if dev user exists - let dev_user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); - if user_repo.find_by_id(dev_user_id).await?.is_none() { - let hash = generate_hash("password"); - let dev_email = Email::try_from("dev@localhost.com") - .map_err(|e| anyhow::anyhow!("Invalid dev email: {}", e))?; - let user = User::with_id( - dev_user_id, - "dev|local", - dev_email, - Some(hash), - chrono::Utc::now(), - ); - user_repo.save(&user).await?; - tracing::info!("Created development user: dev@localhost.com / password"); + // When auth-axum-login feature is enabled, always apply the auth layer. + // This is needed because: + // 1. OIDC callback uses AuthSession for state management + // 2. Session-based login/register routes use it + // 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support + #[cfg(feature = "auth-axum-login")] + { + let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?; + return Ok(app.layer(auth_layer)); } - Ok(()) + // When auth-axum-login is not compiled in, just use session layer for OIDC flow + #[cfg(not(feature = "auth-axum-login"))] + { + let _ = user_repo; // Suppress unused warning + Ok(app.layer(session_layer)) + } +} + +/// Log authentication info based on enabled features and config +fn log_auth_info(config: &Config) { + match config.auth_mode { + AuthMode::Session => { + tracing::info!("🔒 Authentication mode: Session (cookie-based)"); + } + AuthMode::Jwt => { + tracing::info!("🔒 Authentication mode: JWT (Bearer token)"); + } + AuthMode::Both => { + tracing::info!("🔒 Authentication mode: Both (JWT + Session)"); + } + } + + #[cfg(feature = "auth-axum-login")] + tracing::info!(" ✓ Session auth enabled (axum-login)"); + + #[cfg(feature = "auth-jwt")] + if config.jwt_secret.is_some() { + tracing::info!(" ✓ JWT auth enabled"); + } + + #[cfg(feature = "auth-oidc")] + if config.oidc_issuer.is_some() { + tracing::info!(" ✓ OIDC integration enabled"); + } } diff --git a/notes-api/src/routes/auth.rs b/notes-api/src/routes/auth.rs index 9418805..4cc5344 100644 --- a/notes-api/src/routes/auth.rs +++ b/notes-api/src/routes/auth.rs @@ -1,117 +1,491 @@ //! Authentication routes +//! +//! Provides login, register, logout, and token endpoints. +//! Supports both session-based and JWT-based authentication. -use axum::{Json, extract::State, http::StatusCode}; -use axum_login::AuthSession; -use validator::Validate; +#[cfg(feature = "auth-oidc")] +use axum::response::Response; +use axum::{ + Router, + extract::{Json, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, +}; +use serde::Serialize; +#[cfg(feature = "auth-oidc")] +use tower_sessions::Session; -use notes_domain::{Email, User}; -use password_auth::generate_hash; +#[cfg(feature = "auth-axum-login")] +use crate::config::AuthMode; +use crate::{ + dto::{LoginRequest, RegisterRequest, UserResponse}, + error::ApiError, + extractors::CurrentUser, + state::AppState, +}; +#[cfg(feature = "auth-axum-login")] +use notes_domain::DomainError; -use crate::auth::{AuthBackend, AuthUser, Credentials}; -use crate::dto::{LoginRequest, RegisterRequest}; -use crate::error::{ApiError, ApiResult}; -use crate::state::AppState; - -/// Register a new user -pub async fn register( - State(state): State, - mut auth_session: AuthSession, - Json(payload): Json, -) -> ApiResult { - payload - .validate() - .map_err(|e| ApiError::validation(e.to_string()))?; - - // Check if registration is allowed - if !state.config.allow_registration { - return Err(ApiError::Forbidden("Registration is disabled".to_string())); - } - - // Check if user exists - if state - .user_repo - .find_by_email(&payload.email) - .await - .map_err(ApiError::from)? - .is_some() - { - return Err(ApiError::Domain( - notes_domain::DomainError::UserAlreadyExists(payload.email.clone()), - )); - } - - // Hash password - let password_hash = generate_hash(&payload.password); - - // Parse email string to Email newtype - let email = Email::try_from(payload.email) - .map_err(|e| ApiError::validation(format!("Invalid email: {}", e)))?; - - // Create user - for local registration, we use email as subject - let user = User::new_local(email, &password_hash); - - state.user_repo.save(&user).await.map_err(ApiError::from)?; - - // Auto login after registration - let user = AuthUser(user); - auth_session - .login(&user) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(StatusCode::CREATED) +/// Token response for JWT authentication +#[derive(Debug, Serialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: u64, } -/// Login user -pub async fn login( - mut auth_session: AuthSession, - Json(payload): Json, -) -> ApiResult { - payload - .validate() - .map_err(|e| ApiError::validation(e.to_string()))?; +/// Login response that can be either a user (session mode) or a token (JWT mode) +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum LoginResponse { + User(UserResponse), + Token(TokenResponse), +} - let user = auth_session - .authenticate(Credentials { +pub fn router() -> Router { + let r = Router::new() + .route("/login", post(login)) + .route("/register", post(register)) + .route("/logout", post(logout)) + .route("/me", get(me)); + + // Add token endpoint for getting JWT from session + #[cfg(feature = "auth-jwt")] + let r = r.route("/token", post(get_token)); + + #[cfg(feature = "auth-oidc")] + let r = r + .route("/login/oidc", get(oidc_login)) + .route("/callback", get(oidc_callback)); + + r +} + +/// Login endpoint +/// +/// In session mode: Creates a session and returns user info +/// In JWT mode: Validates credentials and returns a JWT token +/// In both mode: Creates session AND returns JWT token +#[cfg(feature = "auth-axum-login")] +async fn login( + State(state): State, + mut auth_session: crate::auth::AuthSession, + Json(payload): Json, +) -> Result { + let user = match auth_session + .authenticate(crate::auth::Credentials { email: payload.email, password: payload.password, }) .await - .map_err(|e| ApiError::internal(e.to_string()))? - .ok_or_else(|| ApiError::validation("Invalid email or password"))?; // Generic error for security + .map_err(|e| ApiError::Internal(e.to_string()))? + { + Some(user) => user, + None => return Err(ApiError::Validation("Invalid credentials".to_string())), + }; - auth_session - .login(&user) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; + let auth_mode = state.config.auth_mode; - Ok(StatusCode::OK) -} - -/// Logout user -pub async fn logout(mut auth_session: AuthSession) -> ApiResult { - auth_session - .logout() - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(StatusCode::OK) -} - -/// Get current user -pub async fn me( - auth_session: AuthSession, -) -> ApiResult> { - let user = + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { auth_session - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Not logged in".to_string(), - )))?; + .login(&user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } - Ok(Json(crate::dto::UserResponse { - id: user.0.id, - email: user.0.email_str().to_string(), // Convert Email to String - created_at: user.0.created_at, + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user.0, &state)?; + return Ok(( + StatusCode::OK, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } + + // Session mode: return user info + Ok(( + StatusCode::OK, + Json(LoginResponse::User(UserResponse { + id: user.0.id, + email: user.0.email, + created_at: user.0.created_at, + })), + )) +} + +/// Fallback login when auth-axum-login is not enabled +/// Without auth-axum-login, password-based authentication is not available. +/// Use OIDC login instead: GET /api/v1/auth/login/oidc +#[cfg(not(feature = "auth-axum-login"))] +async fn login( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(), + )) +} + +/// Register endpoint +#[cfg(feature = "auth-axum-login")] +async fn register( + State(state): State, + mut auth_session: crate::auth::AuthSession, + Json(payload): Json, +) -> Result { + // Email is already validated by the newtype deserialization + let email = payload.email; + + if state + .user_service + .find_by_email(email.as_ref()) + .await? + .is_some() + { + return Err(ApiError::Domain(DomainError::UserAlreadyExists( + email.as_ref().to_string(), + ))); + } + + // Hash password + let password_hash = notes_infra::auth::axum_login::hash_password(payload.password.as_ref()); + + // Create user with password + let user = state + .user_service + .create_local(email.as_ref(), &password_hash) + .await?; + + let auth_mode = state.config.auth_mode; + + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + let auth_user = crate::auth::AuthUser(user.clone()); + auth_session + .login(&auth_user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } + + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(( + StatusCode::CREATED, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } + + Ok(( + StatusCode::CREATED, + Json(LoginResponse::User(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, + })), + )) +} + +/// Fallback register when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn register( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Session-based registration not available. Use JWT token endpoint.".to_string(), + )) +} + +/// Logout endpoint +#[cfg(feature = "auth-axum-login")] +async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse { + match auth_session.logout().await { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +} + +/// Fallback logout when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn logout() -> impl IntoResponse { + // JWT tokens can't be "logged out" server-side without a blocklist + // Just return OK + StatusCode::OK +} + +/// Get current user info +async fn me(CurrentUser(user): CurrentUser) -> Result { + Ok(Json(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, })) } + +/// Get a JWT token for the current session user +/// +/// This allows session-authenticated users to obtain a JWT for API access. +#[cfg(feature = "auth-jwt")] +async fn get_token( + State(state): State, + CurrentUser(user): CurrentUser, +) -> Result { + let token = create_jwt_for_user(&user, &state)?; + + Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })) +} + +/// Helper to create JWT for a user +#[cfg(feature = "auth-jwt")] +fn create_jwt_for_user(user: ¬es_domain::User, state: &AppState) -> Result { + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?; + + validator + .create_token(user) + .map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e))) +} + +// ============================================================================ +// OIDC Routes +// ============================================================================ + +#[cfg(feature = "auth-oidc")] +async fn oidc_login(State(state): State, session: Session) -> Result { + use axum::http::header; + + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let auth_data = service.get_authorization_url(); + + session + .insert("oidc_csrf", &auth_data.csrf_token) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + session + .insert("oidc_nonce", &auth_data.nonce) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + session + .insert("oidc_pkce", &auth_data.pkce_verifier) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + let response = axum::response::Redirect::to(auth_data.url.as_str()).into_response(); + let (mut parts, body) = response.into_parts(); + + parts.headers.insert( + header::CACHE_CONTROL, + "no-cache, no-store, must-revalidate".parse().unwrap(), + ); + parts + .headers + .insert(header::PRAGMA, "no-cache".parse().unwrap()); + parts.headers.insert(header::EXPIRES, "0".parse().unwrap()); + + Ok(Response::from_parts(parts, body)) +} + +#[cfg(feature = "auth-oidc")] +#[derive(serde::Deserialize)] +struct CallbackParams { + code: String, + state: String, +} + +#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))] +async fn oidc_callback( + State(state): State, + session: Session, + mut auth_session: crate::auth::AuthSession, + axum::extract::Query(params): axum::extract::Query, +) -> Result { + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let stored_csrf: notes_domain::CsrfToken = session + .get("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing CSRF token".into()))?; + + if params.state != stored_csrf.as_ref() { + return Err(ApiError::Validation("Invalid CSRF token".into())); + } + + let stored_pkce: notes_domain::PkceVerifier = session + .get("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing PKCE".into()))?; + let stored_nonce: notes_domain::OidcNonce = session + .get("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing Nonce".into()))?; + + let oidc_user = service + .resolve_callback( + notes_domain::AuthorizationCode::new(params.code), + stored_nonce, + stored_pkce, + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let user = state + .user_service + .find_or_create(&oidc_user.subject, &oidc_user.email) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let auth_mode = state.config.auth_mode; + + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + auth_session + .login(&crate::auth::AuthUser(user.clone())) + .await + .map_err(|_| ApiError::Internal("Login failed".into()))?; + } + + // Clean up OIDC state + let _: Option = session + .remove("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + // In JWT mode, return token as JSON + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + }) + .into_response()); + } + + // Session mode: return user info + Ok(Json(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, + }) + .into_response()) +} + +/// Fallback OIDC callback when auth-axum-login is not enabled +#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))] +async fn oidc_callback( + State(state): State, + session: Session, + axum::extract::Query(params): axum::extract::Query, +) -> Result { + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let stored_csrf: notes_domain::CsrfToken = session + .get("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing CSRF token".into()))?; + + if params.state != stored_csrf.as_ref() { + return Err(ApiError::Validation("Invalid CSRF token".into())); + } + + let stored_pkce: notes_domain::PkceVerifier = session + .get("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing PKCE".into()))?; + let stored_nonce: notes_domain::OidcNonce = session + .get("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing Nonce".into()))?; + + let oidc_user = service + .resolve_callback( + notes_domain::AuthorizationCode::new(params.code), + stored_nonce, + stored_pkce, + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let user = state + .user_service + .find_or_create(&oidc_user.subject, &oidc_user.email) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + // Clean up OIDC state + let _: Option = session + .remove("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + // Return token as JSON + #[cfg(feature = "auth-jwt")] + { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })); + } + + #[cfg(not(feature = "auth-jwt"))] + { + let _ = user; // Suppress unused warning + Err(ApiError::Internal( + "No auth backend available for OIDC callback".to_string(), + )) + } +} diff --git a/notes-api/src/routes/import_export.rs b/notes-api/src/routes/import_export.rs index b5274ab..a8199e7 100644 --- a/notes-api/src/routes/import_export.rs +++ b/notes-api/src/routes/import_export.rs @@ -1,9 +1,8 @@ use axum::{Json, extract::State, http::StatusCode}; -use axum_login::{AuthSession, AuthUser}; use serde::{Deserialize, Serialize}; -use crate::auth::AuthBackend; -use crate::error::{ApiError, ApiResult}; +use crate::error::ApiResult; +use crate::extractors::CurrentUser; use crate::state::AppState; use notes_domain::{Note, NoteFilter, Tag}; @@ -17,14 +16,9 @@ pub struct BackupData { /// GET /api/v1/export pub async fn export_data( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let notes = state .note_repo @@ -39,15 +33,10 @@ pub async fn export_data( /// POST /api/v1/import pub async fn import_data( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // 1. Import standalone tags (to ensure even unused tags are restored) for tag in payload.tags { diff --git a/notes-api/src/routes/mod.rs b/notes-api/src/routes/mod.rs index 01f11b3..77b7c2b 100644 --- a/notes-api/src/routes/mod.rs +++ b/notes-api/src/routes/mod.rs @@ -17,10 +17,7 @@ use crate::state::AppState; pub fn api_v1_router() -> Router { let router = Router::new() // Auth routes - .route("/auth/register", post(auth::register)) - .route("/auth/login", post(auth::login)) - .route("/auth/logout", post(auth::logout)) - .route("/auth/me", get(auth::me)) + .nest("/auth", auth::router()) // Note routes .route("/notes", get(notes::list_notes).post(notes::create_note)) .route( diff --git a/notes-api/src/routes/notes.rs b/notes-api/src/routes/notes.rs index f6dc189..2889529 100644 --- a/notes-api/src/routes/notes.rs +++ b/notes-api/src/routes/notes.rs @@ -5,34 +5,29 @@ use axum::{ extract::{Path, Query, State}, http::StatusCode, }; -use axum_login::AuthSession; use uuid::Uuid; use validator::Validate; -use axum_login::AuthUser; use notes_domain::{ CreateNoteRequest as DomainCreateNote, NoteTitle, TagName, UpdateNoteRequest as DomainUpdateNote, }; -use crate::auth::AuthBackend; -use crate::dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}; use crate::error::{ApiError, ApiResult}; use crate::state::AppState; +use crate::{ + dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}, + extractors::CurrentUser, +}; /// List notes with optional filtering /// GET /api/v1/notes pub async fn list_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Query(query): Query, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Build the filter, looking up tag_id by name if needed let mut filter = notes_domain::NoteFilter::new(); @@ -59,15 +54,10 @@ pub async fn list_notes( /// POST /api/v1/notes pub async fn create_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult<(StatusCode, Json)> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Validate input payload @@ -113,15 +103,10 @@ pub async fn create_note( /// GET /api/v1/notes/:id pub async fn get_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let note = state.note_service.get_note(id, user_id).await?; @@ -132,16 +117,11 @@ pub async fn get_note( /// PATCH /api/v1/notes/:id pub async fn update_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, Json(payload): Json, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Validate input payload @@ -195,15 +175,10 @@ pub async fn update_note( /// DELETE /api/v1/notes/:id pub async fn delete_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; state.note_service.delete_note(id, user_id).await?; @@ -214,15 +189,10 @@ pub async fn delete_note( /// GET /api/v1/notes/search pub async fn search_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Query(query): Query, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let notes = state.note_service.search_notes(user_id, &query.q).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); @@ -234,15 +204,10 @@ pub async fn search_notes( /// GET /api/v1/notes/:id/versions pub async fn list_note_versions( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let versions = state.note_service.list_note_versions(id, user_id).await?; let response: Vec = versions @@ -260,15 +225,10 @@ pub async fn list_note_versions( #[cfg(feature = "smart-features")] pub async fn get_related_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Verify access to the source note state.note_service.get_note(id, user_id).await?; diff --git a/notes-api/src/routes/tags.rs b/notes-api/src/routes/tags.rs index 911d30a..09580a8 100644 --- a/notes-api/src/routes/tags.rs +++ b/notes-api/src/routes/tags.rs @@ -5,29 +5,25 @@ use axum::{ extract::{Path, State}, http::StatusCode, }; -use axum_login::{AuthSession, AuthUser}; use uuid::Uuid; use validator::Validate; use notes_domain::TagName; -use crate::auth::AuthBackend; -use crate::dto::{CreateTagRequest, RenameTagRequest, TagResponse}; use crate::error::{ApiError, ApiResult}; use crate::state::AppState; +use crate::{ + dto::{CreateTagRequest, RenameTagRequest, TagResponse}, + extractors::CurrentUser, +}; /// List all tags for the user /// GET /api/v1/tags pub async fn list_tags( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let tags = state.tag_service.list_tags(user_id).await?; let response: Vec = tags.into_iter().map(TagResponse::from).collect(); @@ -39,15 +35,10 @@ pub async fn list_tags( /// POST /api/v1/tags pub async fn create_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult<(StatusCode, Json)> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; payload .validate() @@ -66,16 +57,11 @@ pub async fn create_tag( /// PATCH /api/v1/tags/:id pub async fn rename_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, Json(payload): Json, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; payload .validate() @@ -94,15 +80,10 @@ pub async fn rename_tag( /// DELETE /api/v1/tags/:id pub async fn delete_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; state.tag_service.delete_tag(id, user_id).await?; diff --git a/notes-api/src/state.rs b/notes-api/src/state.rs index 1f6d97e..61f7fb9 100644 --- a/notes-api/src/state.rs +++ b/notes-api/src/state.rs @@ -1,45 +1,123 @@ use std::sync::Arc; -use crate::config::Config; -use notes_domain::{ - NoteRepository, NoteService, TagRepository, TagService, UserRepository, UserService, -}; +use crate::config::{AuthMode, Config}; +use notes_domain::{NoteRepository, NoteService, TagRepository, TagService, UserService}; + +#[cfg(feature = "auth-jwt")] +use notes_infra::auth::jwt::{JwtConfig, JwtValidator}; +#[cfg(feature = "auth-oidc")] +use notes_infra::auth::oidc::OidcService; /// Application state holding all dependencies #[derive(Clone)] pub struct AppState { pub note_repo: Arc, pub tag_repo: Arc, - pub user_repo: Arc, #[cfg(feature = "smart-features")] pub link_repo: Arc, pub note_service: Arc, pub tag_service: Arc, pub user_service: Arc, pub config: Config, + #[cfg(feature = "auth-oidc")] + pub oidc_service: Option>, + #[cfg(feature = "auth-jwt")] + pub jwt_validator: Option>, } impl AppState { - pub fn new( + pub async fn new( note_repo: Arc, tag_repo: Arc, - user_repo: Arc, #[cfg(feature = "smart-features")] link_repo: Arc, note_service: Arc, tag_service: Arc, user_service: Arc, config: Config, - ) -> Self { - Self { + ) -> anyhow::Result { + #[cfg(feature = "auth-oidc")] + let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = ( + &config.oidc_issuer, + &config.oidc_client_id, + &config.oidc_client_secret, + &config.oidc_redirect_url, + &config.oidc_resource_id, + ) { + tracing::info!("Initializing OIDC service with issuer: {}", issuer); + + // Construct newtypes from config strings + let issuer_url = notes_domain::IssuerUrl::new(issuer) + .map_err(|e| anyhow::anyhow!("Invalid OIDC issuer URL: {}", e))?; + let client_id = notes_domain::ClientId::new(id) + .map_err(|e| anyhow::anyhow!("Invalid OIDC client ID: {}", e))?; + let client_secret = secret.as_ref().map(|s| notes_domain::ClientSecret::new(s)); + let redirect_url = notes_domain::RedirectUrl::new(redirect) + .map_err(|e| anyhow::anyhow!("Invalid OIDC redirect URL: {}", e))?; + let resource = resource_id + .as_ref() + .map(|r| notes_domain::ResourceId::new(r)) + .transpose() + .map_err(|e| anyhow::anyhow!("Invalid OIDC resource ID: {}", e))?; + + Some(Arc::new( + OidcService::new(issuer_url, client_id, client_secret, redirect_url, resource) + .await?, + )) + } else { + None + }; + + #[cfg(feature = "auth-jwt")] + let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) { + // Use provided secret or fall back to a development secret + let secret = if let Some(ref s) = config.jwt_secret { + if s.is_empty() { None } else { Some(s.clone()) } + } else { + None + }; + + let secret = match secret { + Some(s) => s, + None => { + if config.is_production { + anyhow::bail!( + "JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production" + ); + } + // Use a development-only default secret + tracing::warn!( + "⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!" + ); + "k-template-dev-secret-not-for-production-use-only".to_string() + } + }; + + tracing::info!("Initializing JWT validator"); + let jwt_config = JwtConfig::new( + secret, + config.jwt_issuer.clone(), + config.jwt_audience.clone(), + Some(config.jwt_expiry_hours), + config.is_production, + )?; + Some(Arc::new(JwtValidator::new(jwt_config))) + } else { + None + }; + + Ok(Self { note_repo, tag_repo, - user_repo, #[cfg(feature = "smart-features")] link_repo, note_service, tag_service, user_service, config, - } + #[cfg(feature = "auth-oidc")] + oidc_service, + #[cfg(feature = "auth-jwt")] + jwt_validator, + }) } } diff --git a/notes-domain/src/errors.rs b/notes-domain/src/errors.rs index 2f4cec1..8c5de6e 100644 --- a/notes-domain/src/errors.rs +++ b/notes-domain/src/errors.rs @@ -91,6 +91,12 @@ impl DomainError { } } +impl From for DomainError { + fn from(error: crate::value_objects::ValidationError) -> Self { + DomainError::ValidationError(error.to_string()) + } +} + /// Result type alias for domain operations pub type DomainResult = Result; diff --git a/notes-domain/src/services.rs b/notes-domain/src/services.rs index 348ff20..de139aa 100644 --- a/notes-domain/src/services.rs +++ b/notes-domain/src/services.rs @@ -375,36 +375,46 @@ impl UserService { Self { user_repo } } - /// Find or create a user by OIDC subject - /// This is the main entry point for OIDC authentication - pub async fn find_or_create_by_subject( - &self, - subject: &str, - email: Email, - ) -> DomainResult { + pub async fn find_or_create(&self, subject: &str, email: &str) -> DomainResult { + // 1. Try to find by subject (OIDC id) if let Some(user) = self.user_repo.find_by_subject(subject).await? { - Ok(user) - } else { - let user = User::new(subject, email); - self.user_repo.save(&user).await?; - Ok(user) + return Ok(user); } + + // 2. Try to find by email + if let Some(mut user) = self.user_repo.find_by_email(email).await? { + // Link subject if missing (account linking logic) + if user.subject != subject { + user.subject = subject.to_string(); + self.user_repo.save(&user).await?; + } + return Ok(user); + } + + // 3. Create new user + let email = Email::try_from(email)?; + let user = User::new(subject, email); + self.user_repo.save(&user).await?; + + Ok(user) } - /// Get a user by ID - pub async fn get_user(&self, id: Uuid) -> DomainResult { + pub async fn find_by_id(&self, id: Uuid) -> DomainResult { self.user_repo .find_by_id(id) .await? .ok_or(DomainError::UserNotFound(id)) } - /// Delete a user and all associated data - pub async fn delete_user(&self, id: Uuid) -> DomainResult<()> { - // Note: In practice, we'd also need to delete notes and tags - // This would be handled by cascade delete in the database - // or by coordinating with other services - self.user_repo.delete(id).await + pub async fn find_by_email(&self, email: &str) -> DomainResult> { + self.user_repo.find_by_email(email).await + } + + pub async fn create_local(&self, email: &str, password_hash: &str) -> DomainResult { + let email = Email::try_from(email)?; + let user = User::new_local(email, password_hash); + self.user_repo.save(&user).await?; + Ok(user) } } @@ -889,7 +899,7 @@ mod tests { let email = Email::try_from("test@example.com").unwrap(); let user = service - .find_or_create_by_subject("oidc|123", email) + .find_or_create("oidc|123", email.as_ref()) .await .unwrap(); @@ -903,13 +913,13 @@ mod tests { let email1 = Email::try_from("test@example.com").unwrap(); let user1 = service - .find_or_create_by_subject("oidc|123", email1) + .find_or_create("oidc|123", email1.as_ref()) .await .unwrap(); let email2 = Email::try_from("test@example.com").unwrap(); let user2 = service - .find_or_create_by_subject("oidc|123", email2) + .find_or_create("oidc|123", email2.as_ref()) .await .unwrap(); diff --git a/notes-infra/src/auth/mod.rs b/notes-infra/src/auth/mod.rs index ac01f0b..3760d03 100644 --- a/notes-infra/src/auth/mod.rs +++ b/notes-infra/src/auth/mod.rs @@ -1,6 +1,6 @@ #[cfg(feature = "auth-axum-login")] -mod axum_login; +pub mod axum_login; #[cfg(feature = "auth-jwt")] -mod jwt; +pub mod jwt; #[cfg(feature = "auth-oidc")] -mod oidc; +pub mod oidc; diff --git a/notes-infra/src/session_store.rs b/notes-infra/src/session_store.rs index e9f5bee..edb657f 100644 --- a/notes-infra/src/session_store.rs +++ b/notes-infra/src/session_store.rs @@ -1 +1,2 @@ pub use k_core::session::store::InfraSessionStore; +pub use tower_sessions::{Expiry, SessionManagerLayer};