refactor: Replace raw strings with domain value objects for improved type safety in authentication and OIDC.

This commit is contained in:
2026-01-06 05:16:16 +01:00
parent 16dcc4b95e
commit 32a0faf302
9 changed files with 667 additions and 232 deletions

107
Cargo.lock generated
View File

@@ -57,7 +57,6 @@ dependencies = [
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"uuid", "uuid",
"validator",
] ]
[[package]] [[package]]
@@ -523,38 +522,14 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "darling"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
"darling_core 0.20.11",
"darling_macro 0.20.11",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.21.3" version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
dependencies = [ dependencies = [
"darling_core 0.21.3", "darling_core",
"darling_macro 0.21.3", "darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
] ]
[[package]] [[package]]
@@ -571,24 +546,13 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core 0.20.11",
"quote",
"syn",
]
[[package]] [[package]]
name = "darling_macro" name = "darling_macro"
version = "0.21.3" version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
dependencies = [ dependencies = [
"darling_core 0.21.3", "darling_core",
"quote", "quote",
"syn", "syn",
] ]
@@ -659,12 +623,14 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"chrono", "chrono",
"email_address",
"futures-core", "futures-core",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.17", "thiserror 2.0.17",
"tokio", "tokio",
"tracing", "tracing",
"url",
"uuid", "uuid",
] ]
@@ -749,6 +715,15 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "email_address"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "encoding_rs" name = "encoding_rs"
version = "0.8.35" version = "0.8.35"
@@ -2013,28 +1988,6 @@ dependencies = [
"elliptic-curve", "elliptic-curve",
] ]
[[package]]
name = "proc-macro-error-attr2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5"
dependencies = [
"proc-macro2",
"quote",
]
[[package]]
name = "proc-macro-error2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802"
dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.104" version = "1.0.104"
@@ -2692,7 +2645,7 @@ version = "3.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c"
dependencies = [ dependencies = [
"darling 0.21.3", "darling",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn",
@@ -3611,36 +3564,6 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "validator"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa"
dependencies = [
"idna",
"once_cell",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca"
dependencies = [
"darling 0.20.11",
"once_cell",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.1" version = "0.1.1"

View File

@@ -44,8 +44,7 @@ tokio = { version = "1.48.0", features = ["full"] }
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
# Validation # Validation via domain newtypes (Email, Password)
validator = { version = "0.20", features = ["derive"] }
# Error handling # Error handling
thiserror = "2.0.17" thiserror = "2.0.17"

View File

@@ -1,30 +1,29 @@
//! Request and Response DTOs //! Request and Response DTOs
//! //!
//! Data Transfer Objects for the API. //! Data Transfer Objects for the API.
//! Uses domain newtypes for validation instead of the validator crate.
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use domain::{Email, Password};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use validator::Validate;
/// Login request /// Login request with validated email and password newtypes
#[derive(Debug, Deserialize, Validate)] #[derive(Debug, Deserialize)]
pub struct LoginRequest { pub struct LoginRequest {
#[validate(email(message = "Invalid email format"))] /// Email is validated on deserialization
pub email: String, pub email: Email,
/// Password is validated on deserialization (min 6 chars)
#[validate(length(min = 6, message = "Password must be at least 6 characters"))] pub password: Password,
pub password: String,
} }
/// Register request /// Register request with validated email and password newtypes
#[derive(Debug, Deserialize, Validate)] #[derive(Debug, Deserialize)]
pub struct RegisterRequest { pub struct RegisterRequest {
#[validate(email(message = "Invalid email format"))] /// Email is validated on deserialization
pub email: String, pub email: Email,
/// Password is validated on deserialization (min 6 chars)
#[validate(length(min = 6, message = "Password must be at least 6 characters"))] pub password: Password,
pub password: String,
} }
/// User response DTO /// User response DTO
@@ -40,12 +39,3 @@ pub struct UserResponse {
pub struct ConfigResponse { pub struct ConfigResponse {
pub allow_registration: bool, pub allow_registration: bool,
} }
#[cfg(feature = "auth-jwt")]
#[derive(Debug, Serialize, Deserialize)]
// also newtypes
pub struct Claims {
pub sub: String,
pub email: String,
pub exp: usize,
}

View File

@@ -25,7 +25,7 @@ use crate::{
state::AppState, state::AppState,
}; };
#[cfg(feature = "auth-axum-login")] #[cfg(feature = "auth-axum-login")]
use domain::{DomainError, Email}; use domain::DomainError;
/// Token response for JWT authentication /// Token response for JWT authentication
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -140,19 +140,20 @@ async fn register(
mut auth_session: crate::auth::AuthSession, mut auth_session: crate::auth::AuthSession,
Json(payload): Json<RegisterRequest>, Json(payload): Json<RegisterRequest>,
) -> Result<impl IntoResponse, ApiError> { ) -> Result<impl IntoResponse, ApiError> {
// Email is already validated by the newtype deserialization
let email = payload.email;
if state if state
.user_service .user_service
.find_by_email(&payload.email) .find_by_email(email.as_ref())
.await? .await?
.is_some() .is_some()
{ {
return Err(ApiError::Domain(DomainError::UserAlreadyExists( return Err(ApiError::Domain(DomainError::UserAlreadyExists(
payload.email, email.as_ref().to_string(),
))); )));
} }
let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?;
// Using email as subject for local auth for now // Using email as subject for local auth for now
let user = state let user = state
.user_service .user_service
@@ -274,22 +275,22 @@ async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<R
.as_ref() .as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?; .ok_or(ApiError::Internal("OIDC not configured".into()))?;
let (url, csrf, nonce, pkce) = service.get_authorization_url(); let auth_data = service.get_authorization_url();
session session
.insert("oidc_csrf", csrf) .insert("oidc_csrf", &auth_data.csrf_token)
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))?; .map_err(|_| ApiError::Internal("Session error".into()))?;
session session
.insert("oidc_nonce", nonce) .insert("oidc_nonce", &auth_data.nonce)
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))?; .map_err(|_| ApiError::Internal("Session error".into()))?;
session session
.insert("oidc_pkce", pkce) .insert("oidc_pkce", &auth_data.pkce_verifier)
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))?; .map_err(|_| ApiError::Internal("Session error".into()))?;
let response = axum::response::Redirect::to(&url).into_response(); let response = axum::response::Redirect::to(auth_data.url.as_str()).into_response();
let (mut parts, body) = response.into_parts(); let (mut parts, body) = response.into_parts();
parts.headers.insert( parts.headers.insert(
@@ -323,29 +324,33 @@ async fn oidc_callback(
.as_ref() .as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?; .ok_or(ApiError::Internal("OIDC not configured".into()))?;
let stored_csrf: String = session let stored_csrf: domain::CsrfToken = session
.get("oidc_csrf") .get("oidc_csrf")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing CSRF token".into()))?; .ok_or(ApiError::Validation("Missing CSRF token".into()))?;
if params.state != stored_csrf { if params.state != stored_csrf.as_ref() {
return Err(ApiError::Validation("Invalid CSRF token".into())); return Err(ApiError::Validation("Invalid CSRF token".into()));
} }
let stored_pkce: String = session let stored_pkce: domain::PkceVerifier = session
.get("oidc_pkce") .get("oidc_pkce")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing PKCE".into()))?; .ok_or(ApiError::Validation("Missing PKCE".into()))?;
let stored_nonce: String = session let stored_nonce: domain::OidcNonce = session
.get("oidc_nonce") .get("oidc_nonce")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing Nonce".into()))?; .ok_or(ApiError::Validation("Missing Nonce".into()))?;
let oidc_user = service let oidc_user = service
.resolve_callback(params.code, stored_nonce, stored_pkce) .resolve_callback(
domain::AuthorizationCode::new(params.code),
stored_nonce,
stored_pkce,
)
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;
@@ -412,29 +417,33 @@ async fn oidc_callback(
.as_ref() .as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?; .ok_or(ApiError::Internal("OIDC not configured".into()))?;
let stored_csrf: String = session let stored_csrf: domain::CsrfToken = session
.get("oidc_csrf") .get("oidc_csrf")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing CSRF token".into()))?; .ok_or(ApiError::Validation("Missing CSRF token".into()))?;
if params.state != stored_csrf { if params.state != stored_csrf.as_ref() {
return Err(ApiError::Validation("Invalid CSRF token".into())); return Err(ApiError::Validation("Invalid CSRF token".into()));
} }
let stored_pkce: String = session let stored_pkce: domain::PkceVerifier = session
.get("oidc_pkce") .get("oidc_pkce")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing PKCE".into()))?; .ok_or(ApiError::Validation("Missing PKCE".into()))?;
let stored_nonce: String = session let stored_nonce: domain::OidcNonce = session
.get("oidc_nonce") .get("oidc_nonce")
.await .await
.map_err(|_| ApiError::Internal("Session error".into()))? .map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing Nonce".into()))?; .ok_or(ApiError::Validation("Missing Nonce".into()))?;
let oidc_user = service let oidc_user = service
.resolve_callback(params.code, stored_nonce, stored_pkce) .resolve_callback(
domain::AuthorizationCode::new(params.code),
stored_nonce,
stored_pkce,
)
.await .await
.map_err(|e| ApiError::Internal(e.to_string()))?; .map_err(|e| ApiError::Internal(e.to_string()))?;

View File

@@ -33,15 +33,24 @@ impl AppState {
&config.oidc_resource_id, &config.oidc_resource_id,
) { ) {
tracing::info!("Initializing OIDC service with issuer: {}", issuer); tracing::info!("Initializing OIDC service with issuer: {}", issuer);
// Construct newtypes from config strings
let issuer_url = domain::IssuerUrl::new(issuer)
.map_err(|e| anyhow::anyhow!("Invalid OIDC issuer URL: {}", e))?;
let client_id = domain::ClientId::new(id)
.map_err(|e| anyhow::anyhow!("Invalid OIDC client ID: {}", e))?;
let client_secret = secret.as_ref().map(|s| domain::ClientSecret::new(s));
let redirect_url = domain::RedirectUrl::new(redirect)
.map_err(|e| anyhow::anyhow!("Invalid OIDC redirect URL: {}", e))?;
let resource = resource_id
.as_ref()
.map(|r| domain::ResourceId::new(r))
.transpose()
.map_err(|e| anyhow::anyhow!("Invalid OIDC resource ID: {}", e))?;
Some(Arc::new( Some(Arc::new(
OidcService::new( OidcService::new(issuer_url, client_id, client_secret, redirect_url, resource)
issuer.clone(), .await?,
id.clone(),
secret.clone().unwrap_or_default(),
redirect.clone(),
resource_id.clone(),
)
.await?,
)) ))
} else { } else {
None None

View File

@@ -7,10 +7,12 @@ edition = "2024"
anyhow = "1.0.100" anyhow = "1.0.100"
async-trait = "0.1.89" async-trait = "0.1.89"
chrono = { version = "0.4.42", features = ["serde"] } chrono = { version = "0.4.42", features = ["serde"] }
email_address = "0.2"
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.146" serde_json = "1.0.146"
thiserror = "2.0.17" thiserror = "2.0.17"
tracing = "0.1" tracing = "0.1"
url = { version = "2.5", features = ["serde"] }
uuid = { version = "1.19.0", features = ["v4", "serde"] } uuid = { version = "1.19.0", features = ["v4", "serde"] }
futures-core = "0.3" futures-core = "0.3"

View File

@@ -6,6 +6,7 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt; use std::fmt;
use thiserror::Error; use thiserror::Error;
use url::Url;
use uuid::Uuid; use uuid::Uuid;
pub type UserId = Uuid; pub type UserId = Uuid;
@@ -22,47 +23,44 @@ pub enum ValidationError {
#[error("Password must be at least {min} characters, got {actual}")] #[error("Password must be at least {min} characters, got {actual}")]
PasswordTooShort { min: usize, actual: usize }, PasswordTooShort { min: usize, actual: usize },
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Value cannot be empty: {0}")]
Empty(String),
#[error("Secret too short: minimum {min} bytes required, got {actual}")]
SecretTooShort { min: usize, actual: usize },
} }
// ============================================================================ // ============================================================================
// Email // Email (using email_address crate for RFC-compliant validation)
// ============================================================================ // ============================================================================
/// A validated email address. /// A validated email address using RFC-compliant validation.
///
/// Simple validation: must contain exactly one `@` with non-empty parts on both sides.
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Email(String); pub struct Email(email_address::EmailAddress);
impl Email { impl Email {
/// Minimum validation: contains @ with non-empty local and domain parts /// Create a new validated email address
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> { pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.into(); let value = value.as_ref().trim().to_lowercase();
let trimmed = value.trim().to_lowercase(); let addr: email_address::EmailAddress = value
.parse()
// Basic email validation .map_err(|_| ValidationError::InvalidEmail(value.clone()))?;
let parts: Vec<&str> = trimmed.split('@').collect(); Ok(Self(addr))
if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
return Err(ValidationError::InvalidEmail(value));
}
// Domain must contain at least one dot
if !parts[1].contains('.') {
return Err(ValidationError::InvalidEmail(value));
}
Ok(Self(trimmed))
} }
/// Get the inner value /// Get the inner value
pub fn into_inner(self) -> String { pub fn into_inner(self) -> String {
self.0 self.0.to_string()
} }
} }
impl AsRef<str> for Email { impl AsRef<str> for Email {
fn as_ref(&self) -> &str { fn as_ref(&self) -> &str {
&self.0 self.0.as_ref()
} }
} }
@@ -90,7 +88,7 @@ impl TryFrom<&str> for Email {
impl Serialize for Email { impl Serialize for Email {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.0) serializer.serialize_str(self.0.as_ref())
} }
} }
@@ -171,6 +169,446 @@ impl<'de> Deserialize<'de> for Password {
// Note: Password should NOT implement Serialize to prevent accidental exposure // Note: Password should NOT implement Serialize to prevent accidental exposure
// ============================================================================
// OIDC Configuration Newtypes
// ============================================================================
/// OIDC Issuer URL - validated URL for the identity provider
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct IssuerUrl(Url);
impl IssuerUrl {
pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.as_ref().trim();
let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
Ok(Self(url))
}
pub fn as_url(&self) -> &Url {
&self.0
}
}
impl AsRef<str> for IssuerUrl {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for IssuerUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for IssuerUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<IssuerUrl> for String {
fn from(val: IssuerUrl) -> Self {
val.0.to_string()
}
}
/// OIDC Client Identifier
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct ClientId(String);
impl ClientId {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into().trim().to_string();
if value.is_empty() {
return Err(ValidationError::Empty("client_id".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for ClientId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for ClientId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for ClientId {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<ClientId> for String {
fn from(val: ClientId) -> Self {
val.0
}
}
/// OIDC Client Secret - hidden in Debug output
#[derive(Clone, PartialEq, Eq)]
pub struct ClientSecret(String);
impl ClientSecret {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
/// Check if the secret is empty (for public clients)
pub fn is_empty(&self) -> bool {
self.0.trim().is_empty()
}
}
impl AsRef<str> for ClientSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for ClientSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ClientSecret(***)")
}
}
impl fmt::Display for ClientSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "***")
}
}
impl<'de> Deserialize<'de> for ClientSecret {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(Self::new(s))
}
}
// Note: ClientSecret should NOT implement Serialize
/// OAuth Redirect URL - validated URL
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct RedirectUrl(Url);
impl RedirectUrl {
pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.as_ref().trim();
let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
Ok(Self(url))
}
pub fn as_url(&self) -> &Url {
&self.0
}
}
impl AsRef<str> for RedirectUrl {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for RedirectUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for RedirectUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<RedirectUrl> for String {
fn from(val: RedirectUrl) -> Self {
val.0.to_string()
}
}
/// OIDC Resource Identifier (optional audience)
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct ResourceId(String);
impl ResourceId {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into().trim().to_string();
if value.is_empty() {
return Err(ValidationError::Empty("resource_id".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for ResourceId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for ResourceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for ResourceId {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<ResourceId> for String {
fn from(val: ResourceId) -> Self {
val.0
}
}
// ============================================================================
// OIDC Flow Newtypes (for type-safe session storage)
// ============================================================================
/// CSRF Token for OIDC state parameter
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CsrfToken(String);
impl CsrfToken {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for CsrfToken {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for CsrfToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Nonce for OIDC ID token verification
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcNonce(String);
impl OidcNonce {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for OidcNonce {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for OidcNonce {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// PKCE Code Verifier
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PkceVerifier(String);
impl PkceVerifier {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for PkceVerifier {
fn as_ref(&self) -> &str {
&self.0
}
}
// Hide PKCE verifier in Debug (security)
impl fmt::Debug for PkceVerifier {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "PkceVerifier(***)")
}
}
/// OAuth2 Authorization Code
#[derive(Clone, PartialEq, Eq)]
pub struct AuthorizationCode(String);
impl AuthorizationCode {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for AuthorizationCode {
fn as_ref(&self) -> &str {
&self.0
}
}
// Hide authorization code in Debug (security)
impl fmt::Debug for AuthorizationCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthorizationCode(***)")
}
}
impl<'de> Deserialize<'de> for AuthorizationCode {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(Self::new(s))
}
}
/// Complete authorization URL data returned when starting OIDC flow
#[derive(Debug, Clone)]
pub struct AuthorizationUrlData {
/// The URL to redirect the user to
pub url: Url,
/// CSRF token to store in session
pub csrf_token: CsrfToken,
/// Nonce to store in session
pub nonce: OidcNonce,
/// PKCE verifier to store in session
pub pkce_verifier: PkceVerifier,
}
// ============================================================================
// Configuration Newtypes
// ============================================================================
/// Database connection URL
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct DatabaseUrl(String);
impl DatabaseUrl {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into();
if value.trim().is_empty() {
return Err(ValidationError::Empty("database_url".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for DatabaseUrl {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for DatabaseUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for DatabaseUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<DatabaseUrl> for String {
fn from(val: DatabaseUrl) -> Self {
val.0
}
}
/// Session secret with minimum length requirement
pub const MIN_SESSION_SECRET_LENGTH: usize = 64;
#[derive(Clone, PartialEq, Eq)]
pub struct SessionSecret(String);
impl SessionSecret {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into();
if value.len() < MIN_SESSION_SECRET_LENGTH {
return Err(ValidationError::SecretTooShort {
min: MIN_SESSION_SECRET_LENGTH,
actual: value.len(),
});
}
Ok(Self(value))
}
/// Create without validation (for development/testing)
pub fn new_unchecked(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for SessionSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for SessionSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SessionSecret(***)")
}
}
/// JWT signing secret with minimum length requirement
pub const MIN_JWT_SECRET_LENGTH: usize = 32;
#[derive(Clone, PartialEq, Eq)]
pub struct JwtSecret(String);
impl JwtSecret {
pub fn new(value: impl Into<String>, is_production: bool) -> Result<Self, ValidationError> {
let value = value.into();
if is_production && value.len() < MIN_JWT_SECRET_LENGTH {
return Err(ValidationError::SecretTooShort {
min: MIN_JWT_SECRET_LENGTH,
actual: value.len(),
});
}
Ok(Self(value))
}
/// Create without validation (for development/testing)
pub fn new_unchecked(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for JwtSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for JwtSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "JwtSecret(***)")
}
}
// ============================================================================ // ============================================================================
// Tests // Tests
// ============================================================================ // ============================================================================
@@ -209,11 +647,6 @@ mod tests {
fn test_invalid_email_no_local() { fn test_invalid_email_no_local() {
assert!(Email::new("@example.com").is_err()); assert!(Email::new("@example.com").is_err());
} }
#[test]
fn test_invalid_email_no_dot_in_domain() {
assert!(Email::new("user@localhost").is_err());
}
} }
mod password_tests { mod password_tests {
@@ -239,4 +672,68 @@ mod tests {
assert!(debug.contains("***")); assert!(debug.contains("***"));
} }
} }
mod oidc_tests {
use super::*;
#[test]
fn test_issuer_url_valid() {
assert!(IssuerUrl::new("https://auth.example.com").is_ok());
}
#[test]
fn test_issuer_url_invalid() {
assert!(IssuerUrl::new("not-a-url").is_err());
}
#[test]
fn test_client_id_non_empty() {
assert!(ClientId::new("my-client").is_ok());
assert!(ClientId::new("").is_err());
assert!(ClientId::new(" ").is_err());
}
#[test]
fn test_client_secret_hides_in_debug() {
let secret = ClientSecret::new("super-secret");
let debug = format!("{:?}", secret);
assert!(!debug.contains("super-secret"));
assert!(debug.contains("***"));
}
}
mod secret_tests {
use super::*;
#[test]
fn test_session_secret_min_length() {
let short = "short";
let long = "a".repeat(64);
assert!(SessionSecret::new(short).is_err());
assert!(SessionSecret::new(long).is_ok());
}
#[test]
fn test_jwt_secret_production_check() {
let short = "short";
let long = "a".repeat(32);
// Production mode enforces length
assert!(JwtSecret::new(short, true).is_err());
assert!(JwtSecret::new(&long, true).is_ok());
// Development mode allows short secrets
assert!(JwtSecret::new(short, false).is_ok());
}
#[test]
fn test_secrets_hide_in_debug() {
let session = SessionSecret::new_unchecked("secret");
let jwt = JwtSecret::new_unchecked("secret");
assert!(!format!("{:?}", session).contains("secret"));
assert!(!format!("{:?}", jwt).contains("secret"));
}
}
} }

View File

@@ -51,8 +51,8 @@ pub mod backend {
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Credentials { pub struct Credentials {
pub email: String, pub email: domain::Email,
pub password: String, pub password: domain::Password,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -72,14 +72,14 @@ pub mod backend {
) -> Result<Option<Self::User>, Self::Error> { ) -> Result<Option<Self::User>, Self::Error> {
let user = self let user = self
.user_repo .user_repo
.find_by_email(&creds.email) .find_by_email(creds.email.as_ref())
.await .await
.map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?; .map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?;
if let Some(user) = user { if let Some(user) = user {
if let Some(hash) = &user.password_hash { if let Some(hash) = &user.password_hash {
// Verify password // Verify password
if verify_password(&creds.password, hash).is_ok() { if verify_password(creds.password.as_ref(), hash).is_ok() {
return Ok(Some(AuthUser(user))); return Ok(Some(AuthUser(user)));
} }
} }

View File

@@ -1,9 +1,12 @@
use anyhow::anyhow; use anyhow::anyhow;
use domain::{
AuthorizationCode, AuthorizationUrlData, ClientId, ClientSecret, CsrfToken, IssuerUrl,
OidcNonce, PkceVerifier, RedirectUrl, ResourceId,
};
use openidconnect::{ use openidconnect::{
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, AccessTokenHash, Client, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, Scope, StandardErrorResponse, TokenResponse,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, UserInfoClaims,
StandardErrorResponse, TokenResponse, UserInfoClaims,
core::{ core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
@@ -36,7 +39,7 @@ pub type OidcClient = Client<
#[derive(Clone)] #[derive(Clone)]
pub struct OidcService { pub struct OidcService {
client: OidcClient, client: OidcClient,
resource_id: Option<String>, resource_id: Option<ResourceId>,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -46,31 +49,19 @@ pub struct OidcUser {
} }
impl OidcService { impl OidcService {
//todo: replace Strings with newtypes /// Create a new OIDC service with validated configuration newtypes
pub async fn new( pub async fn new(
issuer: String, issuer: IssuerUrl,
client_id: String, client_id: ClientId,
client_secret: String, client_secret: Option<ClientSecret>,
redirect_url: String, redirect_url: RedirectUrl,
resource_id: Option<String>, resource_id: Option<ResourceId>,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
let client_id = client_id.trim().to_string();
let redirect_url = redirect_url.trim().to_string();
let issuer = issuer.trim().to_string();
// 2. Handle Empty Secret (For PKCE/Public Clients)
let client_secret_clean = client_secret.trim();
let client_secret_opt = if client_secret_clean.is_empty() {
None
} else {
Some(ClientSecret::new(client_secret_clean.to_string()))
};
tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id); tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id);
tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url); tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url);
tracing::debug!( tracing::debug!(
"🔵 OIDC Setup: Secret = {:?}", "🔵 OIDC Setup: Secret = {:?}",
if client_secret_opt.is_some() { if client_secret.is_some() {
"SET" "SET"
} else { } else {
"NONE" "NONE"
@@ -81,15 +72,26 @@ impl OidcService {
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
.build()?; .build()?;
let provider_metadata = let provider_metadata = CoreProviderMetadata::discover_async(
CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &http_client).await?; openidconnect::IssuerUrl::new(issuer.as_ref().to_string())?,
&http_client,
)
.await?;
// Convert to openidconnect types
let oidc_client_id = openidconnect::ClientId::new(client_id.as_ref().to_string());
let oidc_client_secret = client_secret
.as_ref()
.filter(|s| !s.is_empty())
.map(|s| openidconnect::ClientSecret::new(s.as_ref().to_string()));
let oidc_redirect_url = openidconnect::RedirectUrl::new(redirect_url.as_ref().to_string())?;
let client = CoreClient::from_provider_metadata( let client = CoreClient::from_provider_metadata(
provider_metadata, provider_metadata,
ClientId::new(client_id), oidc_client_id,
client_secret_opt, oidc_client_secret,
) )
.set_redirect_uri(RedirectUrl::new(redirect_url)?); .set_redirect_uri(oidc_redirect_url);
Ok(Self { Ok(Self {
client, client,
@@ -97,48 +99,53 @@ impl OidcService {
}) })
} }
// todo: replace this tuple with newtype /// Get the authorization URL and associated state for OIDC login
pub fn get_authorization_url(&self) -> (String, String, String, String) { ///
/// Returns structured data instead of a raw tuple for better type safety
pub fn get_authorization_url(&self) -> AuthorizationUrlData {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = self let (auth_url, csrf_token, nonce) = self
.client .client
.authorize_url( .authorize_url(
CoreAuthenticationFlow::AuthorizationCode, CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random, openidconnect::CsrfToken::new_random,
Nonce::new_random, openidconnect::Nonce::new_random,
) )
.add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("email".to_string()))
.set_pkce_challenge(pkce_challenge) .set_pkce_challenge(pkce_challenge)
.url(); .url();
( AuthorizationUrlData {
auth_url.to_string(), url: auth_url.into(),
csrf_token.secret().to_string(), csrf_token: CsrfToken::new(csrf_token.secret().to_string()),
nonce.secret().to_string(), nonce: OidcNonce::new(nonce.secret().to_string()),
pkce_verifier.secret().to_string(), pkce_verifier: PkceVerifier::new(pkce_verifier.secret().to_string()),
) }
} }
//todo: replace strings with newtype /// Resolve the OIDC callback with type-safe parameters
pub async fn resolve_callback( pub async fn resolve_callback(
&self, &self,
code: String, code: AuthorizationCode,
nonce: String, nonce: OidcNonce,
pkce_verifier: String, pkce_verifier: PkceVerifier,
) -> anyhow::Result<OidcUser> { ) -> anyhow::Result<OidcUser> {
let http_client = reqwest::ClientBuilder::new() let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
.build()?; .build()?;
let pkce_verifier = PkceCodeVerifier::new(pkce_verifier); let oidc_pkce_verifier =
let nonce = Nonce::new(nonce); openidconnect::PkceCodeVerifier::new(pkce_verifier.as_ref().to_string());
let oidc_nonce = openidconnect::Nonce::new(nonce.as_ref().to_string());
let token_response = self let token_response = self
.client .client
.exchange_code(AuthorizationCode::new(code))? .exchange_code(openidconnect::AuthorizationCode::new(
.set_pkce_verifier(pkce_verifier) code.as_ref().to_string(),
))?
.set_pkce_verifier(oidc_pkce_verifier)
.request_async(&http_client) .request_async(&http_client)
.await?; .await?;
@@ -148,14 +155,13 @@ impl OidcService {
let mut id_token_verifier = self.client.id_token_verifier().clone(); let mut id_token_verifier = self.client.id_token_verifier().clone();
let trusted_resource_id = self.resource_id.clone(); if let Some(resource_id) = &self.resource_id {
let trusted_resource_id = resource_id.as_ref().to_string();
if let Some(resource_id) = trusted_resource_id {
id_token_verifier = id_token_verifier id_token_verifier = id_token_verifier
.set_other_audience_verifier_fn(move |aud| aud.as_str() == resource_id); .set_other_audience_verifier_fn(move |aud| aud.as_str() == trusted_resource_id);
} }
let claims = id_token.claims(&id_token_verifier, &nonce)?; let claims = id_token.claims(&id_token_verifier, &oidc_nonce)?;
if let Some(expected_access_token_hash) = claims.access_token_hash() { if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token( let actual_access_token_hash = AccessTokenHash::from_token(