feat: add JWT authentication and flexible auth modes with configurable login responses
This commit is contained in:
48
Cargo.lock
generated
48
Cargo.lock
generated
@@ -1360,6 +1360,7 @@ dependencies = [
|
|||||||
"domain",
|
"domain",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
"jsonwebtoken",
|
||||||
"k-core",
|
"k-core",
|
||||||
"openidconnect",
|
"openidconnect",
|
||||||
"password-auth",
|
"password-auth",
|
||||||
@@ -1427,6 +1428,21 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jsonwebtoken"
|
||||||
|
version = "9.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.22.1",
|
||||||
|
"js-sys",
|
||||||
|
"pem",
|
||||||
|
"ring",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"simple_asn1",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "k-core"
|
name = "k-core"
|
||||||
version = "0.1.10"
|
version = "0.1.10"
|
||||||
@@ -1605,6 +1621,16 @@ dependencies = [
|
|||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-bigint"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
|
||||||
|
dependencies = [
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-bigint-dig"
|
name = "num-bigint-dig"
|
||||||
version = "0.8.6"
|
version = "0.8.6"
|
||||||
@@ -1821,6 +1847,16 @@ version = "0.2.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
|
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pem"
|
||||||
|
version = "3.0.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.22.1",
|
||||||
|
"serde_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pem-rfc7468"
|
name = "pem-rfc7468"
|
||||||
version = "0.7.0"
|
version = "0.7.0"
|
||||||
@@ -2731,6 +2767,18 @@ dependencies = [
|
|||||||
"rand_core 0.6.4",
|
"rand_core 0.6.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simple_asn1"
|
||||||
|
version = "0.6.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb"
|
||||||
|
dependencies = [
|
||||||
|
"num-bigint",
|
||||||
|
"num-traits",
|
||||||
|
"thiserror 2.0.17",
|
||||||
|
"time",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slab"
|
name = "slab"
|
||||||
version = "0.4.11"
|
version = "0.4.11"
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
A production-ready, modular Rust template for K-Suite applications, following Hexagonal Architecture principles.
|
A production-ready, modular Rust template for K-Suite applications, following Hexagonal Architecture principles.
|
||||||
|
|
||||||
## 🌟 Features
|
## Features
|
||||||
|
|
||||||
- **Hexagonal Architecture**: Clear separation of concerns between Domain, Infrastructure, and API layers.
|
- **Hexagonal Architecture**: Clear separation of concerns between Domain, Infrastructure, and API layers.
|
||||||
- **Modular & Swappable**: Vendor implementations (databases, message brokers) are behind feature flags and trait objects.
|
- **Modular & Swappable**: Vendor implementations (databases, message brokers) are behind feature flags and trait objects.
|
||||||
@@ -10,7 +10,7 @@ A production-ready, modular Rust template for K-Suite applications, following He
|
|||||||
- **Cargo Generate Ready**: Pre-configured for `cargo-generate` to easily scaffold new services.
|
- **Cargo Generate Ready**: Pre-configured for `cargo-generate` to easily scaffold new services.
|
||||||
- **Testable**: Domain logic is pure and easily testable; Infrastructure is tested with integration tests.
|
- **Testable**: Domain logic is pure and easily testable; Infrastructure is tested with integration tests.
|
||||||
|
|
||||||
## 🏗️ Project Structure
|
## Project Structure
|
||||||
|
|
||||||
The workspace consists of three main crates:
|
The workspace consists of three main crates:
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ The workspace consists of three main crates:
|
|||||||
- Wires everything together using dependency injection.
|
- Wires everything together using dependency injection.
|
||||||
- Handles HTTP/REST/gRPC interfaces.
|
- Handles HTTP/REST/gRPC interfaces.
|
||||||
|
|
||||||
## 🚀 Getting Started
|
## Getting Started
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ cargo test
|
|||||||
cargo test -p template-infra --no-default-features --features postgres
|
cargo test -p template-infra --no-default-features --features postgres
|
||||||
```
|
```
|
||||||
|
|
||||||
## ⚙️ Configuration & Feature Flags
|
## Configuration & Feature Flags
|
||||||
|
|
||||||
This template uses Cargo features to control compilation of infrastructure adapters.
|
This template uses Cargo features to control compilation of infrastructure adapters.
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ default = ["postgres"]
|
|||||||
# ...
|
# ...
|
||||||
```
|
```
|
||||||
|
|
||||||
## 📐 Architecture Guide
|
## Architecture Guide
|
||||||
|
|
||||||
### Adding a New Feature
|
### Adding a New Feature
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ edition = "2024"
|
|||||||
default-run = "api"
|
default-run = "api"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["sqlite", "auth-axum-login", "auth-oidc"]
|
default = ["sqlite", "auth-axum-login", "auth-oidc", "auth-jwt"]
|
||||||
sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"]
|
sqlite = ["infra/sqlite", "tower-sessions-sqlx-store/sqlite"]
|
||||||
postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"]
|
postgres = ["infra/postgres", "tower-sessions-sqlx-store/postgres"]
|
||||||
auth-axum-login = ["infra/auth-axum-login"]
|
auth-axum-login = ["infra/auth-axum-login"]
|
||||||
auth-oidc = ["infra/auth-oidc"]
|
auth-oidc = ["infra/auth-oidc"]
|
||||||
|
auth-jwt = ["infra/auth-jwt"]
|
||||||
|
auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||||
|
|||||||
@@ -2,11 +2,15 @@
|
|||||||
//!
|
//!
|
||||||
//! Proxies to infra implementation if enabled.
|
//! Proxies to infra implementation if enabled.
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
use domain::UserRepository;
|
use domain::UserRepository;
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
use infra::session_store::{InfraSessionStore, SessionManagerLayer};
|
use infra::session_store::{InfraSessionStore, SessionManagerLayer};
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
use crate::error::ApiError;
|
use crate::error::ApiError;
|
||||||
|
|
||||||
#[cfg(feature = "auth-axum-login")]
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
|||||||
@@ -6,6 +6,30 @@ use std::env;
|
|||||||
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// Authentication mode - determines how the API authenticates requests
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum AuthMode {
|
||||||
|
/// Session-based authentication using cookies (default for backward compatibility)
|
||||||
|
#[default]
|
||||||
|
Session,
|
||||||
|
/// JWT-based authentication using Bearer tokens
|
||||||
|
Jwt,
|
||||||
|
/// Support both session and JWT authentication (try JWT first, then session)
|
||||||
|
Both,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthMode {
|
||||||
|
/// Parse auth mode from string
|
||||||
|
pub fn from_str(s: &str) -> Self {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"jwt" => AuthMode::Jwt,
|
||||||
|
"both" => AuthMode::Both,
|
||||||
|
_ => AuthMode::Session,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//todo: replace with newtypes
|
//todo: replace with newtypes
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
@@ -28,10 +52,27 @@ pub struct Config {
|
|||||||
#[serde(default = "default_db_min_connections")]
|
#[serde(default = "default_db_min_connections")]
|
||||||
pub db_min_connections: u32,
|
pub db_min_connections: u32,
|
||||||
|
|
||||||
|
// OIDC configuration
|
||||||
pub oidc_issuer: Option<String>,
|
pub oidc_issuer: Option<String>,
|
||||||
pub oidc_client_id: Option<String>,
|
pub oidc_client_id: Option<String>,
|
||||||
pub oidc_client_secret: Option<String>,
|
pub oidc_client_secret: Option<String>,
|
||||||
pub oidc_redirect_url: Option<String>,
|
pub oidc_redirect_url: Option<String>,
|
||||||
|
pub oidc_resource_id: Option<String>,
|
||||||
|
|
||||||
|
// Auth mode configuration
|
||||||
|
#[serde(default)]
|
||||||
|
pub auth_mode: AuthMode,
|
||||||
|
|
||||||
|
// JWT configuration
|
||||||
|
pub jwt_secret: Option<String>,
|
||||||
|
pub jwt_issuer: Option<String>,
|
||||||
|
pub jwt_audience: Option<String>,
|
||||||
|
#[serde(default = "default_jwt_expiry_hours")]
|
||||||
|
pub jwt_expiry_hours: u64,
|
||||||
|
|
||||||
|
/// Whether the application is running in production mode
|
||||||
|
#[serde(default)]
|
||||||
|
pub is_production: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_secure_cookie() -> bool {
|
fn default_secure_cookie() -> bool {
|
||||||
@@ -54,6 +95,10 @@ fn default_host() -> String {
|
|||||||
"127.0.0.1".to_string()
|
"127.0.0.1".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_jwt_expiry_hours() -> u64 {
|
||||||
|
24
|
||||||
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn new() -> Result<Self, config::ConfigError> {
|
pub fn new() -> Result<Self, config::ConfigError> {
|
||||||
config::Config::builder()
|
config::Config::builder()
|
||||||
@@ -108,6 +153,26 @@ impl Config {
|
|||||||
let oidc_client_id = env::var("OIDC_CLIENT_ID").ok();
|
let oidc_client_id = env::var("OIDC_CLIENT_ID").ok();
|
||||||
let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok();
|
let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok();
|
||||||
let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok();
|
let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok();
|
||||||
|
let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok();
|
||||||
|
|
||||||
|
// Auth mode configuration
|
||||||
|
let auth_mode = env::var("AUTH_MODE")
|
||||||
|
.map(|s| AuthMode::from_str(&s))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// JWT configuration
|
||||||
|
let jwt_secret = env::var("JWT_SECRET").ok();
|
||||||
|
let jwt_issuer = env::var("JWT_ISSUER").ok();
|
||||||
|
let jwt_audience = env::var("JWT_AUDIENCE").ok();
|
||||||
|
let jwt_expiry_hours = env::var("JWT_EXPIRY_HOURS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(24);
|
||||||
|
|
||||||
|
let is_production = env::var("PRODUCTION")
|
||||||
|
.or_else(|_| env::var("RUST_ENV"))
|
||||||
|
.map(|v| v.to_lowercase() == "production" || v == "1" || v == "true")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
host,
|
host,
|
||||||
@@ -122,6 +187,13 @@ impl Config {
|
|||||||
oidc_client_id,
|
oidc_client_id,
|
||||||
oidc_client_secret,
|
oidc_client_secret,
|
||||||
oidc_redirect_url,
|
oidc_redirect_url,
|
||||||
|
oidc_resource_id,
|
||||||
|
auth_mode,
|
||||||
|
jwt_secret,
|
||||||
|
jwt_issuer,
|
||||||
|
jwt_audience,
|
||||||
|
jwt_expiry_hours,
|
||||||
|
is_production,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,3 +40,12 @@ pub struct UserResponse {
|
|||||||
pub struct ConfigResponse {
|
pub struct ConfigResponse {
|
||||||
pub allow_registration: bool,
|
pub allow_registration: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
// also newtypes
|
||||||
|
pub struct Claims {
|
||||||
|
pub sub: String,
|
||||||
|
pub email: String,
|
||||||
|
pub exp: usize,
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ use domain::DomainError;
|
|||||||
|
|
||||||
/// API-level errors
|
/// API-level errors
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
#[allow(dead_code)] // Some variants are reserved for future use
|
||||||
pub enum ApiError {
|
pub enum ApiError {
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
Domain(#[from] DomainError),
|
Domain(#[from] DomainError),
|
||||||
@@ -107,6 +108,7 @@ impl IntoResponse for ApiError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)] // Helper constructors for future use
|
||||||
impl ApiError {
|
impl ApiError {
|
||||||
pub fn validation(msg: impl Into<String>) -> Self {
|
pub fn validation(msg: impl Into<String>) -> Self {
|
||||||
Self::Validation(msg.into())
|
Self::Validation(msg.into())
|
||||||
@@ -118,4 +120,5 @@ impl ApiError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Result type alias for API handlers
|
/// Result type alias for API handlers
|
||||||
|
#[allow(dead_code)]
|
||||||
pub type ApiResult<T> = Result<T, ApiError>;
|
pub type ApiResult<T> = Result<T, ApiError>;
|
||||||
|
|||||||
150
api/src/extractors.rs
Normal file
150
api/src/extractors.rs
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
//! Auth extractors for API handlers
|
||||||
|
//!
|
||||||
|
//! Provides the `CurrentUser` extractor that works with both session and JWT auth.
|
||||||
|
|
||||||
|
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||||
|
use domain::User;
|
||||||
|
|
||||||
|
use crate::config::AuthMode;
|
||||||
|
use crate::error::ApiError;
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// Extracted current user from the request.
|
||||||
|
///
|
||||||
|
/// This extractor supports multiple authentication methods based on the configured `AuthMode`:
|
||||||
|
/// - `Session`: Uses axum-login session cookies
|
||||||
|
/// - `Jwt`: Uses Bearer token in Authorization header
|
||||||
|
/// - `Both`: Tries JWT first, then falls back to session
|
||||||
|
pub struct CurrentUser(pub User);
|
||||||
|
|
||||||
|
impl FromRequestParts<AppState> for CurrentUser {
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut Parts,
|
||||||
|
state: &AppState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let auth_mode = state.config.auth_mode;
|
||||||
|
|
||||||
|
// Try JWT first if enabled
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||||
|
match try_jwt_auth(parts, state).await {
|
||||||
|
Ok(Some(user)) => return Ok(CurrentUser(user)),
|
||||||
|
Ok(None) => {
|
||||||
|
// No JWT token present, continue to session auth if Both mode
|
||||||
|
if auth_mode == AuthMode::Jwt {
|
||||||
|
return Err(ApiError::Unauthorized(
|
||||||
|
"Missing or invalid Authorization header".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// JWT was present but invalid
|
||||||
|
tracing::debug!("JWT auth failed: {}", e);
|
||||||
|
if auth_mode == AuthMode::Jwt {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
// In Both mode, continue to try session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try session auth if enabled
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||||
|
if let Some(user) = try_session_auth(parts).await? {
|
||||||
|
return Ok(CurrentUser(user));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(ApiError::Unauthorized("Not authenticated".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to authenticate using JWT Bearer token
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result<Option<User>, ApiError> {
|
||||||
|
use axum::http::header::AUTHORIZATION;
|
||||||
|
|
||||||
|
// Get Authorization header
|
||||||
|
let auth_header = match parts.headers.get(AUTHORIZATION) {
|
||||||
|
Some(header) => header,
|
||||||
|
None => return Ok(None), // No header = no JWT auth attempted
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_str = auth_header
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?;
|
||||||
|
|
||||||
|
// Extract Bearer token
|
||||||
|
let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| {
|
||||||
|
ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Get JWT validator
|
||||||
|
let validator = state
|
||||||
|
.jwt_validator
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
|
||||||
|
|
||||||
|
// Validate token
|
||||||
|
let claims = validator.validate_token(token).map_err(|e| {
|
||||||
|
tracing::debug!("JWT validation failed: {:?}", e);
|
||||||
|
match e {
|
||||||
|
infra::auth::jwt::JwtError::Expired => {
|
||||||
|
ApiError::Unauthorized("Token expired".to_string())
|
||||||
|
}
|
||||||
|
infra::auth::jwt::JwtError::InvalidFormat => {
|
||||||
|
ApiError::Unauthorized("Invalid token format".to_string())
|
||||||
|
}
|
||||||
|
_ => ApiError::Unauthorized("Token validation failed".to_string()),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Fetch user from database by ID (subject contains user ID)
|
||||||
|
let user_id: uuid::Uuid = claims
|
||||||
|
.sub
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?;
|
||||||
|
|
||||||
|
let user = state
|
||||||
|
.user_service
|
||||||
|
.find_by_id(user_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(Some(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to authenticate using session cookie
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
async fn try_session_auth(parts: &mut Parts) -> Result<Option<User>, ApiError> {
|
||||||
|
use infra::auth::backend::AuthSession;
|
||||||
|
|
||||||
|
// Check if AuthSession extension is present (added by auth middleware)
|
||||||
|
if let Some(auth_session) = parts.extensions.get::<AuthSession>() {
|
||||||
|
if let Some(auth_user) = &auth_session.user {
|
||||||
|
return Ok(Some(auth_user.0.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fallback for when auth-axum-login is not enabled
|
||||||
|
#[cfg(not(feature = "auth-axum-login"))]
|
||||||
|
async fn try_session_auth(_parts: &mut Parts) -> Result<Option<User>, ApiError> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fallback for when auth-jwt is not enabled but auth mode requires it
|
||||||
|
#[cfg(not(feature = "auth-jwt"))]
|
||||||
|
async fn try_jwt_auth(_parts: &mut Parts, state: &AppState) -> Result<Option<User>, ApiError> {
|
||||||
|
if matches!(state.config.auth_mode, AuthMode::Jwt) {
|
||||||
|
return Err(ApiError::Internal(
|
||||||
|
"JWT auth mode configured but auth-jwt feature not enabled".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
//! API Server Entry Point
|
||||||
|
//!
|
||||||
|
//! Configures and starts the HTTP server with authentication based on AUTH_MODE.
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::time::Duration as StdDuration;
|
use std::time::Duration as StdDuration;
|
||||||
|
|
||||||
@@ -12,17 +16,18 @@ use k_core::http::server::apply_standard_middleware;
|
|||||||
use k_core::logging;
|
use k_core::logging;
|
||||||
use time::Duration;
|
use time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
use tower_sessions::cookie::SameSite;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
mod config;
|
mod config;
|
||||||
mod dto;
|
mod dto;
|
||||||
mod error;
|
mod error;
|
||||||
|
mod extractors;
|
||||||
mod routes;
|
mod routes;
|
||||||
mod state;
|
mod state;
|
||||||
|
|
||||||
use crate::auth::setup_auth_layer;
|
use crate::config::{AuthMode, Config};
|
||||||
use crate::config::Config;
|
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -32,6 +37,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let config = Config::from_env();
|
let config = Config::from_env();
|
||||||
|
|
||||||
info!("Starting server on {}:{}", config.host, config.port);
|
info!("Starting server on {}:{}", config.host, config.port);
|
||||||
|
info!("Auth mode: {:?}", config.auth_mode);
|
||||||
|
|
||||||
// Setup database
|
// Setup database
|
||||||
tracing::info!("Connecting to database: {}", config.database_url);
|
tracing::info!("Connecting to database: {}", config.database_url);
|
||||||
@@ -51,6 +57,7 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let state = AppState::new(user_service, config.clone()).await?;
|
let state = AppState::new(user_service, config.clone()).await?;
|
||||||
|
|
||||||
|
// Build session store (needed for OIDC flow even in JWT mode)
|
||||||
let session_store = build_session_store(&db_pool)
|
let session_store = build_session_store(&db_pool)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!(e))?;
|
.map_err(|e| anyhow::anyhow!(e))?;
|
||||||
@@ -61,30 +68,85 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let session_layer = SessionManagerLayer::new(session_store)
|
let session_layer = SessionManagerLayer::new(session_store)
|
||||||
.with_secure(config.secure_cookie)
|
.with_secure(config.secure_cookie)
|
||||||
|
.with_same_site(SameSite::Lax)
|
||||||
.with_expiry(Expiry::OnInactivity(Duration::days(7)));
|
.with_expiry(Expiry::OnInactivity(Duration::days(7)));
|
||||||
|
|
||||||
let auth_layer = setup_auth_layer(session_layer, user_repo).await?;
|
|
||||||
|
|
||||||
let server_config = ServerConfig {
|
let server_config = ServerConfig {
|
||||||
cors_origins: config.cors_allowed_origins.clone(),
|
cors_origins: config.cors_allowed_origins.clone(),
|
||||||
session_secret: Some(config.session_secret.clone()),
|
session_secret: Some(config.session_secret.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
// Build the app with appropriate auth layers based on config
|
||||||
.nest("/api/v1", routes::api_v1_router())
|
let app = build_app(state, session_layer, user_repo, &config).await?;
|
||||||
.layer(auth_layer)
|
|
||||||
.with_state(state);
|
|
||||||
|
|
||||||
let app = apply_standard_middleware(app, &server_config);
|
let app = apply_standard_middleware(app, &server_config);
|
||||||
|
|
||||||
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
|
||||||
let listener = TcpListener::bind(addr).await?;
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
|
||||||
tracing::info!("🚀 API server running at http://{}", addr);
|
tracing::info!("🚀 API server running at http://{}", addr);
|
||||||
tracing::info!("🔒 Authentication enabled (axum-login)");
|
log_auth_info(&config);
|
||||||
tracing::info!("📝 API endpoints available at /api/v1/...");
|
tracing::info!("📝 API endpoints available at /api/v1/...");
|
||||||
|
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the application router with appropriate auth layers
|
||||||
|
#[allow(unused_variables)] // config/user_repo used conditionally based on features
|
||||||
|
async fn build_app(
|
||||||
|
state: AppState,
|
||||||
|
session_layer: SessionManagerLayer<infra::session_store::InfraSessionStore>,
|
||||||
|
user_repo: std::sync::Arc<dyn domain::UserRepository>,
|
||||||
|
config: &Config,
|
||||||
|
) -> anyhow::Result<Router> {
|
||||||
|
let app = Router::new()
|
||||||
|
.nest("/api/v1", routes::api_v1_router())
|
||||||
|
.with_state(state);
|
||||||
|
|
||||||
|
// When auth-axum-login feature is enabled, always apply the auth layer.
|
||||||
|
// This is needed because:
|
||||||
|
// 1. OIDC callback uses AuthSession for state management
|
||||||
|
// 2. Session-based login/register routes use it
|
||||||
|
// 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
{
|
||||||
|
let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?;
|
||||||
|
return Ok(app.layer(auth_layer));
|
||||||
|
}
|
||||||
|
|
||||||
|
// When auth-axum-login is not compiled in, just use session layer for OIDC flow
|
||||||
|
#[cfg(not(feature = "auth-axum-login"))]
|
||||||
|
{
|
||||||
|
let _ = user_repo; // Suppress unused warning
|
||||||
|
Ok(app.layer(session_layer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Log authentication info based on enabled features and config
|
||||||
|
fn log_auth_info(config: &Config) {
|
||||||
|
match config.auth_mode {
|
||||||
|
AuthMode::Session => {
|
||||||
|
tracing::info!("🔒 Authentication mode: Session (cookie-based)");
|
||||||
|
}
|
||||||
|
AuthMode::Jwt => {
|
||||||
|
tracing::info!("🔒 Authentication mode: JWT (Bearer token)");
|
||||||
|
}
|
||||||
|
AuthMode::Both => {
|
||||||
|
tracing::info!("🔒 Authentication mode: Both (JWT + Session)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
tracing::info!(" ✓ Session auth enabled (axum-login)");
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
if config.jwt_secret.is_some() {
|
||||||
|
tracing::info!(" ✓ JWT auth enabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-oidc")]
|
||||||
|
if config.oidc_issuer.is_some() {
|
||||||
|
tracing::info!(" ✓ OIDC integration enabled");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,35 +1,75 @@
|
|||||||
use axum::http::StatusCode;
|
//! Authentication routes
|
||||||
|
//!
|
||||||
|
//! Provides login, register, logout, and token endpoints.
|
||||||
|
//! Supports both session-based and JWT-based authentication.
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-oidc")]
|
||||||
|
use axum::response::Response;
|
||||||
use axum::{
|
use axum::{
|
||||||
Router,
|
Router,
|
||||||
extract::{Json, State},
|
extract::{Json, State},
|
||||||
|
http::StatusCode,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::post,
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
|
use serde::Serialize;
|
||||||
|
#[cfg(feature = "auth-oidc")]
|
||||||
|
use tower_sessions::Session;
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
|
use crate::config::AuthMode;
|
||||||
use crate::{
|
use crate::{
|
||||||
dto::{LoginRequest, RegisterRequest, UserResponse},
|
dto::{LoginRequest, RegisterRequest, UserResponse},
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
|
extractors::CurrentUser,
|
||||||
state::AppState,
|
state::AppState,
|
||||||
};
|
};
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
use domain::{DomainError, Email};
|
use domain::{DomainError, Email};
|
||||||
use tower_sessions::Session;
|
|
||||||
|
/// Token response for JWT authentication
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct TokenResponse {
|
||||||
|
pub access_token: String,
|
||||||
|
pub token_type: String,
|
||||||
|
pub expires_in: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Login response that can be either a user (session mode) or a token (JWT mode)
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum LoginResponse {
|
||||||
|
User(UserResponse),
|
||||||
|
Token(TokenResponse),
|
||||||
|
}
|
||||||
|
|
||||||
pub fn router() -> Router<AppState> {
|
pub fn router() -> Router<AppState> {
|
||||||
let r = Router::new()
|
let r = Router::new()
|
||||||
.route("/login", post(login))
|
.route("/login", post(login))
|
||||||
.route("/register", post(register))
|
.route("/register", post(register))
|
||||||
.route("/logout", post(logout))
|
.route("/logout", post(logout))
|
||||||
.route("/me", post(me));
|
.route("/me", get(me));
|
||||||
|
|
||||||
|
// Add token endpoint for getting JWT from session
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
let r = r.route("/token", post(get_token));
|
||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
let r = r
|
let r = r
|
||||||
.route("/login/oidc", axum::routing::get(oidc_login))
|
.route("/login/oidc", get(oidc_login))
|
||||||
.route("/auth/callback", axum::routing::get(oidc_callback));
|
.route("/callback", get(oidc_callback));
|
||||||
|
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Login endpoint
|
||||||
|
///
|
||||||
|
/// In session mode: Creates a session and returns user info
|
||||||
|
/// In JWT mode: Validates credentials and returns a JWT token
|
||||||
|
/// In both mode: Creates session AND returns JWT token
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
async fn login(
|
async fn login(
|
||||||
|
State(state): State<AppState>,
|
||||||
mut auth_session: crate::auth::AuthSession,
|
mut auth_session: crate::auth::AuthSession,
|
||||||
Json(payload): Json<LoginRequest>,
|
Json(payload): Json<LoginRequest>,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
@@ -45,21 +85,56 @@ async fn login(
|
|||||||
None => return Err(ApiError::Validation("Invalid credentials".to_string())),
|
None => return Err(ApiError::Validation("Invalid credentials".to_string())),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let auth_mode = state.config.auth_mode;
|
||||||
|
|
||||||
|
// In session or both mode, create session
|
||||||
|
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||||
auth_session
|
auth_session
|
||||||
.login(&user)
|
.login(&user)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// In JWT or both mode, return token
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||||
|
let token = create_jwt_for_user(&user.0, &state)?;
|
||||||
|
return Ok((
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(LoginResponse::Token(TokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
})),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session mode: return user info
|
||||||
Ok((
|
Ok((
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
Json(UserResponse {
|
Json(LoginResponse::User(UserResponse {
|
||||||
id: user.0.id,
|
id: user.0.id,
|
||||||
email: user.0.email.into_inner(),
|
email: user.0.email.into_inner(),
|
||||||
created_at: user.0.created_at,
|
created_at: user.0.created_at,
|
||||||
}),
|
})),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fallback login when auth-axum-login is not enabled
|
||||||
|
/// Without auth-axum-login, password-based authentication is not available.
|
||||||
|
/// Use OIDC login instead: GET /api/v1/auth/login/oidc
|
||||||
|
#[cfg(not(feature = "auth-axum-login"))]
|
||||||
|
async fn login(
|
||||||
|
State(_state): State<AppState>,
|
||||||
|
Json(_payload): Json<LoginRequest>,
|
||||||
|
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
|
||||||
|
Err(ApiError::Internal(
|
||||||
|
"Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register endpoint
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
async fn register(
|
async fn register(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
mut auth_session: crate::auth::AuthSession,
|
mut auth_session: crate::auth::AuthSession,
|
||||||
@@ -76,9 +151,6 @@ async fn register(
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: In a real app, you would hash the password here.
|
|
||||||
// This template uses a simplified User::new which doesn't take password.
|
|
||||||
// You should extend User to handle passwords or use an OIDC flow.
|
|
||||||
let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?;
|
let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?;
|
||||||
|
|
||||||
// Using email as subject for local auth for now
|
// Using email as subject for local auth for now
|
||||||
@@ -87,24 +159,54 @@ async fn register(
|
|||||||
.find_or_create(&email.as_ref().to_string(), email.as_ref())
|
.find_or_create(&email.as_ref().to_string(), email.as_ref())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Log the user in
|
let auth_mode = state.config.auth_mode;
|
||||||
let auth_user = crate::auth::AuthUser(user.clone());
|
|
||||||
|
|
||||||
|
// In session or both mode, create session
|
||||||
|
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||||
|
let auth_user = crate::auth::AuthUser(user.clone());
|
||||||
auth_session
|
auth_session
|
||||||
.login(&auth_user)
|
.login(&auth_user)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
.map_err(|_| ApiError::Internal("Login failed".to_string()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// In JWT or both mode, return token
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||||
|
let token = create_jwt_for_user(&user, &state)?;
|
||||||
|
return Ok((
|
||||||
|
StatusCode::CREATED,
|
||||||
|
Json(LoginResponse::Token(TokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
})),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
StatusCode::CREATED,
|
StatusCode::CREATED,
|
||||||
Json(UserResponse {
|
Json(LoginResponse::User(UserResponse {
|
||||||
id: user.id,
|
id: user.id,
|
||||||
email: user.email.into_inner(),
|
email: user.email.into_inner(),
|
||||||
created_at: user.created_at,
|
created_at: user.created_at,
|
||||||
}),
|
})),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fallback register when auth-axum-login is not enabled
|
||||||
|
#[cfg(not(feature = "auth-axum-login"))]
|
||||||
|
async fn register(
|
||||||
|
State(_state): State<AppState>,
|
||||||
|
Json(_payload): Json<RegisterRequest>,
|
||||||
|
) -> Result<(StatusCode, Json<LoginResponse>), ApiError> {
|
||||||
|
Err(ApiError::Internal(
|
||||||
|
"Session-based registration not available. Use JWT token endpoint.".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Logout endpoint
|
||||||
|
#[cfg(feature = "auth-axum-login")]
|
||||||
async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse {
|
async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse {
|
||||||
match auth_session.logout().await {
|
match auth_session.logout().await {
|
||||||
Ok(_) => StatusCode::OK,
|
Ok(_) => StatusCode::OK,
|
||||||
@@ -112,23 +214,61 @@ async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn me(auth_session: crate::auth::AuthSession) -> Result<impl IntoResponse, ApiError> {
|
/// Fallback logout when auth-axum-login is not enabled
|
||||||
let user = auth_session
|
#[cfg(not(feature = "auth-axum-login"))]
|
||||||
.user
|
async fn logout() -> impl IntoResponse {
|
||||||
.ok_or(ApiError::Unauthorized("Not logged in".to_string()))?;
|
// JWT tokens can't be "logged out" server-side without a blocklist
|
||||||
|
// Just return OK
|
||||||
|
StatusCode::OK
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get current user info
|
||||||
|
async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoResponse, ApiError> {
|
||||||
Ok(Json(UserResponse {
|
Ok(Json(UserResponse {
|
||||||
id: user.0.id,
|
id: user.id,
|
||||||
email: user.0.email.into_inner(),
|
email: user.email.into_inner(),
|
||||||
created_at: user.0.created_at,
|
created_at: user.created_at,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
/// Get a JWT token for the current session user
|
||||||
async fn oidc_login(
|
///
|
||||||
|
/// This allows session-authenticated users to obtain a JWT for API access.
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
async fn get_token(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
session: Session,
|
CurrentUser(user): CurrentUser,
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
|
let token = create_jwt_for_user(&user, &state)?;
|
||||||
|
|
||||||
|
Ok(Json(TokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to create JWT for a user
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
fn create_jwt_for_user(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
||||||
|
let validator = state
|
||||||
|
.jwt_validator
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
|
||||||
|
|
||||||
|
validator
|
||||||
|
.create_token(user)
|
||||||
|
.map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// OIDC Routes
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-oidc")]
|
||||||
|
async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<Response, ApiError> {
|
||||||
|
use axum::http::header;
|
||||||
|
|
||||||
let service = state
|
let service = state
|
||||||
.oidc_service
|
.oidc_service
|
||||||
.as_ref()
|
.as_ref()
|
||||||
@@ -149,7 +289,19 @@ async fn oidc_login(
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||||
|
|
||||||
Ok(axum::response::Redirect::to(&url))
|
let response = axum::response::Redirect::to(&url).into_response();
|
||||||
|
let (mut parts, body) = response.into_parts();
|
||||||
|
|
||||||
|
parts.headers.insert(
|
||||||
|
header::CACHE_CONTROL,
|
||||||
|
"no-cache, no-store, must-revalidate".parse().unwrap(),
|
||||||
|
);
|
||||||
|
parts
|
||||||
|
.headers
|
||||||
|
.insert(header::PRAGMA, "no-cache".parse().unwrap());
|
||||||
|
parts.headers.insert(header::EXPIRES, "0".parse().unwrap());
|
||||||
|
|
||||||
|
Ok(Response::from_parts(parts, body))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
@@ -159,7 +311,7 @@ struct CallbackParams {
|
|||||||
state: String,
|
state: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))]
|
||||||
async fn oidc_callback(
|
async fn oidc_callback(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
session: Session,
|
session: Session,
|
||||||
@@ -181,7 +333,6 @@ async fn oidc_callback(
|
|||||||
return Err(ApiError::Validation("Invalid CSRF token".into()));
|
return Err(ApiError::Validation("Invalid CSRF token".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Retrieve secrets
|
|
||||||
let stored_pkce: String = session
|
let stored_pkce: String = session
|
||||||
.get("oidc_pkce")
|
.get("oidc_pkce")
|
||||||
.await
|
.await
|
||||||
@@ -204,11 +355,17 @@ async fn oidc_callback(
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
let auth_mode = state.config.auth_mode;
|
||||||
|
|
||||||
|
// In session or both mode, create session
|
||||||
|
if matches!(auth_mode, AuthMode::Session | AuthMode::Both) {
|
||||||
auth_session
|
auth_session
|
||||||
.login(&crate::auth::AuthUser(user))
|
.login(&crate::auth::AuthUser(user.clone()))
|
||||||
.await
|
.await
|
||||||
.map_err(|_| ApiError::Internal("Login failed".into()))?;
|
.map_err(|_| ApiError::Internal("Login failed".into()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up OIDC state
|
||||||
let _: Option<String> = session
|
let _: Option<String> = session
|
||||||
.remove("oidc_csrf")
|
.remove("oidc_csrf")
|
||||||
.await
|
.await
|
||||||
@@ -222,5 +379,101 @@ async fn oidc_callback(
|
|||||||
.await
|
.await
|
||||||
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||||
|
|
||||||
Ok(axum::response::Redirect::to("/"))
|
// In JWT mode, return token as JSON
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||||
|
let token = create_jwt_for_user(&user, &state)?;
|
||||||
|
return Ok(Json(TokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
})
|
||||||
|
.into_response());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session mode: return user info
|
||||||
|
Ok(Json(UserResponse {
|
||||||
|
id: user.id,
|
||||||
|
email: user.email.into_inner(),
|
||||||
|
created_at: user.created_at,
|
||||||
|
})
|
||||||
|
.into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fallback OIDC callback when auth-axum-login is not enabled
|
||||||
|
#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))]
|
||||||
|
async fn oidc_callback(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
session: Session,
|
||||||
|
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
|
||||||
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
|
let service = state
|
||||||
|
.oidc_service
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
|
||||||
|
|
||||||
|
let stored_csrf: String = session
|
||||||
|
.get("oidc_csrf")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||||
|
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
|
||||||
|
|
||||||
|
if params.state != stored_csrf {
|
||||||
|
return Err(ApiError::Validation("Invalid CSRF token".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stored_pkce: String = session
|
||||||
|
.get("oidc_pkce")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||||
|
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
|
||||||
|
let stored_nonce: String = session
|
||||||
|
.get("oidc_nonce")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?
|
||||||
|
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
|
||||||
|
|
||||||
|
let oidc_user = service
|
||||||
|
.resolve_callback(params.code, stored_nonce, stored_pkce)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
let user = state
|
||||||
|
.user_service
|
||||||
|
.find_or_create(&oidc_user.subject, &oidc_user.email)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
// Clean up OIDC state
|
||||||
|
let _: Option<String> = session
|
||||||
|
.remove("oidc_csrf")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||||
|
let _: Option<String> = session
|
||||||
|
.remove("oidc_pkce")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||||
|
let _: Option<String> = session
|
||||||
|
.remove("oidc_nonce")
|
||||||
|
.await
|
||||||
|
.map_err(|_| ApiError::Internal("Session error".into()))?;
|
||||||
|
|
||||||
|
// Return token as JSON
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
{
|
||||||
|
let token = create_jwt_for_user(&user, &state)?;
|
||||||
|
return Ok(Json(TokenResponse {
|
||||||
|
access_token: token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "auth-jwt"))]
|
||||||
|
{
|
||||||
|
let _ = user; // Suppress unused warning
|
||||||
|
Err(ApiError::Internal(
|
||||||
|
"No auth backend available for OIDC callback".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
//! Holds shared state for the application.
|
//! Holds shared state for the application.
|
||||||
|
|
||||||
use axum::extract::FromRef;
|
use axum::extract::FromRef;
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
use infra::auth::jwt::{JwtConfig, JwtValidator};
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
use infra::auth::oidc::OidcService;
|
use infra::auth::oidc::OidcService;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::{AuthMode, Config};
|
||||||
use domain::UserService;
|
use domain::UserService;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -15,32 +17,80 @@ pub struct AppState {
|
|||||||
pub user_service: Arc<UserService>,
|
pub user_service: Arc<UserService>,
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
pub oidc_service: Option<Arc<OidcService>>,
|
pub oidc_service: Option<Arc<OidcService>>,
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
pub jwt_validator: Option<Arc<JwtValidator>>,
|
||||||
pub config: Arc<Config>,
|
pub config: Arc<Config>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
pub async fn new(user_service: UserService, config: Config) -> anyhow::Result<Self> {
|
pub async fn new(user_service: UserService, config: Config) -> anyhow::Result<Self> {
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
let oidc_service = if let (Some(issuer), Some(id), Some(secret), Some(redirect)) = (
|
let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = (
|
||||||
&config.oidc_issuer,
|
&config.oidc_issuer,
|
||||||
&config.oidc_client_id,
|
&config.oidc_client_id,
|
||||||
&config.oidc_client_secret,
|
&config.oidc_client_secret,
|
||||||
&config.oidc_redirect_url,
|
&config.oidc_redirect_url,
|
||||||
|
&config.oidc_resource_id,
|
||||||
) {
|
) {
|
||||||
tracing::info!("Initializing OIDC service with issuer: {}", issuer);
|
tracing::info!("Initializing OIDC service with issuer: {}", issuer);
|
||||||
Some(Arc::new(
|
Some(Arc::new(
|
||||||
OidcService::new(issuer.clone(), id.clone(), secret.clone(), redirect.clone())
|
OidcService::new(
|
||||||
|
issuer.clone(),
|
||||||
|
id.clone(),
|
||||||
|
secret.clone().unwrap_or_default(),
|
||||||
|
redirect.clone(),
|
||||||
|
resource_id.clone(),
|
||||||
|
)
|
||||||
.await?,
|
.await?,
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) {
|
||||||
|
// Use provided secret or fall back to a development secret
|
||||||
|
let secret = if let Some(ref s) = config.jwt_secret {
|
||||||
|
if s.is_empty() { None } else { Some(s.clone()) }
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let secret = match secret {
|
||||||
|
Some(s) => s,
|
||||||
|
None => {
|
||||||
|
if config.is_production {
|
||||||
|
anyhow::bail!(
|
||||||
|
"JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Use a development-only default secret
|
||||||
|
tracing::warn!(
|
||||||
|
"⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!"
|
||||||
|
);
|
||||||
|
"k-template-dev-secret-not-for-production-use-only".to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!("Initializing JWT validator");
|
||||||
|
let jwt_config = JwtConfig::new(
|
||||||
|
secret,
|
||||||
|
config.jwt_issuer.clone(),
|
||||||
|
config.jwt_audience.clone(),
|
||||||
|
Some(config.jwt_expiry_hours),
|
||||||
|
config.is_production,
|
||||||
|
)?;
|
||||||
|
Some(Arc::new(JwtValidator::new(jwt_config)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
user_service: Arc::new(user_service),
|
user_service: Arc::new(user_service),
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
oidc_service,
|
oidc_service,
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
jwt_validator,
|
||||||
config: Arc::new(config),
|
config: Arc::new(config),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ postgres = [
|
|||||||
broker-nats = ["dep:futures-util", "k-core/broker-nats"]
|
broker-nats = ["dep:futures-util", "k-core/broker-nats"]
|
||||||
auth-axum-login = ["dep:axum-login", "dep:password-auth"]
|
auth-axum-login = ["dep:axum-login", "dep:password-auth"]
|
||||||
auth-oidc = ["dep:openidconnect", "dep:url"]
|
auth-oidc = ["dep:openidconnect", "dep:url"]
|
||||||
|
auth-jwt = ["dep:jsonwebtoken"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [
|
||||||
@@ -50,4 +51,5 @@ axum-login = { version = "0.18", optional = true }
|
|||||||
password-auth = { version = "1.0", optional = true }
|
password-auth = { version = "1.0", optional = true }
|
||||||
openidconnect = { version = "4.0.1", optional = true }
|
openidconnect = { version = "4.0.1", optional = true }
|
||||||
url = { version = "2.5.8", optional = true }
|
url = { version = "2.5.8", optional = true }
|
||||||
|
jsonwebtoken = { version = "9.3", optional = true }
|
||||||
# reqwest = { version = "0.13.1", features = ["blocking", "json"], optional = true }
|
# reqwest = { version = "0.13.1", features = ["blocking", "json"], optional = true }
|
||||||
|
|||||||
278
infra/src/auth/jwt.rs
Normal file
278
infra/src/auth/jwt.rs
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
//! JWT Authentication Infrastructure
|
||||||
|
//!
|
||||||
|
//! Provides JWT token creation and validation using HS256 (secret-based).
|
||||||
|
//! For OIDC/JWKS validation, see the `oidc` module.
|
||||||
|
|
||||||
|
use domain::User;
|
||||||
|
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
/// Minimum secret length for production (256 bits = 32 bytes)
|
||||||
|
const MIN_SECRET_LENGTH: usize = 32;
|
||||||
|
|
||||||
|
/// JWT configuration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct JwtConfig {
|
||||||
|
/// Secret key for HS256 signing/verification
|
||||||
|
pub secret: String,
|
||||||
|
/// Expected issuer (for validation)
|
||||||
|
pub issuer: Option<String>,
|
||||||
|
/// Expected audience (for validation)
|
||||||
|
pub audience: Option<String>,
|
||||||
|
/// Token expiry in hours (default: 24)
|
||||||
|
pub expiry_hours: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JwtConfig {
|
||||||
|
/// Create a new JWT config with validation
|
||||||
|
///
|
||||||
|
/// In production mode, this will reject weak secrets.
|
||||||
|
pub fn new(
|
||||||
|
secret: String,
|
||||||
|
issuer: Option<String>,
|
||||||
|
audience: Option<String>,
|
||||||
|
expiry_hours: Option<u64>,
|
||||||
|
is_production: bool,
|
||||||
|
) -> Result<Self, JwtError> {
|
||||||
|
// Validate secret strength in production
|
||||||
|
if is_production && secret.len() < MIN_SECRET_LENGTH {
|
||||||
|
return Err(JwtError::WeakSecret {
|
||||||
|
min_length: MIN_SECRET_LENGTH,
|
||||||
|
actual_length: secret.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
secret,
|
||||||
|
issuer,
|
||||||
|
audience,
|
||||||
|
expiry_hours: expiry_hours.unwrap_or(24),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create config without validation (for testing)
|
||||||
|
pub fn new_unchecked(secret: String) -> Self {
|
||||||
|
Self {
|
||||||
|
secret,
|
||||||
|
issuer: None,
|
||||||
|
audience: None,
|
||||||
|
expiry_hours: 24,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JWT claims structure
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct JwtClaims {
|
||||||
|
/// Subject - the user's unique identifier (user ID as string)
|
||||||
|
pub sub: String,
|
||||||
|
/// User's email address
|
||||||
|
pub email: String,
|
||||||
|
/// Expiry timestamp (seconds since UNIX epoch)
|
||||||
|
pub exp: usize,
|
||||||
|
/// Issued at timestamp (seconds since UNIX epoch)
|
||||||
|
pub iat: usize,
|
||||||
|
/// Issuer
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub iss: Option<String>,
|
||||||
|
/// Audience
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub aud: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JWT-related errors
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum JwtError {
|
||||||
|
#[error("JWT secret is too weak: minimum {min_length} bytes required, got {actual_length}")]
|
||||||
|
WeakSecret {
|
||||||
|
min_length: usize,
|
||||||
|
actual_length: usize,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Token creation failed: {0}")]
|
||||||
|
CreationFailed(#[from] jsonwebtoken::errors::Error),
|
||||||
|
|
||||||
|
#[error("Token validation failed: {0}")]
|
||||||
|
ValidationFailed(String),
|
||||||
|
|
||||||
|
#[error("Token expired")]
|
||||||
|
Expired,
|
||||||
|
|
||||||
|
#[error("Invalid token format")]
|
||||||
|
InvalidFormat,
|
||||||
|
|
||||||
|
#[error("Missing configuration")]
|
||||||
|
MissingConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JWT token validator and generator
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct JwtValidator {
|
||||||
|
config: JwtConfig,
|
||||||
|
encoding_key: EncodingKey,
|
||||||
|
decoding_key: DecodingKey,
|
||||||
|
validation: Validation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JwtValidator {
|
||||||
|
/// Create a new JWT validator with the given configuration
|
||||||
|
pub fn new(config: JwtConfig) -> Self {
|
||||||
|
let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
|
||||||
|
let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
|
||||||
|
|
||||||
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
|
|
||||||
|
// Configure issuer validation if set
|
||||||
|
if let Some(ref issuer) = config.issuer {
|
||||||
|
validation.set_issuer(&[issuer]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure audience validation if set
|
||||||
|
if let Some(ref audience) = config.audience {
|
||||||
|
validation.set_audience(&[audience]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
encoding_key,
|
||||||
|
decoding_key,
|
||||||
|
validation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a JWT token for the given user
|
||||||
|
pub fn create_token(&self, user: &User) -> Result<String, JwtError> {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("Time went backwards")
|
||||||
|
.as_secs() as usize;
|
||||||
|
|
||||||
|
let expiry = now + (self.config.expiry_hours as usize * 3600);
|
||||||
|
|
||||||
|
let claims = JwtClaims {
|
||||||
|
sub: user.id.to_string(),
|
||||||
|
email: user.email.as_ref().to_string(),
|
||||||
|
exp: expiry,
|
||||||
|
iat: now,
|
||||||
|
iss: self.config.issuer.clone(),
|
||||||
|
aud: self.config.audience.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let header = Header::new(Algorithm::HS256);
|
||||||
|
encode(&header, &claims, &self.encoding_key).map_err(JwtError::CreationFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a JWT token and return the claims
|
||||||
|
pub fn validate_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
||||||
|
let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation).map_err(
|
||||||
|
|e| match e.kind() {
|
||||||
|
jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired,
|
||||||
|
jsonwebtoken::errors::ErrorKind::InvalidToken => JwtError::InvalidFormat,
|
||||||
|
_ => JwtError::ValidationFailed(e.to_string()),
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the user ID (subject) from a token without full validation
|
||||||
|
/// Useful for logging/debugging, but should not be trusted for auth
|
||||||
|
pub fn decode_unverified(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
||||||
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
|
validation.insecure_disable_signature_validation();
|
||||||
|
validation.validate_exp = false;
|
||||||
|
|
||||||
|
let token_data = decode::<JwtClaims>(token, &self.decoding_key, &validation)
|
||||||
|
.map_err(|_| JwtError::InvalidFormat)?;
|
||||||
|
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for JwtValidator {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("JwtValidator")
|
||||||
|
.field("issuer", &self.config.issuer)
|
||||||
|
.field("audience", &self.config.audience)
|
||||||
|
.field("expiry_hours", &self.config.expiry_hours)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use domain::Email;
|
||||||
|
|
||||||
|
fn create_test_user() -> User {
|
||||||
|
let email = Email::try_from("test@example.com").unwrap();
|
||||||
|
User::new("test-subject", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_and_validate_token() {
|
||||||
|
let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string());
|
||||||
|
let validator = JwtValidator::new(config);
|
||||||
|
let user = create_test_user();
|
||||||
|
|
||||||
|
let token = validator.create_token(&user).expect("Should create token");
|
||||||
|
let claims = validator
|
||||||
|
.validate_token(&token)
|
||||||
|
.expect("Should validate token");
|
||||||
|
|
||||||
|
assert_eq!(claims.sub, user.id.to_string());
|
||||||
|
assert_eq!(claims.email, "test@example.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_weak_secret_rejected_in_production() {
|
||||||
|
let result = JwtConfig::new(
|
||||||
|
"short".to_string(), // Too short
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
true, // Production mode
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(JwtError::WeakSecret { .. })));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_weak_secret_allowed_in_development() {
|
||||||
|
let result = JwtConfig::new(
|
||||||
|
"short".to_string(), // Too short but OK in dev
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false, // Development mode
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_token_rejected() {
|
||||||
|
let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string());
|
||||||
|
let validator = JwtValidator::new(config);
|
||||||
|
|
||||||
|
let result = validator.validate_token("invalid.token.here");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wrong_secret_rejected() {
|
||||||
|
let config1 = JwtConfig::new_unchecked("secret-one-that-is-long-enough".to_string());
|
||||||
|
let config2 = JwtConfig::new_unchecked("secret-two-that-is-long-enough".to_string());
|
||||||
|
|
||||||
|
let validator1 = JwtValidator::new(config1);
|
||||||
|
let validator2 = JwtValidator::new(config2);
|
||||||
|
|
||||||
|
let user = create_test_user();
|
||||||
|
let token = validator1.create_token(&user).unwrap();
|
||||||
|
|
||||||
|
// Token from validator1 should fail on validator2
|
||||||
|
let result = validator2.validate_token(&token);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -118,3 +118,6 @@ pub mod backend {
|
|||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
pub mod oidc;
|
pub mod oidc;
|
||||||
|
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
pub mod jwt;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use openidconnect::{
|
|||||||
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
|
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
|
||||||
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
|
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
|
||||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope,
|
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope,
|
||||||
StandardErrorResponse, TokenResponse,
|
StandardErrorResponse, TokenResponse, UserInfoClaims,
|
||||||
core::{
|
core::{
|
||||||
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
|
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
|
||||||
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
|
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
|
||||||
@@ -36,6 +36,7 @@ pub type OidcClient = Client<
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OidcService {
|
pub struct OidcService {
|
||||||
client: OidcClient,
|
client: OidcClient,
|
||||||
|
resource_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -51,7 +52,31 @@ impl OidcService {
|
|||||||
client_id: String,
|
client_id: String,
|
||||||
client_secret: String,
|
client_secret: String,
|
||||||
redirect_url: String,
|
redirect_url: String,
|
||||||
|
resource_id: Option<String>,
|
||||||
) -> anyhow::Result<Self> {
|
) -> anyhow::Result<Self> {
|
||||||
|
let client_id = client_id.trim().to_string();
|
||||||
|
let redirect_url = redirect_url.trim().to_string();
|
||||||
|
let issuer = issuer.trim().to_string();
|
||||||
|
|
||||||
|
// 2. Handle Empty Secret (For PKCE/Public Clients)
|
||||||
|
let client_secret_clean = client_secret.trim();
|
||||||
|
let client_secret_opt = if client_secret_clean.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(ClientSecret::new(client_secret_clean.to_string()))
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id);
|
||||||
|
tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url);
|
||||||
|
tracing::debug!(
|
||||||
|
"🔵 OIDC Setup: Secret = {:?}",
|
||||||
|
if client_secret_opt.is_some() {
|
||||||
|
"SET"
|
||||||
|
} else {
|
||||||
|
"NONE"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
let http_client = reqwest::ClientBuilder::new()
|
let http_client = reqwest::ClientBuilder::new()
|
||||||
.redirect(reqwest::redirect::Policy::none())
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
.build()?;
|
.build()?;
|
||||||
@@ -62,11 +87,14 @@ impl OidcService {
|
|||||||
let client = CoreClient::from_provider_metadata(
|
let client = CoreClient::from_provider_metadata(
|
||||||
provider_metadata,
|
provider_metadata,
|
||||||
ClientId::new(client_id),
|
ClientId::new(client_id),
|
||||||
Some(ClientSecret::new(client_secret)),
|
client_secret_opt,
|
||||||
)
|
)
|
||||||
.set_redirect_uri(RedirectUrl::new(redirect_url)?);
|
.set_redirect_uri(RedirectUrl::new(redirect_url)?);
|
||||||
|
|
||||||
Ok(Self { client })
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
resource_id,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: replace this tuple with newtype
|
// todo: replace this tuple with newtype
|
||||||
@@ -118,7 +146,15 @@ impl OidcService {
|
|||||||
.id_token()
|
.id_token()
|
||||||
.ok_or_else(|| anyhow!("Server did not return an ID token"))?;
|
.ok_or_else(|| anyhow!("Server did not return an ID token"))?;
|
||||||
|
|
||||||
let id_token_verifier = self.client.id_token_verifier();
|
let mut id_token_verifier = self.client.id_token_verifier().clone();
|
||||||
|
|
||||||
|
let trusted_resource_id = self.resource_id.clone();
|
||||||
|
|
||||||
|
if let Some(resource_id) = trusted_resource_id {
|
||||||
|
id_token_verifier = id_token_verifier
|
||||||
|
.set_other_audience_verifier_fn(move |aud| aud.as_str() == resource_id);
|
||||||
|
}
|
||||||
|
|
||||||
let claims = id_token.claims(&id_token_verifier, &nonce)?;
|
let claims = id_token.claims(&id_token_verifier, &nonce)?;
|
||||||
|
|
||||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||||
@@ -133,13 +169,28 @@ impl OidcService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let email = if let Some(email) = claims.email() {
|
||||||
|
Some(email.as_str().to_string())
|
||||||
|
} else {
|
||||||
|
// Fallback: Call UserInfo Endpoint using the Access Token
|
||||||
|
tracing::debug!("🔵 Email missing in ID Token, fetching UserInfo...");
|
||||||
|
|
||||||
|
let user_info: UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim> = self
|
||||||
|
.client
|
||||||
|
.user_info(token_response.access_token().clone(), None)?
|
||||||
|
.request_async(&http_client)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
user_info.email().map(|e| e.as_str().to_string())
|
||||||
|
};
|
||||||
|
|
||||||
|
// If email is still missing, we must error out because your app requires valid emails
|
||||||
|
let email =
|
||||||
|
email.ok_or_else(|| anyhow!("User has no verified email address in ZITADEL"))?;
|
||||||
|
|
||||||
Ok(OidcUser {
|
Ok(OidcUser {
|
||||||
subject: claims.subject().to_string(),
|
subject: claims.subject().to_string(),
|
||||||
email: claims
|
email,
|
||||||
.email()
|
|
||||||
.map(|email| email.as_str())
|
|
||||||
.unwrap_or("<not provided>")
|
|
||||||
.to_string(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user