From 66f6a7d70c0e352f4c51804179a82519908379bf Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 19:25:18 +0100 Subject: [PATCH 1/5] feat: enhance email validation and add new OIDC/OAuth configuration and flow value objects --- notes-domain/Cargo.toml | 2 + notes-domain/src/value_objects.rs | 555 ++++++++++++++++++++++++++++-- 2 files changed, 528 insertions(+), 29 deletions(-) diff --git a/notes-domain/Cargo.toml b/notes-domain/Cargo.toml index 5b2b2ef..1235086 100644 --- a/notes-domain/Cargo.toml +++ b/notes-domain/Cargo.toml @@ -13,6 +13,8 @@ thiserror = "2.0.17" tracing = "0.1" uuid = { version = "1.19.0", features = ["v4", "serde"] } futures-core = "0.3" +email_address = "0.2.9" +url = { version = "2.5.8", features = ["serde"] } [dev-dependencies] tokio = { version = "1", features = ["rt", "macros"] } diff --git a/notes-domain/src/value_objects.rs b/notes-domain/src/value_objects.rs index 979de18..fceab2a 100644 --- a/notes-domain/src/value_objects.rs +++ b/notes-domain/src/value_objects.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; use thiserror::Error; +use url::Url; // ============================================================================ // Validation Error @@ -28,47 +29,44 @@ pub enum ValidationError { #[error("Note title cannot exceed {max} characters, got {actual}")] TitleTooLong { max: usize, actual: usize }, + + #[error("Invalid URL: {0}")] + InvalidUrl(String), + + #[error("Value cannot be empty: {0}")] + Empty(String), + + #[error("Secret too short: minimum {min} bytes required, got {actual}")] + SecretTooShort { min: usize, actual: usize }, } // ============================================================================ // Email // ============================================================================ -/// A validated email address. -/// -/// Simple validation: must contain exactly one `@` with non-empty parts on both sides. +/// A validated email address using RFC-compliant validation. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Email(String); +pub struct Email(email_address::EmailAddress); impl Email { - /// Minimum validation: contains @ with non-empty local and domain parts - pub fn new(value: impl Into) -> Result { - let value = value.into(); - let trimmed = value.trim().to_lowercase(); - - // Basic email validation - let parts: Vec<&str> = trimmed.split('@').collect(); - if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { - return Err(ValidationError::InvalidEmail(value)); - } - - // Domain must contain at least one dot - if !parts[1].contains('.') { - return Err(ValidationError::InvalidEmail(value)); - } - - Ok(Self(trimmed)) + /// Create a new validated email address + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim().to_lowercase(); + let addr: email_address::EmailAddress = value + .parse() + .map_err(|_| ValidationError::InvalidEmail(value.clone()))?; + Ok(Self(addr)) } /// Get the inner value pub fn into_inner(self) -> String { - self.0 + self.0.to_string() } } impl AsRef for Email { fn as_ref(&self) -> &str { - &self.0 + self.0.as_ref() } } @@ -96,7 +94,7 @@ impl TryFrom<&str> for Email { impl Serialize for Email { fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.0) + serializer.serialize_str(self.0.as_ref()) } } @@ -351,6 +349,446 @@ impl<'de> Deserialize<'de> for NoteTitle { } } +// ============================================================================ +// OIDC Configuration Newtypes +// ============================================================================ + +/// OIDC Issuer URL - validated URL for the identity provider +/// +/// Stores the original string to preserve exact formatting (e.g., trailing slashes) +/// since OIDC providers expect issuer URLs to match exactly. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct IssuerUrl(String); + +impl IssuerUrl { + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim().to_string(); + // Validate URL format but store original string to preserve exact formatting + Url::parse(&value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?; + Ok(Self(value)) + } +} + +impl AsRef for IssuerUrl { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for IssuerUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for IssuerUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: IssuerUrl) -> Self { + val.0 + } +} + +/// OIDC Client Identifier +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct ClientId(String); + +impl ClientId { + pub fn new(value: impl Into) -> Result { + let value = value.into().trim().to_string(); + if value.is_empty() { + return Err(ValidationError::Empty("client_id".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for ClientId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for ClientId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for ClientId { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: ClientId) -> Self { + val.0 + } +} + +/// OIDC Client Secret - hidden in Debug output +#[derive(Clone, PartialEq, Eq)] +pub struct ClientSecret(String); + +impl ClientSecret { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } + + /// Check if the secret is empty (for public clients) + pub fn is_empty(&self) -> bool { + self.0.trim().is_empty() + } +} + +impl AsRef for ClientSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for ClientSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ClientSecret(***)") + } +} + +impl fmt::Display for ClientSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "***") + } +} + +impl<'de> Deserialize<'de> for ClientSecret { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Ok(Self::new(s)) + } +} + +// Note: ClientSecret should NOT implement Serialize + +/// OAuth Redirect URL - validated URL +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct RedirectUrl(Url); + +impl RedirectUrl { + pub fn new(value: impl AsRef) -> Result { + let value = value.as_ref().trim(); + let url = Url::parse(value).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?; + Ok(Self(url)) + } + + pub fn as_url(&self) -> &Url { + &self.0 + } +} + +impl AsRef for RedirectUrl { + fn as_ref(&self) -> &str { + self.0.as_str() + } +} + +impl fmt::Display for RedirectUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for RedirectUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: RedirectUrl) -> Self { + val.0.to_string() + } +} + +/// OIDC Resource Identifier (optional audience) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct ResourceId(String); + +impl ResourceId { + pub fn new(value: impl Into) -> Result { + let value = value.into().trim().to_string(); + if value.is_empty() { + return Err(ValidationError::Empty("resource_id".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for ResourceId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for ResourceId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for ResourceId { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: ResourceId) -> Self { + val.0 + } +} + +// ============================================================================ +// OIDC Flow Newtypes (for type-safe session storage) +// ============================================================================ + +/// CSRF Token for OIDC state parameter +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CsrfToken(String); + +impl CsrfToken { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for CsrfToken { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for CsrfToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Nonce for OIDC ID token verification +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OidcNonce(String); + +impl OidcNonce { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for OidcNonce { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for OidcNonce { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// PKCE Code Verifier +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PkceVerifier(String); + +impl PkceVerifier { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for PkceVerifier { + fn as_ref(&self) -> &str { + &self.0 + } +} + +// Hide PKCE verifier in Debug (security) +impl fmt::Debug for PkceVerifier { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PkceVerifier(***)") + } +} + +/// OAuth2 Authorization Code +#[derive(Clone, PartialEq, Eq)] +pub struct AuthorizationCode(String); + +impl AuthorizationCode { + pub fn new(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for AuthorizationCode { + fn as_ref(&self) -> &str { + &self.0 + } +} + +// Hide authorization code in Debug (security) +impl fmt::Debug for AuthorizationCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AuthorizationCode(***)") + } +} + +impl<'de> Deserialize<'de> for AuthorizationCode { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Ok(Self::new(s)) + } +} + +/// Complete authorization URL data returned when starting OIDC flow +#[derive(Debug, Clone)] +pub struct AuthorizationUrlData { + /// The URL to redirect the user to + pub url: Url, + /// CSRF token to store in session + pub csrf_token: CsrfToken, + /// Nonce to store in session + pub nonce: OidcNonce, + /// PKCE verifier to store in session + pub pkce_verifier: PkceVerifier, +} + +// ============================================================================ +// Configuration Newtypes +// ============================================================================ + +/// Database connection URL +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "String", into = "String")] +pub struct DatabaseUrl(String); + +impl DatabaseUrl { + pub fn new(value: impl Into) -> Result { + let value = value.into(); + if value.trim().is_empty() { + return Err(ValidationError::Empty("database_url".to_string())); + } + Ok(Self(value)) + } +} + +impl AsRef for DatabaseUrl { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for DatabaseUrl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom for DatabaseUrl { + type Error = ValidationError; + fn try_from(value: String) -> Result { + Self::new(value) + } +} + +impl From for String { + fn from(val: DatabaseUrl) -> Self { + val.0 + } +} + +/// Session secret with minimum length requirement +pub const MIN_SESSION_SECRET_LENGTH: usize = 64; + +#[derive(Clone, PartialEq, Eq)] +pub struct SessionSecret(String); + +impl SessionSecret { + pub fn new(value: impl Into) -> Result { + let value = value.into(); + if value.len() < MIN_SESSION_SECRET_LENGTH { + return Err(ValidationError::SecretTooShort { + min: MIN_SESSION_SECRET_LENGTH, + actual: value.len(), + }); + } + Ok(Self(value)) + } + + /// Create without validation (for development/testing) + pub fn new_unchecked(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for SessionSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for SessionSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SessionSecret(***)") + } +} + +/// JWT signing secret with minimum length requirement +pub const MIN_JWT_SECRET_LENGTH: usize = 32; + +#[derive(Clone, PartialEq, Eq)] +pub struct JwtSecret(String); + +impl JwtSecret { + pub fn new(value: impl Into, is_production: bool) -> Result { + let value = value.into(); + if is_production && value.len() < MIN_JWT_SECRET_LENGTH { + return Err(ValidationError::SecretTooShort { + min: MIN_JWT_SECRET_LENGTH, + actual: value.len(), + }); + } + Ok(Self(value)) + } + + /// Create without validation (for development/testing) + pub fn new_unchecked(value: impl Into) -> Self { + Self(value.into()) + } +} + +impl AsRef for JwtSecret { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for JwtSecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "JwtSecret(***)") + } +} + // ============================================================================ // Tests // ============================================================================ @@ -389,11 +827,6 @@ mod tests { fn test_invalid_email_no_local() { assert!(Email::new("@example.com").is_err()); } - - #[test] - fn test_invalid_email_no_dot_in_domain() { - assert!(Email::new("user@localhost").is_err()); - } } mod password_tests { @@ -493,4 +926,68 @@ mod tests { assert_eq!(result.unwrap().as_ref(), "My Note"); } } + + mod oidc_tests { + use super::*; + + #[test] + fn test_issuer_url_valid() { + assert!(IssuerUrl::new("https://auth.example.com").is_ok()); + } + + #[test] + fn test_issuer_url_invalid() { + assert!(IssuerUrl::new("not-a-url").is_err()); + } + + #[test] + fn test_client_id_non_empty() { + assert!(ClientId::new("my-client").is_ok()); + assert!(ClientId::new("").is_err()); + assert!(ClientId::new(" ").is_err()); + } + + #[test] + fn test_client_secret_hides_in_debug() { + let secret = ClientSecret::new("super-secret"); + let debug = format!("{:?}", secret); + assert!(!debug.contains("super-secret")); + assert!(debug.contains("***")); + } + } + + mod secret_tests { + use super::*; + + #[test] + fn test_session_secret_min_length() { + let short = "short"; + let long = "a".repeat(64); + + assert!(SessionSecret::new(short).is_err()); + assert!(SessionSecret::new(long).is_ok()); + } + + #[test] + fn test_jwt_secret_production_check() { + let short = "short"; + let long = "a".repeat(32); + + // Production mode enforces length + assert!(JwtSecret::new(short, true).is_err()); + assert!(JwtSecret::new(&long, true).is_ok()); + + // Development mode allows short secrets + assert!(JwtSecret::new(short, false).is_ok()); + } + + #[test] + fn test_secrets_hide_in_debug() { + let session = SessionSecret::new_unchecked("secret"); + let jwt = JwtSecret::new_unchecked("secret"); + + assert!(!format!("{:?}", session).contains("secret")); + assert!(!format!("{:?}", jwt).contains("secret")); + } + } } -- 2.49.1 From 6a3259d3474bdb65f5b8b62dde6d779d7c20b619 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 19:25:29 +0100 Subject: [PATCH 2/5] chore: update dependencies --- Cargo.lock | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c57f3f..0ef8745 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -911,6 +911,15 @@ dependencies = [ "serde", ] +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encode_unicode" version = "1.0.0" @@ -2195,12 +2204,14 @@ dependencies = [ "anyhow", "async-trait", "chrono", + "email_address", "futures-core", "serde", "serde_json", "thiserror 2.0.17", "tokio", "tracing", + "url", "uuid", ] @@ -4545,14 +4556,15 @@ dependencies = [ [[package]] name = "url" -version = "2.5.7" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] -- 2.49.1 From 82a6c087901d04242749bc4780c2e286ddc991e7 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 19:39:05 +0100 Subject: [PATCH 3/5] auth in infra --- Cargo.lock | 418 ++++++++++++++++++++++++++++- notes-domain/src/lib.rs | 13 +- notes-infra/Cargo.toml | 46 +++- notes-infra/src/auth/axum_login.rs | 110 ++++++++ notes-infra/src/auth/jwt.rs | 278 +++++++++++++++++++ notes-infra/src/auth/mod.rs | 6 + notes-infra/src/auth/oidc.rs | 202 ++++++++++++++ notes-infra/src/lib.rs | 1 + 8 files changed, 1051 insertions(+), 23 deletions(-) create mode 100644 notes-infra/src/auth/axum_login.rs create mode 100644 notes-infra/src/auth/jwt.rs create mode 100644 notes-infra/src/auth/mod.rs create mode 100644 notes-infra/src/auth/oidc.rs diff --git a/Cargo.lock b/Cargo.lock index 0ef8745..d28ae67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -378,12 +378,24 @@ dependencies = [ "syn", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -692,6 +704,18 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -715,6 +739,7 @@ dependencies = [ "fiat-crypto", "rustc_version", "subtle", + "zeroize", ] [[package]] @@ -734,8 +759,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -752,13 +787,38 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn", ] @@ -814,7 +874,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn", @@ -880,12 +940,33 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + [[package]] name = "ed25519" version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" dependencies = [ + "pkcs8", "signature", ] @@ -897,9 +978,11 @@ checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" dependencies = [ "curve25519-dalek", "ed25519", + "serde", "sha2", "signature", "subtle", + "zeroize", ] [[package]] @@ -911,6 +994,27 @@ dependencies = [ "serde", ] +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff", + "generic-array", + "group", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = "email_address" version = "0.2.9" @@ -1065,6 +1169,16 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "fiat-crypto" version = "0.2.9" @@ -1254,6 +1368,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -1293,6 +1408,17 @@ dependencies = [ "weezl", ] +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "h2" version = "0.4.12" @@ -1735,6 +1861,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -1745,6 +1872,8 @@ checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", "hashbrown 0.16.1", + "serde", + "serde_core", ] [[package]] @@ -1787,6 +1916,15 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -1822,6 +1960,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "10.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e" +dependencies = [ + "base64 0.22.1", + "ed25519-dalek", + "getrandom 0.2.16", + "hmac", + "js-sys", + "p256", + "p384", + "pem", + "rand 0.8.5", + "rsa", + "serde", + "serde_json", + "sha2", + "signature", + "simple_asn1", +] + [[package]] name = "k-core" version = "0.1.10" @@ -2221,18 +2382,24 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "axum-login", "chrono", "futures-core", "futures-util", + "jsonwebtoken", "k-core", "notes-domain", + "openidconnect", + "password-auth", "serde", "serde_json", "sqlx", "thiserror 2.0.17", "tokio", + "tower-sessions", "tower-sessions-sqlx-store", "tracing", + "url", "uuid", ] @@ -2375,6 +2542,26 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "oauth2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" +dependencies = [ + "base64 0.22.1", + "chrono", + "getrandom 0.2.16", + "http", + "rand 0.8.5", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -2403,6 +2590,37 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "openidconnect" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c6709ba2ea764bbed26bce1adf3c10517113ddea6f2d4196e4851757ef2b2" +dependencies = [ + "base64 0.21.7", + "chrono", + "dyn-clone", + "ed25519-dalek", + "hmac", + "http", + "itertools 0.10.5", + "log", + "oauth2", + "p256", + "p384", + "rand 0.8.5", + "rsa", + "serde", + "serde-value", + "serde_json", + "serde_path_to_error", + "serde_plain", + "serde_with", + "sha2", + "subtle", + "thiserror 1.0.69", + "url", +] + [[package]] name = "openssl" version = "0.10.75" @@ -2459,6 +2677,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ort" version = "2.0.0-rc.10" @@ -2484,6 +2711,30 @@ dependencies = [ "ureq 3.1.4", ] +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + [[package]] name = "parking" version = "2.2.1" @@ -2548,6 +2799,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2674,6 +2935,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve", +] + [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -2741,7 +3011,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn", @@ -2947,7 +3217,7 @@ dependencies = [ "built", "cfg-if", "interpolate_name", - "itertools", + "itertools 0.14.0", "libc", "libfuzzer-sys", "log", @@ -3004,7 +3274,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" dependencies = [ "either", - "itertools", + "itertools 0.14.0", "rayon", ] @@ -3047,6 +3317,26 @@ dependencies = [ "thiserror 2.0.17", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "regex" version = "1.12.2" @@ -3123,6 +3413,16 @@ dependencies = [ "webpki-roots 1.0.4", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "rgb" version = "0.8.52" @@ -3321,12 +3621,50 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e910108742c57a770f492731f99be216a52fadd361b06c8fb59d74ccc267d2" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -3379,6 +3717,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float", + "serde", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -3432,6 +3780,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.20" @@ -3455,6 +3812,37 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fa237f2807440d238e0364a218270b98f767a00d3dada77b1c53ae88940e2e7" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.12.1", + "schemars 0.9.0", + "schemars 1.2.0", + "serde_core", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sha1" version = "0.10.6" @@ -3539,6 +3927,18 @@ dependencies = [ "quote", ] +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.17", + "time", +] + [[package]] name = "slab" version = "0.4.11" @@ -4065,7 +4465,7 @@ dependencies = [ "derive_builder", "esaxx-rs", "getrandom 0.3.4", - "itertools", + "itertools 0.14.0", "log", "macro_rules_attribute", "monostate", @@ -4630,7 +5030,7 @@ version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" dependencies = [ - "darling", + "darling 0.20.11", "once_cell", "proc-macro-error2", "proc-macro2", diff --git a/notes-domain/src/lib.rs b/notes-domain/src/lib.rs index 0ee58eb..7dc26c9 100644 --- a/notes-domain/src/lib.rs +++ b/notes-domain/src/lib.rs @@ -17,12 +17,9 @@ pub mod services; pub mod value_objects; // Re-export commonly used types at crate root -pub use entities::{MAX_TAGS_PER_NOTE, Note, NoteFilter, NoteVersion, Tag, User}; +pub use entities::*; pub use errors::{DomainError, DomainResult}; -pub use ports::MessageBroker; -pub use repositories::{NoteRepository, TagRepository, UserRepository}; -pub use services::{CreateNoteRequest, NoteService, TagService, UpdateNoteRequest, UserService}; -pub use value_objects::{ - Email, MAX_NOTE_TITLE_LENGTH, MAX_TAG_NAME_LENGTH, MIN_PASSWORD_LENGTH, NoteTitle, Password, - TagName, ValidationError, -}; +pub use ports::*; +pub use repositories::*; +pub use services::*; +pub use value_objects::*; diff --git a/notes-infra/Cargo.toml b/notes-infra/Cargo.toml index aaa07d2..629036a 100644 --- a/notes-infra/Cargo.toml +++ b/notes-infra/Cargo.toml @@ -4,18 +4,38 @@ version = "0.1.0" edition = "2024" [features] -default = ["sqlite", "smart-features", "broker-nats"] -sqlite = ["sqlx/sqlite", "k-core/sqlite", "tower-sessions-sqlx-store", "k-core/sessions-db"] -postgres = ["sqlx/postgres", "k-core/postgres", "tower-sessions-sqlx-store", "k-core/sessions-db"] +default = [ + "sqlite", + "smart-features", + "broker-nats", + "auth-jwt", + "auth-oidc", + "auth-axum-login", +] +sqlite = [ + "sqlx/sqlite", + "k-core/sqlite", + "tower-sessions-sqlx-store", + "k-core/sessions-db", +] +postgres = [ + "sqlx/postgres", + "k-core/postgres", + "tower-sessions-sqlx-store", + "k-core/sessions-db", +] smart-features = ["k-core/ai"] broker-nats = ["dep:futures-util", "k-core/broker-nats"] +auth-axum-login = ["dep:axum-login", "dep:password-auth"] +auth-oidc = ["dep:openidconnect", "dep:url"] +auth-jwt = ["dep:jsonwebtoken"] [dependencies] k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features = [ "logging", "db-sqlx", - "sessions-db" -], version = "*"} + "sessions-db", +], version = "*" } notes-domain = { path = "../notes-domain" } chrono = { version = "0.4.42", features = ["serde"] } @@ -31,4 +51,18 @@ futures-util = { version = "0.3", optional = true } futures-core = "0.3" async-trait = "0.1.89" anyhow = "1.0.100" -tower-sessions-sqlx-store = { version = "0.15.0", optional = true} +tower-sessions-sqlx-store = { version = "0.15.0", optional = true } +tower-sessions = "0.14" + +# Auth dependencies (optional) +axum-login = { version = "0.18", optional = true } +password-auth = { version = "1.0", optional = true } +openidconnect = { version = "4.0.1", optional = true } +url = { version = "2.5.8", optional = true } +jsonwebtoken = { version = "10.2.0", features = [ + "sha2", + "p256", + "hmac", + "rsa", + "rust_crypto", +], optional = true } diff --git a/notes-infra/src/auth/axum_login.rs b/notes-infra/src/auth/axum_login.rs new file mode 100644 index 0000000..75bd0fc --- /dev/null +++ b/notes-infra/src/auth/axum_login.rs @@ -0,0 +1,110 @@ +use std::sync::Arc; + +use axum_login::{AuthnBackend, UserId}; +use password_auth::verify_password; +use serde::{Deserialize, Serialize}; +use tower_sessions::SessionManagerLayer; +use uuid::Uuid; + +use notes_domain::{User, UserRepository}; + +use crate::session_store::InfraSessionStore; + +/// Wrapper around domain User to implement AuthUser +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthUser(pub User); + +impl axum_login::AuthUser for AuthUser { + type Id = Uuid; + + fn id(&self) -> Self::Id { + self.0.id + } + + fn session_auth_hash(&self) -> &[u8] { + // Use password hash to invalidate sessions if password changes + self.0 + .password_hash + .as_ref() + .map(|s| s.as_bytes()) + .unwrap_or(&[]) + } +} + +#[derive(Clone)] +pub struct AuthBackend { + pub user_repo: Arc, +} + +impl AuthBackend { + pub fn new(user_repo: Arc) -> Self { + Self { user_repo } + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Credentials { + pub email: notes_domain::Email, + pub password: notes_domain::Password, +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + #[error(transparent)] + Anyhow(#[from] anyhow::Error), +} + +impl AuthnBackend for AuthBackend { + type User = AuthUser; + type Credentials = Credentials; + type Error = AuthError; + + async fn authenticate( + &self, + creds: Self::Credentials, + ) -> Result, Self::Error> { + let user = self + .user_repo + .find_by_email(creds.email.as_ref()) + .await + .map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?; + + if let Some(user) = user { + if let Some(hash) = &user.password_hash { + // Verify password + if verify_password(creds.password.as_ref(), hash).is_ok() { + return Ok(Some(AuthUser(user))); + } + } + } + + Ok(None) + } + + async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { + let user = self + .user_repo + .find_by_id(*user_id) + .await + .map_err(|e| AuthError::Anyhow(anyhow::anyhow!(e)))?; + + Ok(user.map(AuthUser)) + } +} + +pub type AuthSession = axum_login::AuthSession; +pub type AuthManagerLayer = axum_login::AuthManagerLayer; + +pub async fn setup_auth_layer( + session_layer: SessionManagerLayer, + user_repo: Arc, +) -> Result { + let backend = AuthBackend::new(user_repo); + + let auth_layer = axum_login::AuthManagerLayerBuilder::new(backend, session_layer).build(); + Ok(auth_layer) +} + +pub fn hash_password(password: &str) -> String { + password_auth::generate_hash(password) +} diff --git a/notes-infra/src/auth/jwt.rs b/notes-infra/src/auth/jwt.rs new file mode 100644 index 0000000..bf1c917 --- /dev/null +++ b/notes-infra/src/auth/jwt.rs @@ -0,0 +1,278 @@ +//! JWT Authentication Infrastructure +//! +//! Provides JWT token creation and validation using HS256 (secret-based). +//! For OIDC/JWKS validation, see the `oidc` module. + +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; +use notes_domain::User; +use serde::{Deserialize, Serialize}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Minimum secret length for production (256 bits = 32 bytes) +const MIN_SECRET_LENGTH: usize = 32; + +/// JWT configuration +#[derive(Debug, Clone)] +pub struct JwtConfig { + /// Secret key for HS256 signing/verification + pub secret: String, + /// Expected issuer (for validation) + pub issuer: Option, + /// Expected audience (for validation) + pub audience: Option, + /// Token expiry in hours (default: 24) + pub expiry_hours: u64, +} + +impl JwtConfig { + /// Create a new JWT config with validation + /// + /// In production mode, this will reject weak secrets. + pub fn new( + secret: String, + issuer: Option, + audience: Option, + expiry_hours: Option, + is_production: bool, + ) -> Result { + // Validate secret strength in production + if is_production && secret.len() < MIN_SECRET_LENGTH { + return Err(JwtError::WeakSecret { + min_length: MIN_SECRET_LENGTH, + actual_length: secret.len(), + }); + } + + Ok(Self { + secret, + issuer, + audience, + expiry_hours: expiry_hours.unwrap_or(24), + }) + } + + /// Create config without validation (for testing) + pub fn new_unchecked(secret: String) -> Self { + Self { + secret, + issuer: None, + audience: None, + expiry_hours: 24, + } + } +} + +/// JWT claims structure +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct JwtClaims { + /// Subject - the user's unique identifier (user ID as string) + pub sub: String, + /// User's email address + pub email: String, + /// Expiry timestamp (seconds since UNIX epoch) + pub exp: usize, + /// Issued at timestamp (seconds since UNIX epoch) + pub iat: usize, + /// Issuer + #[serde(skip_serializing_if = "Option::is_none")] + pub iss: Option, + /// Audience + #[serde(skip_serializing_if = "Option::is_none")] + pub aud: Option, +} + +/// JWT-related errors +#[derive(Debug, thiserror::Error)] +pub enum JwtError { + #[error("JWT secret is too weak: minimum {min_length} bytes required, got {actual_length}")] + WeakSecret { + min_length: usize, + actual_length: usize, + }, + + #[error("Token creation failed: {0}")] + CreationFailed(#[from] jsonwebtoken::errors::Error), + + #[error("Token validation failed: {0}")] + ValidationFailed(String), + + #[error("Token expired")] + Expired, + + #[error("Invalid token format")] + InvalidFormat, + + #[error("Missing configuration")] + MissingConfig, +} + +/// JWT token validator and generator +#[derive(Clone)] +pub struct JwtValidator { + config: JwtConfig, + encoding_key: EncodingKey, + decoding_key: DecodingKey, + validation: Validation, +} + +impl JwtValidator { + /// Create a new JWT validator with the given configuration + pub fn new(config: JwtConfig) -> Self { + let encoding_key = EncodingKey::from_secret(config.secret.as_bytes()); + let decoding_key = DecodingKey::from_secret(config.secret.as_bytes()); + + let mut validation = Validation::new(Algorithm::HS256); + + // Configure issuer validation if set + if let Some(ref issuer) = config.issuer { + validation.set_issuer(&[issuer]); + } + + // Configure audience validation if set + if let Some(ref audience) = config.audience { + validation.set_audience(&[audience]); + } + + Self { + config, + encoding_key, + decoding_key, + validation, + } + } + + /// Create a JWT token for the given user + pub fn create_token(&self, user: &User) -> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as usize; + + let expiry = now + (self.config.expiry_hours as usize * 3600); + + let claims = JwtClaims { + sub: user.id.to_string(), + email: user.email.as_ref().to_string(), + exp: expiry, + iat: now, + iss: self.config.issuer.clone(), + aud: self.config.audience.clone(), + }; + + let header = Header::new(Algorithm::HS256); + encode(&header, &claims, &self.encoding_key).map_err(JwtError::CreationFailed) + } + + /// Validate a JWT token and return the claims + pub fn validate_token(&self, token: &str) -> Result { + let token_data = decode::(token, &self.decoding_key, &self.validation).map_err( + |e| match e.kind() { + jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired, + jsonwebtoken::errors::ErrorKind::InvalidToken => JwtError::InvalidFormat, + _ => JwtError::ValidationFailed(e.to_string()), + }, + )?; + + Ok(token_data.claims) + } + + /// Get the user ID (subject) from a token without full validation + /// Useful for logging/debugging, but should not be trusted for auth + pub fn decode_unverified(&self, token: &str) -> Result { + let mut validation = Validation::new(Algorithm::HS256); + validation.insecure_disable_signature_validation(); + validation.validate_exp = false; + + let token_data = decode::(token, &self.decoding_key, &validation) + .map_err(|_| JwtError::InvalidFormat)?; + + Ok(token_data.claims) + } +} + +impl std::fmt::Debug for JwtValidator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JwtValidator") + .field("issuer", &self.config.issuer) + .field("audience", &self.config.audience) + .field("expiry_hours", &self.config.expiry_hours) + .finish_non_exhaustive() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use notes_domain::Email; + + fn create_test_user() -> User { + let email = Email::try_from("test@example.com").unwrap(); + User::new("test-subject", email) + } + + #[test] + fn test_create_and_validate_token() { + let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string()); + let validator = JwtValidator::new(config); + let user = create_test_user(); + + let token = validator.create_token(&user).expect("Should create token"); + let claims = validator + .validate_token(&token) + .expect("Should validate token"); + + assert_eq!(claims.sub, user.id.to_string()); + assert_eq!(claims.email, "test@example.com"); + } + + #[test] + fn test_weak_secret_rejected_in_production() { + let result = JwtConfig::new( + "short".to_string(), // Too short + None, + None, + None, + true, // Production mode + ); + + assert!(matches!(result, Err(JwtError::WeakSecret { .. }))); + } + + #[test] + fn test_weak_secret_allowed_in_development() { + let result = JwtConfig::new( + "short".to_string(), // Too short but OK in dev + None, + None, + None, + false, // Development mode + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_invalid_token_rejected() { + let config = JwtConfig::new_unchecked("test-secret-key-that-is-long-enough".to_string()); + let validator = JwtValidator::new(config); + + let result = validator.validate_token("invalid.token.here"); + assert!(result.is_err()); + } + + #[test] + fn test_wrong_secret_rejected() { + let config1 = JwtConfig::new_unchecked("secret-one-that-is-long-enough".to_string()); + let config2 = JwtConfig::new_unchecked("secret-two-that-is-long-enough".to_string()); + + let validator1 = JwtValidator::new(config1); + let validator2 = JwtValidator::new(config2); + + let user = create_test_user(); + let token = validator1.create_token(&user).unwrap(); + + // Token from validator1 should fail on validator2 + let result = validator2.validate_token(&token); + assert!(result.is_err()); + } +} diff --git a/notes-infra/src/auth/mod.rs b/notes-infra/src/auth/mod.rs new file mode 100644 index 0000000..ac01f0b --- /dev/null +++ b/notes-infra/src/auth/mod.rs @@ -0,0 +1,6 @@ +#[cfg(feature = "auth-axum-login")] +mod axum_login; +#[cfg(feature = "auth-jwt")] +mod jwt; +#[cfg(feature = "auth-oidc")] +mod oidc; diff --git a/notes-infra/src/auth/oidc.rs b/notes-infra/src/auth/oidc.rs new file mode 100644 index 0000000..76f0929 --- /dev/null +++ b/notes-infra/src/auth/oidc.rs @@ -0,0 +1,202 @@ +use anyhow::anyhow; +use notes_domain::{ + AuthorizationCode, AuthorizationUrlData, ClientId, ClientSecret, CsrfToken, IssuerUrl, + OidcNonce, PkceVerifier, RedirectUrl, ResourceId, +}; +use openidconnect::{ + AccessTokenHash, Client, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, + OAuth2TokenResponse, PkceCodeChallenge, Scope, StandardErrorResponse, TokenResponse, + UserInfoClaims, + core::{ + CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, + CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, + CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, + CoreTokenResponse, + }, + reqwest, +}; + +pub type OidcClient = Client< + EmptyAdditionalClaims, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + CoreTokenResponse, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, + EndpointSet, // HasAuthUrl (Required and guaranteed by discovery) + EndpointNotSet, // HasDeviceAuthUrl + EndpointNotSet, // HasIntrospectionUrl + EndpointNotSet, // HasRevocationUrl + EndpointMaybeSet, // HasTokenUrl (Discovered, might be missing) + EndpointMaybeSet, // HasUserInfoUrl (Discovered, might be missing) +>; + +#[derive(Clone)] +pub struct OidcService { + client: OidcClient, + resource_id: Option, +} + +#[derive(Debug)] +pub struct OidcUser { + pub subject: String, + pub email: String, +} + +impl OidcService { + /// Create a new OIDC service with validated configuration newtypes + pub async fn new( + issuer: IssuerUrl, + client_id: ClientId, + client_secret: Option, + redirect_url: RedirectUrl, + resource_id: Option, + ) -> anyhow::Result { + tracing::debug!("🔵 OIDC Setup: Client ID = '{}'", client_id); + tracing::debug!("🔵 OIDC Setup: Redirect = '{}'", redirect_url); + tracing::debug!( + "🔵 OIDC Setup: Secret = {:?}", + if client_secret.is_some() { + "SET" + } else { + "NONE" + } + ); + + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build()?; + + let provider_metadata = CoreProviderMetadata::discover_async( + openidconnect::IssuerUrl::new(issuer.as_ref().to_string())?, + &http_client, + ) + .await?; + + // Convert to openidconnect types + let oidc_client_id = openidconnect::ClientId::new(client_id.as_ref().to_string()); + let oidc_client_secret = client_secret + .as_ref() + .filter(|s| !s.is_empty()) + .map(|s| openidconnect::ClientSecret::new(s.as_ref().to_string())); + let oidc_redirect_url = openidconnect::RedirectUrl::new(redirect_url.as_ref().to_string())?; + + let client = CoreClient::from_provider_metadata( + provider_metadata, + oidc_client_id, + oidc_client_secret, + ) + .set_redirect_uri(oidc_redirect_url); + + Ok(Self { + client, + resource_id, + }) + } + + /// Get the authorization URL and associated state for OIDC login + /// + /// Returns structured data instead of a raw tuple for better type safety + pub fn get_authorization_url(&self) -> AuthorizationUrlData { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + let (auth_url, csrf_token, nonce) = self + .client + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + openidconnect::CsrfToken::new_random, + openidconnect::Nonce::new_random, + ) + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .set_pkce_challenge(pkce_challenge) + .url(); + + AuthorizationUrlData { + url: auth_url.into(), + csrf_token: CsrfToken::new(csrf_token.secret().to_string()), + nonce: OidcNonce::new(nonce.secret().to_string()), + pkce_verifier: PkceVerifier::new(pkce_verifier.secret().to_string()), + } + } + + /// Resolve the OIDC callback with type-safe parameters + pub async fn resolve_callback( + &self, + code: AuthorizationCode, + nonce: OidcNonce, + pkce_verifier: PkceVerifier, + ) -> anyhow::Result { + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build()?; + + let oidc_pkce_verifier = + openidconnect::PkceCodeVerifier::new(pkce_verifier.as_ref().to_string()); + let oidc_nonce = openidconnect::Nonce::new(nonce.as_ref().to_string()); + + let token_response = self + .client + .exchange_code(openidconnect::AuthorizationCode::new( + code.as_ref().to_string(), + ))? + .set_pkce_verifier(oidc_pkce_verifier) + .request_async(&http_client) + .await?; + + let id_token = token_response + .id_token() + .ok_or_else(|| anyhow!("Server did not return an ID token"))?; + + let mut id_token_verifier = self.client.id_token_verifier().clone(); + + if let Some(resource_id) = &self.resource_id { + let trusted_resource_id = resource_id.as_ref().to_string(); + id_token_verifier = id_token_verifier + .set_other_audience_verifier_fn(move |aud| aud.as_str() == trusted_resource_id); + } + + let claims = id_token.claims(&id_token_verifier, &oidc_nonce)?; + + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + id_token.signing_alg()?, + id_token.signing_key(&id_token_verifier)?, + )?; + + if actual_access_token_hash != *expected_access_token_hash { + return Err(anyhow!("Invalid access token")); + } + } + + let email = if let Some(email) = claims.email() { + Some(email.as_str().to_string()) + } else { + // Fallback: Call UserInfo Endpoint using the Access Token + tracing::debug!("🔵 Email missing in ID Token, fetching UserInfo..."); + + let user_info: UserInfoClaims = self + .client + .user_info(token_response.access_token().clone(), None)? + .request_async(&http_client) + .await?; + + user_info.email().map(|e| e.as_str().to_string()) + }; + + // If email is still missing, we must error out because your app requires valid emails + let email = + email.ok_or_else(|| anyhow!("User has no verified email address in ZITADEL"))?; + + Ok(OidcUser { + subject: claims.subject().to_string(), + email, + }) + } +} diff --git a/notes-infra/src/lib.rs b/notes-infra/src/lib.rs index 2a0ad21..9b4659e 100644 --- a/notes-infra/src/lib.rs +++ b/notes-infra/src/lib.rs @@ -13,6 +13,7 @@ //! //! - [`db::run_migrations`] - Run database migrations +pub mod auth; #[cfg(feature = "broker-nats")] pub mod broker; pub mod db; -- 2.49.1 From a5f9e8ae9ea9abbf75b300ed3c3d9f05f2dc917e Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 20:31:57 +0100 Subject: [PATCH 4/5] feat: Implement flexible authentication supporting JWT, OIDC, and session modes, alongside new configuration options and refactored auth layer setup. --- Cargo.lock | 2 - notes-api/Cargo.toml | 28 +- notes-api/src/auth.rs | 98 +---- notes-api/src/config.rs | 118 ++++++ notes-api/src/dto.rs | 24 +- notes-api/src/error.rs | 11 + notes-api/src/extractors.rs | 133 ++++++ notes-api/src/main.rs | 128 +++--- notes-api/src/routes/auth.rs | 574 +++++++++++++++++++++----- notes-api/src/routes/import_export.rs | 23 +- notes-api/src/routes/mod.rs | 5 +- notes-api/src/routes/notes.rs | 80 +--- notes-api/src/routes/tags.rs | 43 +- notes-api/src/state.rs | 100 ++++- notes-domain/src/errors.rs | 6 + notes-domain/src/services.rs | 56 +-- notes-infra/src/auth/mod.rs | 6 +- notes-infra/src/session_store.rs | 1 + 18 files changed, 1022 insertions(+), 414 deletions(-) create mode 100644 notes-api/src/extractors.rs diff --git a/Cargo.lock b/Cargo.lock index d28ae67..09ce3ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2336,7 +2336,6 @@ dependencies = [ "anyhow", "async-trait", "axum 0.8.8", - "axum-login", "chrono", "dotenvy", "k-core", @@ -2351,7 +2350,6 @@ dependencies = [ "tower 0.5.2", "tower-http", "tower-sessions", - "tower-sessions-sqlx-store", "tracing", "tracing-subscriber", "uuid", diff --git a/notes-api/Cargo.toml b/notes-api/Cargo.toml index cd58238..69cfe7a 100644 --- a/notes-api/Cargo.toml +++ b/notes-api/Cargo.toml @@ -5,16 +5,14 @@ edition = "2024" default-run = "notes-api" [features] -default = ["sqlite", "smart-features"] -sqlite = [ - "notes-infra/sqlite", - "tower-sessions-sqlx-store/sqlite", -] -postgres = [ - "notes-infra/postgres", - "tower-sessions-sqlx-store/postgres", -] +default = ["sqlite", "smart-features", "auth-oidc", "auth-jwt"] +sqlite = ["notes-infra/sqlite"] +postgres = ["notes-infra/postgres"] smart-features = ["notes-infra/smart-features", "notes-infra/broker-nats"] +auth-axum-login = ["notes-infra/auth-axum-login"] +auth-oidc = ["notes-infra/auth-oidc"] +auth-jwt = ["notes-infra/auth-jwt"] +auth-full = ["auth-axum-login", "auth-oidc", "auth-jwt"] [dependencies] notes-domain = { path = "../notes-domain" } @@ -28,9 +26,7 @@ tower = "0.5.2" tower-http = { version = "0.6.2", features = ["cors", "trace"] } # Authentication -axum-login = "0.18" -tower-sessions = "0.14" -tower-sessions-sqlx-store = { version = "0.15", features = ["sqlite"] } + password-auth = "1.0" time = "0.3" async-trait = "0.1.89" @@ -64,5 +60,9 @@ k-core = { git = "https://git.gabrielkaszewski.dev/GKaszewski/k-core", features "db-sqlx", "sqlite", "http", - "auth","sessions-db" -] } \ No newline at end of file + "auth", + "sessions-db", +] } + + +tower-sessions = "0.14.0" diff --git a/notes-api/src/auth.rs b/notes-api/src/auth.rs index 2ad1be0..a7a930b 100644 --- a/notes-api/src/auth.rs +++ b/notes-api/src/auth.rs @@ -1,87 +1,27 @@ -//! Authentication logic using axum-login +//! Authentication logic +//! +//! Proxies to infra implementation if enabled. +#[cfg(feature = "auth-axum-login")] use std::sync::Arc; -use axum_login::{AuthnBackend, UserId}; -use password_auth::verify_password; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; +#[cfg(feature = "auth-axum-login")] +use notes_domain::UserRepository; +#[cfg(feature = "auth-axum-login")] +use notes_infra::session_store::{InfraSessionStore, SessionManagerLayer}; +#[cfg(feature = "auth-axum-login")] use crate::error::ApiError; -use notes_domain::{User, UserRepository}; -/// Wrapper around domain User to implement AuthUser -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthUser(pub User); +#[cfg(feature = "auth-axum-login")] +pub use notes_infra::auth::axum_login::{AuthManagerLayer, AuthSession, AuthUser, Credentials}; -impl axum_login::AuthUser for AuthUser { - type Id = Uuid; - - fn id(&self) -> Self::Id { - self.0.id - } - - fn session_auth_hash(&self) -> &[u8] { - // Use password hash to invalidate sessions if password changes - self.0 - .password_hash - .as_ref() - .map(|s| s.as_bytes()) - .unwrap_or(&[]) - } -} - -#[derive(Clone)] -pub struct AuthBackend { - pub user_repo: Arc, -} - -impl AuthBackend { - pub fn new(user_repo: Arc) -> Self { - Self { user_repo } - } -} - -#[derive(Clone, Debug, Deserialize)] -pub struct Credentials { - pub email: String, - pub password: String, -} - -impl AuthnBackend for AuthBackend { - type User = AuthUser; - type Credentials = Credentials; - type Error = ApiError; - - async fn authenticate( - &self, - creds: Self::Credentials, - ) -> Result, Self::Error> { - let user = self - .user_repo - .find_by_email(&creds.email) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - if let Some(user) = user { - if let Some(hash) = &user.password_hash { - // Verify password - if verify_password(&creds.password, hash).is_ok() { - return Ok(Some(AuthUser(user))); - } - } - } - - Ok(None) - } - - async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { - let user = self - .user_repo - .find_by_id(*user_id) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(user.map(AuthUser)) - } +#[cfg(feature = "auth-axum-login")] +pub async fn setup_auth_layer( + session_layer: SessionManagerLayer, + user_repo: Arc, +) -> Result { + notes_infra::auth::axum_login::setup_auth_layer(session_layer, user_repo) + .await + .map_err(|e| ApiError::Internal(e.to_string())) } diff --git a/notes-api/src/config.rs b/notes-api/src/config.rs index c20f5ee..7d35876 100644 --- a/notes-api/src/config.rs +++ b/notes-api/src/config.rs @@ -1,7 +1,32 @@ #[cfg(feature = "smart-features")] use notes_infra::factory::{EmbeddingProvider, VectorProvider}; +use serde::Deserialize; use std::env; +/// Authentication mode - determines how the API authenticates requests +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthMode { + /// Session-based authentication using cookies (default for backward compatibility) + #[default] + Session, + /// JWT-based authentication using Bearer tokens + Jwt, + /// Support both session and JWT authentication (try JWT first, then session) + Both, +} + +impl AuthMode { + /// Parse auth mode from string + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "jwt" => AuthMode::Jwt, + "both" => AuthMode::Both, + _ => AuthMode::Session, + } + } +} + /// Server configuration #[derive(Debug, Clone)] pub struct Config { @@ -16,6 +41,31 @@ pub struct Config { #[cfg(feature = "smart-features")] pub vector_provider: VectorProvider, pub broker_url: String, + + pub secure_cookie: bool, + + pub db_max_connections: u32, + + pub db_min_connections: u32, + + // OIDC configuration + pub oidc_issuer: Option, + pub oidc_client_id: Option, + pub oidc_client_secret: Option, + pub oidc_redirect_url: Option, + pub oidc_resource_id: Option, + + // Auth mode configuration + pub auth_mode: AuthMode, + + // JWT configuration + pub jwt_secret: Option, + pub jwt_issuer: Option, + pub jwt_audience: Option, + pub jwt_expiry_hours: u64, + + /// Whether the application is running in production mode + pub is_production: bool, } impl Default for Config { @@ -36,6 +86,20 @@ impl Default for Config { collection: "notes".to_string(), }, broker_url: "nats://localhost:4222".to_string(), + secure_cookie: false, + db_max_connections: 5, + db_min_connections: 1, + oidc_issuer: None, + oidc_client_id: None, + oidc_client_secret: None, + oidc_redirect_url: None, + oidc_resource_id: None, + auth_mode: AuthMode::Session, + jwt_secret: None, + jwt_issuer: None, + jwt_audience: None, + jwt_expiry_hours: 24, + is_production: false, } } } @@ -89,6 +153,46 @@ impl Config { let broker_url = env::var("BROKER_URL").unwrap_or_else(|_| "nats://localhost:4222".to_string()); + let secure_cookie = env::var("SECURE_COOKIE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(false); + + let db_max_connections = env::var("DB_MAX_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(5); + + let db_min_connections = env::var("DB_MIN_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1); + + let oidc_issuer = env::var("OIDC_ISSUER").ok(); + let oidc_client_id = env::var("OIDC_CLIENT_ID").ok(); + let oidc_client_secret = env::var("OIDC_CLIENT_SECRET").ok(); + let oidc_redirect_url = env::var("OIDC_REDIRECT_URL").ok(); + let oidc_resource_id = env::var("OIDC_RESOURCE_ID").ok(); + + // Auth mode configuration + let auth_mode = env::var("AUTH_MODE") + .map(|s| AuthMode::from_str(&s)) + .unwrap_or_default(); + + // JWT configuration + let jwt_secret = env::var("JWT_SECRET").ok(); + let jwt_issuer = env::var("JWT_ISSUER").ok(); + let jwt_audience = env::var("JWT_AUDIENCE").ok(); + let jwt_expiry_hours = env::var("JWT_EXPIRY_HOURS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(24); + + let is_production = env::var("PRODUCTION") + .or_else(|_| env::var("RUST_ENV")) + .map(|v| v.to_lowercase() == "production" || v == "1" || v == "true") + .unwrap_or(false); + Self { host, port, @@ -101,6 +205,20 @@ impl Config { #[cfg(feature = "smart-features")] vector_provider, broker_url, + secure_cookie, + db_max_connections, + db_min_connections, + oidc_issuer, + oidc_client_id, + oidc_client_secret, + oidc_redirect_url, + oidc_resource_id, + auth_mode, + jwt_secret, + jwt_issuer, + jwt_audience, + jwt_expiry_hours, + is_production, } } } diff --git a/notes-api/src/dto.rs b/notes-api/src/dto.rs index 46d9fbd..c424759 100644 --- a/notes-api/src/dto.rs +++ b/notes-api/src/dto.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use validator::Validate; -use notes_domain::{Note, Tag}; +use notes_domain::{Email, Note, Password, Tag}; /// Request to create a new note #[derive(Debug, Deserialize, Validate)] @@ -118,30 +118,24 @@ pub struct RenameTagRequest { } /// Login request -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize)] pub struct LoginRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + pub email: Email, + pub password: Password, } /// Register request -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize)] pub struct RegisterRequest { - #[validate(email(message = "Invalid email format"))] - pub email: String, - - #[validate(length(min = 6, message = "Password must be at least 6 characters"))] - pub password: String, + pub email: Email, + pub password: Password, } /// User response DTO #[derive(Debug, Serialize)] pub struct UserResponse { pub id: Uuid, - pub email: String, + pub email: Email, pub created_at: DateTime, } @@ -160,7 +154,7 @@ impl From for NoteVersionResponse { Self { id: version.id, note_id: version.note_id, - title: version.title.unwrap_or_default(), // Convert Option to String + title: version.title.unwrap_or_default(), content: version.content, created_at: version.created_at, } diff --git a/notes-api/src/error.rs b/notes-api/src/error.rs index 6b2fb67..5ea5be7 100644 --- a/notes-api/src/error.rs +++ b/notes-api/src/error.rs @@ -26,6 +26,9 @@ pub enum ApiError { #[error("Forbidden: {0}")] Forbidden(String), + + #[error("Unauthorized: {0}")] + Unauthorized(String), } /// Error response body @@ -96,6 +99,14 @@ impl IntoResponse for ApiError { details: Some(msg.clone()), }, ), + + ApiError::Unauthorized(msg) => ( + StatusCode::UNAUTHORIZED, + ErrorResponse { + error: "Unauthorized".to_string(), + details: Some(msg.clone()), + }, + ), }; (status, Json(error_response)).into_response() diff --git a/notes-api/src/extractors.rs b/notes-api/src/extractors.rs new file mode 100644 index 0000000..3ec5a54 --- /dev/null +++ b/notes-api/src/extractors.rs @@ -0,0 +1,133 @@ +//! Auth extractors for API handlers +//! +//! Provides the `CurrentUser` extractor that works with both session and JWT auth. + +use axum::{extract::FromRequestParts, http::request::Parts}; +use notes_domain::User; + +use crate::config::AuthMode; +use crate::error::ApiError; +use crate::state::AppState; + +/// Extracted current user from the request. +/// +/// This extractor supports multiple authentication methods based on the configured `AuthMode`: +/// - `Session`: Uses axum-login session cookies +/// - `Jwt`: Uses Bearer token in Authorization header +/// - `Both`: Tries JWT first, then falls back to session +pub struct CurrentUser(pub User); + +impl FromRequestParts for CurrentUser { + type Rejection = ApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let auth_mode = state.config.auth_mode; + + // Try JWT first if enabled + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + match try_jwt_auth(parts, state).await { + Ok(Some(user)) => return Ok(CurrentUser(user)), + Ok(None) => { + // No JWT token present, continue to session auth if Both mode + if auth_mode == AuthMode::Jwt { + return Err(ApiError::Unauthorized( + "Missing or invalid Authorization header".to_string(), + )); + } + } + Err(e) => { + // JWT was present but invalid + tracing::debug!("JWT auth failed: {}", e); + if auth_mode == AuthMode::Jwt { + return Err(e); + } + // In Both mode, continue to try session + } + } + } + + // Try session auth if enabled + #[cfg(feature = "auth-axum-login")] + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + if let Some(user) = try_session_auth(parts).await? { + return Ok(CurrentUser(user)); + } + } + + Err(ApiError::Unauthorized("Not authenticated".to_string())) + } +} + +/// Try to authenticate using JWT Bearer token +#[cfg(feature = "auth-jwt")] +async fn try_jwt_auth(parts: &mut Parts, state: &AppState) -> Result, ApiError> { + use axum::http::header::AUTHORIZATION; + + // Get Authorization header + let auth_header = match parts.headers.get(AUTHORIZATION) { + Some(header) => header, + None => return Ok(None), // No header = no JWT auth attempted + }; + + let auth_str = auth_header + .to_str() + .map_err(|_| ApiError::Unauthorized("Invalid Authorization header encoding".to_string()))?; + + // Extract Bearer token + let token = auth_str.strip_prefix("Bearer ").ok_or_else(|| { + ApiError::Unauthorized("Authorization header must use Bearer scheme".to_string()) + })?; + + // Get JWT validator + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?; + + // Validate token + let claims = validator.validate_token(token).map_err(|e| { + tracing::debug!("JWT validation failed: {:?}", e); + match e { + notes_infra::auth::jwt::JwtError::Expired => { + ApiError::Unauthorized("Token expired".to_string()) + } + notes_infra::auth::jwt::JwtError::InvalidFormat => { + ApiError::Unauthorized("Invalid token format".to_string()) + } + _ => ApiError::Unauthorized("Token validation failed".to_string()), + } + })?; + + // Fetch user from database by ID (subject contains user ID) + let user_id: uuid::Uuid = claims + .sub + .parse() + .map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?; + + let user = state + .user_service + .find_by_id(user_id) + .await + .map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?; + + Ok(Some(user)) +} + +/// Try to authenticate using session cookie +#[cfg(feature = "auth-axum-login")] +async fn try_session_auth(parts: &mut Parts) -> Result, ApiError> { + use notes_infra::auth::axum_login::AuthSession; + + // Check if AuthSession extension is present (added by auth middleware) + if let Some(auth_session) = parts.extensions.get::() { + if let Some(auth_user) = &auth_session.user { + return Ok(Some(auth_user.0.clone())); + } + } + + Ok(None) +} diff --git a/notes-api/src/main.rs b/notes-api/src/main.rs index 99abea2..d9fc0e3 100644 --- a/notes-api/src/main.rs +++ b/notes-api/src/main.rs @@ -2,17 +2,15 @@ //! //! A high-performance, self-hosted note-taking API following hexagonal architecture. -use k_core::{ - db::DatabasePool, - http::server::{ServerConfig, apply_standard_middleware}, -}; +use k_core::http::server::{ServerConfig, apply_standard_middleware}; +use std::net::SocketAddr; use std::{sync::Arc, time::Duration as StdDuration}; use time::Duration; +use tokio::net::TcpListener; +use tower_sessions::cookie::SameSite; +use tower_sessions::{Expiry, SessionManagerLayer}; use axum::Router; -use axum_login::AuthManagerLayerBuilder; - -use tower_sessions::{Expiry, SessionManagerLayer}; use notes_infra::run_migrations; @@ -20,13 +18,15 @@ mod auth; mod config; mod dto; mod error; +mod extractors; mod routes; mod state; -use auth::AuthBackend; use config::Config; use state::AppState; +use crate::config::AuthMode; + #[tokio::main] async fn main() -> anyhow::Result<()> { k_core::logging::init("notes_api"); @@ -53,9 +53,6 @@ async fn main() -> anyhow::Result<()> { build_note_repository, build_session_store, build_tag_repository, build_user_repository, }; - // Create a default user for development - create_dev_user(&db_pool).await.ok(); - // Create repositories via factory let note_repo = build_note_repository(&db_pool) .await @@ -105,20 +102,16 @@ async fn main() -> anyhow::Result<()> { let state = AppState::new( note_repo, tag_repo, - user_repo.clone(), #[cfg(feature = "smart-features")] link_repo, note_service, tag_service, user_service, config.clone(), - ); + ) + .await?; - // Auth backend - let backend = AuthBackend::new(user_repo); // no idea what now with this - - // Session layer - // Use the factory to build the session store, agnostic of the underlying DB + // Build session store (needed for OIDC flow even in JWT mode) let session_store = build_session_store(&db_pool) .await .map_err(|e| anyhow::anyhow!(e))?; @@ -128,28 +121,24 @@ async fn main() -> anyhow::Result<()> { .map_err(|e| anyhow::anyhow!(e))?; let session_layer = SessionManagerLayer::new(session_store) - .with_secure(false) // Set to true in prod + .with_secure(config.secure_cookie) + .with_same_site(SameSite::Lax) .with_expiry(Expiry::OnInactivity(Duration::days(7))); - let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build(); - let server_config = ServerConfig { cors_origins: config.cors_allowed_origins.clone(), session_secret: Some(config.session_secret.clone()), }; - let app = Router::new() - .nest("/api/v1", routes::api_v1_router()) - .layer(auth_layer) - .with_state(state); - + // Build the app with appropriate auth layers based on config + let app = build_app(state, session_layer, user_repo, &config).await?; let app = apply_standard_middleware(app, &server_config); - let addr = format!("{}:{}", config.host, config.port); - let listener = tokio::net::TcpListener::bind(&addr).await?; + let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?; + let listener = TcpListener::bind(addr).await?; - tracing::info!("🚀 K-Notes API server running at http://{}", addr); - tracing::info!("🔒 Authentication enabled (axum-login)"); + tracing::info!("🚀 API server running at http://{}", addr); + log_auth_info(&config); tracing::info!("📝 API endpoints available at /api/v1/..."); axum::serve(listener, app).await?; @@ -157,32 +146,61 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn create_dev_user(pool: &DatabasePool) -> anyhow::Result<()> { - use notes_domain::{Email, User}; - use notes_infra::factory::build_user_repository; - use password_auth::generate_hash; - use uuid::Uuid; +/// Build the application router with appropriate auth layers +#[allow(unused_variables)] // config/user_repo used conditionally based on features +async fn build_app( + state: AppState, + session_layer: SessionManagerLayer, + user_repo: std::sync::Arc, + config: &Config, +) -> anyhow::Result { + let app = Router::new() + .nest("/api/v1", routes::api_v1_router()) + .with_state(state); - let user_repo = build_user_repository(pool) - .await - .map_err(|e| anyhow::anyhow!(e))?; - - // Check if dev user exists - let dev_user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); - if user_repo.find_by_id(dev_user_id).await?.is_none() { - let hash = generate_hash("password"); - let dev_email = Email::try_from("dev@localhost.com") - .map_err(|e| anyhow::anyhow!("Invalid dev email: {}", e))?; - let user = User::with_id( - dev_user_id, - "dev|local", - dev_email, - Some(hash), - chrono::Utc::now(), - ); - user_repo.save(&user).await?; - tracing::info!("Created development user: dev@localhost.com / password"); + // When auth-axum-login feature is enabled, always apply the auth layer. + // This is needed because: + // 1. OIDC callback uses AuthSession for state management + // 2. Session-based login/register routes use it + // 3. The "JWT mode" just changes what the login endpoint returns, not the underlying session support + #[cfg(feature = "auth-axum-login")] + { + let auth_layer = auth::setup_auth_layer(session_layer, user_repo).await?; + return Ok(app.layer(auth_layer)); } - Ok(()) + // When auth-axum-login is not compiled in, just use session layer for OIDC flow + #[cfg(not(feature = "auth-axum-login"))] + { + let _ = user_repo; // Suppress unused warning + Ok(app.layer(session_layer)) + } +} + +/// Log authentication info based on enabled features and config +fn log_auth_info(config: &Config) { + match config.auth_mode { + AuthMode::Session => { + tracing::info!("🔒 Authentication mode: Session (cookie-based)"); + } + AuthMode::Jwt => { + tracing::info!("🔒 Authentication mode: JWT (Bearer token)"); + } + AuthMode::Both => { + tracing::info!("🔒 Authentication mode: Both (JWT + Session)"); + } + } + + #[cfg(feature = "auth-axum-login")] + tracing::info!(" ✓ Session auth enabled (axum-login)"); + + #[cfg(feature = "auth-jwt")] + if config.jwt_secret.is_some() { + tracing::info!(" ✓ JWT auth enabled"); + } + + #[cfg(feature = "auth-oidc")] + if config.oidc_issuer.is_some() { + tracing::info!(" ✓ OIDC integration enabled"); + } } diff --git a/notes-api/src/routes/auth.rs b/notes-api/src/routes/auth.rs index 9418805..4cc5344 100644 --- a/notes-api/src/routes/auth.rs +++ b/notes-api/src/routes/auth.rs @@ -1,117 +1,491 @@ //! Authentication routes +//! +//! Provides login, register, logout, and token endpoints. +//! Supports both session-based and JWT-based authentication. -use axum::{Json, extract::State, http::StatusCode}; -use axum_login::AuthSession; -use validator::Validate; +#[cfg(feature = "auth-oidc")] +use axum::response::Response; +use axum::{ + Router, + extract::{Json, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, +}; +use serde::Serialize; +#[cfg(feature = "auth-oidc")] +use tower_sessions::Session; -use notes_domain::{Email, User}; -use password_auth::generate_hash; +#[cfg(feature = "auth-axum-login")] +use crate::config::AuthMode; +use crate::{ + dto::{LoginRequest, RegisterRequest, UserResponse}, + error::ApiError, + extractors::CurrentUser, + state::AppState, +}; +#[cfg(feature = "auth-axum-login")] +use notes_domain::DomainError; -use crate::auth::{AuthBackend, AuthUser, Credentials}; -use crate::dto::{LoginRequest, RegisterRequest}; -use crate::error::{ApiError, ApiResult}; -use crate::state::AppState; - -/// Register a new user -pub async fn register( - State(state): State, - mut auth_session: AuthSession, - Json(payload): Json, -) -> ApiResult { - payload - .validate() - .map_err(|e| ApiError::validation(e.to_string()))?; - - // Check if registration is allowed - if !state.config.allow_registration { - return Err(ApiError::Forbidden("Registration is disabled".to_string())); - } - - // Check if user exists - if state - .user_repo - .find_by_email(&payload.email) - .await - .map_err(ApiError::from)? - .is_some() - { - return Err(ApiError::Domain( - notes_domain::DomainError::UserAlreadyExists(payload.email.clone()), - )); - } - - // Hash password - let password_hash = generate_hash(&payload.password); - - // Parse email string to Email newtype - let email = Email::try_from(payload.email) - .map_err(|e| ApiError::validation(format!("Invalid email: {}", e)))?; - - // Create user - for local registration, we use email as subject - let user = User::new_local(email, &password_hash); - - state.user_repo.save(&user).await.map_err(ApiError::from)?; - - // Auto login after registration - let user = AuthUser(user); - auth_session - .login(&user) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(StatusCode::CREATED) +/// Token response for JWT authentication +#[derive(Debug, Serialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + pub expires_in: u64, } -/// Login user -pub async fn login( - mut auth_session: AuthSession, - Json(payload): Json, -) -> ApiResult { - payload - .validate() - .map_err(|e| ApiError::validation(e.to_string()))?; +/// Login response that can be either a user (session mode) or a token (JWT mode) +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum LoginResponse { + User(UserResponse), + Token(TokenResponse), +} - let user = auth_session - .authenticate(Credentials { +pub fn router() -> Router { + let r = Router::new() + .route("/login", post(login)) + .route("/register", post(register)) + .route("/logout", post(logout)) + .route("/me", get(me)); + + // Add token endpoint for getting JWT from session + #[cfg(feature = "auth-jwt")] + let r = r.route("/token", post(get_token)); + + #[cfg(feature = "auth-oidc")] + let r = r + .route("/login/oidc", get(oidc_login)) + .route("/callback", get(oidc_callback)); + + r +} + +/// Login endpoint +/// +/// In session mode: Creates a session and returns user info +/// In JWT mode: Validates credentials and returns a JWT token +/// In both mode: Creates session AND returns JWT token +#[cfg(feature = "auth-axum-login")] +async fn login( + State(state): State, + mut auth_session: crate::auth::AuthSession, + Json(payload): Json, +) -> Result { + let user = match auth_session + .authenticate(crate::auth::Credentials { email: payload.email, password: payload.password, }) .await - .map_err(|e| ApiError::internal(e.to_string()))? - .ok_or_else(|| ApiError::validation("Invalid email or password"))?; // Generic error for security + .map_err(|e| ApiError::Internal(e.to_string()))? + { + Some(user) => user, + None => return Err(ApiError::Validation("Invalid credentials".to_string())), + }; - auth_session - .login(&user) - .await - .map_err(|e| ApiError::internal(e.to_string()))?; + let auth_mode = state.config.auth_mode; - Ok(StatusCode::OK) -} - -/// Logout user -pub async fn logout(mut auth_session: AuthSession) -> ApiResult { - auth_session - .logout() - .await - .map_err(|e| ApiError::internal(e.to_string()))?; - - Ok(StatusCode::OK) -} - -/// Get current user -pub async fn me( - auth_session: AuthSession, -) -> ApiResult> { - let user = + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { auth_session - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Not logged in".to_string(), - )))?; + .login(&user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } - Ok(Json(crate::dto::UserResponse { - id: user.0.id, - email: user.0.email_str().to_string(), // Convert Email to String - created_at: user.0.created_at, + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user.0, &state)?; + return Ok(( + StatusCode::OK, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } + + // Session mode: return user info + Ok(( + StatusCode::OK, + Json(LoginResponse::User(UserResponse { + id: user.0.id, + email: user.0.email, + created_at: user.0.created_at, + })), + )) +} + +/// Fallback login when auth-axum-login is not enabled +/// Without auth-axum-login, password-based authentication is not available. +/// Use OIDC login instead: GET /api/v1/auth/login/oidc +#[cfg(not(feature = "auth-axum-login"))] +async fn login( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Password-based login not available. auth-axum-login feature is required. Use OIDC login at /api/v1/auth/login/oidc instead.".to_string(), + )) +} + +/// Register endpoint +#[cfg(feature = "auth-axum-login")] +async fn register( + State(state): State, + mut auth_session: crate::auth::AuthSession, + Json(payload): Json, +) -> Result { + // Email is already validated by the newtype deserialization + let email = payload.email; + + if state + .user_service + .find_by_email(email.as_ref()) + .await? + .is_some() + { + return Err(ApiError::Domain(DomainError::UserAlreadyExists( + email.as_ref().to_string(), + ))); + } + + // Hash password + let password_hash = notes_infra::auth::axum_login::hash_password(payload.password.as_ref()); + + // Create user with password + let user = state + .user_service + .create_local(email.as_ref(), &password_hash) + .await?; + + let auth_mode = state.config.auth_mode; + + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + let auth_user = crate::auth::AuthUser(user.clone()); + auth_session + .login(&auth_user) + .await + .map_err(|_| ApiError::Internal("Login failed".to_string()))?; + } + + // In JWT or both mode, return token + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(( + StatusCode::CREATED, + Json(LoginResponse::Token(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })), + )); + } + + Ok(( + StatusCode::CREATED, + Json(LoginResponse::User(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, + })), + )) +} + +/// Fallback register when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn register( + State(_state): State, + Json(_payload): Json, +) -> Result<(StatusCode, Json), ApiError> { + Err(ApiError::Internal( + "Session-based registration not available. Use JWT token endpoint.".to_string(), + )) +} + +/// Logout endpoint +#[cfg(feature = "auth-axum-login")] +async fn logout(mut auth_session: crate::auth::AuthSession) -> impl IntoResponse { + match auth_session.logout().await { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +} + +/// Fallback logout when auth-axum-login is not enabled +#[cfg(not(feature = "auth-axum-login"))] +async fn logout() -> impl IntoResponse { + // JWT tokens can't be "logged out" server-side without a blocklist + // Just return OK + StatusCode::OK +} + +/// Get current user info +async fn me(CurrentUser(user): CurrentUser) -> Result { + Ok(Json(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, })) } + +/// Get a JWT token for the current session user +/// +/// This allows session-authenticated users to obtain a JWT for API access. +#[cfg(feature = "auth-jwt")] +async fn get_token( + State(state): State, + CurrentUser(user): CurrentUser, +) -> Result { + let token = create_jwt_for_user(&user, &state)?; + + Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })) +} + +/// Helper to create JWT for a user +#[cfg(feature = "auth-jwt")] +fn create_jwt_for_user(user: ¬es_domain::User, state: &AppState) -> Result { + let validator = state + .jwt_validator + .as_ref() + .ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?; + + validator + .create_token(user) + .map_err(|e| ApiError::Internal(format!("Failed to create token: {}", e))) +} + +// ============================================================================ +// OIDC Routes +// ============================================================================ + +#[cfg(feature = "auth-oidc")] +async fn oidc_login(State(state): State, session: Session) -> Result { + use axum::http::header; + + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let auth_data = service.get_authorization_url(); + + session + .insert("oidc_csrf", &auth_data.csrf_token) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + session + .insert("oidc_nonce", &auth_data.nonce) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + session + .insert("oidc_pkce", &auth_data.pkce_verifier) + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + let response = axum::response::Redirect::to(auth_data.url.as_str()).into_response(); + let (mut parts, body) = response.into_parts(); + + parts.headers.insert( + header::CACHE_CONTROL, + "no-cache, no-store, must-revalidate".parse().unwrap(), + ); + parts + .headers + .insert(header::PRAGMA, "no-cache".parse().unwrap()); + parts.headers.insert(header::EXPIRES, "0".parse().unwrap()); + + Ok(Response::from_parts(parts, body)) +} + +#[cfg(feature = "auth-oidc")] +#[derive(serde::Deserialize)] +struct CallbackParams { + code: String, + state: String, +} + +#[cfg(all(feature = "auth-oidc", feature = "auth-axum-login"))] +async fn oidc_callback( + State(state): State, + session: Session, + mut auth_session: crate::auth::AuthSession, + axum::extract::Query(params): axum::extract::Query, +) -> Result { + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let stored_csrf: notes_domain::CsrfToken = session + .get("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing CSRF token".into()))?; + + if params.state != stored_csrf.as_ref() { + return Err(ApiError::Validation("Invalid CSRF token".into())); + } + + let stored_pkce: notes_domain::PkceVerifier = session + .get("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing PKCE".into()))?; + let stored_nonce: notes_domain::OidcNonce = session + .get("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing Nonce".into()))?; + + let oidc_user = service + .resolve_callback( + notes_domain::AuthorizationCode::new(params.code), + stored_nonce, + stored_pkce, + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let user = state + .user_service + .find_or_create(&oidc_user.subject, &oidc_user.email) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let auth_mode = state.config.auth_mode; + + // In session or both mode, create session + if matches!(auth_mode, AuthMode::Session | AuthMode::Both) { + auth_session + .login(&crate::auth::AuthUser(user.clone())) + .await + .map_err(|_| ApiError::Internal("Login failed".into()))?; + } + + // Clean up OIDC state + let _: Option = session + .remove("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + // In JWT mode, return token as JSON + #[cfg(feature = "auth-jwt")] + if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + }) + .into_response()); + } + + // Session mode: return user info + Ok(Json(UserResponse { + id: user.id, + email: user.email, + created_at: user.created_at, + }) + .into_response()) +} + +/// Fallback OIDC callback when auth-axum-login is not enabled +#[cfg(all(feature = "auth-oidc", not(feature = "auth-axum-login")))] +async fn oidc_callback( + State(state): State, + session: Session, + axum::extract::Query(params): axum::extract::Query, +) -> Result { + let service = state + .oidc_service + .as_ref() + .ok_or(ApiError::Internal("OIDC not configured".into()))?; + + let stored_csrf: notes_domain::CsrfToken = session + .get("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing CSRF token".into()))?; + + if params.state != stored_csrf.as_ref() { + return Err(ApiError::Validation("Invalid CSRF token".into())); + } + + let stored_pkce: notes_domain::PkceVerifier = session + .get("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing PKCE".into()))?; + let stored_nonce: notes_domain::OidcNonce = session + .get("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))? + .ok_or(ApiError::Validation("Missing Nonce".into()))?; + + let oidc_user = service + .resolve_callback( + notes_domain::AuthorizationCode::new(params.code), + stored_nonce, + stored_pkce, + ) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + let user = state + .user_service + .find_or_create(&oidc_user.subject, &oidc_user.email) + .await + .map_err(|e| ApiError::Internal(e.to_string()))?; + + // Clean up OIDC state + let _: Option = session + .remove("oidc_csrf") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_pkce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + let _: Option = session + .remove("oidc_nonce") + .await + .map_err(|_| ApiError::Internal("Session error".into()))?; + + // Return token as JSON + #[cfg(feature = "auth-jwt")] + { + let token = create_jwt_for_user(&user, &state)?; + return Ok(Json(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: state.config.jwt_expiry_hours * 3600, + })); + } + + #[cfg(not(feature = "auth-jwt"))] + { + let _ = user; // Suppress unused warning + Err(ApiError::Internal( + "No auth backend available for OIDC callback".to_string(), + )) + } +} diff --git a/notes-api/src/routes/import_export.rs b/notes-api/src/routes/import_export.rs index b5274ab..a8199e7 100644 --- a/notes-api/src/routes/import_export.rs +++ b/notes-api/src/routes/import_export.rs @@ -1,9 +1,8 @@ use axum::{Json, extract::State, http::StatusCode}; -use axum_login::{AuthSession, AuthUser}; use serde::{Deserialize, Serialize}; -use crate::auth::AuthBackend; -use crate::error::{ApiError, ApiResult}; +use crate::error::ApiResult; +use crate::extractors::CurrentUser; use crate::state::AppState; use notes_domain::{Note, NoteFilter, Tag}; @@ -17,14 +16,9 @@ pub struct BackupData { /// GET /api/v1/export pub async fn export_data( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let notes = state .note_repo @@ -39,15 +33,10 @@ pub async fn export_data( /// POST /api/v1/import pub async fn import_data( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // 1. Import standalone tags (to ensure even unused tags are restored) for tag in payload.tags { diff --git a/notes-api/src/routes/mod.rs b/notes-api/src/routes/mod.rs index 01f11b3..77b7c2b 100644 --- a/notes-api/src/routes/mod.rs +++ b/notes-api/src/routes/mod.rs @@ -17,10 +17,7 @@ use crate::state::AppState; pub fn api_v1_router() -> Router { let router = Router::new() // Auth routes - .route("/auth/register", post(auth::register)) - .route("/auth/login", post(auth::login)) - .route("/auth/logout", post(auth::logout)) - .route("/auth/me", get(auth::me)) + .nest("/auth", auth::router()) // Note routes .route("/notes", get(notes::list_notes).post(notes::create_note)) .route( diff --git a/notes-api/src/routes/notes.rs b/notes-api/src/routes/notes.rs index f6dc189..2889529 100644 --- a/notes-api/src/routes/notes.rs +++ b/notes-api/src/routes/notes.rs @@ -5,34 +5,29 @@ use axum::{ extract::{Path, Query, State}, http::StatusCode, }; -use axum_login::AuthSession; use uuid::Uuid; use validator::Validate; -use axum_login::AuthUser; use notes_domain::{ CreateNoteRequest as DomainCreateNote, NoteTitle, TagName, UpdateNoteRequest as DomainUpdateNote, }; -use crate::auth::AuthBackend; -use crate::dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}; use crate::error::{ApiError, ApiResult}; use crate::state::AppState; +use crate::{ + dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}, + extractors::CurrentUser, +}; /// List notes with optional filtering /// GET /api/v1/notes pub async fn list_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Query(query): Query, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Build the filter, looking up tag_id by name if needed let mut filter = notes_domain::NoteFilter::new(); @@ -59,15 +54,10 @@ pub async fn list_notes( /// POST /api/v1/notes pub async fn create_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult<(StatusCode, Json)> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Validate input payload @@ -113,15 +103,10 @@ pub async fn create_note( /// GET /api/v1/notes/:id pub async fn get_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let note = state.note_service.get_note(id, user_id).await?; @@ -132,16 +117,11 @@ pub async fn get_note( /// PATCH /api/v1/notes/:id pub async fn update_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, Json(payload): Json, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Validate input payload @@ -195,15 +175,10 @@ pub async fn update_note( /// DELETE /api/v1/notes/:id pub async fn delete_note( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; state.note_service.delete_note(id, user_id).await?; @@ -214,15 +189,10 @@ pub async fn delete_note( /// GET /api/v1/notes/search pub async fn search_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Query(query): Query, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let notes = state.note_service.search_notes(user_id, &query.q).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); @@ -234,15 +204,10 @@ pub async fn search_notes( /// GET /api/v1/notes/:id/versions pub async fn list_note_versions( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let versions = state.note_service.list_note_versions(id, user_id).await?; let response: Vec = versions @@ -260,15 +225,10 @@ pub async fn list_note_versions( #[cfg(feature = "smart-features")] pub async fn get_related_notes( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; // Verify access to the source note state.note_service.get_note(id, user_id).await?; diff --git a/notes-api/src/routes/tags.rs b/notes-api/src/routes/tags.rs index 911d30a..09580a8 100644 --- a/notes-api/src/routes/tags.rs +++ b/notes-api/src/routes/tags.rs @@ -5,29 +5,25 @@ use axum::{ extract::{Path, State}, http::StatusCode, }; -use axum_login::{AuthSession, AuthUser}; use uuid::Uuid; use validator::Validate; use notes_domain::TagName; -use crate::auth::AuthBackend; -use crate::dto::{CreateTagRequest, RenameTagRequest, TagResponse}; use crate::error::{ApiError, ApiResult}; use crate::state::AppState; +use crate::{ + dto::{CreateTagRequest, RenameTagRequest, TagResponse}, + extractors::CurrentUser, +}; /// List all tags for the user /// GET /api/v1/tags pub async fn list_tags( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, ) -> ApiResult>> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; let tags = state.tag_service.list_tags(user_id).await?; let response: Vec = tags.into_iter().map(TagResponse::from).collect(); @@ -39,15 +35,10 @@ pub async fn list_tags( /// POST /api/v1/tags pub async fn create_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Json(payload): Json, ) -> ApiResult<(StatusCode, Json)> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; payload .validate() @@ -66,16 +57,11 @@ pub async fn create_tag( /// PATCH /api/v1/tags/:id pub async fn rename_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, Json(payload): Json, ) -> ApiResult> { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; payload .validate() @@ -94,15 +80,10 @@ pub async fn rename_tag( /// DELETE /api/v1/tags/:id pub async fn delete_tag( State(state): State, - auth: AuthSession, + CurrentUser(user): CurrentUser, Path(id): Path, ) -> ApiResult { - let user = auth - .user - .ok_or(ApiError::Domain(notes_domain::DomainError::Unauthorized( - "Login required".to_string(), - )))?; - let user_id = user.id(); + let user_id = user.id; state.tag_service.delete_tag(id, user_id).await?; diff --git a/notes-api/src/state.rs b/notes-api/src/state.rs index 1f6d97e..61f7fb9 100644 --- a/notes-api/src/state.rs +++ b/notes-api/src/state.rs @@ -1,45 +1,123 @@ use std::sync::Arc; -use crate::config::Config; -use notes_domain::{ - NoteRepository, NoteService, TagRepository, TagService, UserRepository, UserService, -}; +use crate::config::{AuthMode, Config}; +use notes_domain::{NoteRepository, NoteService, TagRepository, TagService, UserService}; + +#[cfg(feature = "auth-jwt")] +use notes_infra::auth::jwt::{JwtConfig, JwtValidator}; +#[cfg(feature = "auth-oidc")] +use notes_infra::auth::oidc::OidcService; /// Application state holding all dependencies #[derive(Clone)] pub struct AppState { pub note_repo: Arc, pub tag_repo: Arc, - pub user_repo: Arc, #[cfg(feature = "smart-features")] pub link_repo: Arc, pub note_service: Arc, pub tag_service: Arc, pub user_service: Arc, pub config: Config, + #[cfg(feature = "auth-oidc")] + pub oidc_service: Option>, + #[cfg(feature = "auth-jwt")] + pub jwt_validator: Option>, } impl AppState { - pub fn new( + pub async fn new( note_repo: Arc, tag_repo: Arc, - user_repo: Arc, #[cfg(feature = "smart-features")] link_repo: Arc, note_service: Arc, tag_service: Arc, user_service: Arc, config: Config, - ) -> Self { - Self { + ) -> anyhow::Result { + #[cfg(feature = "auth-oidc")] + let oidc_service = if let (Some(issuer), Some(id), secret, Some(redirect), resource_id) = ( + &config.oidc_issuer, + &config.oidc_client_id, + &config.oidc_client_secret, + &config.oidc_redirect_url, + &config.oidc_resource_id, + ) { + tracing::info!("Initializing OIDC service with issuer: {}", issuer); + + // Construct newtypes from config strings + let issuer_url = notes_domain::IssuerUrl::new(issuer) + .map_err(|e| anyhow::anyhow!("Invalid OIDC issuer URL: {}", e))?; + let client_id = notes_domain::ClientId::new(id) + .map_err(|e| anyhow::anyhow!("Invalid OIDC client ID: {}", e))?; + let client_secret = secret.as_ref().map(|s| notes_domain::ClientSecret::new(s)); + let redirect_url = notes_domain::RedirectUrl::new(redirect) + .map_err(|e| anyhow::anyhow!("Invalid OIDC redirect URL: {}", e))?; + let resource = resource_id + .as_ref() + .map(|r| notes_domain::ResourceId::new(r)) + .transpose() + .map_err(|e| anyhow::anyhow!("Invalid OIDC resource ID: {}", e))?; + + Some(Arc::new( + OidcService::new(issuer_url, client_id, client_secret, redirect_url, resource) + .await?, + )) + } else { + None + }; + + #[cfg(feature = "auth-jwt")] + let jwt_validator = if matches!(config.auth_mode, AuthMode::Jwt | AuthMode::Both) { + // Use provided secret or fall back to a development secret + let secret = if let Some(ref s) = config.jwt_secret { + if s.is_empty() { None } else { Some(s.clone()) } + } else { + None + }; + + let secret = match secret { + Some(s) => s, + None => { + if config.is_production { + anyhow::bail!( + "JWT_SECRET is required when AUTH_MODE is 'jwt' or 'both' in production" + ); + } + // Use a development-only default secret + tracing::warn!( + "⚠️ JWT_SECRET not set - using insecure development secret. DO NOT USE IN PRODUCTION!" + ); + "k-template-dev-secret-not-for-production-use-only".to_string() + } + }; + + tracing::info!("Initializing JWT validator"); + let jwt_config = JwtConfig::new( + secret, + config.jwt_issuer.clone(), + config.jwt_audience.clone(), + Some(config.jwt_expiry_hours), + config.is_production, + )?; + Some(Arc::new(JwtValidator::new(jwt_config))) + } else { + None + }; + + Ok(Self { note_repo, tag_repo, - user_repo, #[cfg(feature = "smart-features")] link_repo, note_service, tag_service, user_service, config, - } + #[cfg(feature = "auth-oidc")] + oidc_service, + #[cfg(feature = "auth-jwt")] + jwt_validator, + }) } } diff --git a/notes-domain/src/errors.rs b/notes-domain/src/errors.rs index 2f4cec1..8c5de6e 100644 --- a/notes-domain/src/errors.rs +++ b/notes-domain/src/errors.rs @@ -91,6 +91,12 @@ impl DomainError { } } +impl From for DomainError { + fn from(error: crate::value_objects::ValidationError) -> Self { + DomainError::ValidationError(error.to_string()) + } +} + /// Result type alias for domain operations pub type DomainResult = Result; diff --git a/notes-domain/src/services.rs b/notes-domain/src/services.rs index 348ff20..de139aa 100644 --- a/notes-domain/src/services.rs +++ b/notes-domain/src/services.rs @@ -375,36 +375,46 @@ impl UserService { Self { user_repo } } - /// Find or create a user by OIDC subject - /// This is the main entry point for OIDC authentication - pub async fn find_or_create_by_subject( - &self, - subject: &str, - email: Email, - ) -> DomainResult { + pub async fn find_or_create(&self, subject: &str, email: &str) -> DomainResult { + // 1. Try to find by subject (OIDC id) if let Some(user) = self.user_repo.find_by_subject(subject).await? { - Ok(user) - } else { - let user = User::new(subject, email); - self.user_repo.save(&user).await?; - Ok(user) + return Ok(user); } + + // 2. Try to find by email + if let Some(mut user) = self.user_repo.find_by_email(email).await? { + // Link subject if missing (account linking logic) + if user.subject != subject { + user.subject = subject.to_string(); + self.user_repo.save(&user).await?; + } + return Ok(user); + } + + // 3. Create new user + let email = Email::try_from(email)?; + let user = User::new(subject, email); + self.user_repo.save(&user).await?; + + Ok(user) } - /// Get a user by ID - pub async fn get_user(&self, id: Uuid) -> DomainResult { + pub async fn find_by_id(&self, id: Uuid) -> DomainResult { self.user_repo .find_by_id(id) .await? .ok_or(DomainError::UserNotFound(id)) } - /// Delete a user and all associated data - pub async fn delete_user(&self, id: Uuid) -> DomainResult<()> { - // Note: In practice, we'd also need to delete notes and tags - // This would be handled by cascade delete in the database - // or by coordinating with other services - self.user_repo.delete(id).await + pub async fn find_by_email(&self, email: &str) -> DomainResult> { + self.user_repo.find_by_email(email).await + } + + pub async fn create_local(&self, email: &str, password_hash: &str) -> DomainResult { + let email = Email::try_from(email)?; + let user = User::new_local(email, password_hash); + self.user_repo.save(&user).await?; + Ok(user) } } @@ -889,7 +899,7 @@ mod tests { let email = Email::try_from("test@example.com").unwrap(); let user = service - .find_or_create_by_subject("oidc|123", email) + .find_or_create("oidc|123", email.as_ref()) .await .unwrap(); @@ -903,13 +913,13 @@ mod tests { let email1 = Email::try_from("test@example.com").unwrap(); let user1 = service - .find_or_create_by_subject("oidc|123", email1) + .find_or_create("oidc|123", email1.as_ref()) .await .unwrap(); let email2 = Email::try_from("test@example.com").unwrap(); let user2 = service - .find_or_create_by_subject("oidc|123", email2) + .find_or_create("oidc|123", email2.as_ref()) .await .unwrap(); diff --git a/notes-infra/src/auth/mod.rs b/notes-infra/src/auth/mod.rs index ac01f0b..3760d03 100644 --- a/notes-infra/src/auth/mod.rs +++ b/notes-infra/src/auth/mod.rs @@ -1,6 +1,6 @@ #[cfg(feature = "auth-axum-login")] -mod axum_login; +pub mod axum_login; #[cfg(feature = "auth-jwt")] -mod jwt; +pub mod jwt; #[cfg(feature = "auth-oidc")] -mod oidc; +pub mod oidc; diff --git a/notes-infra/src/session_store.rs b/notes-infra/src/session_store.rs index e9f5bee..edb657f 100644 --- a/notes-infra/src/session_store.rs +++ b/notes-infra/src/session_store.rs @@ -1 +1,2 @@ pub use k_core::session::store::InfraSessionStore; +pub use tower_sessions::{Expiry, SessionManagerLayer}; -- 2.49.1 From 3d9c72a7ef5cd3d96c07124c11156431ed4fb3f4 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Tue, 6 Jan 2026 21:10:57 +0100 Subject: [PATCH 5/5] feat: Implement OIDC authentication with JWT token handling and dynamic auth configuration --- .../public/locales/de/translation.json | 5 +- .../public/locales/en/translation.json | 5 +- .../public/locales/es/translation.json | 5 +- .../public/locales/fr/translation.json | 5 +- .../public/locales/pl/translation.json | 5 +- k-notes-frontend/src/App.tsx | 3 + k-notes-frontend/src/hooks/use-auth.ts | 51 ++++++++- k-notes-frontend/src/hooks/useConfig.ts | 6 + k-notes-frontend/src/lib/api.ts | 37 ++++++- k-notes-frontend/src/pages/login.tsx | 103 ++++++++++++------ k-notes-frontend/src/pages/oidc-callback.tsx | 52 +++++++++ k-notes-frontend/src/pages/register.tsx | 6 +- notes-api/Cargo.toml | 2 +- notes-api/src/config.rs | 10 +- notes-api/src/dto.rs | 5 + notes-api/src/routes/auth.rs | 34 +++--- notes-api/src/routes/config.rs | 6 + 17 files changed, 265 insertions(+), 75 deletions(-) create mode 100644 k-notes-frontend/src/pages/oidc-callback.tsx diff --git a/k-notes-frontend/public/locales/de/translation.json b/k-notes-frontend/public/locales/de/translation.json index 4e3f97c..b4a0659 100644 --- a/k-notes-frontend/public/locales/de/translation.json +++ b/k-notes-frontend/public/locales/de/translation.json @@ -92,5 +92,8 @@ "Update": "Aktualisieren", "Welcome back": "Willkommen zurück", "work, todo, ideas": "Arbeit, Aufgaben, Ideen", - "Your notes will appear here. Click + to create one.": "Deine Notizen werden hier erscheinen. Klicke +, um eine zu erstellen." + "Your notes will appear here. Click + to create one.": "Deine Notizen werden hier erscheinen. Klicke +, um eine zu erstellen.", + "Sign in with SSO": "Mit SSO anmelden", + "Or continue with": "Oder fortfahren mit", + "Completing sign in...": "Anmeldung wird abgeschlossen..." } \ No newline at end of file diff --git a/k-notes-frontend/public/locales/en/translation.json b/k-notes-frontend/public/locales/en/translation.json index 203c5f4..acd4f40 100644 --- a/k-notes-frontend/public/locales/en/translation.json +++ b/k-notes-frontend/public/locales/en/translation.json @@ -92,5 +92,8 @@ "Update": "Update", "Welcome back": "Welcome back", "work, todo, ideas": "work, todo, ideas", - "Your notes will appear here. Click + to create one.": "Your notes will appear here. Click + to create one." + "Your notes will appear here. Click + to create one.": "Your notes will appear here. Click + to create one.", + "Sign in with SSO": "Sign in with SSO", + "Or continue with": "Or continue with", + "Completing sign in...": "Completing sign in..." } \ No newline at end of file diff --git a/k-notes-frontend/public/locales/es/translation.json b/k-notes-frontend/public/locales/es/translation.json index 804f3af..84891c0 100644 --- a/k-notes-frontend/public/locales/es/translation.json +++ b/k-notes-frontend/public/locales/es/translation.json @@ -96,5 +96,8 @@ "Update": "Actualizar", "Welcome back": "Bienvenido de nuevo", "work, todo, ideas": "trabajo, tareas, ideas", - "Your notes will appear here. Click + to create one.": "Tus notas aparecerán aquí. Haz clic en + para crear una." + "Your notes will appear here. Click + to create one.": "Tus notas aparecerán aquí. Haz clic en + para crear una.", + "Sign in with SSO": "Iniciar sesión con SSO", + "Or continue with": "O continuar con", + "Completing sign in...": "Completando inicio de sesión..." } \ No newline at end of file diff --git a/k-notes-frontend/public/locales/fr/translation.json b/k-notes-frontend/public/locales/fr/translation.json index d0d6ff7..812bcde 100644 --- a/k-notes-frontend/public/locales/fr/translation.json +++ b/k-notes-frontend/public/locales/fr/translation.json @@ -96,5 +96,8 @@ "Update": "Mettre à jour", "Welcome back": "Bon retour", "work, todo, ideas": "travail, tâches, idées", - "Your notes will appear here. Click + to create one.": "Tes notes apparaîtront ici. Clique sur + pour en créer une." + "Your notes will appear here. Click + to create one.": "Tes notes apparaîtront ici. Clique sur + pour en créer une.", + "Sign in with SSO": "Se connecter avec SSO", + "Or continue with": "Ou continuer avec", + "Completing sign in...": "Connexion en cours..." } \ No newline at end of file diff --git a/k-notes-frontend/public/locales/pl/translation.json b/k-notes-frontend/public/locales/pl/translation.json index 68e0955..7040651 100644 --- a/k-notes-frontend/public/locales/pl/translation.json +++ b/k-notes-frontend/public/locales/pl/translation.json @@ -100,5 +100,8 @@ "Update": "Aktualizuj", "Welcome back": "Witaj ponownie", "work, todo, ideas": "praca, zadania, pomysły", - "Your notes will appear here. Click + to create one.": "Twoje notatki pojawią się tutaj. Kliknij +, aby utworzyć notatkę." + "Your notes will appear here. Click + to create one.": "Twoje notatki pojawią się tutaj. Kliknij +, aby utworzyć notatkę.", + "Sign in with SSO": "Zaloguj się przez SSO", + "Or continue with": "Lub kontynuuj przez", + "Completing sign in...": "Kończenie logowania..." } \ No newline at end of file diff --git a/k-notes-frontend/src/App.tsx b/k-notes-frontend/src/App.tsx index e95bc14..d2010a7 100644 --- a/k-notes-frontend/src/App.tsx +++ b/k-notes-frontend/src/App.tsx @@ -5,6 +5,7 @@ import LoginPage from "@/pages/login"; import RegisterPage from "@/pages/register"; import DashboardPage from "@/pages/dashboard"; import PrivacyPolicyPage from "@/pages/privacy-policy"; +import OidcCallbackPage from "@/pages/oidc-callback"; import Layout from "@/components/layout"; import { useSync } from "@/lib/sync"; import { useMobileStatusBar } from "@/hooks/use-mobile-status-bar"; @@ -17,6 +18,7 @@ function App() { {/* Public Routes (accessible to everyone) */} } /> + } /> {/* Public Routes (only accessible if NOT logged in) */} }> @@ -40,3 +42,4 @@ function App() { } export default App; + diff --git a/k-notes-frontend/src/hooks/use-auth.ts b/k-notes-frontend/src/hooks/use-auth.ts index 82d0145..5a867c5 100644 --- a/k-notes-frontend/src/hooks/use-auth.ts +++ b/k-notes-frontend/src/hooks/use-auth.ts @@ -1,5 +1,5 @@ import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; -import { api } from "@/lib/api"; +import { api, setAuthToken, clearAuthToken, getBaseUrl } from "@/lib/api"; import { useNavigate } from "react-router-dom"; export interface User { @@ -8,6 +8,20 @@ export interface User { created_at: string; } +// Token response from JWT/OIDC login +export interface TokenResponse { + access_token: string; + token_type: string; + expires_in: number; +} + +// Login can return either User (session mode) or Token (JWT mode) +export type LoginResult = User | TokenResponse; + +function isTokenResponse(result: LoginResult): result is TokenResponse { + return 'access_token' in result; +} + // Fetch current user async function fetchUser(): Promise { try { @@ -35,8 +49,13 @@ export function useLogin() { const navigate = useNavigate(); return useMutation({ - mutationFn: (credentials: any) => api.post("/auth/login", credentials), - onSuccess: () => { + mutationFn: (credentials: { email: string; password: string }): Promise => + api.post("/auth/login", credentials), + onSuccess: (result: LoginResult) => { + // If we got a token response, store the token + if (isTokenResponse(result)) { + setAuthToken(result.access_token); + } queryClient.invalidateQueries({ queryKey: ["user"] }); navigate("/"); }, @@ -48,8 +67,13 @@ export function useRegister() { const navigate = useNavigate(); return useMutation({ - mutationFn: (credentials: any) => api.post("/auth/register", credentials), - onSuccess: () => { + mutationFn: (credentials: { email: string; password: string }): Promise => + api.post("/auth/register", credentials), + onSuccess: (result: LoginResult) => { + // If we got a token response, store the token + if (isTokenResponse(result)) { + setAuthToken(result.access_token); + } queryClient.invalidateQueries({ queryKey: ["user"] }); navigate("/"); }, @@ -63,8 +87,25 @@ export function useLogout() { return useMutation({ mutationFn: () => api.post("/auth/logout", {}), onSuccess: () => { + // Clear both session data and JWT token + clearAuthToken(); + queryClient.setQueryData(["user"], null); + navigate("/login"); + }, + onError: () => { + // Even on error, clear local state + clearAuthToken(); queryClient.setQueryData(["user"], null); navigate("/login"); }, }); } + +// Hook to initiate OIDC login flow +export function useOidcLogin() { + return () => { + // Redirect to OIDC login endpoint + window.location.href = `${getBaseUrl()}/api/v1/auth/login/oidc`; + }; +} + diff --git a/k-notes-frontend/src/hooks/useConfig.ts b/k-notes-frontend/src/hooks/useConfig.ts index 91f9305..7265611 100644 --- a/k-notes-frontend/src/hooks/useConfig.ts +++ b/k-notes-frontend/src/hooks/useConfig.ts @@ -2,8 +2,13 @@ import { useQuery } from "@tanstack/react-query"; import { api } from "@/lib/api"; +export type AuthMode = 'session' | 'jwt' | 'both'; + export interface ConfigResponse { allow_registration: boolean; + auth_mode: AuthMode; + oidc_enabled: boolean; + password_login_enabled: boolean; } export function useConfig() { @@ -13,3 +18,4 @@ export function useConfig() { staleTime: Infinity, // Config rarely changes }); } + diff --git a/k-notes-frontend/src/lib/api.ts b/k-notes-frontend/src/lib/api.ts index 7191d3c..858db52 100644 --- a/k-notes-frontend/src/lib/api.ts +++ b/k-notes-frontend/src/lib/api.ts @@ -6,6 +6,21 @@ declare global { } } +const TOKEN_STORAGE_KEY = 'k_notes_auth_token'; + +// JWT Token management +export function setAuthToken(token: string): void { + localStorage.setItem(TOKEN_STORAGE_KEY, token); +} + +export function getAuthToken(): string | null { + return localStorage.getItem(TOKEN_STORAGE_KEY); +} + +export function clearAuthToken(): void { + localStorage.removeItem(TOKEN_STORAGE_KEY); +} + const getApiUrl = () => { // 1. Runtime config (Docker) if (window.env?.API_URL) { @@ -40,17 +55,22 @@ export class ApiError extends Error { async function fetchWithAuth(endpoint: string, options: RequestInit = {}) { const url = `${getApiUrl()}${endpoint}`; + const token = getAuthToken(); - const headers = { + const headers: Record = { "Content-Type": "application/json", - ...options.headers, + ...(options.headers as Record || {}), }; + // Add Authorization header if we have a JWT token + if (token) { + headers["Authorization"] = `Bearer ${token}`; + } + const config: RequestInit = { ...options, headers, - credentials: "include", // Important for cookies! - // signal: controller.signal, // Removing signal, using race instead + credentials: "include", // Still include for session-based auth }; try { @@ -60,8 +80,6 @@ async function fetchWithAuth(endpoint: string, options: RequestInit = {}) { ); const response = (await Promise.race([fetchPromise, timeoutPromise])) as Response; - // clearTimeout(timeoutId); // Not needed with race logic here (though leaking timer? No, race settles.) - if (!response.ok) { // Try to parse error message @@ -109,11 +127,18 @@ export const api = { }), delete: (endpoint: string) => fetchWithAuth(endpoint, { method: "DELETE" }), exportData: async () => { + const token = getAuthToken(); + const headers: Record = {}; + if (token) { + headers["Authorization"] = `Bearer ${token}`; + } const response = await fetch(`${getApiUrl()}/export`, { credentials: "include", + headers, }); if (!response.ok) throw new ApiError(response.status, "Failed to export data"); return response.blob(); }, importData: (data: any) => api.post("/import", data), }; + diff --git a/k-notes-frontend/src/pages/login.tsx b/k-notes-frontend/src/pages/login.tsx index acfa012..926313f 100644 --- a/k-notes-frontend/src/pages/login.tsx +++ b/k-notes-frontend/src/pages/login.tsx @@ -1,11 +1,11 @@ import { useState } from "react"; import { useForm } from "react-hook-form"; -import { Settings } from "lucide-react"; +import { Settings, ExternalLink } from "lucide-react"; import { SettingsDialog } from "@/components/settings-dialog"; import { zodResolver } from "@hookform/resolvers/zod"; import { z } from "zod"; import { Link } from "react-router-dom"; -import { useLogin } from "@/hooks/use-auth"; +import { useLogin, useOidcLogin } from "@/hooks/use-auth"; import { useConfig } from "@/hooks/useConfig"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -26,6 +26,7 @@ export default function LoginPage() { const { mutate: login, isPending } = useLogin(); const { data: config } = useConfig(); const { t } = useTranslation(); + const startOidcLogin = useOidcLogin(); const form = useForm({ resolver: zodResolver(loginSchema), @@ -63,40 +64,71 @@ export default function LoginPage() { {t("Enter your email to sign in to your account")} - -
- - ( - - {t("Email")} - - - - - - )} - /> - ( - - {t("Password")} - - - - - - )} - /> - - - + {/* Divider only if both OIDC and password login are enabled */} + {config?.password_login_enabled && ( +
+
+ +
+
+ + {t("Or continue with")} + +
+
+ )} + + )} + + {/* Email/Password Form - only show if password login is enabled */} + {config?.password_login_enabled !== false && ( +
+ + ( + + {t("Email")} + + + + + + )} + /> + ( + + {t("Password")} + + + + + + )} + /> + + + + )}
{config?.allow_registration !== false && ( @@ -113,3 +145,4 @@ export default function LoginPage() { ); } + diff --git a/k-notes-frontend/src/pages/oidc-callback.tsx b/k-notes-frontend/src/pages/oidc-callback.tsx new file mode 100644 index 0000000..aac735a --- /dev/null +++ b/k-notes-frontend/src/pages/oidc-callback.tsx @@ -0,0 +1,52 @@ +import { useEffect } from "react"; +import { useNavigate, useSearchParams } from "react-router-dom"; +import { useQueryClient } from "@tanstack/react-query"; +import { setAuthToken } from "@/lib/api"; +import { useTranslation } from "react-i18next"; + +/** + * OIDC Callback Handler + * + * This page handles redirects from the OIDC provider after authentication. + * + * In Session mode: The backend sets a session cookie during the callback, + * so we just need to redirect to the dashboard. + * + * In JWT mode: The backend redirects here with a token in the URL fragment + * or query params, which we need to extract and store. + */ +export default function OidcCallbackPage() { + const navigate = useNavigate(); + const [searchParams] = useSearchParams(); + const queryClient = useQueryClient(); + const { t } = useTranslation(); + + useEffect(() => { + // Check for token in URL hash (implicit flow) or query params + const hashParams = new URLSearchParams(window.location.hash.slice(1)); + const accessToken = + hashParams.get("access_token") || searchParams.get("access_token"); + + if (accessToken) { + // JWT mode: store the token + setAuthToken(accessToken); + } + + // Invalidate user query to refetch with new auth state + queryClient.invalidateQueries({ queryKey: ["user"] }); + + // Redirect to dashboard + navigate("/", { replace: true }); + }, [navigate, searchParams, queryClient]); + + return ( +
+
+
+

+ {t("Completing sign in...")} +

+
+
+ ); +} diff --git a/k-notes-frontend/src/pages/register.tsx b/k-notes-frontend/src/pages/register.tsx index 1e95b8d..0adb1a4 100644 --- a/k-notes-frontend/src/pages/register.tsx +++ b/k-notes-frontend/src/pages/register.tsx @@ -36,10 +36,14 @@ export default function RegisterPage() { if (!isConfigLoading && config?.allow_registration === false) { toast.error(t("Registration is currently disabled")); navigate("/login"); + } else if (!isConfigLoading && config?.password_login_enabled === false) { + // Registration requires password login to be enabled + toast.error(t("Registration is not available")); + navigate("/login"); } }, [config, isConfigLoading, navigate, t]); - if (isConfigLoading || config?.allow_registration === false) { + if (isConfigLoading || config?.allow_registration === false || config?.password_login_enabled === false) { return null; // Or a loading spinner } diff --git a/notes-api/Cargo.toml b/notes-api/Cargo.toml index 69cfe7a..9fa77d3 100644 --- a/notes-api/Cargo.toml +++ b/notes-api/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" default-run = "notes-api" [features] -default = ["sqlite", "smart-features", "auth-oidc", "auth-jwt"] +default = ["sqlite", "smart-features"] sqlite = ["notes-infra/sqlite"] postgres = ["notes-infra/postgres"] smart-features = ["notes-infra/smart-features", "notes-infra/broker-nats"] diff --git a/notes-api/src/config.rs b/notes-api/src/config.rs index 7d35876..229954e 100644 --- a/notes-api/src/config.rs +++ b/notes-api/src/config.rs @@ -1,10 +1,10 @@ #[cfg(feature = "smart-features")] use notes_infra::factory::{EmbeddingProvider, VectorProvider}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::env; /// Authentication mode - determines how the API authenticates requests -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] pub enum AuthMode { /// Session-based authentication using cookies (default for backward compatibility) @@ -66,6 +66,9 @@ pub struct Config { /// Whether the application is running in production mode pub is_production: bool, + + /// Frontend URL for OIDC redirect (defaults to first CORS origin) + pub frontend_url: String, } impl Default for Config { @@ -100,6 +103,7 @@ impl Default for Config { jwt_audience: None, jwt_expiry_hours: 24, is_production: false, + frontend_url: "http://localhost:5173".to_string(), } } } @@ -219,6 +223,8 @@ impl Config { jwt_audience, jwt_expiry_hours, is_production, + frontend_url: env::var("FRONTEND_URL") + .unwrap_or_else(|_| "http://localhost:5173".to_string()), } } } diff --git a/notes-api/src/dto.rs b/notes-api/src/dto.rs index c424759..88dc5ad 100644 --- a/notes-api/src/dto.rs +++ b/notes-api/src/dto.rs @@ -7,6 +7,8 @@ use validator::Validate; use notes_domain::{Email, Note, Password, Tag}; +use crate::config::AuthMode; + /// Request to create a new note #[derive(Debug, Deserialize, Validate)] pub struct CreateNoteRequest { @@ -165,6 +167,9 @@ impl From for NoteVersionResponse { #[derive(Debug, Serialize)] pub struct ConfigResponse { pub allow_registration: bool, + pub auth_mode: AuthMode, + pub oidc_enabled: bool, + pub password_login_enabled: bool, } /// Note Link response DTO diff --git a/notes-api/src/routes/auth.rs b/notes-api/src/routes/auth.rs index 4cc5344..e355435 100644 --- a/notes-api/src/routes/auth.rs +++ b/notes-api/src/routes/auth.rs @@ -387,25 +387,19 @@ async fn oidc_callback( .await .map_err(|_| ApiError::Internal("Session error".into()))?; - // In JWT mode, return token as JSON + // In JWT mode, redirect to frontend with token in URL fragment #[cfg(feature = "auth-jwt")] if matches!(auth_mode, AuthMode::Jwt | AuthMode::Both) { let token = create_jwt_for_user(&user, &state)?; - return Ok(Json(TokenResponse { - access_token: token, - token_type: "Bearer".to_string(), - expires_in: state.config.jwt_expiry_hours * 3600, - }) - .into_response()); + let redirect_url = format!( + "{}/auth/callback#access_token={}", + state.config.frontend_url, token + ); + return Ok(axum::response::Redirect::to(&redirect_url).into_response()); } - // Session mode: return user info - Ok(Json(UserResponse { - id: user.id, - email: user.email, - created_at: user.created_at, - }) - .into_response()) + // Session mode: redirect to frontend (session cookie already set) + Ok(axum::response::Redirect::to(&state.config.frontend_url).into_response()) } /// Fallback OIDC callback when auth-axum-login is not enabled @@ -470,15 +464,15 @@ async fn oidc_callback( .await .map_err(|_| ApiError::Internal("Session error".into()))?; - // Return token as JSON + // Redirect to frontend with token in URL fragment #[cfg(feature = "auth-jwt")] { let token = create_jwt_for_user(&user, &state)?; - return Ok(Json(TokenResponse { - access_token: token, - token_type: "Bearer".to_string(), - expires_in: state.config.jwt_expiry_hours * 3600, - })); + let redirect_url = format!( + "{}/auth/callback#access_token={}", + state.config.frontend_url, token + ); + return Ok(axum::response::Redirect::to(&redirect_url)); } #[cfg(not(feature = "auth-jwt"))] diff --git a/notes-api/src/routes/config.rs b/notes-api/src/routes/config.rs index 94e14c3..09d1c4b 100644 --- a/notes-api/src/routes/config.rs +++ b/notes-api/src/routes/config.rs @@ -10,5 +10,11 @@ use crate::state::AppState; pub async fn get_config(State(state): State) -> ApiResult> { Ok(Json(ConfigResponse { allow_registration: state.config.allow_registration, + auth_mode: state.config.auth_mode, + #[cfg(feature = "auth-oidc")] + oidc_enabled: state.oidc_service.is_some(), + #[cfg(not(feature = "auth-oidc"))] + oidc_enabled: false, + password_login_enabled: cfg!(feature = "auth-axum-login"), })) } -- 2.49.1