feat: enhance application state management with cookie support
- Added cookie key to AppState for managing session cookies. - Updated AppState initialization to derive cookie key from configuration. - Removed session-based authentication option from cargo-generate prompts. - Refactored JWT authentication logic to improve clarity and error handling. - Updated password validation to align with NIST recommendations (minimum length increased). - Removed unused session store implementation and related code. - Improved error handling in user repository for unique constraint violations. - Refactored OIDC service to include state management for authentication flow. - Cleaned up dependencies in Cargo.toml and Cargo.toml.template for clarity and efficiency.
This commit is contained in:
@@ -5,13 +5,11 @@ edition = "2024"
|
||||
default-run = "api"
|
||||
|
||||
[features]
|
||||
default = ["sqlite"]
|
||||
sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"]
|
||||
postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"]
|
||||
auth-axum-login = ["infra/auth-axum-login"]
|
||||
default = ["sqlite", "auth-jwt"]
|
||||
sqlite = ["infra/sqlite"]
|
||||
postgres = ["infra/postgres"]
|
||||
auth-oidc = ["infra/auth-oidc"]
|
||||
auth-jwt = ["infra/auth-jwt"]
|
||||
auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"]
|
||||
|
||||
[dependencies]
|
||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||
@@ -19,24 +17,16 @@ k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features
|
||||
"db-sqlx",
|
||||
"sqlite",
|
||||
"http",
|
||||
"auth",
|
||||
"sessions-db",
|
||||
] }
|
||||
domain = { path = "../domain" }
|
||||
infra = { path = "../infra", default-features = false, features = ["sqlite"] }
|
||||
|
||||
#Web framework
|
||||
# Web framework
|
||||
axum = { version = "0.8.8", features = ["macros"] }
|
||||
axum-extra = { version = "0.10", features = ["cookie-private", "cookie-key-expansion"] }
|
||||
tower = "0.5.2"
|
||||
tower-http = { version = "0.6.2", features = ["cors", "trace"] }
|
||||
|
||||
# Authentication
|
||||
# Moved to infra
|
||||
tower-sessions-sqlx-store = { version = "0.15", features = ["sqlite"] }
|
||||
# password-auth removed
|
||||
time = "0.3"
|
||||
async-trait = "0.1.89"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.48.0", features = ["full"] }
|
||||
|
||||
@@ -44,8 +34,6 @@ tokio = { version = "1.48.0", features = ["full"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
# Validation via domain newtypes (Email, Password)
|
||||
|
||||
# Error handling
|
||||
thiserror = "2.0.17"
|
||||
anyhow = "1.0"
|
||||
@@ -56,8 +44,6 @@ uuid = { version = "1.19.0", features = ["v4", "serde"] }
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
|
||||
|
||||
dotenvy = "0.15.7"
|
||||
config = "0.15.19"
|
||||
tower-sessions = "0.14.0"
|
||||
time = "0.3"
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
[package]
|
||||
name = "api"
|
||||
name = "{{project_name}}"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
default-run = "api"
|
||||
default-run = "{{project_name}}"
|
||||
|
||||
[features]
|
||||
default = ["{{database}}"{% if auth_session %}, "auth-axum-login"{% endif %}{% if auth_oidc %}, "auth-oidc"{% endif %}{% if auth_jwt %}, "auth-jwt"{% endif %}]
|
||||
sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"]
|
||||
postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"]
|
||||
auth-axum-login = ["infra/auth-axum-login"]
|
||||
default = ["{{database}}"{% if auth_oidc %}, "auth-oidc"{% endif %}{% if auth_jwt %}, "auth-jwt"{% endif %}]
|
||||
sqlite = ["infra/sqlite"]
|
||||
postgres = ["infra/postgres"]
|
||||
auth-oidc = ["infra/auth-oidc"]
|
||||
auth-jwt = ["infra/auth-jwt"]
|
||||
auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"]
|
||||
|
||||
[dependencies]
|
||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||
@@ -19,24 +17,16 @@ k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features
|
||||
"db-sqlx",
|
||||
"{{database}}",
|
||||
"http",
|
||||
"auth",
|
||||
"sessions-db",
|
||||
] }
|
||||
domain = { path = "../domain" }
|
||||
infra = { path = "../infra", default-features = false, features = ["{{database}}"] }
|
||||
|
||||
#Web framework
|
||||
# Web framework
|
||||
axum = { version = "0.8.8", features = ["macros"] }
|
||||
axum-extra = { version = "0.10", features = ["cookie-private", "cookie-key-expansion"] }
|
||||
tower = "0.5.2"
|
||||
tower-http = { version = "0.6.2", features = ["cors", "trace"] }
|
||||
|
||||
# Authentication
|
||||
# Moved to infra
|
||||
tower-sessions-sqlx-store = { version = "0.15", features = ["{{database}}"] }
|
||||
# password-auth removed
|
||||
time = "0.3"
|
||||
async-trait = "0.1.89"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.48.0", features = ["full"] }
|
||||
|
||||
@@ -44,8 +34,6 @@ tokio = { version = "1.48.0", features = ["full"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
# Validation via domain newtypes (Email, Password)
|
||||
|
||||
# Error handling
|
||||
thiserror = "2.0.17"
|
||||
anyhow = "1.0"
|
||||
@@ -56,8 +44,6 @@ uuid = { version = "1.19.0", features = ["v4", "serde"] }
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
|
||||
|
||||
dotenvy = "0.15.7"
|
||||
config = "0.15.19"
|
||||
tower-sessions = "0.14.0"
|
||||
time = "0.3"
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
//! Authentication logic
|
||||
//!
|
||||
//! Proxies to infra implementation if enabled.
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use domain::UserRepository;
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use infra::session_store::{InfraSessionStore, SessionManagerLayer};
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
pub use infra::auth::backend::{AuthManagerLayer, AuthSession, AuthUser, Credentials};
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
pub async fn setup_auth_layer(
|
||||
session_layer: SessionManagerLayer<InfraSessionStore>,
|
||||
user_repo: Arc<dyn UserRepository>,
|
||||
) -> Result<AuthManagerLayer, ApiError> {
|
||||
infra::auth::backend::setup_auth_layer(session_layer, user_repo)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))
|
||||
}
|
||||
@@ -4,52 +4,16 @@
|
||||
|
||||
use std::env;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Authentication mode - determines how the API authenticates requests
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AuthMode {
|
||||
/// Session-based authentication using cookies (default for backward compatibility)
|
||||
#[default]
|
||||
Session,
|
||||
/// JWT-based authentication using Bearer tokens
|
||||
Jwt,
|
||||
/// Support both session and JWT authentication (try JWT first, then session)
|
||||
Both,
|
||||
}
|
||||
|
||||
impl AuthMode {
|
||||
/// Parse auth mode from string
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"jwt" => AuthMode::Jwt,
|
||||
"both" => AuthMode::Both,
|
||||
_ => AuthMode::Session,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//todo: replace with newtypes
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
/// Application configuration loaded from environment variables
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub database_url: String,
|
||||
pub session_secret: String,
|
||||
pub cookie_secret: String,
|
||||
pub cors_allowed_origins: Vec<String>,
|
||||
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
|
||||
#[serde(default = "default_secure_cookie")]
|
||||
pub secure_cookie: bool,
|
||||
|
||||
#[serde(default = "default_db_max_connections")]
|
||||
pub db_max_connections: u32,
|
||||
|
||||
#[serde(default = "default_db_min_connections")]
|
||||
pub db_min_connections: u32,
|
||||
|
||||
// OIDC configuration
|
||||
@@ -59,57 +23,18 @@ pub struct Config {
|
||||
pub oidc_redirect_url: Option<String>,
|
||||
pub oidc_resource_id: Option<String>,
|
||||
|
||||
// Auth mode configuration
|
||||
#[serde(default)]
|
||||
pub auth_mode: AuthMode,
|
||||
|
||||
// JWT configuration
|
||||
pub jwt_secret: Option<String>,
|
||||
pub jwt_issuer: Option<String>,
|
||||
pub jwt_audience: Option<String>,
|
||||
#[serde(default = "default_jwt_expiry_hours")]
|
||||
pub jwt_expiry_hours: u64,
|
||||
|
||||
/// Whether the application is running in production mode
|
||||
#[serde(default)]
|
||||
pub is_production: bool,
|
||||
}
|
||||
|
||||
fn default_secure_cookie() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn default_db_max_connections() -> u32 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_db_min_connections() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_jwt_expiry_hours() -> u64 {
|
||||
24
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new() -> Result<Self, config::ConfigError> {
|
||||
config::Config::builder()
|
||||
.add_source(config::Environment::default())
|
||||
//.add_source(config::File::with_name(".env").required(false)) // Optional .env file
|
||||
.build()?
|
||||
.try_deserialize()
|
||||
}
|
||||
|
||||
pub fn from_env() -> Self {
|
||||
// Load .env file if it exists, ignore errors if it doesn't
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
|
||||
@@ -121,8 +46,10 @@ impl Config {
|
||||
let database_url =
|
||||
env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:data.db?mode=rwc".to_string());
|
||||
|
||||
let session_secret = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
||||
"k-notes-super-secret-key-must-be-at-least-64-bytes-long!!!!".to_string()
|
||||
// Cookie secret for PrivateCookieJar (OIDC state encryption).
|
||||
// Must be at least 64 bytes in production.
|
||||
let cookie_secret = env::var("COOKIE_SECRET").unwrap_or_else(|_| {
|
||||
"k-template-cookie-secret-key-must-be-at-least-64-bytes-long!!".to_string()
|
||||
});
|
||||
|
||||
let cors_origins_str = env::var("CORS_ALLOWED_ORIGINS")
|
||||
@@ -155,12 +82,6 @@ impl Config {
|
||||
let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok();
|
||||
let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok();
|
||||
|
||||
// Auth mode configuration
|
||||
let auth_mode = env::var("AUTH_MODE")
|
||||
.map(|s| AuthMode::from_str(&s))
|
||||
.unwrap_or_default();
|
||||
|
||||
// JWT configuration
|
||||
let jwt_secret = env::var("JWT_SECRET").ok();
|
||||
let jwt_issuer = env::var("JWT_ISSUER").ok();
|
||||
let jwt_audience = env::var("JWT_AUDIENCE").ok();
|
||||
@@ -178,7 +99,7 @@ impl Config {
|
||||
host,
|
||||
port,
|
||||
database_url,
|
||||
session_secret,
|
||||
cookie_secret,
|
||||
cors_allowed_origins,
|
||||
secure_cookie,
|
||||
db_max_connections,
|
||||
@@ -188,7 +109,6 @@ impl Config {
|
||||
oidc_client_secret,
|
||||
oidc_redirect_url,
|
||||
oidc_resource_id,
|
||||
auth_mode,
|
||||
jwt_secret,
|
||||
jwt_issuer,
|
||||
jwt_audience,
|
||||
|
||||
@@ -10,21 +10,19 @@ use uuid::Uuid;
|
||||
|
||||
/// Login request with validated email and password newtypes
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct LoginRequest {
|
||||
/// Email is validated on deserialization
|
||||
pub email: Email,
|
||||
/// Password is validated on deserialization (min 6 chars)
|
||||
/// Password is validated on deserialization (min 8 chars)
|
||||
pub password: Password,
|
||||
}
|
||||
|
||||
/// Register request with validated email and password newtypes
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct RegisterRequest {
|
||||
/// Email is validated on deserialization
|
||||
pub email: Email,
|
||||
/// Password is validated on deserialization (min 6 chars)
|
||||
/// Password is validated on deserialization (min 8 chars)
|
||||
pub password: Password,
|
||||
}
|
||||
|
||||
@@ -36,6 +34,14 @@ pub struct UserResponse {
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// JWT token response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// System configuration response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ConfigResponse {
|
||||
|
||||
@@ -14,7 +14,6 @@ use domain::DomainError;
|
||||
|
||||
/// API-level errors
|
||||
#[derive(Debug, Error)]
|
||||
#[allow(dead_code)] // Some variants are reserved for future use
|
||||
pub enum ApiError {
|
||||
#[error("{0}")]
|
||||
Domain(#[from] DomainError),
|
||||
@@ -51,11 +50,17 @@ impl IntoResponse for ApiError {
|
||||
|
||||
DomainError::ValidationError(_) => StatusCode::BAD_REQUEST,
|
||||
|
||||
DomainError::Unauthorized(_) => StatusCode::FORBIDDEN,
|
||||
// Unauthenticated = not logged in → 401
|
||||
DomainError::Unauthenticated(_) => StatusCode::UNAUTHORIZED,
|
||||
|
||||
// Forbidden = not allowed to perform action → 403
|
||||
DomainError::Forbidden(_) => StatusCode::FORBIDDEN,
|
||||
|
||||
DomainError::RepositoryError(_) | DomainError::InfrastructureError(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
(
|
||||
@@ -76,7 +81,6 @@ impl IntoResponse for ApiError {
|
||||
),
|
||||
|
||||
ApiError::Internal(msg) => {
|
||||
// Log internal errors but don't expose details
|
||||
tracing::error!("Internal error: {}", msg);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -108,7 +112,6 @@ impl IntoResponse for ApiError {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Helper constructors for future use
|
||||
impl ApiError {
|
||||
pub fn validation(msg: impl Into<String>) -> Self {
|
||||
Self::Validation(msg.into())
|
||||
@@ -120,5 +123,4 @@ impl ApiError {
|
||||
}
|
||||
|
||||
/// Result type alias for API handlers
|
||||
#[allow(dead_code)]
|
||||
pub type ApiResult<T> = Result<T, ApiError>;
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
//! Auth extractors for API handlers
|
||||
//!
|
||||
//! Provides the `CurrentUser` extractor that works with both session and JWT auth.
|
||||
//! Provides the `CurrentUser` extractor that validates JWT Bearer tokens.
|
||||
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
use domain::User;
|
||||
|
||||
use crate::config::AuthMode;
|
||||
use crate::error::ApiError;
|
||||
use crate::state::AppState;
|
||||
|
||||
/// Extracted current user from the request.
|
||||
///
|
||||
/// This extractor supports multiple authentication methods based on the configured `AuthMode`:
|
||||
/// - `Session`: Uses axum-login session cookies
|
||||
/// - `Jwt`: Uses Bearer token in Authorization header
|
||||
/// - `Both`: Tries JWT first, then falls back to session
|
||||
/// Validates a JWT Bearer token from the `Authorization` header.
|
||||
pub struct CurrentUser(pub User);
|
||||
|
||||
impl FromRequestParts<AppState> for CurrentUser {
|
||||
@@ -24,71 +20,47 @@ impl FromRequestParts<AppState> for CurrentUser {
|
||||
parts: &mut Parts,
|
||||
state: &AppState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let auth_mode = state.config.auth_mode;
|
||||
|
||||
// Try JWT first if enabled
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||
match try_jwt_auth(parts, state).await {
|
||||
Ok(Some(user)) => return Ok(CurrentUser(user)),
|
||||
Ok(None) => {
|
||||
// No JWT token present, continue to session auth if Both mode
|
||||
if auth_mode == AuthMode::Jwt {
|
||||
return Err(ApiError::Unauthorized(
|
||||
"Missing or invalid Authorization header".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// JWT was present but invalid
|
||||
tracing::debug!("JWT auth failed: {}", e);
|
||||
if auth_mode == AuthMode::Jwt {
|
||||
return Err(e);
|
||||
}
|
||||
// In Both mode, continue to try session
|
||||
}
|
||||
}
|
||||
{
|
||||
return match try_jwt_auth(parts, state).await {
|
||||
Ok(user) => Ok(CurrentUser(user)),
|
||||
Err(e) => Err(e),
|
||||
};
|
||||
}
|
||||
|
||||
// Try session auth if enabled
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||
if let Some(user) = try_session_auth(parts).await? {
|
||||
return Ok(CurrentUser(user));
|
||||
}
|
||||
#[cfg(not(feature = "auth-jwt"))]
|
||||
{
|
||||
let _ = (parts, state);
|
||||
Err(ApiError::Unauthorized(
|
||||
"No authentication backend configured".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
Err(ApiError::Unauthorized("Not authenticated".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to authenticate using JWT Bearer token
|
||||
/// Authenticate using JWT Bearer token
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<Option<User>, ApiError> {
|
||||
async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<User, ApiError> {
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
|
||||
// Get Authorization header
|
||||
let auth_header = match parts.headers.get(AUTHORIZATION) {
|
||||
Some(header) => header,
|
||||
None => return Ok(None), // No header = no JWT auth attempted
|
||||
};
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(AUTHORIZATION)
|
||||
.ok_or_else(|| ApiError::Unauthorized("Missing Authorization header".to_string()))?;
|
||||
|
||||
let auth_str = auth_header
|
||||
.to_str()
|
||||
.map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?;
|
||||
|
||||
// Extract Bearer token
|
||||
let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| {
|
||||
ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string())
|
||||
})?;
|
||||
|
||||
// Get JWT validator
|
||||
let validator = state
|
||||
.jwt_validator
|
||||
.as_ref()
|
||||
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
|
||||
|
||||
// Validate token
|
||||
let claims = validator.validate_token(token).map_err(|e| {
|
||||
tracing::debug!("JWT validation failed: {:?}", e);
|
||||
match e {
|
||||
@@ -102,7 +74,6 @@ async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<Option<User
|
||||
}
|
||||
})?;
|
||||
|
||||
// Fetch user from database by ID (subject contains user ID)
|
||||
let user_id: uuid::Uuid = claims
|
||||
.sub
|
||||
.parse()
|
||||
@@ -114,20 +85,5 @@ async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<Option<User
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?;
|
||||
|
||||
Ok(Some(user))
|
||||
}
|
||||
|
||||
/// Try to authenticate using session cookie
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
async fn try_session_auth(parts: &mut Parts) -> Result<Option<User>, ApiError> {
|
||||
use infra::auth::backend::AuthSession;
|
||||
|
||||
// Check if AuthSession extension is present (added by auth middleware)
|
||||
if let Some(auth_session) = parts.extensions.get::<AuthSession>() {
|
||||
if let Some(auth_user) = &auth_session.user {
|
||||
return Ok(Some(auth_user.0.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
107
api/src/main.rs
107
api/src/main.rs
@@ -1,25 +1,19 @@
|
||||
//! API Server Entry Point
|
||||
//!
|
||||
//! Configures and starts the HTTP server with authentication based on AUTH_MODE.
|
||||
//! Configures and starts the HTTP server with JWT-based authentication.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration as StdDuration;
|
||||
|
||||
use axum::Router;
|
||||
use domain::UserService;
|
||||
use infra::factory::build_session_store;
|
||||
use infra::factory::build_user_repository;
|
||||
use infra::run_migrations;
|
||||
use infra::session_store::{Expiry, SessionManagerLayer};
|
||||
use k_core::http::server::ServerConfig;
|
||||
use k_core::http::server::apply_standard_middleware;
|
||||
use k_core::http::server::{ServerConfig, apply_standard_middleware};
|
||||
use k_core::logging;
|
||||
use time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_sessions::cookie::SameSite;
|
||||
use tracing::info;
|
||||
|
||||
mod auth;
|
||||
mod config;
|
||||
mod dto;
|
||||
mod error;
|
||||
@@ -27,7 +21,7 @@ mod extractors;
|
||||
mod routes;
|
||||
mod state;
|
||||
|
||||
use crate::config::{AuthMode, Config};
|
||||
use crate::config::Config;
|
||||
use crate::state::AppState;
|
||||
|
||||
#[tokio::main]
|
||||
@@ -37,7 +31,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let config = Config::from_env();
|
||||
|
||||
info!("Starting server on {}:{}", config.host, config.port);
|
||||
info!("Auth mode: {:?}", config.auth_mode);
|
||||
|
||||
// Setup database
|
||||
tracing::info!("Connecting to database: {}", config.database_url);
|
||||
@@ -49,104 +42,40 @@ async fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
|
||||
let db_pool = k_core::db::connect(&db_config).await?;
|
||||
|
||||
run_migrations(&db_pool).await?;
|
||||
|
||||
let user_repo = build_user_repository(&db_pool).await?;
|
||||
let user_service = UserService::new(user_repo.clone());
|
||||
let user_service = UserService::new(user_repo);
|
||||
|
||||
let state = AppState::new(user_service, config.clone()).await?;
|
||||
|
||||
// Build session store (needed for OIDC flow even in JWT mode)
|
||||
let session_store = build_session_store(&db_pool)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
session_store
|
||||
.migrate()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
.with_secure(config.secure_cookie)
|
||||
.with_same_site(SameSite::Lax)
|
||||
.with_expiry(Expiry::OnInactivity(Duration::days(7)));
|
||||
|
||||
let server_config = ServerConfig {
|
||||
cors_origins: config.cors_allowed_origins.clone(),
|
||||
session_secret: Some(config.session_secret.clone()),
|
||||
// session_secret is unused (sessions removed); kept for k-core API compat
|
||||
session_secret: None,
|
||||
};
|
||||
|
||||
// Build the app with appropriate auth layers based on config
|
||||
let app = build_app(state, session_layer, user_repo, &config).await?;
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", routes::api_v1_router())
|
||||
.with_state(state);
|
||||
|
||||
let app = apply_standard_middleware(app, &server_config);
|
||||
|
||||
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
tracing::info!("🚀 API server running at http://{}", addr);
|
||||
log_auth_info(&config);
|
||||
tracing::info!("🔒 Authentication mode: JWT (Bearer token)");
|
||||
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
tracing::info!(" ✓ JWT auth enabled");
|
||||
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
tracing::info!(" ✓ OIDC integration enabled (stateless cookie state)");
|
||||
|
||||
tracing::info!("📝 API endpoints available at /api/v1/...");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the application router with appropriate auth layers
|
||||
#[allow(unused_variables)] // config/user_repo used conditionally based on features
|
||||
async fn build_app(
|
||||
state: AppState,
|
||||
session_layer: SessionManagerLayer<infra::session_store::InfraSessionStore>,
|
||||
user_repo: std::sync::Arc<dyn domain::UserRepository>,
|
||||
config: &Config,
|
||||
) -> anyhow::Result<Router> {
|
||||
let app = Router::new()
|
||||
.nest("/api/v1", routes::api_v1_router())
|
||||
.with_state(state);
|
||||
|
||||
// When auth-axum-login feature is enabled, always apply the auth layer.
|
||||
// This is needed because:
|
||||
// 1. OIDC callback uses AuthSession for state management
|
||||
// 2. Session-based login/register routes use it
|
||||
// 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
{
|
||||
let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?;
|
||||
return Ok(app.layer(auth_layer));
|
||||
}
|
||||
|
||||
// When auth-axum-login is not compiled in, just use session layer for OIDC flow
|
||||
#[cfg(not(feature = "auth-axum-login"))]
|
||||
{
|
||||
let _ = user_repo; // Suppress unused warning
|
||||
Ok(app.layer(session_layer))
|
||||
}
|
||||
}
|
||||
|
||||
/// Log authentication info based on enabled features and config
|
||||
fn log_auth_info(config: &Config) {
|
||||
match config.auth_mode {
|
||||
AuthMode::Session => {
|
||||
tracing::info!("🔒 Authentication mode: Session (cookie-based)");
|
||||
}
|
||||
AuthMode::Jwt => {
|
||||
tracing::info!("🔒 Authentication mode: JWT (Bearer token)");
|
||||
}
|
||||
AuthMode::Both => {
|
||||
tracing::info!("🔒 Authentication mode: Both (JWT + Session)");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
tracing::info!(" ✓ Session auth enabled (axum-login)");
|
||||
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
if config.jwt_secret.is_some() {
|
||||
tracing::info!(" ✓ JWT auth enabled");
|
||||
}
|
||||
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
if config.oidc_issuer.is_some() {
|
||||
tracing::info!(" ✓ OIDC integration enabled");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//! Authentication routes
|
||||
//!
|
||||
//! Provides login, register, logout, and token endpoints.
|
||||
//! Supports both session-based and JWT-based authentication.
|
||||
//! Provides login, register, logout, token, and OIDC endpoints.
|
||||
//! All authentication is JWT-based. OIDC state is stored in an encrypted cookie.
|
||||
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
use axum::response::Response;
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Json, State},
|
||||
@@ -12,36 +10,13 @@ use axum::{
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
};
|
||||
use serde::Serialize;
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
use tower_sessions::Session;
|
||||
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use crate::config::AuthMode;
|
||||
use crate::{
|
||||
dto::{LoginRequest, RegisterRequest, UserResponse},
|
||||
dto::{LoginRequest, RegisterRequest, TokenResponse, UserResponse},
|
||||
error::ApiError,
|
||||
extractors::CurrentUser,
|
||||
state::AppState,
|
||||
};
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
use domain::DomainError;
|
||||
|
||||
/// Token response for JWT authentication
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// Login response that can be either a user (session mode) or a token (JWT mode)
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum LoginResponse {
|
||||
User(UserResponse),
|
||||
Token(TokenResponse),
|
||||
}
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
let r = Router::new()
|
||||
@@ -50,7 +25,6 @@ pub fn router() -> Router<AppState> {
|
||||
.route("/logout", post(logout))
|
||||
.route("/me", get(me));
|
||||
|
||||
// Add token endpoint for getting JWT from session
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
let r = r.route("/token", post(get_token));
|
||||
|
||||
@@ -62,171 +36,68 @@ pub fn router() -> Router<AppState> {
|
||||
r
|
||||
}
|
||||
|
||||
/// Login endpoint
|
||||
///
|
||||
/// In session mode: Creates a session and returns user info
|
||||
/// In JWT mode: Validates credentials and returns a JWT token
|
||||
/// In both mode: Creates session AND returns JWT token
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
/// Login with email + password → JWT token
|
||||
async fn login(
|
||||
State(state): State<AppState>,
|
||||
mut auth_session: crate::auth::AuthSession,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let user = match auth_session
|
||||
.authenticate(crate::auth::Credentials {
|
||||
email: payload.email,
|
||||
password: payload.password,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?
|
||||
{
|
||||
Some(user) => user,
|
||||
None => return Err(ApiError::Validation("Invalid credentials".to_string())),
|
||||
};
|
||||
|
||||
let auth_mode = state.config.auth_mode;
|
||||
|
||||
// In session or both mode, create session
|
||||
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||
auth_session
|
||||
.login(&user)
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
||||
}
|
||||
|
||||
// In JWT or both mode, return token
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||
let token = create_jwt_for_user(&user.0, &state)?;
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
Json(LoginResponse::Token(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// Session mode: return user info
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(LoginResponse::User(UserResponse {
|
||||
id: user.0.id,
|
||||
email: user.0.email.into_inner(),
|
||||
created_at: user.0.created_at,
|
||||
})),
|
||||
))
|
||||
}
|
||||
|
||||
/// Fallback login when auth-axum-login is not enabled
|
||||
/// Without auth-axum-login, password-based authentication is not available.
|
||||
/// Use OIDC login instead: GET /api/v1/auth/login/oidc
|
||||
#[cfg(not(feature = "auth-axum-login"))]
|
||||
async fn login(
|
||||
State(_state): State<AppState>,
|
||||
Json(_payload): Json<LoginRequest>,
|
||||
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
|
||||
Err(ApiError::Internal(
|
||||
"Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Register endpoint
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
async fn register(
|
||||
State(state): State<AppState>,
|
||||
mut auth_session: crate::auth::AuthSession,
|
||||
Json(payload): Json<RegisterRequest>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
// Email is already validated by the newtype deserialization
|
||||
let email = payload.email;
|
||||
|
||||
if state
|
||||
.user_service
|
||||
.find_by_email(email.as_ref())
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
return Err(ApiError::Domain(DomainError::UserAlreadyExists(
|
||||
email.as_ref().to_string(),
|
||||
)));
|
||||
}
|
||||
|
||||
// Hash password
|
||||
let password_hash = infra::auth::backend::hash_password(payload.password.as_ref());
|
||||
|
||||
// Create user with password
|
||||
let user = state
|
||||
.user_service
|
||||
.create_local(email.as_ref(), &password_hash)
|
||||
.find_by_email(payload.email.as_ref())
|
||||
.await?
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid credentials".to_string()))?;
|
||||
|
||||
let hash = user
|
||||
.password_hash
|
||||
.as_deref()
|
||||
.ok_or_else(|| ApiError::Unauthorized("Invalid credentials".to_string()))?;
|
||||
|
||||
if !infra::auth::verify_password(payload.password.as_ref(), hash) {
|
||||
return Err(ApiError::Unauthorized("Invalid credentials".to_string()));
|
||||
}
|
||||
|
||||
let token = create_jwt(&user, &state)?;
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
Json(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
/// Register a new local user → JWT token
|
||||
async fn register(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<RegisterRequest>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let password_hash = infra::auth::hash_password(payload.password.as_ref());
|
||||
|
||||
let user = state
|
||||
.user_service
|
||||
.create_local(payload.email.as_ref(), &password_hash)
|
||||
.await?;
|
||||
|
||||
let auth_mode = state.config.auth_mode;
|
||||
|
||||
// In session or both mode, create session
|
||||
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||
let auth_user = crate::auth::AuthUser(user.clone());
|
||||
auth_session
|
||||
.login(&auth_user)
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
||||
}
|
||||
|
||||
// In JWT or both mode, return token
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||
let token = create_jwt_for_user(&user, &state)?;
|
||||
return Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(LoginResponse::Token(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
})),
|
||||
));
|
||||
}
|
||||
let token = create_jwt(&user, &state)?;
|
||||
|
||||
Ok((
|
||||
StatusCode::CREATED,
|
||||
Json(LoginResponse::User(UserResponse {
|
||||
id: user.id,
|
||||
email: user.email.into_inner(),
|
||||
created_at: user.created_at,
|
||||
})),
|
||||
Json(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
/// Fallback register when auth-axum-login is not enabled
|
||||
#[cfg(not(feature = "auth-axum-login"))]
|
||||
async fn register(
|
||||
State(_state): State<AppState>,
|
||||
Json(_payload): Json<RegisterRequest>,
|
||||
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
|
||||
Err(ApiError::Internal(
|
||||
"Session-based registration not available. Use JWT token endpoint.".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Logout endpoint
|
||||
#[cfg(feature = "auth-axum-login")]
|
||||
async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse {
|
||||
match auth_session.logout().await {
|
||||
Ok(_) => StatusCode::OK,
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fallback logout when auth-axum-login is not enabled
|
||||
#[cfg(not(feature = "auth-axum-login"))]
|
||||
/// Logout — JWT is stateless; instruct the client to drop the token
|
||||
async fn logout() -> impl IntoResponse {
|
||||
// JWT tokens can't be "logged out" server-side without a blocklist
|
||||
// Just return OK
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
/// Get current user info
|
||||
/// Get current user info from JWT
|
||||
async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoResponse, ApiError> {
|
||||
Ok(Json(UserResponse {
|
||||
id: user.id,
|
||||
@@ -235,15 +106,13 @@ async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoResponse, ApiErro
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get a JWT token for the current session user
|
||||
///
|
||||
/// This allows session-authenticated users to obtain a JWT for API access.
|
||||
/// Issue a new JWT for the currently authenticated user (OIDC→JWT exchange or token refresh)
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
async fn get_token(
|
||||
State(state): State<AppState>,
|
||||
CurrentUser(user): CurrentUser,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let token = create_jwt_for_user(&user, &state)?;
|
||||
let token = create_jwt(&user, &state)?;
|
||||
|
||||
Ok(Json(TokenResponse {
|
||||
access_token: token,
|
||||
@@ -252,9 +121,9 @@ async fn get_token(
|
||||
}))
|
||||
}
|
||||
|
||||
/// Helper to create JWT for a user
|
||||
/// Helper: create JWT for a user
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
fn create_jwt_for_user(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
||||
fn create_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
||||
let validator = state
|
||||
.jwt_validator
|
||||
.as_ref()
|
||||
@@ -265,37 +134,54 @@ fn create_jwt_for_user(user: &domain::User, state: &AppState) -> Result<String,
|
||||
.map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e)))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "auth-jwt"))]
|
||||
fn create_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
|
||||
Err(ApiError::Internal("JWT feature not enabled".to_string()))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OIDC Routes
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<Response, ApiError> {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
/// Start OIDC login: generate authorization URL and store state in encrypted cookie
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
async fn oidc_login(
|
||||
State(state): State<AppState>,
|
||||
jar: axum_extra::extract::PrivateCookieJar,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
use axum::http::header;
|
||||
use axum::response::Response;
|
||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
|
||||
let service = state
|
||||
.oidc_service
|
||||
.as_ref()
|
||||
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
|
||||
|
||||
let auth_data = service.get_authorization_url();
|
||||
let (auth_data, oidc_state) = service.get_authorization_url();
|
||||
|
||||
session
|
||||
.insert("oidc_csrf", &auth_data.csrf_token)
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
session
|
||||
.insert("oidc_nonce", &auth_data.nonce)
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
session
|
||||
.insert("oidc_pkce", &auth_data.pkce_verifier)
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
let state_json = serde_json::to_string(&oidc_state)
|
||||
.map_err(|e| ApiError::Internal(format!("Failed to serialize OIDC state: {}", e)))?;
|
||||
|
||||
let response = axum::response::Redirect::to(auth_data.url.as_str()).into_response();
|
||||
let (mut parts, body) = response.into_parts();
|
||||
let cookie = Cookie::build(("oidc_state", state_json))
|
||||
.max_age(time::Duration::minutes(5))
|
||||
.http_only(true)
|
||||
.same_site(SameSite::Lax)
|
||||
.secure(state.config.secure_cookie)
|
||||
.path("/")
|
||||
.build();
|
||||
|
||||
let updated_jar = jar.add(cookie);
|
||||
|
||||
let redirect = axum::response::Redirect::to(auth_data.url.as_str()).into_response();
|
||||
let (mut parts, body) = redirect.into_parts();
|
||||
parts.headers.insert(
|
||||
header::CACHE_CONTROL,
|
||||
"no-cache, no-store, must-revalidate".parse().unwrap(),
|
||||
@@ -305,54 +191,42 @@ async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<R
|
||||
.insert(header::PRAGMA, "no-cache".parse().unwrap());
|
||||
parts.headers.insert(header::EXPIRES, "0".parse().unwrap());
|
||||
|
||||
Ok(Response::from_parts(parts, body))
|
||||
Ok((updated_jar, Response::from_parts(parts, body)))
|
||||
}
|
||||
|
||||
/// Handle OIDC callback: verify state cookie, complete exchange, issue JWT, clear cookie
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))]
|
||||
async fn oidc_callback(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
mut auth_session: crate::auth::AuthSession,
|
||||
jar: axum_extra::extract::PrivateCookieJar,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
use infra::auth::oidc::OidcState;
|
||||
|
||||
let service = state
|
||||
.oidc_service
|
||||
.as_ref()
|
||||
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
|
||||
|
||||
let stored_csrf: domain::CsrfToken = session
|
||||
.get("oidc_csrf")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
|
||||
// Read and decrypt OIDC state from cookie
|
||||
let cookie = jar
|
||||
.get("oidc_state")
|
||||
.ok_or(ApiError::Validation("Missing OIDC state cookie".into()))?;
|
||||
|
||||
if params.state != stored_csrf.as_ref() {
|
||||
let oidc_state: OidcState = serde_json::from_str(cookie.value())
|
||||
.map_err(|_| ApiError::Validation("Invalid OIDC state cookie".into()))?;
|
||||
|
||||
// Verify CSRF token
|
||||
if params.state != oidc_state.csrf_token.as_ref() {
|
||||
return Err(ApiError::Validation("Invalid CSRF token".into()));
|
||||
}
|
||||
|
||||
let stored_pkce: domain::PkceVerifier = session
|
||||
.get("oidc_pkce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
|
||||
let stored_nonce: domain::OidcNonce = session
|
||||
.get("oidc_nonce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
|
||||
|
||||
// Complete OIDC exchange
|
||||
let oidc_user = service
|
||||
.resolve_callback(
|
||||
domain::AuthorizationCode::new(params.code),
|
||||
stored_nonce,
|
||||
stored_pkce,
|
||||
oidc_state.nonce,
|
||||
oidc_state.pkce_verifier,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
@@ -363,129 +237,17 @@ async fn oidc_callback(
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
let auth_mode = state.config.auth_mode;
|
||||
// Clear the OIDC state cookie
|
||||
let cleared_jar = jar.remove(axum_extra::extract::cookie::Cookie::from("oidc_state"));
|
||||
|
||||
// In session or both mode, create session
|
||||
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||
auth_session
|
||||
.login(&crate::auth::AuthUser(user.clone()))
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Login failed".into()))?;
|
||||
}
|
||||
let token = create_jwt(&user, &state)?;
|
||||
|
||||
// Clean up OIDC state
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_csrf")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_pkce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_nonce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
|
||||
// In JWT mode, return token as JSON
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||
let token = create_jwt_for_user(&user, &state)?;
|
||||
return Ok(Json(TokenResponse {
|
||||
Ok((
|
||||
cleared_jar,
|
||||
Json(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
})
|
||||
.into_response());
|
||||
}
|
||||
|
||||
// Session mode: return user info
|
||||
Ok(Json(UserResponse {
|
||||
id: user.id,
|
||||
email: user.email.into_inner(),
|
||||
created_at: user.created_at,
|
||||
})
|
||||
.into_response())
|
||||
}
|
||||
|
||||
/// Fallback OIDC callback when auth-axum-login is not enabled
|
||||
#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))]
|
||||
async fn oidc_callback(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let service = state
|
||||
.oidc_service
|
||||
.as_ref()
|
||||
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
|
||||
|
||||
let stored_csrf: domain::CsrfToken = session
|
||||
.get("oidc_csrf")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
|
||||
|
||||
if params.state != stored_csrf.as_ref() {
|
||||
return Err(ApiError::Validation("Invalid CSRF token".into()));
|
||||
}
|
||||
|
||||
let stored_pkce: domain::PkceVerifier = session
|
||||
.get("oidc_pkce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
|
||||
let stored_nonce: domain::OidcNonce = session
|
||||
.get("oidc_nonce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
|
||||
|
||||
let oidc_user = service
|
||||
.resolve_callback(
|
||||
domain::AuthorizationCode::new(params.code),
|
||||
stored_nonce,
|
||||
stored_pkce,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
let user = state
|
||||
.user_service
|
||||
.find_or_create(&oidc_user.subject, &oidc_user.email)
|
||||
.await
|
||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||
|
||||
// Clean up OIDC state
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_csrf")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_pkce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
let _: Option<String> = session
|
||||
.remove("oidc_nonce")
|
||||
.await
|
||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||
|
||||
// Return token as JSON
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
{
|
||||
let token = create_jwt_for_user(&user, &state)?;
|
||||
return Ok(Json(TokenResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||
}));
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "auth-jwt"))]
|
||||
{
|
||||
let _ = user; // Suppress unused warning
|
||||
Err(ApiError::Internal(
|
||||
"No auth backend available for OIDC callback".to_string(),
|
||||
))
|
||||
}
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -3,18 +3,20 @@
|
||||
//! Holds shared state for the application.
|
||||
|
||||
use axum::extract::FromRef;
|
||||
use axum_extra::extract::cookie::Key;
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
use infra::auth::jwt::{JwtConfig, JwtValidator};
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
use infra::auth::oidc::OidcService;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::{AuthMode, Config};
|
||||
use crate::config::Config;
|
||||
use domain::UserService;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub user_service: Arc<UserService>,
|
||||
pub cookie_key: Key,
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
pub oidc_service: Option<Arc<OidcService>>,
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
@@ -24,6 +26,8 @@ pub struct AppState {
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(user_service: UserService, config: Config) -> anyhow::Result<Self> {
|
||||
let cookie_key = Key::derive_from(config.cookie_secret.as_bytes());
|
||||
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = (
|
||||
&config.oidc_issuer,
|
||||
@@ -34,7 +38,6 @@ impl AppState {
|
||||
) {
|
||||
tracing::info!("Initializing OIDC service with issuer: {}", issuer);
|
||||
|
||||
// Construct newtypes from config strings
|
||||
let issuer_url = domain::IssuerUrl::new(issuer)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid OIDC issuer URL: {}", e))?;
|
||||
let client_id = domain::ClientId::new(id)
|
||||
@@ -57,25 +60,15 @@ impl AppState {
|
||||
};
|
||||
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||
// Use provided secret or fall back to a development secret
|
||||
let secret = if let Some(ref s) = config.jwt_secret {
|
||||
if s.is_empty() { None } else { Some(s.clone()) }
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let secret = match secret {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
let jwt_validator = {
|
||||
let secret = match &config.jwt_secret {
|
||||
Some(s) if !s.is_empty() => s.clone(),
|
||||
_ => {
|
||||
if config.is_production {
|
||||
anyhow::bail!(
|
||||
"JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production"
|
||||
);
|
||||
anyhow::bail!("JWT_SECRET is required in production");
|
||||
}
|
||||
// Use a development-only default secret
|
||||
tracing::warn!(
|
||||
"⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!"
|
||||
"⚠️ JWT_SECRET not set — using insecure development secret. DO NOT USE IN PRODUCTION!"
|
||||
);
|
||||
"k-template-dev-secret-not-for-production-use-only".to_string()
|
||||
}
|
||||
@@ -90,12 +83,11 @@ impl AppState {
|
||||
config.is_production,
|
||||
)?;
|
||||
Some(Arc::new(JwtValidator::new(jwt_config)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
user_service: Arc::new(user_service),
|
||||
cookie_key,
|
||||
#[cfg(feature = "auth-oidc")]
|
||||
oidc_service,
|
||||
#[cfg(feature = "auth-jwt")]
|
||||
@@ -116,3 +108,9 @@ impl FromRef<AppState> for Arc<Config> {
|
||||
input.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for Key {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.cookie_key.clone()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user