refactor (v2): better arch

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-06-07 21:19:54 +02:00
parent 0753f3d256
commit 839308ec19
166 changed files with 8553 additions and 884 deletions

View File

@@ -0,0 +1,27 @@
[package]
name = "auth"
version = "0.1.0"
edition = "2024"
[features]
default = []
jwt = ["dep:jsonwebtoken"]
oidc = ["dep:openidconnect", "dep:reqwest"]
[dependencies]
domain = { workspace = true }
async-trait = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
argon2 = "0.5"
url = "2"
jsonwebtoken = { version = "10", features = ["rust_crypto"], optional = true }
openidconnect = { version = "4", optional = true }
reqwest = { version = "0.12", features = ["json"], optional = true }
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,30 @@
/// Config for OIDC. Validated when constructing OidcService.
#[derive(Debug, Clone)]
pub struct OidcConfig {
pub issuer_url: String,
pub client_id: String,
pub client_secret: Option<String>,
pub redirect_url: String,
/// Optional audience / resource ID for token validation.
pub resource_id: Option<String>,
}
/// Config for JWT. Validated when constructing JwtValidator.
#[derive(Debug, Clone)]
pub struct JwtConfig {
pub secret: String,
pub issuer: Option<String>,
pub audience: Option<String>,
pub expiry_hours: u64,
}
impl JwtConfig {
pub fn new(secret: impl Into<String>) -> Self {
Self {
secret: secret.into(),
issuer: None,
audience: None,
expiry_hours: 24,
}
}
}

View File

@@ -0,0 +1,100 @@
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use domain::user::entity::User;
use crate::config::JwtConfig;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JwtClaims {
pub sub: String,
pub email: String,
pub exp: usize,
pub iat: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub iss: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<String>,
}
#[derive(Debug, Error)]
pub enum JwtError {
#[error("token creation failed: {0}")]
Creation(#[from] jsonwebtoken::errors::Error),
#[error("token expired")]
Expired,
#[error("invalid token: {0}")]
Invalid(String),
}
pub struct JwtValidator {
config: JwtConfig,
encoding_key: EncodingKey,
decoding_key: DecodingKey,
validation: Validation,
}
impl JwtValidator {
pub fn new(config: JwtConfig) -> Self {
let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
let mut validation = Validation::new(Algorithm::HS256);
if let Some(ref iss) = config.issuer {
validation.set_issuer(&[iss]);
}
if let Some(ref aud) = config.audience {
validation.set_audience(&[aud]);
}
Self {
config,
encoding_key,
decoding_key,
validation,
}
}
pub fn create_token(&self, user: &User) -> Result<String, JwtError> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock before epoch")
.as_secs() as usize;
let claims = JwtClaims {
sub: user.id.as_uuid().to_string(),
email: user.email.as_ref().to_string(),
exp: now + self.config.expiry_hours as usize * 3600,
iat: now,
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
};
encode(&Header::new(Algorithm::HS256), &claims, &self.encoding_key)
.map_err(JwtError::Creation)
}
pub fn validate_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
.map(|td| td.claims)
.map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired,
_ => JwtError::Invalid(e.to_string()),
})
}
}
impl std::fmt::Debug for JwtValidator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JwtValidator")
.field("issuer", &self.config.issuer)
.field("expiry_hours", &self.config.expiry_hours)
.finish_non_exhaustive()
}
}
#[cfg(test)]
#[path = "tests/jwt.rs"]
mod tests;

View File

@@ -0,0 +1,8 @@
pub mod config;
pub mod password;
#[cfg(feature = "jwt")]
pub mod jwt;
#[cfg(feature = "oidc")]
pub mod oidc;

View File

@@ -0,0 +1,178 @@
use anyhow::{Result, anyhow};
use openidconnect::{
AccessTokenHash, Client, EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet,
OAuth2TokenResponse, PkceCodeChallenge, Scope, StandardErrorResponse, TokenResponse,
UserInfoClaims,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreClient, CoreErrorResponseType,
CoreGenderClaim, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse,
},
reqwest,
};
use serde::{Deserialize, Serialize};
use crate::config::OidcConfig;
pub type OidcClient = Client<
EmptyAdditionalClaims,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
CoreTokenResponse,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointMaybeSet,
EndpointMaybeSet,
>;
/// Data returned when starting the OIDC authorization flow.
#[derive(Debug, Clone)]
pub struct AuthorizationUrlData {
pub url: url::Url,
pub csrf_token: String,
pub nonce: String,
pub pkce_verifier: String,
}
/// Verified identity returned after a successful callback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcUser {
pub subject: String,
pub email: String,
}
#[derive(Clone)]
pub struct OidcService {
client: OidcClient,
resource_id: Option<String>,
}
impl OidcService {
pub async fn new(config: OidcConfig) -> Result<Self> {
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()?;
let provider_metadata = CoreProviderMetadata::discover_async(
openidconnect::IssuerUrl::new(config.issuer_url)?,
&http_client,
)
.await?;
let client_secret = config
.client_secret
.filter(|s| !s.trim().is_empty())
.map(openidconnect::ClientSecret::new);
let client = CoreClient::from_provider_metadata(
provider_metadata,
openidconnect::ClientId::new(config.client_id),
client_secret,
)
.set_redirect_uri(openidconnect::RedirectUrl::new(config.redirect_url)?);
Ok(Self {
client,
resource_id: config.resource_id,
})
}
pub fn authorization_url(&self) -> AuthorizationUrlData {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (url, csrf_token, nonce) = self
.client
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
openidconnect::CsrfToken::new_random,
openidconnect::Nonce::new_random,
)
.add_scope(Scope::new("profile".into()))
.add_scope(Scope::new("email".into()))
.set_pkce_challenge(pkce_challenge)
.url();
AuthorizationUrlData {
url: url.into(),
csrf_token: csrf_token.secret().clone(),
nonce: nonce.secret().clone(),
pkce_verifier: pkce_verifier.secret().clone(),
}
}
pub async fn exchange_code(
&self,
code: &str,
nonce: &str,
pkce_verifier: &str,
) -> Result<OidcUser> {
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()?;
let token_response = self
.client
.exchange_code(openidconnect::AuthorizationCode::new(code.to_owned()))?
.set_pkce_verifier(openidconnect::PkceCodeVerifier::new(
pkce_verifier.to_owned(),
))
.request_async(&http_client)
.await?;
let id_token = token_response
.id_token()
.ok_or_else(|| anyhow!("server did not return an ID token"))?;
let mut verifier = self.client.id_token_verifier().clone();
if let Some(ref rid) = self.resource_id {
let rid = rid.clone();
verifier =
verifier.set_other_audience_verifier_fn(move |aud| aud.as_str() == rid.as_str());
}
let oidc_nonce = openidconnect::Nonce::new(nonce.to_owned());
let claims = id_token.claims(&verifier, &oidc_nonce)?;
if let Some(expected_hash) = claims.access_token_hash() {
let actual_hash = AccessTokenHash::from_token(
token_response.access_token(),
id_token.signing_alg()?,
id_token.signing_key(&verifier)?,
)?;
if actual_hash != *expected_hash {
return Err(anyhow!("access token hash mismatch"));
}
}
let email = match claims.email() {
Some(e) => e.as_str().to_owned(),
None => {
tracing::debug!("email absent in ID token, fetching userinfo");
let userinfo: UserInfoClaims<EmptyAdditionalClaims, CoreGenderClaim> = self
.client
.user_info(token_response.access_token().clone(), None)?
.request_async(&http_client)
.await?;
userinfo
.email()
.map(|e| e.as_str().to_owned())
.ok_or_else(|| anyhow!("no verified email in identity provider response"))?
}
};
Ok(OidcUser {
subject: claims.subject().to_string(),
email,
})
}
}

View File

@@ -0,0 +1,51 @@
use argon2::{
Argon2,
password_hash::{
PasswordHash, PasswordHasher as _, PasswordVerifier, SaltString, rand_core::OsRng,
},
};
use async_trait::async_trait;
use domain::{
errors::{DomainError, DomainResult},
user::{
ports::PasswordHasher,
value_objects::{Password, PasswordHash as DomainPasswordHash},
},
};
pub struct Argon2PasswordHasher;
#[async_trait]
impl PasswordHasher for Argon2PasswordHasher {
async fn hash(&self, password: &Password) -> DomainResult<DomainPasswordHash> {
let password_str = password.as_ref().to_owned();
tokio::task::spawn_blocking(move || {
let salt = SaltString::generate(&mut OsRng);
let hash = Argon2::default()
.hash_password(password_str.as_bytes(), &salt)
.map_err(|e| DomainError::Infrastructure(format!("hash failed: {e}")))?;
Ok(DomainPasswordHash::new(hash.to_string()))
})
.await
.map_err(|e| DomainError::Infrastructure(format!("task panicked: {e}")))?
}
async fn verify(&self, password: &Password, hash: &DomainPasswordHash) -> DomainResult<bool> {
let password_str = password.as_ref().to_owned();
let hash_str = hash.as_str().to_owned();
tokio::task::spawn_blocking(move || {
let parsed = PasswordHash::new(&hash_str)
.map_err(|e| DomainError::Infrastructure(format!("invalid hash: {e}")))?;
Ok(Argon2::default()
.verify_password(password_str.as_bytes(), &parsed)
.is_ok())
})
.await
.map_err(|e| DomainError::Infrastructure(format!("task panicked: {e}")))?
}
}
#[cfg(test)]
#[path = "tests/password.rs"]
mod tests;

