refactor: Replace raw strings with domain value objects for improved type safety in authentication and OIDC.

This commit is contained in:
2026-01-06 05:16:16 +01:00
parent 16dcc4b95e
commit 32a0faf302
9 changed files with 667 additions and 232 deletions

107
Cargo.lock generated
View File

@@ -57,7 +57,6 @@ dependencies = [
"tracing",
"tracing-subscriber",
"uuid",
"validator",
]
[[package]]
@@ -523,38 +522,14 @@ dependencies = [
"syn",
]
[[package]]
name = "darling"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
"darling_core 0.20.11",
"darling_macro 0.20.11",
]
[[package]]
name = "darling"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
dependencies = [
"darling_core 0.21.3",
"darling_macro 0.21.3",
]
[[package]]
name = "darling_core"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
"darling_core",
"darling_macro",
]
[[package]]
@@ -571,24 +546,13 @@ dependencies = [
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core 0.20.11",
"quote",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
dependencies = [
"darling_core 0.21.3",
"darling_core",
"quote",
"syn",
]
@@ -659,12 +623,14 @@ dependencies = [
"anyhow",
"async-trait",
"chrono",
"email_address",
"futures-core",
"serde",
"serde_json",
"thiserror 2.0.17",
"tokio",
"tracing",
"url",
"uuid",
]
@@ -749,6 +715,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "email_address"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449"
dependencies = [
"serde",
]
[[package]]
name = "encoding_rs"
version = "0.8.35"
@@ -2013,28 +1988,6 @@ dependencies = [
"elliptic-curve",
]
[[package]]
name = "proc-macro-error-attr2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5"
dependencies = [
"proc-macro2",
"quote",
]
[[package]]
name = "proc-macro-error2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802"
dependencies = [
"proc-macro-error-attr2",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.104"
@@ -2692,7 +2645,7 @@ version = "3.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c"
dependencies = [
"darling 0.21.3",
"darling",
"proc-macro2",
"quote",
"syn",
@@ -3611,36 +3564,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "validator"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa"
dependencies = [
"idna",
"once_cell",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca"
dependencies = [
"darling 0.20.11",
"once_cell",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "valuable"
version = "0.1.1"

View File

@@ -44,8 +44,7 @@ tokio = { version = "1.48.0", features = ["full"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0"
# Validation
validator = { version = "0.20", features = ["derive"] }
# Validation via domain newtypes (Email, Password)
# Error handling
thiserror = "2.0.17"

View File

@@ -1,30 +1,29 @@
//! Request and Response DTOs
//!
//! Data Transfer Objects for the API.
//! Uses domain newtypes for validation instead of the validator crate.
use chrono::{DateTime, Utc};
use domain::{Email, Password};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use validator::Validate;
/// Login request
#[derive(Debug, Deserialize, Validate)]
/// Login request with validated email and password newtypes
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
#[validate(email(message = "Invalid email format"))]
pub email: String,
#[validate(length(min = 6, message = "Password must be at least 6 characters"))]
pub password: String,
/// Email is validated on deserialization
pub email: Email,
/// Password is validated on deserialization (min 6 chars)
pub password: Password,
}
/// Register request
#[derive(Debug, Deserialize, Validate)]
/// Register request with validated email and password newtypes
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
#[validate(email(message = "Invalid email format"))]
pub email: String,
#[validate(length(min = 6, message = "Password must be at least 6 characters"))]
pub password: String,
/// Email is validated on deserialization
pub email: Email,
/// Password is validated on deserialization (min 6 chars)
pub password: Password,
}
/// User response DTO
@@ -40,12 +39,3 @@ pub struct UserResponse {
pub struct ConfigResponse {
pub allow_registration: bool,
}
#[cfg(feature = "auth-jwt")]
#[derive(Debug, Serialize, Deserialize)]
// also newtypes
pub struct Claims {
pub sub: String,
pub email: String,
pub exp: usize,
}

View File

@@ -25,7 +25,7 @@ use crate::{
state::AppState,
};
#[cfg(feature = "auth-axum-login")]
use domain::{DomainError, Email};
use domain::DomainError;
/// Token response for JWT authentication
#[derive(Debug, Serialize)]
@@ -140,19 +140,20 @@ async fn register(
mut auth_session: crate::auth::AuthSession,
Json(payload): Json<RegisterRequest>,
) -> Result<impl IntoResponse, ApiError> {
// Email is already validated by the newtype deserialization
let email = payload.email;
if state
.user_service
.find_by_email(&payload.email)
.find_by_email(email.as_ref())
.await?
.is_some()
{
return Err(ApiError::Domain(DomainError::UserAlreadyExists(
payload.email,
email.as_ref().to_string(),
)));
}
let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?;
// Using email as subject for local auth for now
let user = state
.user_service
@@ -274,22 +275,22 @@ async fn oidc_login(State(state): State<AppState>, session: Session) -> Result<R
.as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
let (url, csrf, nonce, pkce) = service.get_authorization_url();
let auth_data = service.get_authorization_url();
session
.insert("oidc_csrf", csrf)
.insert("oidc_csrf", &auth_data.csrf_token)
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
session
.insert("oidc_nonce", nonce)
.insert("oidc_nonce", &auth_data.nonce)
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
session
.insert("oidc_pkce", pkce)
.insert("oidc_pkce", &auth_data.pkce_verifier)
.await
.map_err(|_| ApiError::Internal("Session error".into()))?;
let response = axum::response::Redirect::to(&url).into_response();
let response = axum::response::Redirect::to(auth_data.url.as_str()).into_response();
let (mut parts, body) = response.into_parts();
parts.headers.insert(
@@ -323,29 +324,33 @@ async fn oidc_callback(
.as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
let stored_csrf: String = session
let stored_csrf: domain::CsrfToken = session
.get("oidc_csrf")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
if params.state != stored_csrf {
if params.state != stored_csrf.as_ref() {
return Err(ApiError::Validation("Invalid CSRF token".into()));
}
let stored_pkce: String = session
let stored_pkce: domain::PkceVerifier = session
.get("oidc_pkce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
let stored_nonce: String = session
let stored_nonce: domain::OidcNonce = session
.get("oidc_nonce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
let oidc_user = service
.resolve_callback(params.code, stored_nonce, stored_pkce)
.resolve_callback(
domain::AuthorizationCode::new(params.code),
stored_nonce,
stored_pkce,
)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;
@@ -412,29 +417,33 @@ async fn oidc_callback(
.as_ref()
.ok_or(ApiError::Internal("OIDC not configured".into()))?;
let stored_csrf: String = session
let stored_csrf: domain::CsrfToken = session
.get("oidc_csrf")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing CSRF token".into()))?;
if params.state != stored_csrf {
if params.state != stored_csrf.as_ref() {
return Err(ApiError::Validation("Invalid CSRF token".into()));
}
let stored_pkce: String = session
let stored_pkce: domain::PkceVerifier = session
.get("oidc_pkce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing PKCE".into()))?;
let stored_nonce: String = session
let stored_nonce: domain::OidcNonce = session
.get("oidc_nonce")
.await
.map_err(|_| ApiError::Internal("Session error".into()))?
.ok_or(ApiError::Validation("Missing Nonce".into()))?;
let oidc_user = service
.resolve_callback(params.code, stored_nonce, stored_pkce)
.resolve_callback(
domain::AuthorizationCode::new(params.code),
stored_nonce,
stored_pkce,
)
.await
.map_err(|e| ApiError::Internal(e.to_string()))?;

View File

@@ -33,14 +33,23 @@ impl AppState {
&config.oidc_resource_id,
) {
tracing::info!("Initializing OIDC service with issuer: {}", issuer);
// Construct newtypes from config strings
let issuer_url = domain::IssuerUrl::new(issuer)
.map_err(|e| anyhow::anyhow!("Invalid OIDC issuer URL: {}", e))?;
let client_id = domain::ClientId::new(id)
.map_err(|e| anyhow::anyhow!("Invalid OIDC client ID: {}", e))?;
let client_secret = secret.as_ref().map(|s| domain::ClientSecret::new(s));
let redirect_url = domain::RedirectUrl::new(redirect)
.map_err(|e| anyhow::anyhow!("Invalid OIDC redirect URL: {}", e))?;
let resource = resource_id
.as_ref()
.map(|r| domain::ResourceId::new(r))
.transpose()
.map_err(|e| anyhow::anyhow!("Invalid OIDC resource ID: {}", e))?;
Some(Arc::new(
OidcService::new(
issuer.clone(),
id.clone(),
secret.clone().unwrap_or_default(),
redirect.clone(),
resource_id.clone(),
)
OidcService::new(issuer_url, client_id, client_secret, redirect_url, resource)
.await?,
))
} else {

View File

@@ -7,10 +7,12 @@ edition = "2024"
anyhow = "1.0.100"
async-trait = "0.1.89"
chrono = { version = "0.4.42", features = ["serde"] }
email_address = "0.2"
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.146"
thiserror = "2.0.17"
tracing = "0.1"
url = { version = "2.5", features = ["serde"] }
uuid = { version = "1.19.0", features = ["v4", "serde"] }
futures-core = "0.3"

View File

@@ -6,6 +6,7 @@
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use thiserror::Error;
use url::Url;
use uuid::Uuid;
pub type UserId = Uuid;
@@ -22,47 +23,44 @@ pub enum ValidationError {
#[error("Password must be at least {min} characters, got {actual}")]
PasswordTooShort { min: usize, actual: usize },
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Value cannot be empty: {0}")]
Empty(String),
#[error("Secret too short: minimum {min} bytes required, got {actual}")]
SecretTooShort { min: usize, actual: usize },
}
// ============================================================================
// Email
// Email (using email_address crate for RFC-compliant validation)
// ============================================================================
/// A validated email address.
///
/// Simple validation: must contain exactly one `@` with non-empty parts on both sides.
/// A validated email address using RFC-compliant validation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Email(String);
pub struct Email(email_address::EmailAddress);
impl Email {
/// Minimum validation: contains @ with non-empty local and domain parts
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into();
let trimmed = value.trim().to_lowercase();
// Basic email validation
let parts: Vec<&str> = trimmed.split('@').collect();
if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
return Err(ValidationError::InvalidEmail(value));
}
// Domain must contain at least one dot
if !parts[1].contains('.') {
return Err(ValidationError::InvalidEmail(value));
}
Ok(Self(trimmed))
/// Create a new validated email address
pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.as_ref().trim().to_lowercase();
let addr: email_address::EmailAddress = value
.parse()
.map_err(|_| ValidationError::InvalidEmail(value.clone()))?;
Ok(Self(addr))
}
/// Get the inner value
pub fn into_inner(self) -> String {
self.0
self.0.to_string()
}
}
impl AsRef<str> for Email {
fn as_ref(&self) -> &str {
&self.0
self.0.as_ref()
}
}
@@ -90,7 +88,7 @@ impl TryFrom<&str> for Email {
impl Serialize for Email {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.0)
serializer.serialize_str(self.0.as_ref())
}
}
@@ -171,6 +169,446 @@ impl<'de> Deserialize<'de> for Password {
// Note: Password should NOT implement Serialize to prevent accidental exposure
// ============================================================================
// OIDC Configuration Newtypes
// ============================================================================
/// OIDC Issuer URL - validated URL for the identity provider
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct IssuerUrl(Url);
impl IssuerUrl {
pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.as_ref().trim();
let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
Ok(Self(url))
}
pub fn as_url(&self) -> &Url {
&self.0
}
}
impl AsRef<str> for IssuerUrl {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for IssuerUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for IssuerUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<IssuerUrl> for String {
fn from(val: IssuerUrl) -> Self {
val.0.to_string()
}
}
/// OIDC Client Identifier
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct ClientId(String);
impl ClientId {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into().trim().to_string();
if value.is_empty() {
return Err(ValidationError::Empty("client_id".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for ClientId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for ClientId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for ClientId {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<ClientId> for String {
fn from(val: ClientId) -> Self {
val.0
}
}
/// OIDC Client Secret - hidden in Debug output
#[derive(Clone, PartialEq, Eq)]
pub struct ClientSecret(String);
impl ClientSecret {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
/// Check if the secret is empty (for public clients)
pub fn is_empty(&self) -> bool {
self.0.trim().is_empty()
}
}
impl AsRef<str> for ClientSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for ClientSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ClientSecret(***)")
}
}
impl fmt::Display for ClientSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "***")
}
}
impl<'de> Deserialize<'de> for ClientSecret {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(Self::new(s))
}
}
// Note: ClientSecret should NOT implement Serialize
/// OAuth Redirect URL - validated URL
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct RedirectUrl(Url);
impl RedirectUrl {
pub fn new(value: impl AsRef<str>) -> Result<Self, ValidationError> {
let value = value.as_ref().trim();
let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
Ok(Self(url))
}
pub fn as_url(&self) -> &Url {
&self.0
}
}
impl AsRef<str> for RedirectUrl {
fn as_ref(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for RedirectUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for RedirectUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<RedirectUrl> for String {
fn from(val: RedirectUrl) -> Self {
val.0.to_string()
}
}
/// OIDC Resource Identifier (optional audience)
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct ResourceId(String);
impl ResourceId {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into().trim().to_string();
if value.is_empty() {
return Err(ValidationError::Empty("resource_id".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for ResourceId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for ResourceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for ResourceId {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<ResourceId> for String {
fn from(val: ResourceId) -> Self {
val.0
}
}
// ============================================================================
// OIDC Flow Newtypes (for type-safe session storage)
// ============================================================================
/// CSRF Token for OIDC state parameter
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CsrfToken(String);
impl CsrfToken {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for CsrfToken {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for CsrfToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Nonce for OIDC ID token verification
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OidcNonce(String);
impl OidcNonce {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for OidcNonce {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for OidcNonce {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// PKCE Code Verifier
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PkceVerifier(String);
impl PkceVerifier {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for PkceVerifier {
fn as_ref(&self) -> &str {
&self.0
}
}
// Hide PKCE verifier in Debug (security)
impl fmt::Debug for PkceVerifier {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "PkceVerifier(***)")
}
}
/// OAuth2 Authorization Code
#[derive(Clone, PartialEq, Eq)]
pub struct AuthorizationCode(String);
impl AuthorizationCode {
pub fn new(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for AuthorizationCode {
fn as_ref(&self) -> &str {
&self.0
}
}
// Hide authorization code in Debug (security)
impl fmt::Debug for AuthorizationCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthorizationCode(***)")
}
}
impl<'de> Deserialize<'de> for AuthorizationCode {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(Self::new(s))
}
}
/// Complete authorization URL data returned when starting OIDC flow
#[derive(Debug, Clone)]
pub struct AuthorizationUrlData {
/// The URL to redirect the user to
pub url: Url,
/// CSRF token to store in session
pub csrf_token: CsrfToken,
/// Nonce to store in session
pub nonce: OidcNonce,
/// PKCE verifier to store in session
pub pkce_verifier: PkceVerifier,
}
// ============================================================================
// Configuration Newtypes
// ============================================================================
/// Database connection URL
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub struct DatabaseUrl(String);
impl DatabaseUrl {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into();
if value.trim().is_empty() {
return Err(ValidationError::Empty("database_url".to_string()));
}
Ok(Self(value))
}
}
impl AsRef<str> for DatabaseUrl {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for DatabaseUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl TryFrom<String> for DatabaseUrl {
type Error = ValidationError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::new(value)
}
}
impl From<DatabaseUrl> for String {
fn from(val: DatabaseUrl) -> Self {
val.0
}
}
/// Session secret with minimum length requirement
pub const MIN_SESSION_SECRET_LENGTH: usize = 64;
#[derive(Clone, PartialEq, Eq)]
pub struct SessionSecret(String);
impl SessionSecret {
pub fn new(value: impl Into<String>) -> Result<Self, ValidationError> {
let value = value.into();
if value.len() < MIN_SESSION_SECRET_LENGTH {
return Err(ValidationError::SecretTooShort {
min: MIN_SESSION_SECRET_LENGTH,
actual: value.len(),
});
}
Ok(Self(value))
}
/// Create without validation (for development/testing)
pub fn new_unchecked(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for SessionSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for SessionSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SessionSecret(***)")
}
}
/// JWT signing secret with minimum length requirement
pub const MIN_JWT_SECRET_LENGTH: usize = 32;
#[derive(Clone, PartialEq, Eq)]
pub struct JwtSecret(String);
impl JwtSecret {
pub fn new(value: impl Into<String>, is_production: bool) -> Result<Self, ValidationError> {
let value = value.into();
if is_production && value.len() < MIN_JWT_SECRET_LENGTH {
return Err(ValidationError::SecretTooShort {
min: MIN_JWT_SECRET_LENGTH,
actual: value.len(),
});
}
Ok(Self(value))
}
/// Create without validation (for development/testing)
pub fn new_unchecked(value: impl Into<String>) -> Self {
Self(value.into())
}
}
impl AsRef<str> for JwtSecret {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Debug for JwtSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "JwtSecret(***)")
}
}
// ============================================================================
// Tests
// ============================================================================
@@ -209,11 +647,6 @@ mod tests {
fn test_invalid_email_no_local() {
assert!(Email::new("@example.com").is_err());
}
#[test]
fn test_invalid_email_no_dot_in_domain() {
assert!(Email::new("user@localhost").is_err());
}
}
mod password_tests {
@@ -239,4 +672,68 @@ mod tests {
assert!(debug.contains("***"));
}
}
mod oidc_tests {
use super::*;
#[test]
fn test_issuer_url_valid() {
assert!(IssuerUrl::new("https://auth.example.com").is_ok());
}
#[test]
fn test_issuer_url_invalid() {
assert!(IssuerUrl::new("not-a-url").is_err());
}
#[test]
fn test_client_id_non_empty() {
assert!(ClientId::new("my-client").is_ok());
assert!(ClientId::new("").is_err());
assert!(ClientId::new(" ").is_err());
}
#[test]
fn test_client_secret_hides_in_debug() {
let secret = ClientSecret::new("super-secret");
let debug = format!("{:?}", secret);
assert!(!debug.contains("super-secret"));
assert!(debug.contains("***"));
}
}
mod secret_tests {
use super::*;
#[test]
fn test_session_secret_min_length() {
let short = "short";
let long = "a".repeat(64);
assert!(SessionSecret::new(short).is_err());
assert!(SessionSecret::new(long).is_ok());
}
#[test]
fn test_jwt_secret_production_check() {
let short = "short";
let long = "a".repeat(32);
// Production mode enforces length
assert!(JwtSecret::new(short, true).is_err());
assert!(JwtSecret::new(&long, true).is_ok());
// Development mode allows short secrets
assert!(JwtSecret::new(short, false).is_ok());
}
#[test]
fn test_secrets_hide_in_debug() {
let session = SessionSecret::new_unchecked("secret");
let jwt = JwtSecret::new_unchecked("secret");
assert!(!format!("{:?}", session).contains("secret"));
assert!(!format!("{:?}", jwt).contains("secret"));
}
}
}

View File

@@ -51,8 +51,8 @@ pub mod backend {
#[derive(Clone, Debug, Deserialize)]
pub struct Credentials {
pub email: String,
pub password: String,
pub email: domain::Email,
pub password: domain::Password,
}
#[derive(Debug, thiserror::Error)]
@@ -72,14 +72,14 @@ pub mod backend {
) -> Result<Option<Self::User>, Self::Error> {
let user = self
.user_repo
.find_by_email(&creds.email)
.find_by_email(creds.email.as_ref())
.await
.map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?;
if let Some(user) = user {
if let Some(hash) = &user.password_hash {
// Verify password
if verify_password(&creds.password, hash).is_ok() {
if verify_password(creds.password.as_ref(), hash).is_ok() {
return Ok(Some(AuthUser(user)));
}
}

View File

@@ -1,9 +1,12 @@
use anyhow::anyhow;
use domain::{
AuthorizationCode, AuthorizationUrlData, ClientId, ClientSecret, CsrfToken, IssuerUrl,
OidcNonce, PkceVerifier, RedirectUrl, ResourceId,
};
use openidconnect::{
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope,
StandardErrorResponse, TokenResponse, UserInfoClaims,
AccessTokenHash, Client, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
OAuth2TokenResponse, PkceCodeChallenge, Scope, StandardErrorResponse, TokenResponse,
UserInfoClaims,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
@@ -36,7 +39,7 @@ pub type OidcClient = Client<
#[derive(Clone)]
pub struct OidcService {
client: OidcClient,
resource_id: Option<String>,
resource_id: Option<ResourceId>,
}
#[derive(Debug)]
@@ -46,31 +49,19 @@ pub struct OidcUser {
}
impl OidcService {
//todo: replace Strings with newtypes
/// Create a new OIDC service with validated configuration newtypes
pub async fn new(
issuer: String,
client_id: String,
client_secret: String,
redirect_url: String,
resource_id: Option<String>,
issuer: IssuerUrl,
client_id: ClientId,
client_secret: Option<ClientSecret>,
redirect_url: RedirectUrl,
resource_id: Option<ResourceId>,
) -> anyhow::Result<Self> {
let client_id = client_id.trim().to_string();
let redirect_url = redirect_url.trim().to_string();
let issuer = issuer.trim().to_string();
// 2. Handle Empty Secret (For PKCE/Public Clients)
let client_secret_clean = client_secret.trim();
let client_secret_opt = if client_secret_clean.is_empty() {
None
} else {
Some(ClientSecret::new(client_secret_clean.to_string()))
};
tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id);
tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url);
tracing::debug!(
"🔵 OIDC Setup: Secret = {:?}",
if client_secret_opt.is_some() {
if client_secret.is_some() {
"SET"
} else {
"NONE"
@@ -81,15 +72,26 @@ impl OidcService {
.redirect(reqwest::redirect::Policy::none())
.build()?;
let provider_metadata =
CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &http_client).await?;
let provider_metadata = CoreProviderMetadata::discover_async(
openidconnect::IssuerUrl::new(issuer.as_ref().to_string())?,
&http_client,
)
.await?;
// Convert to openidconnect types
let oidc_client_id = openidconnect::ClientId::new(client_id.as_ref().to_string());
let oidc_client_secret = client_secret
.as_ref()
.filter(|s| !s.is_empty())
.map(|s| openidconnect::ClientSecret::new(s.as_ref().to_string()));
let oidc_redirect_url = openidconnect::RedirectUrl::new(redirect_url.as_ref().to_string())?;
let client = CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(client_id),
client_secret_opt,
oidc_client_id,
oidc_client_secret,
)
.set_redirect_uri(RedirectUrl::new(redirect_url)?);
.set_redirect_uri(oidc_redirect_url);
Ok(Self {
client,
@@ -97,48 +99,53 @@ impl OidcService {
})
}
// todo: replace this tuple with newtype
pub fn get_authorization_url(&self) -> (String, String, String, String) {
/// Get the authorization URL and associated state for OIDC login
///
/// Returns structured data instead of a raw tuple for better type safety
pub fn get_authorization_url(&self) -> AuthorizationUrlData {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = self
.client
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
openidconnect::CsrfToken::new_random,
openidconnect::Nonce::new_random,
)
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.set_pkce_challenge(pkce_challenge)
.url();
(
auth_url.to_string(),
csrf_token.secret().to_string(),
nonce.secret().to_string(),
pkce_verifier.secret().to_string(),
)
AuthorizationUrlData {
url: auth_url.into(),
csrf_token: CsrfToken::new(csrf_token.secret().to_string()),
nonce: OidcNonce::new(nonce.secret().to_string()),
pkce_verifier: PkceVerifier::new(pkce_verifier.secret().to_string()),
}
}
//todo: replace strings with newtype
/// Resolve the OIDC callback with type-safe parameters
pub async fn resolve_callback(
&self,
code: String,
nonce: String,
pkce_verifier: String,
code: AuthorizationCode,
nonce: OidcNonce,
pkce_verifier: PkceVerifier,
) -> anyhow::Result<OidcUser> {
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()?;
let pkce_verifier = PkceCodeVerifier::new(pkce_verifier);
let nonce = Nonce::new(nonce);
let oidc_pkce_verifier =
openidconnect::PkceCodeVerifier::new(pkce_verifier.as_ref().to_string());
let oidc_nonce = openidconnect::Nonce::new(nonce.as_ref().to_string());
let token_response = self
.client
.exchange_code(AuthorizationCode::new(code))?
.set_pkce_verifier(pkce_verifier)
.exchange_code(openidconnect::AuthorizationCode::new(
code.as_ref().to_string(),
))?
.set_pkce_verifier(oidc_pkce_verifier)
.request_async(&http_client)
.await?;
@@ -148,14 +155,13 @@ impl OidcService {
let mut id_token_verifier = self.client.id_token_verifier().clone();
let trusted_resource_id = self.resource_id.clone();
if let Some(resource_id) = trusted_resource_id {
if let Some(resource_id) = &self.resource_id {
let trusted_resource_id = resource_id.as_ref().to_string();
id_token_verifier = id_token_verifier
.set_other_audience_verifier_fn(move |aud| aud.as_str() == resource_id);
.set_other_audience_verifier_fn(move |aud| aud.as_str() == trusted_resource_id);
}
let claims = id_token.claims(&id_token_verifier, &nonce)?;
let claims = id_token.claims(&id_token_verifier, &oidc_nonce)?;
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(