diff --git a/Cargo.lock b/Cargo.lock index c40b6a2..4a564a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1360,6 +1360,7 @@ dependencies = [ "domain", "futures-core", "futures-util", + "jsonwebtoken", "k-core", "openidconnect", "password-auth", @@ -1427,6 +1428,21 @@ dependencies = [ "serde", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64 0.22.1", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "k-core" version = "0.1.10" @@ -1605,6 +1621,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1821,6 +1847,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2731,6 +2767,18 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.17", + "time", +] + [[package]] name = "slab" version = "0.4.11" diff --git a/README.md b/README.md index 3ecf1d5..437036b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A production-ready, modular Rust template for K-Suite applications, following Hexagonal Architecture principles. -## 🌟 Features +## Features - **Hexagonal Architecture**: Clear separation of concerns between Domain, Infrastructure, and API layers. - **Modular & Swappable**: Vendor implementations (databases, message brokers) are behind feature flags and trait objects. @@ -10,7 +10,7 @@ A production-ready, modular Rust template for K-Suite applications, following He - **Cargo Generate Ready**: Pre-configured for `cargo-generate` to easily scaffold new services. - **Testable**: Domain logic is pure and easily testable; Infrastructure is tested with integration tests. -## 🏗️ Project Structure +## Project Structure The workspace consists of three main crates: @@ -26,7 +26,7 @@ The workspace consists of three main crates: - Wires everything together using dependency injection. - Handles HTTP/REST/gRPC interfaces. -## 🚀 Getting Started +## Getting Started ### Prerequisites @@ -57,7 +57,7 @@ cargo test cargo test -p template-infra --no-default-features --features postgres ``` -## ⚙️ Configuration & Feature Flags +## Configuration & Feature Flags This template uses Cargo features to control compilation of infrastructure adapters. @@ -86,7 +86,7 @@ default = ["postgres"] # ... ``` -## 📐 Architecture Guide +## Architecture Guide ### Adding a New Feature diff --git a/api/Cargo.toml b/api/Cargo.toml index f0a7290..7482743 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -5,11 +5,13 @@ edition = "2024" default-run = "api" [features] -default = ["sqlite", "auth-axum-login", "auth-oidc"] +default = ["sqlite", "auth-axum-login", "auth-oidc", "auth-jwt"] sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"] postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"] auth-axum-login = ["infra/auth-axum-login"] auth-oidc = ["infra/auth-oidc"] +auth-jwt = ["infra/auth-jwt"] +auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"] [dependencies] k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [ diff --git a/api/src/auth.rs b/api/src/auth.rs index e0d7245..b0ba3b6 100644 --- a/api/src/auth.rs +++ b/api/src/auth.rs @@ -2,11 +2,15 @@ //! //! Proxies to infra implementation if enabled. +#[cfg(feature = "auth-axum-login")] use std::sync::Arc; +#[cfg(feature = "auth-axum-login")] use domain::UserRepository; +#[cfg(feature = "auth-axum-login")] use infra::session_store::{InfraSessionStore, SessionManagerLayer}; +#[cfg(feature = "auth-axum-login")] use crate::error::ApiError; #[cfg(feature = "auth-axum-login")] diff --git a/api/src/config.rs b/api/src/config.rs index defa550..524ab9e 100644 --- a/api/src/config.rs +++ b/api/src/config.rs @@ -6,6 +6,30 @@ use std::env; use serde::Deserialize; +/// Authentication mode - determines how the API authenticates requests +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthMode { + /// Session-based authentication using cookies (default for backward compatibility) + #[default] + Session, + /// JWT-based authentication using Bearer tokens + Jwt, + /// Support both session and JWT authentication (try JWT first, then session) + Both, +} + +impl AuthMode { + /// Parse auth mode from string + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "jwt" => AuthMode::Jwt, + "both" => AuthMode::Both, + _ => AuthMode::Session, + } + } +} + //todo: replace with newtypes #[derive(Debug, Clone, Deserialize)] pub struct Config { @@ -28,10 +52,27 @@ pub struct Config { #[serde(default = "default_db_min_connections")] pub db_min_connections: u32, + // OIDC configuration pub oidc_issuer: Option, pub oidc_client_id: Option, pub oidc_client_secret: Option, pub oidc_redirect_url: Option, + pub oidc_resource_id: Option, + + // Auth mode configuration + #[serde(default)] + pub auth_mode: AuthMode, + + // JWT configuration + pub jwt_secret: Option, + pub jwt_issuer: Option, + pub jwt_audience: Option, + #[serde(default = "default_jwt_expiry_hours")] + pub jwt_expiry_hours: u64, + + /// Whether the application is running in production mode + #[serde(default)] + pub is_production: bool, } fn default_secure_cookie() -> bool { @@ -54,6 +95,10 @@ fn default_host() -> String { "127.0.0.1".to_string() } +fn default_jwt_expiry_hours() -> u64 { + 24 +} + impl Config { pub fn new() -> Result { config::Config::builder() @@ -108,6 +153,26 @@ impl Config { let oidc_client_id = env::var("OIDC_CLIENT_ID").ok(); let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok(); let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok(); + let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok(); + + // Auth mode configuration + let auth_mode = env::var("AUTH_MODE") + .map(|s| AuthMode::from_str(&s)) + .unwrap_or_default(); + + // JWT configuration + let jwt_secret = env::var("JWT_SECRET").ok(); + let jwt_issuer = env::var("JWT_ISSUER").ok(); + let jwt_audience = env::var("JWT_AUDIENCE").ok(); + let jwt_expiry_hours = env::var("JWT_EXPIRY_HOURS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(24); + + let is_production = env::var("PRODUCTION") + .or_else(|_| env::var("RUST_ENV")) + .map(|v| v.to_lowercase() == "production" || v == "1" || v == "true") + .unwrap_or(false); Self { host, @@ -122,6 +187,13 @@ impl Config { oidc_client_id, oidc_client_secret, oidc_redirect_url, + oidc_resource_id, + auth_mode, + jwt_secret, + jwt_issuer, + jwt_audience, + jwt_expiry_hours, + is_production, } } } diff --git a/api/src/dto.rs b/api/src/dto.rs index deaa12f..17cbd1a 100644 --- a/api/src/dto.rs +++ b/api/src/dto.rs @@ -40,3 +40,12 @@ pub struct UserResponse { pub struct ConfigResponse { pub allow_registration: bool, } + +#[cfg(feature = "auth-jwt")] +#[derive(Debug, Serialize, Deserialize)] +// also newtypes +pub struct Claims { + pub sub: String, + pub email: String, + pub exp: usize, +} diff --git a/api/src/error.rs b/api/src/error.rs index d3f0aff..7e39348 100644 --- a/api/src/error.rs +++ b/api/src/error.rs @@ -14,6 +14,7 @@ use domain::DomainError; /// API-level errors #[derive(Debug, Error)] +#[allow(dead_code)] // Some variants are reserved for future use pub enum ApiError { #[error("{0}")] Domain(#[from] DomainError), @@ -107,6 +108,7 @@ impl IntoResponse for ApiError { } } +#[allow(dead_code)] // Helper constructors for future use impl ApiError { pub fn validation(msg: impl Into) -> Self { Self::Validation(msg.into()) @@ -118,4 +120,5 @@ impl ApiError { } /// Result type alias for API handlers +#[allow(dead_code)] pub type ApiResult = Result; diff --git a/api/src/extractors.rs b/api/src/extractors.rs new file mode 100644 index 0000000..bf7193b --- /dev/null +++ b/api/src/extractors.rs @@ -0,0 +1,150 @@ +//! Auth extractors for API handlers +//! +//! Provides the `CurrentUser` extractor that works with both session and JWT auth. + +use axum::{extract::FromRequestParts, http::request::Parts}; +use domain::User; + +use crate::config::AuthMode; +use crate::error::ApiError; +use crate::state::AppState; + +/// Extracted current user from the request. +/// +/// This extractor supports multiple authentication methods based on the configured `AuthMode`: +/// - `Session`: Uses axum-login session cookies +/// - `Jwt`: Uses Bearer token in Authorization header +/// - `Both`: Tries JWT first, then falls back to session +pub struct CurrentUser(pub User); + +impl FromRequestParts for CurrentUser { + type Rejection = ApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let auth_mode = state.config.auth_mode; + + // Try JWT first if enabled + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + match try_jwt_auth(parts, state).await { + Ok(Some(user)) => return Ok(CurrentUser(user)), + Ok(None) => { + // No JWT token present, continue to session auth if Both mode + if auth_mode == AuthMode::Jwt { + return Err(ApiError::Unauthorized( + "Missing or invalid Authorization header".to_string(), + )); + } + } + Err(e) => { + // JWT was present but invalid + tracing::debug!("JWT auth failed: {}", e); + if auth_mode == AuthMode::Jwt { + return Err(e); + } + // In Both mode, continue to try session + } + } + } + + // Try session auth if enabled + #[cfg(feature = "auth-axum-login")] + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + if let Some(user) = try_session_auth(parts).await? { + return Ok(CurrentUser(user)); + } + } + + Err(ApiError::Unauthorized("Not authenticated".to_string())) + } +} + +/// Try to authenticate using JWT Bearer token +#[cfg(feature = "auth-jwt")] +async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result, ApiError> { + use axum::http::header::AUTHORIZATION; + + // Get Authorization header + let auth_header = match parts.headers.get(AUTHORIZATION) { + Some(header) => header, + None => return Ok(None), // No header = no JWT auth attempted + }; + + let auth_str = auth_header + .to_str() + .map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?; + + // Extract Bearer token + let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| { + ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string()) + })?; + + // Get JWT validator + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?; + + // Validate token + let claims = validator.validate_token(token).map_err(|e| { + tracing::debug!("JWT validation failed: {:?}", e); + match e { + infra::auth::jwt::JwtError::Expired => { + ApiError::Unauthorized("Token expired".to_string()) + } + infra::auth::jwt::JwtError::InvalidFormat => { + ApiError::Unauthorized("Invalid token format".to_string()) + } + _ => ApiError::Unauthorized("Token validation failed".to_string()), + } + })?; + + // Fetch user from database by ID (subject contains user ID) + let user_id: uuid::Uuid = claims + .sub + .parse() + .map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?; + + let user = state + .user_service + .find_by_id(user_id) + .await + .map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?; + + Ok(Some(user)) +} + +/// Try to authenticate using session cookie +#[cfg(feature = "auth-axum-login")] +async fn try_session_auth(parts: &mut Parts) -> Result, ApiError> { + use infra::auth::backend::AuthSession; + + // Check if AuthSession extension is present (added by auth middleware) + if let Some(auth_session) = parts.extensions.get::() { + if let Some(auth_user) = &auth_session.user { + return Ok(Some(auth_user.0.clone())); + } + } + + Ok(None) +} + +/// Fallback for when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn try_session_auth(_parts: &mut Parts) -> Result, ApiError> { + Ok(None) +} + +/// Fallback for when auth-jwt is not enabled but auth mode requires it +#[cfg(not(feature = "auth-jwt"))] +async fn try_jwt_auth(_parts: &mut Parts, state: &AppState) -> Result, ApiError> { + if matches!(state.config.auth_mode, AuthMode::Jwt) { + return Err(ApiError::Internal( + "JWT auth mode configured but auth-jwt feature not enabled".to_string(), + )); + } + Ok(None) +} diff --git a/api/src/main.rs b/api/src/main.rs index 287e59a..b498ed2 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -1,3 +1,7 @@ +//! API Server Entry Point +//! +//! Configures and starts the HTTP server with authentication based on AUTH_MODE. + use std::net::SocketAddr; use std::time::Duration as StdDuration; @@ -12,17 +16,18 @@ use k_core::http::server::apply_standard_middleware; use k_core::logging; use time::Duration; use tokio::net::TcpListener; +use tower_sessions::cookie::SameSite; use tracing::info; mod auth; mod config; mod dto; mod error; +mod extractors; mod routes; mod state; -use crate::auth::setup_auth_layer; -use crate::config::Config; +use crate::config::{AuthMode, Config}; use crate::state::AppState; #[tokio::main] @@ -32,6 +37,7 @@ async fn main() -> anyhow::Result<()> { let config = Config::from_env(); info!("Starting server on {}:{}", config.host, config.port); + info!("Auth mode: {:?}", config.auth_mode); // Setup database tracing::info!("Connecting to database: {}", config.database_url); @@ -51,6 +57,7 @@ async fn main() -> anyhow::Result<()> { let state = AppState::new(user_service, config.clone()).await?; + // Build session store (needed for OIDC flow even in JWT mode) let session_store = build_session_store(&db_pool) .await .map_err(|e| anyhow::anyhow!(e))?; @@ -61,30 +68,85 @@ async fn main() -> anyhow::Result<()> { let session_layer = SessionManagerLayer::new(session_store) .with_secure(config.secure_cookie) + .with_same_site(SameSite::Lax) .with_expiry(Expiry::OnInactivity(Duration::days(7))); - let auth_layer = setup_auth_layer(session_layer, user_repo).await?; - let server_config = ServerConfig { cors_origins: config.cors_allowed_origins.clone(), session_secret: Some(config.session_secret.clone()), }; - let app = Router::new() - .nest("/api/v1", routes::api_v1_router()) - .layer(auth_layer) - .with_state(state); - + // Build the app with appropriate auth layers based on config + let app = build_app(state, session_layer, user_repo, &config).await?; let app = apply_standard_middleware(app, &server_config); let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?; let listener = TcpListener::bind(addr).await?; tracing::info!("🚀 API server running at http://{}", addr); - tracing::info!("🔒 Authentication enabled (axum-login)"); + log_auth_info(&config); tracing::info!("📝 API endpoints available at /api/v1/..."); axum::serve(listener, app).await?; Ok(()) } + +/// Build the application router with appropriate auth layers +#[allow(unused_variables)] // config/user_repo used conditionally based on features +async fn build_app( + state: AppState, + session_layer: SessionManagerLayer, + user_repo: std::sync::Arc, + config: &Config, +) -> anyhow::Result { + let app = Router::new() + .nest("/api/v1", routes::api_v1_router()) + .with_state(state); + + // When auth-axum-login feature is enabled, always apply the auth layer. + // This is needed because: + // 1. OIDC callback uses AuthSession for state management + // 2. Session-based login/register routes use it + // 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support + #[cfg(feature = "auth-axum-login")] + { + let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?; + return Ok(app.layer(auth_layer)); + } + + // When auth-axum-login is not compiled in, just use session layer for OIDC flow + #[cfg(not(feature = "auth-axum-login"))] + { + let _ = user_repo; // Suppress unused warning + Ok(app.layer(session_layer)) + } +} + +/// Log authentication info based on enabled features and config +fn log_auth_info(config: &Config) { + match config.auth_mode { + AuthMode::Session => { + tracing::info!("🔒 Authentication mode: Session (cookie-based)"); + } + AuthMode::Jwt => { + tracing::info!("🔒 Authentication mode: JWT (Bearer token)"); + } + AuthMode::Both => { + tracing::info!("🔒 Authentication mode: Both (JWT + Session)"); + } + } + + #[cfg(feature = "auth-axum-login")] + tracing::info!(" ✓ Session auth enabled (axum-login)"); + + #[cfg(feature = "auth-jwt")] + if config.jwt_secret.is_some() { + tracing::info!(" ✓ JWT auth enabled"); + } + + #[cfg(feature = "auth-oidc")] + if config.oidc_issuer.is_some() { + tracing::info!(" ✓ OIDC integration enabled"); + } +} diff --git a/api/src/routes/auth.rs b/api/src/routes/auth.rs index 8124c72..af190f2 100644 --- a/api/src/routes/auth.rs +++ b/api/src/routes/auth.rs @@ -1,35 +1,75 @@ -use axum::http::StatusCode; +//! Authentication routes +//! +//! Provides login, register, logout, and token endpoints. +//! Supports both session-based and JWT-based authentication. + +#[cfg(feature = "auth-oidc")] +use axum::response::Response; use axum::{ Router, extract::{Json, State}, + http::StatusCode, response::IntoResponse, - routing::post, + routing::{get, post}, }; +use serde::Serialize; +#[cfg(feature = "auth-oidc")] +use tower_sessions::Session; +#[cfg(feature = "auth-axum-login")] +use crate::config::AuthMode; use crate::{ dto::{LoginRequest, RegisterRequest, UserResponse}, error::ApiError, + extractors::CurrentUser, state::AppState, }; +#[cfg(feature = "auth-axum-login")] use domain::{DomainError, Email}; -use tower_sessions::Session; + +/// Token response for JWT authentication +#[derive(Debug, Serialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: u64, +} + +/// Login response that can be either a user (session mode) or a token (JWT mode) +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum LoginResponse { + User(UserResponse), + Token(TokenResponse), +} pub fn router() -> Router { let r = Router::new() .route("/login", post(login)) .route("/register", post(register)) .route("/logout", post(logout)) - .route("/me", post(me)); + .route("/me", get(me)); + + // Add token endpoint for getting JWT from session + #[cfg(feature = "auth-jwt")] + let r = r.route("/token", post(get_token)); #[cfg(feature = "auth-oidc")] let r = r - .route("/login/oidc", axum::routing::get(oidc_login)) - .route("/auth/callback", axum::routing::get(oidc_callback)); + .route("/login/oidc", get(oidc_login)) + .route("/callback", get(oidc_callback)); r } +/// Login endpoint +/// +/// In session mode: Creates a session and returns user info +/// In JWT mode: Validates credentials and returns a JWT token +/// In both mode: Creates session AND returns JWT token +#[cfg(feature = "auth-axum-login")] async fn login( + State(state): State, mut auth_session: crate::auth::AuthSession, Json(payload): Json, ) -> Result { @@ -45,21 +85,56 @@ async fn login( None => return Err(ApiError::Validation("Invalid credentials".to_string())), }; - auth_session - .login(&user) - .await - .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + let auth_mode = state.config.auth_mode; + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + auth_session + .login(&user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } + + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user.0, &state)?; + return Ok(( + StatusCode::OK, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } + + // Session mode: return user info Ok(( StatusCode::OK, - Json(UserResponse { + Json(LoginResponse::User(UserResponse { id: user.0.id, email: user.0.email.into_inner(), created_at: user.0.created_at, - }), + })), )) } +/// Fallback login when auth-axum-login is not enabled +/// Without auth-axum-login, password-based authentication is not available. +/// Use OIDC login instead: GET /api/v1/auth/login/oidc +#[cfg(not(feature = "auth-axum-login"))] +async fn login( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(), + )) +} + +/// Register endpoint +#[cfg(feature = "auth-axum-login")] async fn register( State(state): State, mut auth_session: crate::auth::AuthSession, @@ -76,9 +151,6 @@ async fn register( ))); } - // Note: In a real app, you would hash the password here. - // This template uses a simplified User::new which doesn't take password. - // You should extend User to handle passwords or use an OIDC flow. let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?; // Using email as subject for local auth for now @@ -87,24 +159,54 @@ async fn register( .find_or_create(&email.as_ref().to_string(), email.as_ref()) .await?; - // Log the user in - let auth_user = crate::auth::AuthUser(user.clone()); + let auth_mode = state.config.auth_mode; - auth_session - .login(&auth_user) - .await - .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + let auth_user = crate::auth::AuthUser(user.clone()); + auth_session + .login(&auth_user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } + + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(( + StatusCode::CREATED, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } Ok(( StatusCode::CREATED, - Json(UserResponse { + Json(LoginResponse::User(UserResponse { id: user.id, email: user.email.into_inner(), created_at: user.created_at, - }), + })), )) } +/// Fallback register when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn register( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Session-based registration not available. Use JWT token endpoint.".to_string(), + )) +} + +/// Logout endpoint +#[cfg(feature = "auth-axum-login")] async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse { match auth_session.logout().await { Ok(_) => StatusCode::OK, @@ -112,23 +214,61 @@ async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse } } -async fn me(auth_session: crate::auth::AuthSession) -> Result { - let user = auth_session - .user - .ok_or(ApiError::Unauthorized("Not logged in".to_string()))?; +/// Fallback logout when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn logout() -> impl IntoResponse { + // JWT tokens can't be "logged out" server-side without a blocklist + // Just return OK + StatusCode::OK +} +/// Get current user info +async fn me(CurrentUser(user): CurrentUser) -> Result { Ok(Json(UserResponse { - id: user.0.id, - email: user.0.email.into_inner(), - created_at: user.0.created_at, + id: user.id, + email: user.email.into_inner(), + created_at: user.created_at, })) } -#[cfg(feature = "auth-oidc")] -async fn oidc_login( +/// Get a JWT token for the current session user +/// +/// This allows session-authenticated users to obtain a JWT for API access. +#[cfg(feature = "auth-jwt")] +async fn get_token( State(state): State, - session: Session, + CurrentUser(user): CurrentUser, ) -> Result { + let token = create_jwt_for_user(&user, &state)?; + + Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })) +} + +/// Helper to create JWT for a user +#[cfg(feature = "auth-jwt")] +fn create_jwt_for_user(user: &domain::User, state: &AppState) -> Result { + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?; + + validator + .create_token(user) + .map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e))) +} + +// ============================================================================ +// OIDC Routes +// ============================================================================ + +#[cfg(feature = "auth-oidc")] +async fn oidc_login(State(state): State, session: Session) -> Result { + use axum::http::header; + let service = state .oidc_service .as_ref() @@ -149,7 +289,19 @@ async fn oidc_login( .await .map_err(|_| ApiError::Internal("Session error".into()))?; - Ok(axum::response::Redirect::to(&url)) + let response = axum::response::Redirect::to(&url).into_response(); + let (mut parts, body) = response.into_parts(); + + parts.headers.insert( + header::CACHE_CONTROL, + "no-cache, no-store, must-revalidate".parse().unwrap(), + ); + parts + .headers + .insert(header::PRAGMA, "no-cache".parse().unwrap()); + parts.headers.insert(header::EXPIRES, "0".parse().unwrap()); + + Ok(Response::from_parts(parts, body)) } #[cfg(feature = "auth-oidc")] @@ -159,7 +311,7 @@ struct CallbackParams { state: String, } -#[cfg(feature = "auth-oidc")] +#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))] async fn oidc_callback( State(state): State, session: Session, @@ -181,7 +333,6 @@ async fn oidc_callback( return Err(ApiError::Validation("Invalid CSRF token".into())); } - // 2. Retrieve secrets let stored_pkce: String = session .get("oidc_pkce") .await @@ -204,11 +355,17 @@ async fn oidc_callback( .await .map_err(|e| ApiError::Internal(e.to_string()))?; - auth_session - .login(&crate::auth::AuthUser(user)) - .await - .map_err(|_| ApiError::Internal("Login failed".into()))?; + let auth_mode = state.config.auth_mode; + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + auth_session + .login(&crate::auth::AuthUser(user.clone())) + .await + .map_err(|_| ApiError::Internal("Login failed".into()))?; + } + + // Clean up OIDC state let _: Option = session .remove("oidc_csrf") .await @@ -222,5 +379,101 @@ async fn oidc_callback( .await .map_err(|_| ApiError::Internal("Session error".into()))?; - Ok(axum::response::Redirect::to("/")) + // In JWT mode, return token as JSON + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + }) + .into_response()); + } + + // Session mode: return user info + Ok(Json(UserResponse { + id: user.id, + email: user.email.into_inner(), + created_at: user.created_at, + }) + .into_response()) +} + +/// Fallback OIDC callback when auth-axum-login is not enabled +#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))] +async fn oidc_callback( + State(state): State, + session: Session, + axum::extract::Query(params): axum::extract::Query, +) -> Result { + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let stored_csrf: String = session + .get("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing CSRF token".into()))?; + + if params.state != stored_csrf { + return Err(ApiError::Validation("Invalid CSRF token".into())); + } + + let stored_pkce: String = session + .get("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing PKCE".into()))?; + let stored_nonce: String = session + .get("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing Nonce".into()))?; + + let oidc_user = service + .resolve_callback(params.code, stored_nonce, stored_pkce) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let user = state + .user_service + .find_or_create(&oidc_user.subject, &oidc_user.email) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + // Clean up OIDC state + let _: Option = session + .remove("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + // Return token as JSON + #[cfg(feature = "auth-jwt")] + { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })); + } + + #[cfg(not(feature = "auth-jwt"))] + { + let _ = user; // Suppress unused warning + Err(ApiError::Internal( + "No auth backend available for OIDC callback".to_string(), + )) + } } diff --git a/api/src/state.rs b/api/src/state.rs index a6329af..1a0c4ed 100644 --- a/api/src/state.rs +++ b/api/src/state.rs @@ -3,11 +3,13 @@ //! Holds shared state for the application. use axum::extract::FromRef; +#[cfg(feature = "auth-jwt")] +use infra::auth::jwt::{JwtConfig, JwtValidator}; #[cfg(feature = "auth-oidc")] use infra::auth::oidc::OidcService; use std::sync::Arc; -use crate::config::Config; +use crate::config::{AuthMode, Config}; use domain::UserService; #[derive(Clone)] @@ -15,32 +17,80 @@ pub struct AppState { pub user_service: Arc, #[cfg(feature = "auth-oidc")] pub oidc_service: Option>, - + #[cfg(feature = "auth-jwt")] + pub jwt_validator: Option>, pub config: Arc, } impl AppState { pub async fn new(user_service: UserService, config: Config) -> anyhow::Result { #[cfg(feature = "auth-oidc")] - let oidc_service = if let (Some(issuer), Some(id), Some(secret), Some(redirect)) = ( + let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = ( &config.oidc_issuer, &config.oidc_client_id, &config.oidc_client_secret, &config.oidc_redirect_url, + &config.oidc_resource_id, ) { tracing::info!("Initializing OIDC service with issuer: {}", issuer); Some(Arc::new( - OidcService::new(issuer.clone(), id.clone(), secret.clone(), redirect.clone()) - .await?, + OidcService::new( + issuer.clone(), + id.clone(), + secret.clone().unwrap_or_default(), + redirect.clone(), + resource_id.clone(), + ) + .await?, )) } else { None }; + #[cfg(feature = "auth-jwt")] + let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) { + // Use provided secret or fall back to a development secret + let secret = if let Some(ref s) = config.jwt_secret { + if s.is_empty() { None } else { Some(s.clone()) } + } else { + None + }; + + let secret = match secret { + Some(s) => s, + None => { + if config.is_production { + anyhow::bail!( + "JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production" + ); + } + // Use a development-only default secret + tracing::warn!( + "⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!" + ); + "k-template-dev-secret-not-for-production-use-only".to_string() + } + }; + + tracing::info!("Initializing JWT validator"); + let jwt_config = JwtConfig::new( + secret, + config.jwt_issuer.clone(), + config.jwt_audience.clone(), + Some(config.jwt_expiry_hours), + config.is_production, + )?; + Some(Arc::new(JwtValidator::new(jwt_config))) + } else { + None + }; + Ok(Self { user_service: Arc::new(user_service), #[cfg(feature = "auth-oidc")] oidc_service, + #[cfg(feature = "auth-jwt")] + jwt_validator, config: Arc::new(config), }) } diff --git a/infra/Cargo.toml b/infra/Cargo.toml index ddb9095..df0b625 100644 --- a/infra/Cargo.toml +++ b/infra/Cargo.toml @@ -20,6 +20,7 @@ postgres = [ broker-nats = ["dep:futures-util", "k-core/broker-nats"] auth-axum-login = ["dep:axum-login", "dep:password-auth"] auth-oidc = ["dep:openidconnect", "dep:url"] +auth-jwt = ["dep:jsonwebtoken"] [dependencies] k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [ @@ -50,4 +51,5 @@ axum-login = { version = "0.18", optional = true } password-auth = { version = "1.0", optional = true } openidconnect = { version = "4.0.1", optional = true } url = { version = "2.5.8", optional = true } +jsonwebtoken = { version = "9.3", optional = true } # reqwest = { version = "0.13.1", features = ["blocking", "json"], optional = true } diff --git a/infra/src/auth/jwt.rs b/infra/src/auth/jwt.rs new file mode 100644 index 0000000..6fb54b5 --- /dev/null +++ b/infra/src/auth/jwt.rs @@ -0,0 +1,278 @@ +//! JWT Authentication Infrastructure +//! +//! Provides JWT token creation and validation using HS256 (secret-based). +//! For OIDC/JWKS validation, see the `oidc` module. + +use domain::User; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; +use serde::{Deserialize, Serialize}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Minimum secret length for production (256 bits = 32 bytes) +const MIN_SECRET_LENGTH: usize = 32; + +/// JWT configuration +#[derive(Debug, Clone)] +pub struct JwtConfig { + /// Secret key for HS256 signing/verification + pub secret: String, + /// Expected issuer (for validation) + pub issuer: Option, + /// Expected audience (for validation) + pub audience: Option, + /// Token expiry in hours (default: 24) + pub expiry_hours: u64, +} + +impl JwtConfig { + /// Create a new JWT config with validation + /// + /// In production mode, this will reject weak secrets. + pub fn new( + secret: String, + issuer: Option, + audience: Option, + expiry_hours: Option, + is_production: bool, + ) -> Result { + // Validate secret strength in production + if is_production && secret.len() < MIN_SECRET_LENGTH { + return Err(JwtError::WeakSecret { + min_length: MIN_SECRET_LENGTH, + actual_length: secret.len(), + }); + } + + Ok(Self { + secret, + issuer, + audience, + expiry_hours: expiry_hours.unwrap_or(24), + }) + } + + /// Create config without validation (for testing) + pub fn new_unchecked(secret: String) -> Self { + Self { + secret, + issuer: None, + audience: None, + expiry_hours: 24, + } + } +} + +/// JWT claims structure +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct JwtClaims { + /// Subject - the user's unique identifier (user ID as string) + pub sub: String, + /// User's email address + pub email: String, + /// Expiry timestamp (seconds since UNIX epoch) + pub exp: usize, + /// Issued at timestamp (seconds since UNIX epoch) + pub iat: usize, + /// Issuer + #[serde(skip_serializing_if = "Option::is_none")] + pub iss: Option, + /// Audience + #[serde(skip_serializing_if = "Option::is_none")] + pub aud: Option, +} + +/// JWT-related errors +#[derive(Debug, thiserror::Error)] +pub enum JwtError { + #[error("JWT secret is too weak: minimum {min_length} bytes required, got {actual_length}")] + WeakSecret { + min_length: usize, + actual_length: usize, + }, + + #[error("Token creation failed: {0}")] + CreationFailed(#[from] jsonwebtoken::errors::Error), + + #[error("Token validation failed: {0}")] + ValidationFailed(String), + + #[error("Token expired")] + Expired, + + #[error("Invalid token format")] + InvalidFormat, + + #[error("Missing configuration")] + MissingConfig, +} + +/// JWT token validator and generator +#[derive(Clone)] +pub struct JwtValidator { + config: JwtConfig, + encoding_key: EncodingKey, + decoding_key: DecodingKey, + validation: Validation, +} + +impl JwtValidator { + /// Create a new JWT validator with the given configuration + pub fn new(config: JwtConfig) -> Self { + let encoding_key = EncodingKey::from_secret(config.secret.as_bytes()); + let decoding_key = DecodingKey::from_secret(config.secret.as_bytes()); + + let mut validation = Validation::new(Algorithm::HS256); + + // Configure issuer validation if set + if let Some(ref issuer) = config.issuer { + validation.set_issuer(&[issuer]); + } + + // Configure audience validation if set + if let Some(ref audience) = config.audience { + validation.set_audience(&[audience]); + } + + Self { + config, + encoding_key, + decoding_key, + validation, + } + } + + /// Create a JWT token for the given user + pub fn create_token(&self, user: &User) -> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as usize; + + let expiry = now + (self.config.expiry_hours as usize * 3600); + + let claims = JwtClaims { + sub: user.id.to_string(), + email: user.email.as_ref().to_string(), + exp: expiry, + iat: now, + iss: self.config.issuer.clone(), + aud: self.config.audience.clone(), + }; + + let header = Header::new(Algorithm::HS256); + encode(&header, &claims, &self.encoding_key).map_err(JwtError::CreationFailed) + } + + /// Validate a JWT token and return the claims + pub fn validate_token(&self, token: &str) -> Result { + let token_data = decode::(token, &self.decoding_key, &self.validation).map_err( + |e| match e.kind() { + jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired, + jsonwebtoken::errors::ErrorKind::InvalidToken => JwtError::InvalidFormat, + _ => JwtError::ValidationFailed(e.to_string()), + }, + )?; + + Ok(token_data.claims) + } + + /// Get the user ID (subject) from a token without full validation + /// Useful for logging/debugging, but should not be trusted for auth + pub fn decode_unverified(&self, token: &str) -> Result { + let mut validation = Validation::new(Algorithm::HS256); + validation.insecure_disable_signature_validation(); + validation.validate_exp = false; + + let token_data = decode::(token, &self.decoding_key, &validation) + .map_err(|_| JwtError::InvalidFormat)?; + + Ok(token_data.claims) + } +} + +impl std::fmt::Debug for JwtValidator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JwtValidator") + .field("issuer", &self.config.issuer) + .field("audience", &self.config.audience) + .field("expiry_hours", &self.config.expiry_hours) + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use domain::Email; + + fn create_test_user() -> User { + let email = Email::try_from("test@example.com").unwrap(); + User::new("test-subject", email) + } + + #[test] + fn test_create_and_validate_token() { + let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string()); + let validator = JwtValidator::new(config); + let user = create_test_user(); + + let token = validator.create_token(&user).expect("Should create token"); + let claims = validator + .validate_token(&token) + .expect("Should validate token"); + + assert_eq!(claims.sub, user.id.to_string()); + assert_eq!(claims.email, "test@example.com"); + } + + #[test] + fn test_weak_secret_rejected_in_production() { + let result = JwtConfig::new( + "short".to_string(), // Too short + None, + None, + None, + true, // Production mode + ); + + assert!(matches!(result, Err(JwtError::WeakSecret { .. }))); + } + + #[test] + fn test_weak_secret_allowed_in_development() { + let result = JwtConfig::new( + "short".to_string(), // Too short but OK in dev + None, + None, + None, + false, // Development mode + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_invalid_token_rejected() { + let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string()); + let validator = JwtValidator::new(config); + + let result = validator.validate_token("invalid.token.here"); + assert!(result.is_err()); + } + + #[test] + fn test_wrong_secret_rejected() { + let config1 = JwtConfig::new_unchecked("secret-one-that-is-long-enough".to_string()); + let config2 = JwtConfig::new_unchecked("secret-two-that-is-long-enough".to_string()); + + let validator1 = JwtValidator::new(config1); + let validator2 = JwtValidator::new(config2); + + let user = create_test_user(); + let token = validator1.create_token(&user).unwrap(); + + // Token from validator1 should fail on validator2 + let result = validator2.validate_token(&token); + assert!(result.is_err()); + } +} diff --git a/infra/src/auth/mod.rs b/infra/src/auth/mod.rs index f00a3b0..9dd30f5 100644 --- a/infra/src/auth/mod.rs +++ b/infra/src/auth/mod.rs @@ -118,3 +118,6 @@ pub mod backend { #[cfg(feature = "auth-oidc")] pub mod oidc; + +#[cfg(feature = "auth-jwt")] +pub mod jwt; diff --git a/infra/src/auth/oidc.rs b/infra/src/auth/oidc.rs index c6ac845..697f4fd 100644 --- a/infra/src/auth/oidc.rs +++ b/infra/src/auth/oidc.rs @@ -3,7 +3,7 @@ use openidconnect::{ AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, - StandardErrorResponse, TokenResponse, + StandardErrorResponse, TokenResponse, UserInfoClaims, core::{ CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, @@ -36,6 +36,7 @@ pub type OidcClient = Client< #[derive(Clone)] pub struct OidcService { client: OidcClient, + resource_id: Option, } #[derive(Debug)] @@ -51,7 +52,31 @@ impl OidcService { client_id: String, client_secret: String, redirect_url: String, + resource_id: Option, ) -> anyhow::Result { + let client_id = client_id.trim().to_string(); + let redirect_url = redirect_url.trim().to_string(); + let issuer = issuer.trim().to_string(); + + // 2. Handle Empty Secret (For PKCE/Public Clients) + let client_secret_clean = client_secret.trim(); + let client_secret_opt = if client_secret_clean.is_empty() { + None + } else { + Some(ClientSecret::new(client_secret_clean.to_string())) + }; + + tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id); + tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url); + tracing::debug!( + "🔵 OIDC Setup: Secret = {:?}", + if client_secret_opt.is_some() { + "SET" + } else { + "NONE" + } + ); + let http_client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build()?; @@ -62,11 +87,14 @@ impl OidcService { let client = CoreClient::from_provider_metadata( provider_metadata, ClientId::new(client_id), - Some(ClientSecret::new(client_secret)), + client_secret_opt, ) .set_redirect_uri(RedirectUrl::new(redirect_url)?); - Ok(Self { client }) + Ok(Self { + client, + resource_id, + }) } // todo: replace this tuple with newtype @@ -118,7 +146,15 @@ impl OidcService { .id_token() .ok_or_else(|| anyhow!("Server did not return an ID token"))?; - let id_token_verifier = self.client.id_token_verifier(); + let mut id_token_verifier = self.client.id_token_verifier().clone(); + + let trusted_resource_id = self.resource_id.clone(); + + if let Some(resource_id) = trusted_resource_id { + id_token_verifier = id_token_verifier + .set_other_audience_verifier_fn(move |aud| aud.as_str() == resource_id); + } + let claims = id_token.claims(&id_token_verifier, &nonce)?; if let Some(expected_access_token_hash) = claims.access_token_hash() { @@ -133,13 +169,28 @@ impl OidcService { } } + let email = if let Some(email) = claims.email() { + Some(email.as_str().to_string()) + } else { + // Fallback: Call UserInfo Endpoint using the Access Token + tracing::debug!("🔵 Email missing in ID Token, fetching UserInfo..."); + + let user_info: UserInfoClaims = self + .client + .user_info(token_response.access_token().clone(), None)? + .request_async(&http_client) + .await?; + + user_info.email().map(|e| e.as_str().to_string()) + }; + + // If email is still missing, we must error out because your app requires valid emails + let email = + email.ok_or_else(|| anyhow!("User has no verified email address in ZITADEL"))?; + Ok(OidcUser { subject: claims.subject().to_string(), - email: claims - .email() - .map(|email| email.as_str()) - .unwrap_or("") - .to_string(), + email, }) } }