View File

@@ -0,0 +1,68 @@
use domain::user::{entity::User, value_objects::Email};
use crate::{config::JwtConfig, jwt::JwtValidator};
fn validator() -> JwtValidator {
JwtValidator::new(JwtConfig::new(
"a-test-secret-that-is-long-enough-for-hs256",
))
}
fn user() -> User {
User::new_oidc("sub|123", Email::new("test@example.com").unwrap())
}
#[test]
fn create_and_validate_round_trip() {
let v = validator();
let u = user();
let token = v.create_token(&u).unwrap();
let claims = v.validate_token(&token).unwrap();
assert_eq!(claims.email, "test@example.com");
assert_eq!(claims.sub, u.id.as_uuid().to_string());
}
#[test]
fn wrong_secret_rejects_token() {
let v1 = JwtValidator::new(JwtConfig::new(
"secret-one-long-enough-for-hs256-validation",
));
let v2 = JwtValidator::new(JwtConfig::new(
"secret-two-long-enough-for-hs256-validation",
));
let token = v1.create_token(&user()).unwrap();
assert!(v2.validate_token(&token).is_err());
}
#[test]
fn invalid_token_is_rejected() {
let v = validator();
assert!(v.validate_token("not.a.valid.jwt").is_err());
}
#[test]
fn expired_token_returns_expired_error() {
use crate::jwt::JwtError;
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
let secret = "a-test-secret-that-is-long-enough-for-hs256";
let claims = crate::jwt::JwtClaims {
sub: "user-id".into(),
email: "x@example.com".into(),
exp: 1, // epoch + 1 second — already expired
iat: 0,
iss: None,
aud: None,
};
let token = encode(
&Header::new(Algorithm::HS256),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap();
let v = JwtValidator::new(JwtConfig::new(secret));
assert!(matches!(v.validate_token(&token), Err(JwtError::Expired)));
}

View File

@@ -0,0 +1,36 @@
use domain::user::{
ports::PasswordHasher,
value_objects::{Password, PasswordHash},
};
use crate::password::Argon2PasswordHasher;
#[tokio::test]
async fn hash_produces_verifiable_hash() {
let hasher = Argon2PasswordHasher;
let password = Password::new("correcthorsebattery").unwrap();
let hash = hasher.hash(&password).await.unwrap();
assert!(hasher.verify(&password, &hash).await.unwrap());
}
#[tokio::test]
async fn wrong_password_does_not_verify() {
let hasher = Argon2PasswordHasher;
let password = Password::new("correcthorsebattery").unwrap();
let wrong = Password::new("wrongpassword12345").unwrap();
let hash = hasher.hash(&password).await.unwrap();
assert!(!hasher.verify(&wrong, &hash).await.unwrap());
}
#[tokio::test]
async fn same_password_produces_different_hashes() {
let hasher = Argon2PasswordHasher;
let password = Password::new("samepassword123").unwrap();
let hash1 = hasher.hash(&password).await.unwrap();
let hash2 = hasher.hash(&password).await.unwrap();
assert_ne!(hash1.as_str(), hash2.as_str());
}

View File

@@ -0,0 +1,11 @@
[package]
name = "event-payload"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
uuid = { workspace = true }
thiserror = { workspace = true }

View File

@@ -0,0 +1,86 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use domain::{
errors::DomainError, events::DomainEvent, note::entity::NoteId, user::entity::UserId,
};
/// Wire-format representation of a DomainEvent.
/// Uses primitive types only — no domain newtypes — so it is stable across
/// schema versions and safe to serialize to any transport (NATS, HTTP, file).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", content = "data")]
pub enum EventPayload {
NoteCreated { note_id: String, user_id: String },
NoteUpdated { note_id: String, user_id: String },
NoteDeleted { note_id: String, user_id: String },
}
impl EventPayload {
pub fn event_type(&self) -> &'static str {
match self {
Self::NoteCreated { .. } => "NoteCreated",
Self::NoteUpdated { .. } => "NoteUpdated",
Self::NoteDeleted { .. } => "NoteDeleted",
}
}
pub fn to_json(&self) -> Result<Vec<u8>, DomainError> {
serde_json::to_vec(self)
.map_err(|e| DomainError::Infrastructure(format!("serialize failed: {e}")))
}
pub fn from_json(bytes: &[u8]) -> Result<Self, DomainError> {
serde_json::from_slice(bytes)
.map_err(|e| DomainError::Infrastructure(format!("deserialize failed: {e}")))
}
}
impl From<&DomainEvent> for EventPayload {
fn from(event: &DomainEvent) -> Self {
match event {
DomainEvent::NoteCreated { note_id, user_id } => Self::NoteCreated {
note_id: note_id.as_uuid().to_string(),
user_id: user_id.as_uuid().to_string(),
},
DomainEvent::NoteUpdated { note_id, user_id } => Self::NoteUpdated {
note_id: note_id.as_uuid().to_string(),
user_id: user_id.as_uuid().to_string(),
},
DomainEvent::NoteDeleted { note_id, user_id } => Self::NoteDeleted {
note_id: note_id.as_uuid().to_string(),
user_id: user_id.as_uuid().to_string(),
},
}
}
}
impl TryFrom<EventPayload> for DomainEvent {
type Error = DomainError;
fn try_from(payload: EventPayload) -> Result<Self, Self::Error> {
fn parse(s: &str) -> Result<Uuid, DomainError> {
Uuid::parse_str(s)
.map_err(|e| DomainError::Infrastructure(format!("invalid uuid '{s}': {e}")))
}
match payload {
EventPayload::NoteCreated { note_id, user_id } => Ok(DomainEvent::NoteCreated {
note_id: NoteId::from_uuid(parse(&note_id)?),
user_id: UserId::from_uuid(parse(&user_id)?),
}),
EventPayload::NoteUpdated { note_id, user_id } => Ok(DomainEvent::NoteUpdated {
note_id: NoteId::from_uuid(parse(&note_id)?),
user_id: UserId::from_uuid(parse(&user_id)?),
}),
EventPayload::NoteDeleted { note_id, user_id } => Ok(DomainEvent::NoteDeleted {
note_id: NoteId::from_uuid(parse(&note_id)?),
user_id: UserId::from_uuid(parse(&user_id)?),
}),
}
}
}
#[cfg(test)]
#[path = "tests/lib.rs"]
mod tests;

View File

@@ -0,0 +1,88 @@
use domain::{events::DomainEvent, note::entity::NoteId, user::entity::UserId};
use crate::EventPayload;
fn note_created() -> DomainEvent {
DomainEvent::NoteCreated {
note_id: NoteId::new(),
user_id: UserId::new(),
}
}
#[test]
fn domain_event_round_trips_through_payload() {
let event = note_created();
let payload = EventPayload::from(&event);
let recovered = DomainEvent::try_from(payload).unwrap();
// Compare by serialising both — DomainEvent doesn't implement PartialEq.
let EventPayload::NoteCreated {
note_id: orig_nid,
user_id: orig_uid,
} = EventPayload::from(&event)
else {
panic!("wrong variant");
};
let EventPayload::NoteCreated {
note_id: rec_nid,
user_id: rec_uid,
} = EventPayload::from(&recovered)
else {
panic!("wrong variant");
};
assert_eq!(orig_nid, rec_nid);
assert_eq!(orig_uid, rec_uid);
}
#[test]
fn payload_serialises_to_json_and_back() {
let event = note_created();
let payload = EventPayload::from(&event);
let bytes = payload.to_json().unwrap();
let recovered = EventPayload::from_json(&bytes).unwrap();
assert_eq!(payload, recovered);
}
#[test]
fn event_type_label_is_correct() {
let uid = UserId::new();
let nid = NoteId::new();
assert_eq!(
EventPayload::NoteCreated {
note_id: nid.to_string(),
user_id: uid.to_string()
}
.event_type(),
"NoteCreated"
);
assert_eq!(
EventPayload::NoteUpdated {
note_id: nid.to_string(),
user_id: uid.to_string()
}
.event_type(),
"NoteUpdated"
);
assert_eq!(
EventPayload::NoteDeleted {
note_id: nid.to_string(),
user_id: uid.to_string()
}
.event_type(),
"NoteDeleted"
);
}
#[test]
fn invalid_json_returns_error() {
assert!(EventPayload::from_json(b"not json at all").is_err());
}
#[test]
fn invalid_uuid_in_payload_returns_error() {
let payload = EventPayload::NoteCreated {
note_id: "not-a-uuid".into(),
user_id: "also-not-a-uuid".into(),
};
assert!(DomainEvent::try_from(payload).is_err());
}

View File

@@ -0,0 +1,11 @@
[package]
name = "event-publisher-memory"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

View File

