feat: add JWT authentication and flexible auth modes with configurable login responses

This commit is contained in:
2026-01-06 05:01:56 +01:00
parent 5296171b85
commit 16dcc4b95e
15 changed files with 1058 additions and 71 deletions

View File

@@ -5,11 +5,13 @@ edition = "2024"
default-run = "api"
[features]
default = ["sqlite", "auth-axum-login", "auth-oidc"]
default = ["sqlite", "auth-axum-login", "auth-oidc", "auth-jwt"]
sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"]
postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"]
auth-axum-login = ["infra/auth-axum-login"]
auth-oidc = ["infra/auth-oidc"]
auth-jwt = ["infra/auth-jwt"]
auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"]
[dependencies]
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [

View File

@@ -2,11 +2,15 @@
//!
//! Proxies to infra implementation if enabled.
#[cfg(feature = "auth-axum-login")]
use std::sync::Arc;
#[cfg(feature = "auth-axum-login")]
use domain::UserRepository;
#[cfg(feature = "auth-axum-login")]
use infra::session_store::{InfraSessionStore, SessionManagerLayer};
#[cfg(feature = "auth-axum-login")]
use crate::error::ApiError;
#[cfg(feature = "auth-axum-login")]

View File

@@ -6,6 +6,30 @@ use std::env;
use serde::Deserialize;
/// Authentication mode - determines how the API authenticates requests
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
/// Session-based authentication using cookies (default for backward compatibility)
#[default]
Session,
/// JWT-based authentication using Bearer tokens
Jwt,
/// Support both session and JWT authentication (try JWT first, then session)
Both,
}
impl AuthMode {
/// Parse auth mode from string
pub fn from_str(s: &str) -> Self {
match s.to_lowercase().as_str() {
"jwt" => AuthMode::Jwt,
"both" => AuthMode::Both,
_ => AuthMode::Session,
}
}
}
//todo: replace with newtypes
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
@@ -28,10 +52,27 @@ pub struct Config {
#[serde(default = "default_db_min_connections")]
pub db_min_connections: u32,
// OIDC configuration
pub oidc_issuer: Option<String>,
pub oidc_client_id: Option<String>,
pub oidc_client_secret: Option<String>,
pub oidc_redirect_url: Option<String>,
pub oidc_resource_id: Option<String>,
// Auth mode configuration
#[serde(default)]
pub auth_mode: AuthMode,
// JWT configuration
pub jwt_secret: Option<String>,
pub jwt_issuer: Option<String>,
pub jwt_audience: Option<String>,
#[serde(default = "default_jwt_expiry_hours")]
pub jwt_expiry_hours: u64,
/// Whether the application is running in production mode
#[serde(default)]
pub is_production: bool,
}
fn default_secure_cookie() -> bool {
@@ -54,6 +95,10 @@ fn default_host() -> String {
"127.0.0.1".to_string()
}
fn default_jwt_expiry_hours() -> u64 {
24
}
impl Config {
pub fn new() -> Result<Self, config::ConfigError> {
config::Config::builder()
@@ -108,6 +153,26 @@ impl Config {
let oidc_client_id = env::var("OIDC_CLIENT_ID").ok();
let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok();
let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok();
let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok();
// Auth mode configuration
let auth_mode = env::var("AUTH_MODE")
.map(|s| AuthMode::from_str(&s))
.unwrap_or_default();
// JWT configuration
let jwt_secret = env::var("JWT_SECRET").ok();
let jwt_issuer = env::var("JWT_ISSUER").ok();
let jwt_audience = env::var("JWT_AUDIENCE").ok();
let jwt_expiry_hours = env::var("JWT_EXPIRY_HOURS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(24);
let is_production = env::var("PRODUCTION")
.or_else(|_| env::var("RUST_ENV"))
.map(|v| v.to_lowercase() == "production" || v == "1" || v == "true")
.unwrap_or(false);
Self {
host,
@@ -122,6 +187,13 @@ impl Config {
oidc_client_id,
oidc_client_secret,
oidc_redirect_url,
oidc_resource_id,
auth_mode,
jwt_secret,
jwt_issuer,
jwt_audience,
jwt_expiry_hours,
is_production,
}
}
}

View File

@@ -40,3 +40,12 @@ pub struct UserResponse {
pub struct ConfigResponse {
pub allow_registration: bool,
}
#[cfg(feature = "auth-jwt")]
#[derive(Debug, Serialize, Deserialize)]
// also newtypes
pub struct Claims {
pub sub: String,
pub email: String,
pub exp: usize,
}

View File

@@ -14,6 +14,7 @@ use domain::DomainError;
/// API-level errors
#[derive(Debug, Error)]
#[allow(dead_code)] // Some variants are reserved for future use
pub enum ApiError {
#[error("{0}")]
Domain(#[from] DomainError),
@@ -107,6 +108,7 @@ impl IntoResponse for ApiError {
}
}
#[allow(dead_code)] // Helper constructors for future use
impl ApiError {
pub fn validation(msg: impl Into<String>) -> Self {
Self::Validation(msg.into())
@@ -118,4 +120,5 @@ impl ApiError {
}
/// Result type alias for API handlers
#[allow(dead_code)]
pub type ApiResult<T> = Result<T, ApiError>;

150
api/src/extractors.rs Normal file
View File

@@ -0,0 +1,150 @@
//! Auth extractors for API handlers
//!
//! Provides the `CurrentUser` extractor that works with both session and JWT auth.
use axum::{extract::FromRequestParts, http::request::Parts};
use domain::User;
use crate::config::AuthMode;
use crate::error::ApiError;
use crate::state::AppState;
/// Extracted current user from the request.
///
/// This extractor supports multiple authentication methods based on the configured `AuthMode`:
/// - `Session`: Uses axum-login session cookies
/// - `Jwt`: Uses Bearer token in Authorization header
/// - `Both`: Tries JWT first, then falls back to session
pub struct CurrentUser(pub User);
impl FromRequestParts<AppState> for CurrentUser {
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let auth_mode = state.config.auth_mode;
// Try JWT first if enabled
#[cfg(feature = "auth-jwt")]
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
match try_jwt_auth(parts, state).await {
Ok(Some(user)) => return Ok(CurrentUser(user)),
Ok(None) => {
// No JWT token present, continue to session auth if Both mode
if auth_mode == AuthMode::Jwt {
return Err(ApiError::Unauthorized(
"Missing or invalid Authorization header".to_string(),
));
}
}
Err(e) => {
// JWT was present but invalid
tracing::debug!("JWT auth failed: {}", e);
if auth_mode == AuthMode::Jwt {
return Err(e);
}
// In Both mode, continue to try session
}
}
}
// Try session auth if enabled
#[cfg(feature = "auth-axum-login")]
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
if let Some(user) = try_session_auth(parts).await? {
return Ok(CurrentUser(user));
}
}
Err(ApiError::Unauthorized("Not authenticated".to_string()))
}
}
/// Try to authenticate using JWT Bearer token
#[cfg(feature = "auth-jwt")]
async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<Option<User>, ApiError> {
use axum::http::header::AUTHORIZATION;
// Get Authorization header
let auth_header = match parts.headers.get(AUTHORIZATION) {
Some(header) => header,
None => return Ok(None), // No header = no JWT auth attempted
};
let auth_str = auth_header
.to_str()
.map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?;
// Extract Bearer token
let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| {
ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string())
})?;
// Get JWT validator
let validator = state
.jwt_validator
.as_ref()
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
// Validate token
let claims = validator.validate_token(token).map_err(|e| {
tracing::debug!("JWT validation failed: {:?}", e);
match e {
infra::auth::jwt::JwtError::Expired => {
ApiError::Unauthorized("Token expired".to_string())
}
infra::auth::jwt::JwtError::InvalidFormat => {
ApiError::Unauthorized("Invalid token format".to_string())
}
_ => ApiError::Unauthorized("Token validation failed".to_string()),
}
})?;
// Fetch user from database by ID (subject contains user ID)
let user_id: uuid::Uuid = claims
.sub
.parse()
.map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?;
let user = state
.user_service
.find_by_id(user_id)
.await
.map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?;
Ok(Some(user))
}
/// Try to authenticate using session cookie
#[cfg(feature = "auth-axum-login")]
async fn try_session_auth(parts: &mut Parts) -> Result<Option<User>, ApiError> {
use infra::auth::backend::AuthSession;
// Check if AuthSession extension is present (added by auth middleware)
if let Some(auth_session) = parts.extensions.get::<AuthSession>() {
if let Some(auth_user) = &auth_session.user {
return Ok(Some(auth_user.0.clone()));
}
}
Ok(None)
}
/// Fallback for when auth-axum-login is not enabled
#[cfg(not(feature = "auth-axum-login"))]
async fn try_session_auth(_parts: &mut Parts) -> Result<Option<User>, ApiError> {
Ok(None)
}
/// Fallback for when auth-jwt is not enabled but auth mode requires it
#[cfg(not(feature = "auth-jwt"))]
async fn try_jwt_auth(_parts: &mut Parts, state: &AppState) -> Result<Option<User>, ApiError> {
if matches!(state.config.auth_mode, AuthMode::Jwt) {
return Err(ApiError::Internal(
"JWT auth mode configured but auth-jwt feature not enabled".to_string(),
));
}
Ok(None)
}

View File

@@ -1,3 +1,7 @@
//! API Server Entry Point
//!
//! Configures and starts the HTTP server with authentication based on AUTH_MODE.
use std::net::SocketAddr;
use std::time::Duration as StdDuration;
@@ -12,17 +16,18 @@ use k_core::http::server::apply_standard_middleware;
use k_core::logging;
use time::Duration;
use tokio::net::TcpListener;
use tower_sessions::cookie::SameSite;
use tracing::info;
mod auth;
mod config;
mod dto;
mod error;
mod extractors;
mod routes;
mod state;
use crate::auth::setup_auth_layer;
use crate::config::Config;
use crate::config::{AuthMode, Config};
use crate::state::AppState;
#[tokio::main]
@@ -32,6 +37,7 @@ async fn main() -> anyhow::Result<()> {
let config = Config::from_env();
info!("Starting server on {}:{}", config.host, config.port);
info!("Auth mode: {:?}", config.auth_mode);
// Setup database
tracing::info!("Connecting to database: {}", config.database_url);
@@ -51,6 +57,7 @@ async fn main() -> anyhow::Result<()> {
let state = AppState::new(user_service, config.clone()).await?;
// Build session store (needed for OIDC flow even in JWT mode)
let session_store = build_session_store(&db_pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
@@ -61,30 +68,85 @@ async fn main() -> anyhow::Result<()> {
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(config.secure_cookie)
.with_same_site(SameSite::Lax)
.with_expiry(Expiry::OnInactivity(Duration::days(7)));
let auth_layer = setup_auth_layer(session_layer, user_repo).await?;
let server_config = ServerConfig {
cors_origins: config.cors_allowed_origins.clone(),
session_secret: Some(config.session_secret.clone()),
};
let app = Router::new()
.nest("/api/v1", routes::api_v1_router())
.layer(auth_layer)
.with_state(state);
// Build the app with appropriate auth layers based on config
let app = build_app(state, session_layer, user_repo, &config).await?;
let app = apply_standard_middleware(app, &server_config);
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
let listener = TcpListener::bind(addr).await?;
tracing::info!("🚀 API server running at http://{}", addr);
tracing::info!("🔒 Authentication enabled (axum-login)");
log_auth_info(&config);
tracing::info!("📝 API endpoints available at /api/v1/...");
axum::serve(listener, app).await?;
Ok(())
}
/// Build the application router with appropriate auth layers
#[allow(unused_variables)] // config/user_repo used conditionally based on features
async fn build_app(
state: AppState,
session_layer: SessionManagerLayer<infra::session_store::InfraSessionStore>,
user_repo: std::sync::Arc<dyn domain::UserRepository>,
config: &Config,
) -> anyhow::Result<Router> {
let app = Router::new()
.nest("/api/v1", routes::api_v1_router())
.with_state(state);
// When auth-axum-login feature is enabled, always apply the auth layer.
// This is needed because:
// 1. OIDC callback uses AuthSession for state management
// 2. Session-based login/register routes use it
// 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support
#[cfg(feature = "auth-axum-login")]
{
let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?;
return Ok(app.layer(auth_layer));
}
// When auth-axum-login is not compiled in, just use session layer for OIDC flow
#[cfg(not(feature = "auth-axum-login"))]
{
let _ = user_repo; // Suppress unused warning
Ok(app.layer(session_layer))
}
}
/// Log authentication info based on enabled features and config
fn log_auth_info(config: &Config) {
match config.auth_mode {
AuthMode::Session => {
tracing::info!("🔒 Authentication mode: Session (cookie-based)");
}
AuthMode::Jwt => {
tracing::info!("🔒 Authentication mode: JWT (Bearer token)");
}
AuthMode::Both => {
tracing::info!("🔒 Authentication mode: Both (JWT + Session)");
}
}
#[cfg(feature = "auth-axum-login")]
tracing::info!(" ✓ Session auth enabled (axum-login)");
#[cfg(feature = "auth-jwt")]
if config.jwt_secret.is_some() {
tracing::info!(" ✓ JWT auth enabled");
}
#[cfg(feature = "auth-oidc")]
if config.oidc_issuer.is_some() {
tracing::info!(" ✓ OIDC integration enabled");
}
}

View File

@@ -1,35 +1,75 @@
use axum::http::StatusCode;
//! Authentication routes
//!
//! Provides login, register, logout, and token endpoints.
//! Supports both session-based and JWT-based authentication.
#[cfg(feature = "auth-oidc")]
use axum::response::Response;
use axum::{
Router,
extract::{Json, State},
http::StatusCode,
response::IntoResponse,
routing::post,
routing::{get, post},
};
use serde::Serialize;
#[cfg(feature = "auth-oidc")]
use tower_sessions::Session;
#[cfg(feature = "auth-axum-login")]
use crate::config::AuthMode;
use crate::{
dto::{LoginRequest, RegisterRequest, UserResponse},
error::ApiError,
extractors::CurrentUser,
state::AppState,
};
#[cfg(feature = "auth-axum-login")]
use domain::{DomainError, Email};
use tower_sessions::Session;
/// Token response for JWT authentication
#[derive(Debug, Serialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: u64,
}
/// Login response that can be either a user (session mode) or a token (JWT mode)
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum LoginResponse {
User(UserResponse),
Token(TokenResponse),
}
pub fn router() -> Router<AppState> {
let r = Router::new()
.route("/login", post(login))
.route("/register", post(register))
.route("/logout", post(logout))
.route("/me", post(me));
.route("/me", get(me));
// Add token endpoint for getting JWT from session
#[cfg(feature = "auth-jwt")]
let r = r.route("/token", post(get_token));
#[cfg(feature = "auth-oidc")]
let r = r
.route("/login/oidc", axum::routing::get(oidc_login))
.route("/auth/callback", axum::routing::get(oidc_callback));
.route("/login/oidc", get(oidc_login))
.route("/callback", get(oidc_callback));
r
}
/// Login endpoint
///
/// In session mode: Creates a session and returns user info
/// In JWT mode: Validates credentials and returns a JWT token
/// In both mode: Creates session AND returns JWT token
#[cfg(feature = "auth-axum-login")]
async fn login(
State(state): State<AppState>,
mut auth_session: crate::auth::AuthSession,
Json(payload): Json<LoginRequest>,
) -> Result<impl IntoResponse, ApiError> {
@@ -45,21 +85,56 @@ async fn login(
None => return Err(ApiError::Validation("Invalid credentials".to_string())),
};
auth_session
.login(&user)
.await
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
let auth_mode = state.config.auth_mode;
// In session or both mode, create session
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
auth_session
.login(&user)
.await
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
}
// In JWT or both mode, return token
#[cfg(feature = "auth-jwt")]
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
let token = create_jwt_for_user(&user.0, &state)?;
return Ok((
StatusCode::OK,
Json(LoginResponse::Token(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
})),
));
}
// Session mode: return user info
Ok((
StatusCode::OK,
Json(UserResponse {
Json(LoginResponse::User(UserResponse {
id: user.0.id,
email: user.0.email.into_inner(),
created_at: user.0.created_at,
}),
})),
))
}
/// Fallback login when auth-axum-login is not enabled
/// Without auth-axum-login, password-based authentication is not available.
/// Use OIDC login instead: GET /api/v1/auth/login/oidc
#[cfg(not(feature = "auth-axum-login"))]
async fn login(
State(_state): State<AppState>,
Json(_payload): Json<LoginRequest>,
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
Err(ApiError::Internal(
"Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(),
))
}
/// Register endpoint
#[cfg(feature = "auth-axum-login")]
async fn register(
State(state): State<AppState>,
mut auth_session: crate::auth::AuthSession,
@@ -76,9 +151,6 @@ async fn register(
)));
}
// Note: In a real app, you would hash the password here.
// This template uses a simplified User::new which doesn't take password.
// You should extend User to handle passwords or use an OIDC flow.
let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?;
// Using email as subject for local auth for now
@@ -87,24 +159,54 @@ async fn register(
.find_or_create(&email.as_ref().to_string(), email.as_ref())
.await?;
// Log the user in
let auth_user = crate::auth::AuthUser(user.clone());
let auth_mode = state.config.auth_mode;
auth_session
.login(&auth_user)
.await
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
// In session or both mode, create session
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
let auth_user = crate::auth::AuthUser(user.clone());
auth_session
.login(&auth_user)
.await
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
}
// In JWT or both mode, return token
#[cfg(feature = "auth-jwt")]
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
let token = create_jwt_for_user(&user, &state)?;
return Ok((
StatusCode::CREATED,
Json(LoginResponse::Token(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
})),
));
}
Ok((
StatusCode::CREATED,
Json(UserResponse {
Json(LoginResponse::User(UserResponse {
id: user.id,
email: user.email.into_inner(),
created_at: user.created_at,
}),
})),
))
}
/// Fallback register when auth-axum-login is not enabled
#[cfg(not(feature = "auth-axum-login"))]
async fn register(
State(_state): State<AppState>,
Json(_payload): Json<RegisterRequest>,
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
Err(ApiError::Internal(
"Session-based registration not available. Use JWT token endpoint.".to_string(),
))
}
/// Logout endpoint
#[cfg(feature = "auth-axum-login")]
async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse {
match auth_session.logout().await {
Ok(_) => StatusCode::OK,
@@ -112,23 +214,61 @@ async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse
}
}
async fn me(auth_session: crate::auth::AuthSession) -> Result<impl IntoResponse, ApiError> {
let user = auth_session
.user
.ok_or(ApiError::Unauthorized("Not logged in".to_string()))?;
/// Fallback logout when auth-axum-login is not enabled
#[cfg(not(feature = "auth-axum-login"))]
async fn logout() -> impl IntoResponse {
// JWT tokens can't be "logged out" server-side without a blocklist
// Just return OK
StatusCode::OK
}
/// Get current user info
async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoResponse, ApiError> {
Ok(Json(UserResponse {
id: user.0.id,
email: user.0.email.into_inner(),
created_at: user.0.created_at,
id: user.id,
email: user.email.into_inner(),
created_at: user.created_at,
}))
}
#[cfg(feature = "auth-oidc")]
async fn oidc_login(
/// Get a JWT token for the current session user
///
/// This allows session-authenticated users to obtain a JWT for API access.
#[cfg(feature = "auth-jwt")]
async fn get_token(
State(state): State<AppState>,
session: Session,
CurrentUser(user): CurrentUser,
) -> Result<impl IntoResponse, ApiError> {
let token = create_jwt_for_user(&user, &state)?;
Ok(Json(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
}))
}
/// Helper to create JWT for a user
#[cfg(feature = "auth-jwt")]
fn create_jwt_for_user(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
let validator = state
.jwt_validator
.as_ref()
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
validator
.create_token(user)
.map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e)))
}
// ============================================================================
// OIDC Routes
// ============================================================================
#[cfg(feature = "auth-oidc")]
async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<Response, ApiError> {
use axum::http::header;
let service = state
.oidc_service
.as_ref()
@@ -149,7 +289,19 @@ async fn oidc_login(
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
Ok(axum::response::Redirect::to(&url))
let response = axum::response::Redirect::to(&url).into_response();
let (mut parts, body) = response.into_parts();
parts.headers.insert(
header::CACHE_CONTROL,
"no-cache, no-store, must-revalidate".parse().unwrap(),
);
parts
.headers
.insert(header::PRAGMA, "no-cache".parse().unwrap());
parts.headers.insert(header::EXPIRES, "0".parse().unwrap());
Ok(Response::from_parts(parts, body))
}
#[cfg(feature = "auth-oidc")]
@@ -159,7 +311,7 @@ struct CallbackParams {
state: String,
}
#[cfg(feature = "auth-oidc")]
#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))]
async fn oidc_callback(
State(state): State<AppState>,
session: Session,
@@ -181,7 +333,6 @@ async fn oidc_callback(
return Err(ApiError::Validation("Invalid CSRF token".into()));
}
// 2. Retrieve secrets
let stored_pkce: String = session
.get("oidc_pkce")
.await
@@ -204,11 +355,17 @@ async fn oidc_callback(
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
auth_session
.login(&crate::auth::AuthUser(user))
.await
.map_err(|_| ApiError::Internal("Login failed".into()))?;
let auth_mode = state.config.auth_mode;
// In session or both mode, create session
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
auth_session
.login(&crate::auth::AuthUser(user.clone()))
.await
.map_err(|_| ApiError::Internal("Login failed".into()))?;
}
// Clean up OIDC state
let _: Option<String> = session
.remove("oidc_csrf")
.await
@@ -222,5 +379,101 @@ async fn oidc_callback(
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
Ok(axum::response::Redirect::to("/"))
// In JWT mode, return token as JSON
#[cfg(feature = "auth-jwt")]
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
let token = create_jwt_for_user(&user, &state)?;
return Ok(Json(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
})
.into_response());
}
// Session mode: return user info
Ok(Json(UserResponse {
id: user.id,
email: user.email.into_inner(),
created_at: user.created_at,
})
.into_response())
}
/// Fallback OIDC callback when auth-axum-login is not enabled
#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))]
async fn oidc_callback(
State(state): State<AppState>,
session: Session,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<impl IntoResponse, ApiError> {
let service = state
.oidc_service
.as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
let stored_csrf: String = session
.get("oidc_csrf")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
if params.state != stored_csrf {
return Err(ApiError::Validation("Invalid CSRF token".into()));
}
let stored_pkce: String = session
.get("oidc_pkce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
let stored_nonce: String = session
.get("oidc_nonce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
let oidc_user = service
.resolve_callback(params.code, stored_nonce, stored_pkce)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
let user = state
.user_service
.find_or_create(&oidc_user.subject, &oidc_user.email)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
// Clean up OIDC state
let _: Option<String> = session
.remove("oidc_csrf")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
let _: Option<String> = session
.remove("oidc_pkce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
let _: Option<String> = session
.remove("oidc_nonce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
// Return token as JSON
#[cfg(feature = "auth-jwt")]
{
let token = create_jwt_for_user(&user, &state)?;
return Ok(Json(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
}));
}
#[cfg(not(feature = "auth-jwt"))]
{
let _ = user; // Suppress unused warning
Err(ApiError::Internal(
"No auth backend available for OIDC callback".to_string(),
))
}
}

View File

@@ -3,11 +3,13 @@
//! Holds shared state for the application.
use axum::extract::FromRef;
#[cfg(feature = "auth-jwt")]
use infra::auth::jwt::{JwtConfig, JwtValidator};
#[cfg(feature = "auth-oidc")]
use infra::auth::oidc::OidcService;
use std::sync::Arc;
use crate::config::Config;
use crate::config::{AuthMode, Config};
use domain::UserService;
#[derive(Clone)]
@@ -15,32 +17,80 @@ pub struct AppState {
pub user_service: Arc<UserService>,
#[cfg(feature = "auth-oidc")]
pub oidc_service: Option<Arc<OidcService>>,
#[cfg(feature = "auth-jwt")]
pub jwt_validator: Option<Arc<JwtValidator>>,
pub config: Arc<Config>,
}
impl AppState {
pub async fn new(user_service: UserService, config: Config) -> anyhow::Result<Self> {
#[cfg(feature = "auth-oidc")]
let oidc_service = if let (Some(issuer), Some(id), Some(secret), Some(redirect)) = (
let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = (
&config.oidc_issuer,
&config.oidc_client_id,
&config.oidc_client_secret,
&config.oidc_redirect_url,
&config.oidc_resource_id,
) {
tracing::info!("Initializing OIDC service with issuer: {}", issuer);
Some(Arc::new(
OidcService::new(issuer.clone(), id.clone(), secret.clone(), redirect.clone())
.await?,
OidcService::new(
issuer.clone(),
id.clone(),
secret.clone().unwrap_or_default(),
redirect.clone(),
resource_id.clone(),
)
.await?,
))
} else {
None
};
#[cfg(feature = "auth-jwt")]
let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) {
// Use provided secret or fall back to a development secret
let secret = if let Some(ref s) = config.jwt_secret {
if s.is_empty() { None } else { Some(s.clone()) }
} else {
None
};
let secret = match secret {
Some(s) => s,
None => {
if config.is_production {
anyhow::bail!(
"JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production"
);
}
// Use a development-only default secret
tracing::warn!(
"⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!"
);
"k-template-dev-secret-not-for-production-use-only".to_string()
}
};
tracing::info!("Initializing JWT validator");
let jwt_config = JwtConfig::new(
secret,
config.jwt_issuer.clone(),
config.jwt_audience.clone(),
Some(config.jwt_expiry_hours),
config.is_production,
)?;
Some(Arc::new(JwtValidator::new(jwt_config)))
} else {
None
};
Ok(Self {
user_service: Arc::new(user_service),
#[cfg(feature = "auth-oidc")]
oidc_service,
#[cfg(feature = "auth-jwt")]
jwt_validator,
config: Arc::new(config),
})
}