From 32a0faf302429d39e1529690aa547fa5e32d6103 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 05:16:16 +0100 Subject: [PATCH] refactor: Replace raw strings with domain value objects for improved type safety in authentication and OIDC. --- Cargo.lock | 107 +------ api/Cargo.toml | 3 +- api/src/dto.rs | 38 +-- api/src/routes/auth.rs | 49 ++-- api/src/state.rs | 25 +- domain/Cargo.toml | 2 + domain/src/value_objects.rs | 557 ++++++++++++++++++++++++++++++++++-- infra/src/auth/mod.rs | 8 +- infra/src/auth/oidc.rs | 110 +++---- 9 files changed, 667 insertions(+), 232 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a564a1..a7e9b54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,7 +57,6 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", - "validator", ] [[package]] @@ -523,38 +522,14 @@ dependencies = [ "syn", ] -[[package]] -name = "darling" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" -dependencies = [ - "darling_core 0.20.11", - "darling_macro 0.20.11", -] - [[package]] name = "darling" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" dependencies = [ - "darling_core 0.21.3", - "darling_macro 0.21.3", -] - -[[package]] -name = "darling_core" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", + "darling_core", + "darling_macro", ] [[package]] @@ -571,24 +546,13 @@ dependencies = [ "syn", ] -[[package]] -name = "darling_macro" -version = "0.20.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" -dependencies = [ - "darling_core 0.20.11", - "quote", - "syn", -] - [[package]] name = "darling_macro" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ - "darling_core 0.21.3", + "darling_core", "quote", "syn", ] @@ -659,12 +623,14 @@ dependencies = [ "anyhow", "async-trait", "chrono", + "email_address", "futures-core", "serde", "serde_json", "thiserror 2.0.17", "tokio", "tracing", + "url", "uuid", ] @@ -749,6 +715,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encoding_rs" version = "0.8.35" @@ -2013,28 +1988,6 @@ dependencies = [ "elliptic-curve", ] -[[package]] -name = "proc-macro-error-attr2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" -dependencies = [ - "proc-macro2", - "quote", -] - -[[package]] -name = "proc-macro-error2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" -dependencies = [ - "proc-macro-error-attr2", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "proc-macro2" version = "1.0.104" @@ -2692,7 +2645,7 @@ version = "3.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" dependencies = [ - "darling 0.21.3", + "darling", "proc-macro2", "quote", "syn", @@ -3611,36 +3564,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "validator" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa" -dependencies = [ - "idna", - "once_cell", - "regex", - "serde", - "serde_derive", - "serde_json", - "url", - "validator_derive", -] - -[[package]] -name = "validator_derive" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" -dependencies = [ - "darling 0.20.11", - "once_cell", - "proc-macro-error2", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "valuable" version = "0.1.1" diff --git a/api/Cargo.toml b/api/Cargo.toml index 7482743..1aec6a4 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -44,8 +44,7 @@ tokio = { version = "1.48.0", features = ["full"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0" -# Validation -validator = { version = "0.20", features = ["derive"] } +# Validation via domain newtypes (Email, Password) # Error handling thiserror = "2.0.17" diff --git a/api/src/dto.rs b/api/src/dto.rs index 17cbd1a..99a2d88 100644 --- a/api/src/dto.rs +++ b/api/src/dto.rs @@ -1,30 +1,29 @@ //! Request and Response DTOs //! //! Data Transfer Objects for the API. +//! Uses domain newtypes for validation instead of the validator crate. use chrono::{DateTime, Utc}; +use domain::{Email, Password}; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use validator::Validate; -/// Login request -#[derive(Debug, Deserialize, Validate)] +/// Login request with validated email and password newtypes +#[derive(Debug, Deserialize)] pub struct LoginRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + /// Email is validated on deserialization + pub email: Email, + /// Password is validated on deserialization (min 6 chars) + pub password: Password, } -/// Register request -#[derive(Debug, Deserialize, Validate)] +/// Register request with validated email and password newtypes +#[derive(Debug, Deserialize)] pub struct RegisterRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + /// Email is validated on deserialization + pub email: Email, + /// Password is validated on deserialization (min 6 chars) + pub password: Password, } /// User response DTO @@ -40,12 +39,3 @@ pub struct UserResponse { pub struct ConfigResponse { pub allow_registration: bool, } - -#[cfg(feature = "auth-jwt")] -#[derive(Debug, Serialize, Deserialize)] -// also newtypes -pub struct Claims { - pub sub: String, - pub email: String, - pub exp: usize, -} diff --git a/api/src/routes/auth.rs b/api/src/routes/auth.rs index af190f2..382a00c 100644 --- a/api/src/routes/auth.rs +++ b/api/src/routes/auth.rs @@ -25,7 +25,7 @@ use crate::{ state::AppState, }; #[cfg(feature = "auth-axum-login")] -use domain::{DomainError, Email}; +use domain::DomainError; /// Token response for JWT authentication #[derive(Debug, Serialize)] @@ -140,19 +140,20 @@ async fn register( mut auth_session: crate::auth::AuthSession, Json(payload): Json, ) -> Result { + // Email is already validated by the newtype deserialization + let email = payload.email; + if state .user_service - .find_by_email(&payload.email) + .find_by_email(email.as_ref()) .await? .is_some() { return Err(ApiError::Domain(DomainError::UserAlreadyExists( - payload.email, + email.as_ref().to_string(), ))); } - let email = Email::try_from(payload.email).map_err(|e| ApiError::Validation(e.to_string()))?; - // Using email as subject for local auth for now let user = state .user_service @@ -274,22 +275,22 @@ async fn oidc_login(State(state): State, session: Session) -> Result) -> Result { - let value = value.into(); - let trimmed = value.trim().to_lowercase(); - - // Basic email validation - let parts: Vec<&str> = trimmed.split('@').collect(); - if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { - return Err(ValidationError::InvalidEmail(value)); - } - - // Domain must contain at least one dot - if !parts[1].contains('.') { - return Err(ValidationError::InvalidEmail(value)); - } - - Ok(Self(trimmed)) + /// Create a new validated email address + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim().to_lowercase(); + let addr: email_address::EmailAddress = value + .parse() + .map_err(|_| ValidationError::InvalidEmail(value.clone()))?; + Ok(Self(addr)) } /// Get the inner value pub fn into_inner(self) -> String { - self.0 + self.0.to_string() } } impl AsRef for Email { fn as_ref(&self) -> &str { - &self.0 + self.0.as_ref() } } @@ -90,7 +88,7 @@ impl TryFrom<&str> for Email { impl Serialize for Email { fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.0) + serializer.serialize_str(self.0.as_ref()) } } @@ -171,6 +169,446 @@ impl<'de> Deserialize<'de> for Password { // Note: Password should NOT implement Serialize to prevent accidental exposure +// ============================================================================ +// OIDC Configuration Newtypes +// ============================================================================ + +/// OIDC Issuer URL - validated URL for the identity provider +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct IssuerUrl(Url); + +impl IssuerUrl { + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim(); + let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?; + Ok(Self(url)) + } + + pub fn as_url(&self) -> &Url { + &self.0 + } +} + +impl AsRef for IssuerUrl { + fn as_ref(&self) -> &str { + self.0.as_str() + } +} + +impl fmt::Display for IssuerUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for IssuerUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: IssuerUrl) -> Self { + val.0.to_string() + } +} + +/// OIDC Client Identifier +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct ClientId(String); + +impl ClientId { + pub fn new(value: impl Into) -> Result { + let value = value.into().trim().to_string(); + if value.is_empty() { + return Err(ValidationError::Empty("client_id".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for ClientId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for ClientId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for ClientId { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: ClientId) -> Self { + val.0 + } +} + +/// OIDC Client Secret - hidden in Debug output +#[derive(Clone, PartialEq, Eq)] +pub struct ClientSecret(String); + +impl ClientSecret { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } + + /// Check if the secret is empty (for public clients) + pub fn is_empty(&self) -> bool { + self.0.trim().is_empty() + } +} + +impl AsRef for ClientSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for ClientSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ClientSecret(***)") + } +} + +impl fmt::Display for ClientSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "***") + } +} + +impl<'de> Deserialize<'de> for ClientSecret { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Ok(Self::new(s)) + } +} + +// Note: ClientSecret should NOT implement Serialize + +/// OAuth Redirect URL - validated URL +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct RedirectUrl(Url); + +impl RedirectUrl { + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim(); + let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?; + Ok(Self(url)) + } + + pub fn as_url(&self) -> &Url { + &self.0 + } +} + +impl AsRef for RedirectUrl { + fn as_ref(&self) -> &str { + self.0.as_str() + } +} + +impl fmt::Display for RedirectUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for RedirectUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: RedirectUrl) -> Self { + val.0.to_string() + } +} + +/// OIDC Resource Identifier (optional audience) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct ResourceId(String); + +impl ResourceId { + pub fn new(value: impl Into) -> Result { + let value = value.into().trim().to_string(); + if value.is_empty() { + return Err(ValidationError::Empty("resource_id".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for ResourceId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for ResourceId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for ResourceId { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: ResourceId) -> Self { + val.0 + } +} + +// ============================================================================ +// OIDC Flow Newtypes (for type-safe session storage) +// ============================================================================ + +/// CSRF Token for OIDC state parameter +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CsrfToken(String); + +impl CsrfToken { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for CsrfToken { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for CsrfToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Nonce for OIDC ID token verification +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OidcNonce(String); + +impl OidcNonce { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for OidcNonce { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for OidcNonce { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// PKCE Code Verifier +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PkceVerifier(String); + +impl PkceVerifier { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for PkceVerifier { + fn as_ref(&self) -> &str { + &self.0 + } +} + +// Hide PKCE verifier in Debug (security) +impl fmt::Debug for PkceVerifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PkceVerifier(***)") + } +} + +/// OAuth2 Authorization Code +#[derive(Clone, PartialEq, Eq)] +pub struct AuthorizationCode(String); + +impl AuthorizationCode { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for AuthorizationCode { + fn as_ref(&self) -> &str { + &self.0 + } +} + +// Hide authorization code in Debug (security) +impl fmt::Debug for AuthorizationCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthorizationCode(***)") + } +} + +impl<'de> Deserialize<'de> for AuthorizationCode { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Ok(Self::new(s)) + } +} + +/// Complete authorization URL data returned when starting OIDC flow +#[derive(Debug, Clone)] +pub struct AuthorizationUrlData { + /// The URL to redirect the user to + pub url: Url, + /// CSRF token to store in session + pub csrf_token: CsrfToken, + /// Nonce to store in session + pub nonce: OidcNonce, + /// PKCE verifier to store in session + pub pkce_verifier: PkceVerifier, +} + +// ============================================================================ +// Configuration Newtypes +// ============================================================================ + +/// Database connection URL +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct DatabaseUrl(String); + +impl DatabaseUrl { + pub fn new(value: impl Into) -> Result { + let value = value.into(); + if value.trim().is_empty() { + return Err(ValidationError::Empty("database_url".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for DatabaseUrl { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for DatabaseUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for DatabaseUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: DatabaseUrl) -> Self { + val.0 + } +} + +/// Session secret with minimum length requirement +pub const MIN_SESSION_SECRET_LENGTH: usize = 64; + +#[derive(Clone, PartialEq, Eq)] +pub struct SessionSecret(String); + +impl SessionSecret { + pub fn new(value: impl Into) -> Result { + let value = value.into(); + if value.len() < MIN_SESSION_SECRET_LENGTH { + return Err(ValidationError::SecretTooShort { + min: MIN_SESSION_SECRET_LENGTH, + actual: value.len(), + }); + } + Ok(Self(value)) + } + + /// Create without validation (for development/testing) + pub fn new_unchecked(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for SessionSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for SessionSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SessionSecret(***)") + } +} + +/// JWT signing secret with minimum length requirement +pub const MIN_JWT_SECRET_LENGTH: usize = 32; + +#[derive(Clone, PartialEq, Eq)] +pub struct JwtSecret(String); + +impl JwtSecret { + pub fn new(value: impl Into, is_production: bool) -> Result { + let value = value.into(); + if is_production && value.len() < MIN_JWT_SECRET_LENGTH { + return Err(ValidationError::SecretTooShort { + min: MIN_JWT_SECRET_LENGTH, + actual: value.len(), + }); + } + Ok(Self(value)) + } + + /// Create without validation (for development/testing) + pub fn new_unchecked(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for JwtSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for JwtSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "JwtSecret(***)") + } +} + // ============================================================================ // Tests // ============================================================================ @@ -209,11 +647,6 @@ mod tests { fn test_invalid_email_no_local() { assert!(Email::new("@example.com").is_err()); } - - #[test] - fn test_invalid_email_no_dot_in_domain() { - assert!(Email::new("user@localhost").is_err()); - } } mod password_tests { @@ -239,4 +672,68 @@ mod tests { assert!(debug.contains("***")); } } + + mod oidc_tests { + use super::*; + + #[test] + fn test_issuer_url_valid() { + assert!(IssuerUrl::new("https://auth.example.com").is_ok()); + } + + #[test] + fn test_issuer_url_invalid() { + assert!(IssuerUrl::new("not-a-url").is_err()); + } + + #[test] + fn test_client_id_non_empty() { + assert!(ClientId::new("my-client").is_ok()); + assert!(ClientId::new("").is_err()); + assert!(ClientId::new(" ").is_err()); + } + + #[test] + fn test_client_secret_hides_in_debug() { + let secret = ClientSecret::new("super-secret"); + let debug = format!("{:?}", secret); + assert!(!debug.contains("super-secret")); + assert!(debug.contains("***")); + } + } + + mod secret_tests { + use super::*; + + #[test] + fn test_session_secret_min_length() { + let short = "short"; + let long = "a".repeat(64); + + assert!(SessionSecret::new(short).is_err()); + assert!(SessionSecret::new(long).is_ok()); + } + + #[test] + fn test_jwt_secret_production_check() { + let short = "short"; + let long = "a".repeat(32); + + // Production mode enforces length + assert!(JwtSecret::new(short, true).is_err()); + assert!(JwtSecret::new(&long, true).is_ok()); + + // Development mode allows short secrets + assert!(JwtSecret::new(short, false).is_ok()); + } + + #[test] + fn test_secrets_hide_in_debug() { + let session = SessionSecret::new_unchecked("secret"); + let jwt = JwtSecret::new_unchecked("secret"); + + assert!(!format!("{:?}", session).contains("secret")); + assert!(!format!("{:?}", jwt).contains("secret")); + } + } } diff --git a/infra/src/auth/mod.rs b/infra/src/auth/mod.rs index 9dd30f5..9f91882 100644 --- a/infra/src/auth/mod.rs +++ b/infra/src/auth/mod.rs @@ -51,8 +51,8 @@ pub mod backend { #[derive(Clone, Debug, Deserialize)] pub struct Credentials { - pub email: String, - pub password: String, + pub email: domain::Email, + pub password: domain::Password, } #[derive(Debug, thiserror::Error)] @@ -72,14 +72,14 @@ pub mod backend { ) -> Result, Self::Error> { let user = self .user_repo - .find_by_email(&creds.email) + .find_by_email(creds.email.as_ref()) .await .map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?; if let Some(user) = user { if let Some(hash) = &user.password_hash { // Verify password - if verify_password(&creds.password, hash).is_ok() { + if verify_password(creds.password.as_ref(), hash).is_ok() { return Ok(Some(AuthUser(user))); } } diff --git a/infra/src/auth/oidc.rs b/infra/src/auth/oidc.rs index 697f4fd..feff062 100644 --- a/infra/src/auth/oidc.rs +++ b/infra/src/auth/oidc.rs @@ -1,9 +1,12 @@ use anyhow::anyhow; +use domain::{ + AuthorizationCode, AuthorizationUrlData, ClientId, ClientSecret, CsrfToken, IssuerUrl, + OidcNonce, PkceVerifier, RedirectUrl, ResourceId, +}; use openidconnect::{ - AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, - EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, - OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, - StandardErrorResponse, TokenResponse, UserInfoClaims, + AccessTokenHash, Client, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, + OAuth2TokenResponse, PkceCodeChallenge, Scope, StandardErrorResponse, TokenResponse, + UserInfoClaims, core::{ CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, @@ -36,7 +39,7 @@ pub type OidcClient = Client< #[derive(Clone)] pub struct OidcService { client: OidcClient, - resource_id: Option, + resource_id: Option, } #[derive(Debug)] @@ -46,31 +49,19 @@ pub struct OidcUser { } impl OidcService { - //todo: replace Strings with newtypes + /// Create a new OIDC service with validated configuration newtypes pub async fn new( - issuer: String, - client_id: String, - client_secret: String, - redirect_url: String, - resource_id: Option, + issuer: IssuerUrl, + client_id: ClientId, + client_secret: Option, + redirect_url: RedirectUrl, + resource_id: Option, ) -> anyhow::Result { - let client_id = client_id.trim().to_string(); - let redirect_url = redirect_url.trim().to_string(); - let issuer = issuer.trim().to_string(); - - // 2. Handle Empty Secret (For PKCE/Public Clients) - let client_secret_clean = client_secret.trim(); - let client_secret_opt = if client_secret_clean.is_empty() { - None - } else { - Some(ClientSecret::new(client_secret_clean.to_string())) - }; - tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id); tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url); tracing::debug!( "🔵 OIDC Setup: Secret = {:?}", - if client_secret_opt.is_some() { + if client_secret.is_some() { "SET" } else { "NONE" @@ -81,15 +72,26 @@ impl OidcService { .redirect(reqwest::redirect::Policy::none()) .build()?; - let provider_metadata = - CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &http_client).await?; + let provider_metadata = CoreProviderMetadata::discover_async( + openidconnect::IssuerUrl::new(issuer.as_ref().to_string())?, + &http_client, + ) + .await?; + + // Convert to openidconnect types + let oidc_client_id = openidconnect::ClientId::new(client_id.as_ref().to_string()); + let oidc_client_secret = client_secret + .as_ref() + .filter(|s| !s.is_empty()) + .map(|s| openidconnect::ClientSecret::new(s.as_ref().to_string())); + let oidc_redirect_url = openidconnect::RedirectUrl::new(redirect_url.as_ref().to_string())?; let client = CoreClient::from_provider_metadata( provider_metadata, - ClientId::new(client_id), - client_secret_opt, + oidc_client_id, + oidc_client_secret, ) - .set_redirect_uri(RedirectUrl::new(redirect_url)?); + .set_redirect_uri(oidc_redirect_url); Ok(Self { client, @@ -97,48 +99,53 @@ impl OidcService { }) } - // todo: replace this tuple with newtype - pub fn get_authorization_url(&self) -> (String, String, String, String) { + /// Get the authorization URL and associated state for OIDC login + /// + /// Returns structured data instead of a raw tuple for better type safety + pub fn get_authorization_url(&self) -> AuthorizationUrlData { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, csrf_token, nonce) = self .client .authorize_url( CoreAuthenticationFlow::AuthorizationCode, - CsrfToken::new_random, - Nonce::new_random, + openidconnect::CsrfToken::new_random, + openidconnect::Nonce::new_random, ) .add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("email".to_string())) .set_pkce_challenge(pkce_challenge) .url(); - ( - auth_url.to_string(), - csrf_token.secret().to_string(), - nonce.secret().to_string(), - pkce_verifier.secret().to_string(), - ) + AuthorizationUrlData { + url: auth_url.into(), + csrf_token: CsrfToken::new(csrf_token.secret().to_string()), + nonce: OidcNonce::new(nonce.secret().to_string()), + pkce_verifier: PkceVerifier::new(pkce_verifier.secret().to_string()), + } } - //todo: replace strings with newtype + /// Resolve the OIDC callback with type-safe parameters pub async fn resolve_callback( &self, - code: String, - nonce: String, - pkce_verifier: String, + code: AuthorizationCode, + nonce: OidcNonce, + pkce_verifier: PkceVerifier, ) -> anyhow::Result { let http_client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) .build()?; - let pkce_verifier = PkceCodeVerifier::new(pkce_verifier); - let nonce = Nonce::new(nonce); + let oidc_pkce_verifier = + openidconnect::PkceCodeVerifier::new(pkce_verifier.as_ref().to_string()); + let oidc_nonce = openidconnect::Nonce::new(nonce.as_ref().to_string()); let token_response = self .client - .exchange_code(AuthorizationCode::new(code))? - .set_pkce_verifier(pkce_verifier) + .exchange_code(openidconnect::AuthorizationCode::new( + code.as_ref().to_string(), + ))? + .set_pkce_verifier(oidc_pkce_verifier) .request_async(&http_client) .await?; @@ -148,14 +155,13 @@ impl OidcService { let mut id_token_verifier = self.client.id_token_verifier().clone(); - let trusted_resource_id = self.resource_id.clone(); - - if let Some(resource_id) = trusted_resource_id { + if let Some(resource_id) = &self.resource_id { + let trusted_resource_id = resource_id.as_ref().to_string(); id_token_verifier = id_token_verifier - .set_other_audience_verifier_fn(move |aud| aud.as_str() == resource_id); + .set_other_audience_verifier_fn(move |aud| aud.as_str() == trusted_resource_id); } - let claims = id_token.claims(&id_token_verifier, &nonce)?; + let claims = id_token.claims(&id_token_verifier, &oidc_nonce)?; if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token(