@@ -0,0 +1,85 @@
use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::broadcast;
use domain::{
errors::DomainError,
events::{DomainEvent, EventConsumer, EventEnvelope, EventPublisher},
};
const CHANNEL_CAPACITY: usize = 256;
/// Shared in-memory event bus backed by a tokio broadcast channel.
/// Create one bus, then hand out publisher and consumer handles from it.
pub struct MemoryEventBus {
sender: broadcast::Sender<DomainEvent>,
}
impl MemoryEventBus {
pub fn new() -> Self {
let (sender, _) = broadcast::channel(CHANNEL_CAPACITY);
Self { sender }
}
pub fn publisher(&self) -> Arc<MemoryEventPublisher> {
Arc::new(MemoryEventPublisher {
sender: self.sender.clone(),
})
}
pub fn consumer(&self) -> Arc<MemoryEventConsumer> {
Arc::new(MemoryEventConsumer {
sender: self.sender.clone(),
})
}
}
impl Default for MemoryEventBus {
fn default() -> Self {
Self::new()
}
}
pub struct MemoryEventPublisher {
sender: broadcast::Sender<DomainEvent>,
}
#[async_trait]
impl EventPublisher for MemoryEventPublisher {
async fn publish(&self, event: &DomainEvent) -> Result<(), DomainError> {
// send() only fails when there are no receivers; that is fine in dev/test.
let _ = self.sender.send(event.clone());
Ok(())
}
}
pub struct MemoryEventConsumer {
sender: broadcast::Sender<DomainEvent>,
}
impl EventConsumer for MemoryEventConsumer {
fn consume(&self) -> BoxStream<'_, Result<EventEnvelope, DomainError>> {
let rx = self.sender.subscribe();
Box::pin(futures::stream::unfold(rx, |mut rx| async move {
loop {
match rx.recv().await {
Ok(event) => {
let envelope = EventEnvelope::noop(event);
return Some((Ok(envelope), rx));
}
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("memory event bus: consumer lagged, skipped {n} messages");
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}))
}
}
#[cfg(test)]
#[path = "tests/lib.rs"]
mod tests;

View File

@@ -0,0 +1,75 @@
use futures::StreamExt;
use domain::{
events::{DomainEvent, EventConsumer, EventPublisher},
note::entity::NoteId,
user::entity::UserId,
};
use crate::MemoryEventBus;
fn note_updated() -> DomainEvent {
DomainEvent::NoteUpdated {
note_id: NoteId::new(),
user_id: UserId::new(),
}
}
#[tokio::test]
async fn published_event_is_received_by_consumer() {
let bus = MemoryEventBus::new();
let publisher = bus.publisher();
let consumer = bus.consumer();
let event = note_updated();
let mut stream = consumer.consume();
publisher.publish(&event).await.unwrap();
let envelope = stream.next().await.unwrap().unwrap();
assert!(matches!(envelope.event, DomainEvent::NoteUpdated { .. }));
}
#[tokio::test]
async fn ack_on_memory_envelope_is_noop() {
let bus = MemoryEventBus::new();
let publisher = bus.publisher();
let consumer = bus.consumer();
// Subscribe before publishing — broadcast drops messages sent before subscribe.
let mut stream = consumer.consume();
publisher.publish(&note_updated()).await.unwrap();
let envelope = stream.next().await.unwrap().unwrap();
envelope.ack().await.unwrap();
}
#[tokio::test]
async fn multiple_consumers_each_receive_the_event() {
let bus = MemoryEventBus::new();
let publisher = bus.publisher();
let c1 = bus.consumer();
let c2 = bus.consumer();
let mut s1 = c1.consume();
let mut s2 = c2.consume();
publisher.publish(&note_updated()).await.unwrap();
assert!(matches!(
s1.next().await.unwrap().unwrap().event,
DomainEvent::NoteUpdated { .. }
));
assert!(matches!(
s2.next().await.unwrap().unwrap().event,
DomainEvent::NoteUpdated { .. }
));
}
#[tokio::test]
async fn publish_with_no_consumer_does_not_error() {
let bus = MemoryEventBus::new();
let publisher = bus.publisher();
// No consumer — publish should silently succeed.
publisher.publish(&note_updated()).await.unwrap();
}

View File

@@ -0,0 +1,14 @@
[package]
name = "fastembed-adapter"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
async-trait = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
fastembed = "5"
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,88 @@
use std::{
path::PathBuf,
sync::{Arc, Mutex},
};
use async_trait::async_trait;
use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
use domain::{
errors::{DomainError, DomainResult},
smart::ports::EmbeddingGenerator,
};
pub struct FastEmbedConfig {
pub model: EmbeddingModel,
/// Directory used to cache downloaded model files.
/// Defaults to the system cache directory when `None`.
pub cache_dir: Option<PathBuf>,
pub show_download_progress: bool,
}
impl Default for FastEmbedConfig {
fn default() -> Self {
Self {
model: EmbeddingModel::AllMiniLML6V2,
cache_dir: None,
show_download_progress: false,
}
}
}
impl FastEmbedConfig {
pub fn with_model(model: EmbeddingModel) -> Self {
Self {
model,
..Default::default()
}
}
}
pub struct FastEmbedGenerator {
model: Arc<Mutex<TextEmbedding>>,
}
impl FastEmbedGenerator {
/// Initialise the model. Downloads and caches model files on first call.
pub fn new(config: FastEmbedConfig) -> Result<Self, DomainError> {
let mut opts = TextInitOptions::new(config.model)
.with_show_download_progress(config.show_download_progress);
if let Some(dir) = config.cache_dir {
opts = opts.with_cache_dir(dir);
}
let model = TextEmbedding::try_new(opts)
.map_err(|e| DomainError::Infrastructure(format!("fastembed init failed: {e}")))?;
Ok(Self {
model: Arc::new(Mutex::new(model)),
})
}
}
#[async_trait]
impl EmbeddingGenerator for FastEmbedGenerator {
async fn generate(&self, text: &str) -> DomainResult<Vec<f32>> {
let model = Arc::clone(&self.model);
let text = text.to_owned();
tokio::task::spawn_blocking(move || {
let mut guard = model
.lock()
.map_err(|_| DomainError::Infrastructure("model mutex poisoned".into()))?;
guard
.embed(vec![text.as_str()], None)
.map_err(|e| DomainError::Infrastructure(format!("embedding failed: {e}")))?
.into_iter()
.next()
.ok_or_else(|| DomainError::Infrastructure("no embedding returned".into()))
})
.await
.map_err(|e| DomainError::Infrastructure(format!("spawn_blocking panicked: {e}")))?
}
}
#[cfg(test)]
#[path = "tests/lib.rs"]
mod tests;

View File

@@ -0,0 +1,33 @@
use crate::{FastEmbedConfig, FastEmbedGenerator};
use domain::smart::ports::EmbeddingGenerator;
use fastembed::EmbeddingModel;
/// Downloads the model on first run (~90 MB). Run with:
/// cargo test -p fastembed-adapter -- --ignored
#[tokio::test]
#[ignore]
async fn generates_embedding_with_correct_dimension() {
let generator =
FastEmbedGenerator::new(FastEmbedConfig::with_model(EmbeddingModel::AllMiniLML6V2))
.expect("model init failed");
let embedding = generator.generate("hello world").await.unwrap();
// AllMiniLML6V2 produces 384-dimensional vectors.
assert_eq!(embedding.len(), 384);
// Sanity: values are in a reasonable range.
assert!(embedding.iter().all(|v| v.is_finite()));
}
#[tokio::test]
#[ignore]
async fn different_texts_produce_different_embeddings() {
let generator =
FastEmbedGenerator::new(FastEmbedConfig::with_model(EmbeddingModel::AllMiniLML6V2))
.expect("model init failed");
let a = generator.generate("cats").await.unwrap();
let b = generator.generate("quantum mechanics").await.unwrap();
assert_ne!(a, b);
}

View File

@@ -0,0 +1,17 @@
[package]
name = "nats"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
event-payload = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
tracing = { workspace = true }
serde_json = { workspace = true }
async-nats = "0.37"
async-stream = "0.3"
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,104 @@
use std::{sync::Arc, time::Duration};
use async_nats::jetstream::{
AckKind,
consumer::{self, pull},
};
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use domain::{
errors::DomainError,
events::{DomainEvent, EventConsumer, EventEnvelope},
};
use event_payload::EventPayload;
pub struct NatsEventConsumer {
consumer: Arc<consumer::Consumer<pull::Config>>,
}
impl NatsEventConsumer {
pub(crate) fn new(consumer: consumer::Consumer<pull::Config>) -> Self {
Self {
consumer: Arc::new(consumer),
}
}
}
impl EventConsumer for NatsEventConsumer {
fn consume(&self) -> BoxStream<'_, Result<EventEnvelope, DomainError>> {
let consumer = Arc::clone(&self.consumer);
Box::pin(async_stream::stream! {
let mut messages = match consumer.messages().await {
Ok(m) => m,
Err(e) => {
yield Err(DomainError::Infrastructure(
format!("failed to open jetstream message stream: {e}")
));
return;
}
};
while let Some(result) = messages.next().await {
let msg = match result {
Ok(m) => m,
Err(e) => {
yield Err(DomainError::Infrastructure(e.to_string()));
continue;
}
};
// Malformed messages are acked immediately to prevent infinite
// redelivery of poison payloads that can never be processed.
let payload = match EventPayload::from_json(&msg.payload) {
Ok(p) => p,
Err(e) => {
tracing::error!("unprocessable message payload, acking to discard: {e}");
let _ = msg.ack().await;
continue;
}
};
let event = match DomainEvent::try_from(payload) {
Ok(e) => e,
Err(e) => {
tracing::error!("invalid event payload, acking to discard: {e}");
let _ = msg.ack().await;
continue;
}
};
let delivered = msg.info().map(|i| i.delivered).unwrap_or(1);
let nack_delay = backoff(delivered);
let msg = Arc::new(msg);
let ack_msg = Arc::clone(&msg);
let nack_msg = Arc::clone(&msg);
yield Ok(EventEnvelope::new(
event,
move || -> BoxFuture<'static, _> {
Box::pin(async move {
ack_msg.ack().await.map_err(|e| {
DomainError::Infrastructure(format!("nats ack failed: {e}"))
})
})
},
move || -> BoxFuture<'static, _> {
Box::pin(async move {
nack_msg.ack_with(AckKind::Nak(Some(nack_delay))).await.map_err(|e| {
DomainError::Infrastructure(format!("nats nack failed: {e}"))
})
})
},
));
}
})
}
}
/// Exponential backoff capped at 5 minutes: 1s → 5s → 25s → 125s → 300s.
fn backoff(delivered: i64) -> Duration {
let exp = delivered.saturating_sub(1) as u32;
Duration::from_secs(5u64.saturating_pow(exp).min(300))
}

View File

@@ -0,0 +1,92 @@
pub mod consumer;
pub mod publisher;
use std::time::Duration;
use async_nats::jetstream::{self, consumer as nats_consumer, consumer::pull};
use crate::{consumer::NatsEventConsumer, publisher::NatsEventPublisher};
// ── Subject routing ───────────────────────────────────────────────────────────
pub(crate) fn subject_for(event: &domain::events::DomainEvent) -> &'static str {
use domain::events::DomainEvent;
match event {
DomainEvent::NoteCreated { .. } => "knotes.note.created",
DomainEvent::NoteUpdated { .. } => "knotes.note.updated",
DomainEvent::NoteDeleted { .. } => "knotes.note.deleted",
}
}
pub(crate) const SUBSCRIBE_SUBJECT: &str = "knotes.note.>";
// ── Config ────────────────────────────────────────────────────────────────────
/// Configuration for the JetStream stream and durable pull consumer.
///
/// **Dead-letter queue**: after `max_deliver` failed attempts NATS stops
/// redelivering and publishes an advisory to
/// `$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.{stream}.{consumer}`.
/// Subscribe to those with a monitoring consumer or NATS dashboard to
/// observe dead messages.
#[derive(Debug, Clone)]
pub struct JetStreamConfig {
/// Name of the JetStream stream (created on first use if absent).
pub stream_name: String,
/// Durable consumer name — survives worker restarts.
pub consumer_name: String,
/// Maximum delivery attempts before the message is considered dead.
pub max_deliver: i64,
/// How long JetStream waits for an ack before redelivering.
pub ack_wait: Duration,
}
impl Default for JetStreamConfig {
fn default() -> Self {
Self {
stream_name: "KNOTES".into(),
consumer_name: "knotes-worker".into(),
max_deliver: 5,
ack_wait: Duration::from_secs(30),
}
}
}
// ── Setup ─────────────────────────────────────────────────────────────────────
/// Connect to NATS and initialise both the publisher and consumer.
/// Creates the JetStream stream and durable pull consumer if they do not exist.
pub async fn setup(
url: &str,
config: JetStreamConfig,
) -> Result<(NatsEventPublisher, NatsEventConsumer), Box<dyn std::error::Error + Send + Sync>> {
let client = async_nats::connect(url).await?;
let js = jetstream::new(client);
let stream = js
.get_or_create_stream(jetstream::stream::Config {
name: config.stream_name.clone(),
subjects: vec![SUBSCRIBE_SUBJECT.into()],
..Default::default()
})
.await?;
let nats_consumer: nats_consumer::Consumer<pull::Config> = stream
.get_or_create_consumer(
&config.consumer_name,
pull::Config {
durable_name: Some(config.consumer_name.clone()),
ack_policy: jetstream::consumer::AckPolicy::Explicit,
max_deliver: config.max_deliver,
ack_wait: config.ack_wait,
filter_subject: SUBSCRIBE_SUBJECT.into(),
..Default::default()
},
)
.await?;
Ok((
NatsEventPublisher::new(js),
NatsEventConsumer::new(nats_consumer),
))
}

View File

@@ -0,0 +1,34 @@
use async_nats::jetstream;
use async_trait::async_trait;
use domain::{
errors::DomainError,
events::{DomainEvent, EventPublisher},
};
use event_payload::EventPayload;
use crate::subject_for;
pub struct NatsEventPublisher {
js: jetstream::Context,
}
impl NatsEventPublisher {
pub(crate) fn new(js: jetstream::Context) -> Self {
Self { js }
}
}
#[async_trait]
impl EventPublisher for NatsEventPublisher {
async fn publish(&self, event: &DomainEvent) -> Result<(), DomainError> {
let bytes = EventPayload::from(event).to_json()?;
self.js
.publish(subject_for(event), bytes.into())
.await
.map_err(|e| DomainError::Infrastructure(format!("nats publish failed: {e}")))?
.await
.map_err(|e| DomainError::Infrastructure(format!("nats publish ack failed: {e}")))?;
Ok(())
}
}

View File

@@ -0,0 +1,14 @@
[package]
name = "qdrant-adapter"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
async-trait = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
qdrant-client = "1"
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,139 @@
use async_trait::async_trait;
use qdrant_client::{
Qdrant, QdrantError,
qdrant::{
CreateCollectionBuilder, DeletePointsBuilder, Distance, PointId, PointStruct,
PointsIdsList, SearchPointsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
point_id::PointIdOptions,
},
};
use uuid::Uuid;
use domain::{
errors::{DomainError, DomainResult},
note::entity::NoteId,
smart::ports::VectorStore,
};
pub struct QdrantConfig {
pub url: String,
pub collection: String,
/// Dimensionality of the vectors stored in this collection.
/// Must match the output size of the embedding model (e.g. 384 for AllMiniLML6V2).
pub vector_size: u64,
}
impl Default for QdrantConfig {
fn default() -> Self {
Self {
url: "http://localhost:6334".into(),
collection: "notes".into(),
vector_size: 384,
}
}
}
pub struct QdrantVectorStore {
client: Qdrant,
collection: String,
}
impl QdrantVectorStore {
pub fn new(config: QdrantConfig) -> Result<Self, Box<QdrantError>> {
let client = Qdrant::from_url(&config.url).build().map_err(Box::new)?;
Ok(Self {
client,
collection: config.collection,
})
}
/// Ensure the collection exists. Call once during startup before accepting requests.
pub async fn init(&self, vector_size: u64) -> DomainResult<()> {
if self
.client
.collection_exists(&self.collection)
.await
.map_err(qdrant_err)?
{
return Ok(());
}
self.client
.create_collection(
CreateCollectionBuilder::new(&self.collection)
.vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)),
)
.await
.map_err(qdrant_err)?;
tracing::info!(collection = %self.collection, "qdrant collection created");
Ok(())
}
}
#[async_trait]
impl VectorStore for QdrantVectorStore {
async fn upsert(&self, id: &NoteId, vector: &[f32]) -> DomainResult<()> {
let point = PointStruct::new(
uuid_to_point_id(id.as_uuid()),
vector.to_vec(),
qdrant_client::Payload::default(),
);
self.client
.upsert_points(UpsertPointsBuilder::new(&self.collection, vec![point]))
.await
.map_err(qdrant_err)
.map(|_| ())
}
async fn find_similar(&self, vector: &[f32], limit: usize) -> DomainResult<Vec<(NoteId, f32)>> {
let response = self
.client
.search_points(
SearchPointsBuilder::new(&self.collection, vector.to_vec(), limit as u64)
.with_payload(false),
)
.await
.map_err(qdrant_err)?;
response
.result
.into_iter()
.filter_map(|scored| {
let uuid_str = match scored.id?.point_id_options? {
PointIdOptions::Uuid(s) => s,
_ => return None,
};
let uuid = Uuid::parse_str(&uuid_str).ok()?;
Some(Ok((NoteId::from_uuid(uuid), scored.score)))
})
.collect()
}
async fn delete(&self, id: &NoteId) -> DomainResult<()> {
self.client
.delete_points(
DeletePointsBuilder::new(&self.collection).points(PointsIdsList {
ids: vec![uuid_to_point_id(id.as_uuid())],
}),
)
.await
.map_err(qdrant_err)
.map(|_| ())
}
}
fn uuid_to_point_id(uuid: Uuid) -> PointId {
PointId {
point_id_options: Some(PointIdOptions::Uuid(uuid.to_string())),
}
}
fn qdrant_err(e: QdrantError) -> DomainError {
DomainError::Infrastructure(format!("qdrant: {e}"))
}
#[cfg(test)]
#[path = "tests/lib.rs"]
mod tests;

View File

@@ -0,0 +1,45 @@
use domain::{note::entity::NoteId, smart::ports::VectorStore};
use crate::{QdrantConfig, QdrantVectorStore};
const VECTOR_SIZE: u64 = 4; // small for tests
fn test_config() -> QdrantConfig {
QdrantConfig {
url: "http://localhost:6334".into(),
collection: "test-notes".into(),
vector_size: VECTOR_SIZE,
}
}
/// Requires a running Qdrant instance. Run with:
/// cargo test -p qdrant-adapter -- --ignored
#[tokio::test]
#[ignore]
async fn upsert_and_find_similar() {
let store = QdrantVectorStore::new(test_config()).unwrap();
store.init(VECTOR_SIZE).await.unwrap();
let id = NoteId::new();
let vector = vec![1.0f32, 0.0, 0.0, 0.0];
store.upsert(&id, &vector).await.unwrap();
let results = store.find_similar(&vector, 1).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, id);
assert!(results[0].1 > 0.99);
}
#[tokio::test]
#[ignore]
async fn delete_removes_vector() {
let store = QdrantVectorStore::new(test_config()).unwrap();
store.init(VECTOR_SIZE).await.unwrap();
let id = NoteId::new();
store.upsert(&id, &[1.0, 0.0, 0.0, 0.0]).await.unwrap();
store.delete(&id).await.unwrap();
let results = store.find_similar(&[1.0, 0.0, 0.0, 0.0], 10).await.unwrap();
assert!(!results.iter().any(|(rid, _)| rid == &id));
}

View File

@@ -0,0 +1,15 @@
[package]
name = "sqlite"
version = "0.1.0"
edition = "2024"
[dependencies]
domain = { workspace = true }
async-trait = { workspace = true }
chrono = { workspace = true }
uuid = { workspace = true }
sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "sqlite", "chrono", "uuid", "migrate", "macros"] }
serde_json = { workspace = true }
[dev-dependencies]
tokio = { workspace = true }

View File

@@ -0,0 +1,71 @@
-- Initial schema for K-Notes
-- SQLite with FTS5 for full-text search
-- Users table (OIDC-ready)
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY NOT NULL,
subject TEXT UNIQUE NOT NULL, -- OIDC subject identifier
email TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX idx_users_subject ON users(subject);
CREATE INDEX idx_users_email ON users(email);
-- Notes table
CREATE TABLE IF NOT EXISTS notes (
id TEXT PRIMARY KEY NOT NULL,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL DEFAULT '',
is_pinned INTEGER NOT NULL DEFAULT 0,
is_archived INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX idx_notes_user_id ON notes(user_id);
CREATE INDEX idx_notes_is_pinned ON notes(is_pinned);
CREATE INDEX idx_notes_is_archived ON notes(is_archived);
CREATE INDEX idx_notes_updated_at ON notes(updated_at);
-- Tags table (user-scoped)
CREATE TABLE IF NOT EXISTS tags (
id TEXT PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
UNIQUE(name, user_id)
);
CREATE INDEX idx_tags_user_id ON tags(user_id);
-- Junction table for note-tag relationship
CREATE TABLE IF NOT EXISTS note_tags (
note_id TEXT NOT NULL REFERENCES notes(id) ON DELETE CASCADE,
tag_id TEXT NOT NULL REFERENCES tags(id) ON DELETE CASCADE,
PRIMARY KEY (note_id, tag_id)
);
CREATE INDEX idx_note_tags_tag_id ON note_tags(tag_id);
-- Full-text search virtual table
CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5(
title,
content,
content='notes',
content_rowid='rowid'
);
-- Triggers to keep FTS index in sync
CREATE TRIGGER notes_ai AFTER INSERT ON notes BEGIN
INSERT INTO notes_fts(rowid, title, content) VALUES (NEW.rowid, NEW.title, NEW.content);
END;
CREATE TRIGGER notes_ad AFTER DELETE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, title, content) VALUES('delete', OLD.rowid, OLD.title, OLD.content);
END;
CREATE TRIGGER notes_au AFTER UPDATE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, title, content) VALUES('delete', OLD.rowid, OLD.title, OLD.content);
INSERT INTO notes_fts(rowid, title, content) VALUES (NEW.rowid, NEW.title, NEW.content);
END;

View File

@@ -0,0 +1,2 @@
-- Add password_hash column to users table
ALTER TABLE users ADD COLUMN password_hash TEXT;

View File

@@ -0,0 +1 @@
ALTER TABLE notes ADD COLUMN color TEXT NOT NULL DEFAULT 'DEFAULT';

View File

@@ -0,0 +1,11 @@
-- Add note_versions table
CREATE TABLE note_versions (
id TEXT PRIMARY KEY,
note_id TEXT NOT NULL,
title TEXT NOT NULL,
content TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY(note_id) REFERENCES notes(id) ON DELETE CASCADE
);
CREATE INDEX idx_note_versions_note_id ON note_versions(note_id);

View File

@@ -0,0 +1,12 @@
CREATE TABLE IF NOT EXISTS note_links (
source_note_id TEXT NOT NULL,
target_note_id TEXT NOT NULL,
score REAL NOT NULL,
created_at DATETIME NOT NULL,
PRIMARY KEY (source_note_id, target_note_id),
FOREIGN KEY (source_note_id) REFERENCES notes(id) ON DELETE CASCADE,
FOREIGN KEY (target_note_id) REFERENCES notes(id) ON DELETE CASCADE
);
CREATE INDEX idx_note_links_source ON note_links(source_note_id);
CREATE INDEX idx_note_links_target ON note_links(target_note_id);

View File

@@ -0,0 +1,45 @@
-- Allow NULL titles in notes table
-- SQLite doesn't support ALTER COLUMN, so we need to recreate the table
-- Step 1: Create new table with nullable title
CREATE TABLE notes_new (
id TEXT PRIMARY KEY NOT NULL,
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
title TEXT, -- Now nullable
content TEXT NOT NULL DEFAULT '',
color TEXT NOT NULL DEFAULT 'DEFAULT',
is_pinned INTEGER NOT NULL DEFAULT 0,
is_archived INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Step 2: Copy data from old table
INSERT INTO notes_new (id, user_id, title, content, color, is_pinned, is_archived, created_at, updated_at)
SELECT id, user_id, title, content, color, is_pinned, is_archived, created_at, updated_at FROM notes;
-- Step 3: Drop old table
DROP TABLE notes;
-- Step 4: Rename new table
ALTER TABLE notes_new RENAME TO notes;
-- Step 5: Recreate indexes
CREATE INDEX idx_notes_user_id ON notes(user_id);
CREATE INDEX idx_notes_is_pinned ON notes(is_pinned);
CREATE INDEX idx_notes_is_archived ON notes(is_archived);
CREATE INDEX idx_notes_updated_at ON notes(updated_at);
-- Step 6: Recreate FTS triggers
CREATE TRIGGER notes_ai AFTER INSERT ON notes BEGIN
INSERT INTO notes_fts(rowid, title, content) VALUES (NEW.rowid, COALESCE(NEW.title, ''), NEW.content);
END;
CREATE TRIGGER notes_ad AFTER DELETE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, title, content) VALUES('delete', OLD.rowid, COALESCE(OLD.title, ''), OLD.content);
END;
CREATE TRIGGER notes_au AFTER UPDATE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, title, content) VALUES('delete', OLD.rowid, COALESCE(OLD.title, ''), OLD.content);
INSERT INTO notes_fts(rowid, title, content) VALUES (NEW.rowid, COALESCE(NEW.title, ''), NEW.content);
END;

View File

@@ -0,0 +1,16 @@
-- note_versions.title should be nullable to match notes where title is optional
CREATE TABLE note_versions_new (
id TEXT PRIMARY KEY,
note_id TEXT NOT NULL,
title TEXT,
content TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY(note_id) REFERENCES notes(id) ON DELETE CASCADE
);
INSERT INTO note_versions_new SELECT id, note_id, NULLIF(title, ''), content, created_at FROM note_versions;
DROP TABLE note_versions;
ALTER TABLE note_versions_new RENAME TO note_versions;
CREATE INDEX idx_note_versions_note_id ON note_versions(note_id);

View File

@@ -0,0 +1,43 @@
use chrono::{DateTime, Utc};
pub use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions};
use std::str::FromStr;
use domain::errors::DomainError;
pub async fn connect(database_url: &str) -> Result<SqlitePool, sqlx::Error> {
let options = SqliteConnectOptions::from_str(database_url)?
.create_if_missing(true)
.journal_mode(SqliteJournalMode::Wal)
.foreign_keys(true);
SqlitePoolOptions::new()
.max_connections(5)
.connect_with(options)
.await
}
pub async fn run_migrations(pool: &SqlitePool) -> Result<(), sqlx::migrate::MigrateError> {
sqlx::migrate!("./migrations").run(pool).await
}
/// Parse a datetime string from SQLite (RFC3339 or naive format).
pub(crate) fn parse_dt(s: &str) -> Result<DateTime<Utc>, DomainError> {
DateTime::parse_from_rfc3339(s)
.map(|dt| dt.with_timezone(&Utc))
.or_else(|_| {
chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S").map(|dt| dt.and_utc())
})
.map_err(|e| DomainError::Repository(format!("invalid datetime '{s}': {e}")))
}
/// Map a sqlx error to DomainError::Repository.
pub(crate) trait RepoExt<T> {
fn repo(self) -> Result<T, DomainError>;
}
impl<T> RepoExt<T> for Result<T, sqlx::Error> {
fn repo(self) -> Result<T, DomainError> {
self.map_err(|e| DomainError::Repository(e.to_string()))
}
}

View File

@@ -0,0 +1,5 @@
pub mod db;
pub mod link;
pub mod note;
pub mod tag;
pub mod user;

View File

@@ -0,0 +1,103 @@
use async_trait::async_trait;
use sqlx::{FromRow, SqlitePool};
use domain::{
errors::{DomainError, DomainResult},
note::{
entity::{NoteId, NoteLink},
ports::LinkRepository,
},
};
use crate::db::RepoExt;
pub struct SqliteLinkRepository {
pool: SqlitePool,
}
impl SqliteLinkRepository {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
}
#[derive(FromRow)]
struct LinkRow {
source_note_id: String,
target_note_id: String,
score: f32,
created_at: String,
}
impl TryFrom<LinkRow> for NoteLink {
type Error = DomainError;
fn try_from(row: LinkRow) -> Result<Self, Self::Error> {
let source_id = NoteId::from_uuid(
uuid::Uuid::parse_str(&row.source_note_id)
.map_err(|e| DomainError::Repository(format!("invalid source uuid: {e}")))?,
);
let target_id = NoteId::from_uuid(
uuid::Uuid::parse_str(&row.target_note_id)
.map_err(|e| DomainError::Repository(format!("invalid target uuid: {e}")))?,
);
let created_at = crate::db::parse_dt(&row.created_at)?;
Ok(NoteLink {
source_id,
target_id,
score: row.score,
created_at,
})
}
}
#[async_trait]
impl LinkRepository for SqliteLinkRepository {
async fn save_links(&self, links: &[NoteLink]) -> DomainResult<()> {
let mut tx = self.pool.begin().await.repo()?;
for link in links {
sqlx::query(
r#"
INSERT INTO note_links (source_note_id, target_note_id, score, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(source_note_id, target_note_id) DO UPDATE SET
score = excluded.score,
created_at = excluded.created_at
"#,
)
.bind(link.source_id.as_uuid().to_string())
.bind(link.target_id.as_uuid().to_string())
.bind(link.score)
.bind(link.created_at.to_rfc3339())
.execute(&mut *tx)
.await
.repo()?;
}
tx.commit().await.repo()
}
async fn delete_for_source(&self, source_id: &NoteId) -> DomainResult<()> {
sqlx::query("DELETE FROM note_links WHERE source_note_id = ?")
.bind(source_id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn find_for_note(&self, note_id: &NoteId) -> DomainResult<Vec<NoteLink>> {
sqlx::query_as::<_, LinkRow>(
"SELECT source_note_id, target_note_id, score, created_at \
FROM note_links WHERE source_note_id = ? ORDER BY score DESC",
)
.bind(note_id.as_uuid().to_string())
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(NoteLink::try_from)
.collect()
}
}

View File

@@ -0,0 +1,281 @@
use async_trait::async_trait;
use sqlx::{FromRow, QueryBuilder, Sqlite, SqlitePool};
use domain::{
errors::{DomainError, DomainResult},
note::{
entity::{Note, NoteFilter, NoteId, NoteVersion},
ports::NoteRepository,
value_objects::{NoteColor, NoteTitle},
},
tag::entity::{Tag, TagId},
user::entity::UserId,
};
use crate::db::{RepoExt, parse_dt};
pub struct SqliteNoteRepository {
pool: SqlitePool,
}
impl SqliteNoteRepository {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
}
// ── Row types ────────────────────────────────────────────────────────────────
#[derive(FromRow)]
struct NoteRow {
id: String,
user_id: String,
title: Option<String>,
content: String,
color: String,
is_pinned: i32,
is_archived: i32,
created_at: String,
updated_at: String,
tags_json: String,
}
impl TryFrom<NoteRow> for Note {
type Error = DomainError;
fn try_from(row: NoteRow) -> Result<Self, Self::Error> {
let id = NoteId::from_uuid(
uuid::Uuid::parse_str(&row.id)
.map_err(|e| DomainError::Repository(format!("invalid note uuid: {e}")))?,
);
let user_id = UserId::from_uuid(
uuid::Uuid::parse_str(&row.user_id)
.map_err(|e| DomainError::Repository(format!("invalid user uuid: {e}")))?,
);
let title = NoteTitle::from_optional(row.title)?;
let tags = parse_tags_json(&row.tags_json)?;
Ok(Note {
id,
user_id,
title,
content: row.content,
color: NoteColor::new(row.color),
is_pinned: row.is_pinned != 0,
is_archived: row.is_archived != 0,
created_at: parse_dt(&row.created_at)?,
updated_at: parse_dt(&row.updated_at)?,
tags,
})
}
}
fn parse_tags_json(json: &str) -> Result<Vec<Tag>, DomainError> {
let values: Vec<serde_json::Value> = serde_json::from_str(json)
.map_err(|e| DomainError::Repository(format!("invalid tags json: {e}")))?;
values
.into_iter()
.filter(|v| !v.is_null())
.map(|v| {
let parse_str = |key: &str| {
v[key]
.as_str()
.ok_or_else(|| DomainError::Repository(format!("missing tag field '{key}'")))
};
let id = TagId::from_uuid(
uuid::Uuid::parse_str(parse_str("id")?)
.map_err(|e| DomainError::Repository(format!("invalid tag uuid: {e}")))?,
);
let user_id = UserId::from_uuid(
uuid::Uuid::parse_str(parse_str("user_id")?)
.map_err(|e| DomainError::Repository(format!("invalid tag user_id: {e}")))?,
);
let name = domain::tag::value_objects::TagName::new(parse_str("name")?)?;
Ok(Tag::from_row(id, name, user_id))
})
.collect()
}
#[derive(FromRow)]
struct VersionRow {
id: String,
note_id: String,
title: Option<String>,
content: String,
created_at: String,
}
impl TryFrom<VersionRow> for NoteVersion {
type Error = DomainError;
fn try_from(row: VersionRow) -> Result<Self, Self::Error> {
Ok(NoteVersion {
id: uuid::Uuid::parse_str(&row.id)
.map_err(|e| DomainError::Repository(format!("invalid version uuid: {e}")))?,
note_id: NoteId::from_uuid(
uuid::Uuid::parse_str(&row.note_id)
.map_err(|e| DomainError::Repository(format!("invalid note uuid: {e}")))?,
),
title: row.title,
content: row.content,
created_at: parse_dt(&row.created_at)?,
})
}
}
// ── Shared SELECT fragment ────────────────────────────────────────────────────
const NOTE_SELECT: &str = r#"
SELECT n.id, n.user_id, n.title, n.content, n.color, n.is_pinned, n.is_archived,
n.created_at, n.updated_at,
json_group_array(
CASE WHEN t.id IS NOT NULL
THEN json_object('id', t.id, 'name', t.name, 'user_id', t.user_id)
ELSE NULL END
) AS tags_json
FROM notes n
LEFT JOIN note_tags nt ON n.id = nt.note_id
LEFT JOIN tags t ON nt.tag_id = t.id
"#;
// ── NoteRepository ───────────────────────────────────────────────────────────
#[async_trait]
impl NoteRepository for SqliteNoteRepository {
async fn find_by_id(&self, id: &NoteId) -> DomainResult<Option<Note>> {
let sql = format!("{NOTE_SELECT} WHERE n.id = ? GROUP BY n.id");
sqlx::query_as::<_, NoteRow>(&sql)
.bind(id.as_uuid().to_string())
.fetch_optional(&self.pool)
.await
.repo()?
.map(Note::try_from)
.transpose()
}
async fn find_by_user(&self, user_id: &UserId, filter: NoteFilter) -> DomainResult<Vec<Note>> {
let base = format!("{NOTE_SELECT} WHERE n.user_id = ");
let mut qb: QueryBuilder<Sqlite> = QueryBuilder::new(base);
qb.push_bind(user_id.as_uuid().to_string());
if let Some(pinned) = filter.is_pinned {
qb.push(" AND n.is_pinned = ").push_bind(pinned as i32);
}
if let Some(archived) = filter.is_archived {
qb.push(" AND n.is_archived = ").push_bind(archived as i32);
}
if let Some(tag_id) = filter.tag_id {
qb.push(" AND n.id IN (SELECT note_id FROM note_tags WHERE tag_id = ")
.push_bind(tag_id.as_uuid().to_string())
.push(")");
}
qb.push(" GROUP BY n.id ORDER BY n.is_pinned DESC, n.updated_at DESC");
qb.build_query_as::<NoteRow>()
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(Note::try_from)
.collect()
}
async fn search(&self, user_id: &UserId, query: &str) -> DomainResult<Vec<Note>> {
let sql = format!(
r#"{NOTE_SELECT}
WHERE n.user_id = ?
AND (
n.rowid IN (SELECT rowid FROM notes_fts WHERE notes_fts MATCH ?)
OR EXISTS (
SELECT 1 FROM note_tags nt2
JOIN tags t2 ON nt2.tag_id = t2.id
WHERE nt2.note_id = n.id AND t2.name LIKE ?
)
)
GROUP BY n.id ORDER BY n.updated_at DESC"#
);
sqlx::query_as::<_, NoteRow>(&sql)
.bind(user_id.as_uuid().to_string())
.bind(query)
.bind(format!("%{query}%"))
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(Note::try_from)
.collect()
}
async fn save(&self, note: &Note) -> DomainResult<()> {
sqlx::query(
r#"
INSERT INTO notes (id, user_id, title, content, color, is_pinned, is_archived, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
content = excluded.content,
color = excluded.color,
is_pinned = excluded.is_pinned,
is_archived = excluded.is_archived,
updated_at = excluded.updated_at
"#,
)
.bind(note.id.as_uuid().to_string())
.bind(note.user_id.as_uuid().to_string())
.bind(note.title.as_ref().map(|t| t.as_ref()))
.bind(&note.content)
.bind(note.color.as_str())
.bind(note.is_pinned as i32)
.bind(note.is_archived as i32)
.bind(note.created_at.to_rfc3339())
.bind(note.updated_at.to_rfc3339())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn delete(&self, id: &NoteId) -> DomainResult<()> {
sqlx::query("DELETE FROM notes WHERE id = ?")
.bind(id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn save_version(&self, version: &NoteVersion) -> DomainResult<()> {
sqlx::query(
"INSERT INTO note_versions (id, note_id, title, content, created_at) VALUES (?, ?, ?, ?, ?)",
)
.bind(version.id.to_string())
.bind(version.note_id.as_uuid().to_string())
.bind(version.title.as_deref())
.bind(&version.content)
.bind(version.created_at.to_rfc3339())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn find_versions(&self, note_id: &NoteId) -> DomainResult<Vec<NoteVersion>> {
sqlx::query_as::<_, VersionRow>(
"SELECT id, note_id, title, content, created_at FROM note_versions WHERE note_id = ? ORDER BY created_at DESC",
)
.bind(note_id.as_uuid().to_string())
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(NoteVersion::try_from)
.collect()
}
}
#[cfg(test)]
#[path = "tests/note.rs"]
mod tests;

View File

@@ -0,0 +1,157 @@
use async_trait::async_trait;
use sqlx::{FromRow, SqlitePool};
use domain::{
errors::{DomainError, DomainResult},
note::entity::NoteId,
tag::{
entity::{Tag, TagId},
ports::TagRepository,
value_objects::TagName,
},
user::entity::UserId,
};
use crate::db::RepoExt;
pub struct SqliteTagRepository {
pool: SqlitePool,
}
impl SqliteTagRepository {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
}
#[derive(FromRow)]
struct TagRow {
id: String,
name: String,
user_id: String,
}
impl TryFrom<TagRow> for Tag {
type Error = DomainError;
fn try_from(row: TagRow) -> Result<Self, Self::Error> {
let id = TagId::from_uuid(
uuid::Uuid::parse_str(&row.id)
.map_err(|e| DomainError::Repository(format!("invalid tag uuid: {e}")))?,
);
let user_id = UserId::from_uuid(
uuid::Uuid::parse_str(&row.user_id)
.map_err(|e| DomainError::Repository(format!("invalid user uuid: {e}")))?,
);
let name = TagName::new(row.name)?;
Ok(Tag::from_row(id, name, user_id))
}
}
#[async_trait]
impl TagRepository for SqliteTagRepository {
async fn find_by_id(&self, id: &TagId) -> DomainResult<Option<Tag>> {
sqlx::query_as::<_, TagRow>("SELECT id, name, user_id FROM tags WHERE id = ?")
.bind(id.as_uuid().to_string())
.fetch_optional(&self.pool)
.await
.repo()?
.map(Tag::try_from)
.transpose()
}
async fn find_by_user(&self, user_id: &UserId) -> DomainResult<Vec<Tag>> {
sqlx::query_as::<_, TagRow>(
"SELECT id, name, user_id FROM tags WHERE user_id = ? ORDER BY name",
)
.bind(user_id.as_uuid().to_string())
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(Tag::try_from)
.collect()
}
async fn find_by_name(&self, user_id: &UserId, name: &TagName) -> DomainResult<Option<Tag>> {
sqlx::query_as::<_, TagRow>(
"SELECT id, name, user_id FROM tags WHERE user_id = ? AND name = ?",
)
.bind(user_id.as_uuid().to_string())
.bind(name.as_ref())
.fetch_optional(&self.pool)
.await
.repo()?
.map(Tag::try_from)
.transpose()
}
async fn find_by_note(&self, note_id: &NoteId) -> DomainResult<Vec<Tag>> {
sqlx::query_as::<_, TagRow>(
r#"
SELECT t.id, t.name, t.user_id
FROM tags t
INNER JOIN note_tags nt ON t.id = nt.tag_id
WHERE nt.note_id = ?
ORDER BY t.name
"#,
)
.bind(note_id.as_uuid().to_string())
.fetch_all(&self.pool)
.await
.repo()?
.into_iter()
.map(Tag::try_from)
.collect()
}
async fn save(&self, tag: &Tag) -> DomainResult<()> {
sqlx::query(
r#"
INSERT INTO tags (id, name, user_id)
VALUES (?, ?, ?)
ON CONFLICT(id) DO UPDATE SET name = excluded.name
"#,
)
.bind(tag.id.as_uuid().to_string())
.bind(tag.name.as_ref())
.bind(tag.user_id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn delete(&self, id: &TagId) -> DomainResult<()> {
sqlx::query("DELETE FROM tags WHERE id = ?")
.bind(id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn add_to_note(&self, tag_id: &TagId, note_id: &NoteId) -> DomainResult<()> {
sqlx::query("INSERT OR IGNORE INTO note_tags (note_id, tag_id) VALUES (?, ?)")
.bind(note_id.as_uuid().to_string())
.bind(tag_id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn remove_from_note(&self, tag_id: &TagId, note_id: &NoteId) -> DomainResult<()> {
sqlx::query("DELETE FROM note_tags WHERE note_id = ? AND tag_id = ?")
.bind(note_id.as_uuid().to_string())
.bind(tag_id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
}
#[cfg(test)]
#[path = "tests/tag.rs"]
mod tests;

View File

@@ -0,0 +1,119 @@
use sqlx::SqlitePool;
use domain::{
note::{
entity::{Note, NoteFilter},
ports::NoteRepository,
value_objects::NoteTitle,
},
user::{entity::User, ports::UserRepository, value_objects::Email},
};
use crate::{db::run_migrations, note::SqliteNoteRepository, user::SqliteUserRepository};
async fn pool() -> SqlitePool {
let p = SqlitePool::connect("sqlite::memory:").await.unwrap();
run_migrations(&p).await.unwrap();
p
}
async fn seed_user(pool: &SqlitePool) -> User {
let repo = SqliteUserRepository::new(pool.clone());
let user = User::new_oidc("sub", Email::new("u@example.com").unwrap());
repo.save(&user).await.unwrap();
user
}
#[tokio::test]
async fn save_and_find_by_id() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let note = Note::new(user.id, NoteTitle::new("Hello").ok(), "world".to_string());
repo.save(&note).await.unwrap();
let found = repo.find_by_id(&note.id).await.unwrap().unwrap();
assert_eq!(found.content, "world");
assert_eq!(found.title.as_ref().unwrap().as_ref(), "Hello");
}
#[tokio::test]
async fn save_note_without_title() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let note = Note::new(user.id, None, "no title".to_string());
repo.save(&note).await.unwrap();
let found = repo.find_by_id(&note.id).await.unwrap().unwrap();
assert!(found.title.is_none());
}
#[tokio::test]
async fn find_by_user_with_pinned_filter() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let mut pinned = Note::new(user.id, None, "pinned".to_string());
pinned.set_pinned(true);
repo.save(&pinned).await.unwrap();
repo.save(&Note::new(user.id, None, "normal".to_string()))
.await
.unwrap();
let results = repo
.find_by_user(&user.id, NoteFilter::default().pinned())
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].content, "pinned");
}
#[tokio::test]
async fn delete_removes_note() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let note = Note::new(user.id, None, "bye".to_string());
repo.save(&note).await.unwrap();
repo.delete(&note.id).await.unwrap();
assert!(repo.find_by_id(&note.id).await.unwrap().is_none());
}
#[tokio::test]
async fn save_and_find_versions() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let note = Note::new(user.id, None, "v1".to_string());
repo.save(&note).await.unwrap();
let version = domain::note::entity::NoteVersion::snapshot(&note);
repo.save_version(&version).await.unwrap();
let versions = repo.find_versions(&note.id).await.unwrap();
assert_eq!(versions.len(), 1);
assert_eq!(versions[0].content, "v1");
}
#[tokio::test]
async fn upsert_updates_note() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteNoteRepository::new(p);
let mut note = Note::new(user.id, None, "original".to_string());
repo.save(&note).await.unwrap();
note.set_content("updated");
repo.save(&note).await.unwrap();
let found = repo.find_by_id(&note.id).await.unwrap().unwrap();
assert_eq!(found.content, "updated");
}

View File

@@ -0,0 +1,82 @@
use sqlx::SqlitePool;
use domain::{
tag::{entity::Tag, ports::TagRepository, value_objects::TagName},
user::entity::{User, UserId},
};
use crate::{db::run_migrations, tag::SqliteTagRepository, user::SqliteUserRepository};
use domain::user::{ports::UserRepository, value_objects::Email};
async fn pool() -> SqlitePool {
let p = SqlitePool::connect("sqlite::memory:").await.unwrap();
run_migrations(&p).await.unwrap();
p
}
async fn seed_user(pool: &SqlitePool) -> User {
let repo = SqliteUserRepository::new(pool.clone());
let user = User::new_oidc("sub", Email::new("u@example.com").unwrap());
repo.save(&user).await.unwrap();
user
}
#[tokio::test]
async fn save_and_find_by_id() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteTagRepository::new(p);
let tag = Tag::new(TagName::new("work").unwrap(), user.id);
repo.save(&tag).await.unwrap();
let found = repo.find_by_id(&tag.id).await.unwrap().unwrap();
assert_eq!(found.name.as_ref(), "work");
}
#[tokio::test]
async fn find_by_name() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteTagRepository::new(p);
let tag = Tag::new(TagName::new("rust").unwrap(), user.id);
repo.save(&tag).await.unwrap();
let found = repo
.find_by_name(&user.id, &TagName::new("rust").unwrap())
.await
.unwrap();
assert_eq!(found.unwrap().id, tag.id);
}
#[tokio::test]
async fn find_by_user_returns_sorted() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteTagRepository::new(p);
repo.save(&Tag::new(TagName::new("zebra").unwrap(), user.id))
.await
.unwrap();
repo.save(&Tag::new(TagName::new("alpha").unwrap(), user.id))
.await
.unwrap();
let tags = repo.find_by_user(&user.id).await.unwrap();
assert_eq!(tags[0].name.as_ref(), "alpha");
assert_eq!(tags[1].name.as_ref(), "zebra");
}
#[tokio::test]
async fn delete_removes_tag() {
let p = pool().await;
let user = seed_user(&p).await;
let repo = SqliteTagRepository::new(p);
let tag = Tag::new(TagName::new("gone").unwrap(), user.id);
repo.save(&tag).await.unwrap();
repo.delete(&tag.id).await.unwrap();
assert!(repo.find_by_id(&tag.id).await.unwrap().is_none());
}

View File

@@ -0,0 +1,84 @@
use sqlx::SqlitePool;
use domain::user::{
entity::{User, UserId},
ports::UserRepository,
value_objects::{Email, PasswordHash},
};
use crate::{db::run_migrations, user::SqliteUserRepository};
async fn pool() -> SqlitePool {
let p = SqlitePool::connect("sqlite::memory:").await.unwrap();
run_migrations(&p).await.unwrap();
p
}
#[tokio::test]
async fn save_and_find_by_id() {
let repo = SqliteUserRepository::new(pool().await);
let user = User::new_oidc("oidc|123", Email::new("a@example.com").unwrap());
repo.save(&user).await.unwrap();
let found = repo.find_by_id(&user.id).await.unwrap().unwrap();
assert_eq!(found.subject, "oidc|123");
assert_eq!(found.email.as_ref(), "a@example.com");
assert!(found.password_hash.is_none());
}
#[tokio::test]
async fn save_local_user_with_password_hash() {
let repo = SqliteUserRepository::new(pool().await);
let user = User::new_local(
Email::new("local@example.com").unwrap(),
PasswordHash::new("argon2hash"),
);
repo.save(&user).await.unwrap();
let found = repo.find_by_id(&user.id).await.unwrap().unwrap();
assert_eq!(found.password_hash.unwrap().as_str(), "argon2hash");
}
#[tokio::test]
async fn find_by_subject() {
let repo = SqliteUserRepository::new(pool().await);
let user = User::new_oidc("google|456", Email::new("g@example.com").unwrap());
repo.save(&user).await.unwrap();
let found = repo.find_by_subject("google|456").await.unwrap().unwrap();
assert_eq!(found.id, user.id);
}
#[tokio::test]
async fn find_by_email() {
let repo = SqliteUserRepository::new(pool().await);
let email = Email::new("find@example.com").unwrap();
let user = User::new_oidc("sub", email.clone());
repo.save(&user).await.unwrap();
let found = repo.find_by_email(&email).await.unwrap().unwrap();
assert_eq!(found.id, user.id);
}
#[tokio::test]
async fn delete_removes_user() {
let repo = SqliteUserRepository::new(pool().await);
let user = User::new_oidc("del|1", Email::new("del@example.com").unwrap());
repo.save(&user).await.unwrap();
repo.delete(&user.id).await.unwrap();
assert!(repo.find_by_id(&user.id).await.unwrap().is_none());
}
#[tokio::test]
async fn upsert_updates_existing_user() {
let repo = SqliteUserRepository::new(pool().await);
let mut user = User::new_oidc("sub", Email::new("u@example.com").unwrap());
repo.save(&user).await.unwrap();
user.subject = "sub-updated".into();
repo.save(&user).await.unwrap();
let found = repo.find_by_id(&user.id).await.unwrap().unwrap();
assert_eq!(found.subject, "sub-updated");
}

View File

@@ -0,0 +1,129 @@
use async_trait::async_trait;
use sqlx::{FromRow, SqlitePool};
use domain::{
errors::DomainResult,
user::{
entity::{User, UserId},
ports::UserRepository,
value_objects::{Email, PasswordHash},
},
};
use crate::db::{RepoExt, parse_dt};
pub struct SqliteUserRepository {
pool: SqlitePool,
}
impl SqliteUserRepository {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
}
#[derive(FromRow)]
struct UserRow {
id: String,
subject: String,
email: String,
password_hash: Option<String>,
created_at: String,
}
impl TryFrom<UserRow> for User {
type Error = domain::errors::DomainError;
fn try_from(row: UserRow) -> Result<Self, Self::Error> {
use domain::errors::DomainError;
let id = UserId::from_uuid(
uuid::Uuid::parse_str(&row.id)
.map_err(|e| DomainError::Repository(format!("invalid user uuid: {e}")))?,
);
let email = Email::new(&row.email)?;
let password_hash = row.password_hash.map(PasswordHash::new);
let created_at = parse_dt(&row.created_at)?;
Ok(User::from_row(
id,
row.subject,
email,
password_hash,
created_at,
))
}
}
#[async_trait]
impl UserRepository for SqliteUserRepository {
async fn find_by_id(&self, id: &UserId) -> DomainResult<Option<User>> {
let id_str = id.as_uuid().to_string();
sqlx::query_as::<_, UserRow>(
"SELECT id, subject, email, password_hash, created_at FROM users WHERE id = ?",
)
.bind(&id_str)
.fetch_optional(&self.pool)
.await
.repo()?
.map(User::try_from)
.transpose()
}
async fn find_by_subject(&self, subject: &str) -> DomainResult<Option<User>> {
sqlx::query_as::<_, UserRow>(
"SELECT id, subject, email, password_hash, created_at FROM users WHERE subject = ?",
)
.bind(subject)
.fetch_optional(&self.pool)
.await
.repo()?
.map(User::try_from)
.transpose()
}
async fn find_by_email(&self, email: &Email) -> DomainResult<Option<User>> {
sqlx::query_as::<_, UserRow>(
"SELECT id, subject, email, password_hash, created_at FROM users WHERE email = ?",
)
.bind(email.as_ref())
.fetch_optional(&self.pool)
.await
.repo()?
.map(User::try_from)
.transpose()
}
async fn save(&self, user: &User) -> DomainResult<()> {
sqlx::query(
r#"
INSERT INTO users (id, subject, email, password_hash, created_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
subject = excluded.subject,
email = excluded.email,
password_hash = excluded.password_hash
"#,
)
.bind(user.id.as_uuid().to_string())
.bind(&user.subject)
.bind(user.email.as_ref())
.bind(user.password_hash.as_ref().map(PasswordHash::as_str))
.bind(user.created_at.to_rfc3339())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
async fn delete(&self, id: &UserId) -> DomainResult<()> {
sqlx::query("DELETE FROM users WHERE id = ?")
.bind(id.as_uuid().to_string())
.execute(&self.pool)
.await
.repo()
.map(|_| ())
}
}
#[cfg(test)]
#[path = "tests/user.rs"]
mod tests;