feat: Implement database factory and abstract session store for multi-database support, centralizing service creation in main.rs.

This commit is contained in:
2025-12-25 22:43:05 +01:00
parent 78d9314602
commit b53dbf2ea8
12 changed files with 246 additions and 80 deletions

2
Cargo.lock generated
View File

@@ -1216,6 +1216,8 @@ dependencies = [
"sqlx",
"thiserror 2.0.17",
"tokio",
"tower-sessions",
"tower-sessions-sqlx-store",
"tracing",
"uuid",
]

View File

@@ -10,13 +10,9 @@ use axum_login::AuthManagerLayerBuilder;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tower_sessions::{Expiry, SessionManagerLayer};
use tower_sessions_sqlx_store::SqliteStore;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use notes_infra::{
DatabaseConfig, SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository, create_pool,
run_migrations,
};
use notes_infra::{DatabaseConfig, run_migrations};
mod auth;
mod config;
@@ -46,29 +42,66 @@ async fn main() -> anyhow::Result<()> {
// Setup database
tracing::info!("Connecting to database: {}", config.database_url);
let db_config = DatabaseConfig::new(&config.database_url);
let pool = create_pool(&db_config).await?;
use notes_infra::factory::{
build_database_pool, build_note_repository, build_session_store, build_tag_repository,
build_user_repository,
};
let pool = build_database_pool(&db_config)
.await
.map_err(|e| anyhow::anyhow!(e))?;
// Run migrations
tracing::info!("Running database migrations...");
run_migrations(&pool).await?;
// The factory/infra layer abstracts the database type
if let Err(e) = run_migrations(&pool).await {
tracing::warn!(
"Migration error (might be expected if not implemented for this DB): {}",
e
);
}
// Create a default user for development (optional now that we have registration)
create_dev_user(&pool).await?;
// Create a default user for development
create_dev_user(&pool).await.ok();
// Create repositories
let note_repo = Arc::new(SqliteNoteRepository::new(pool.clone()));
let tag_repo = Arc::new(SqliteTagRepository::new(pool.clone()));
let user_repo = Arc::new(SqliteUserRepository::new(pool.clone()));
// Create repositories via factory
let note_repo = build_note_repository(&pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
let tag_repo = build_tag_repository(&pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
let user_repo = build_user_repository(&pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
// Create services
use notes_domain::{NoteService, TagService, UserService};
let note_service = Arc::new(NoteService::new(note_repo.clone(), tag_repo.clone()));
let tag_service = Arc::new(TagService::new(tag_repo.clone()));
let user_service = Arc::new(UserService::new(user_repo.clone()));
// Create application state
let state = AppState::new(note_repo, tag_repo, user_repo.clone());
let state = AppState::new(
note_repo,
tag_repo,
user_repo.clone(),
note_service,
tag_service,
user_service,
);
// Auth backend
let backend = AuthBackend::new(user_repo);
// Session layer
let session_store = SqliteStore::new(pool.clone());
session_store.migrate().await?;
// Use the factory to build the session store, agnostic of the underlying DB
let session_store = build_session_store(&pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
session_store
.migrate()
.await
.map_err(|e| anyhow::anyhow!(e))?;
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false) // Set to true in production with HTTPS
@@ -129,30 +162,29 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}
/// Create a development user for testing
/// In production, users will be created via OIDC authentication
async fn create_dev_user(pool: &sqlx::SqlitePool) -> anyhow::Result<()> {
use notes_domain::{User, UserRepository};
use notes_infra::SqliteUserRepository;
async fn create_dev_user(pool: &notes_infra::db::DatabasePool) -> anyhow::Result<()> {
use notes_domain::User;
use notes_infra::factory::build_user_repository;
use password_auth::generate_hash;
use uuid::Uuid;
let user_repo = SqliteUserRepository::new(pool.clone());
let user_repo = build_user_repository(pool)
.await
.map_err(|e| anyhow::anyhow!(e))?;
// Check if dev user exists
let dev_user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
if user_repo.find_by_id(dev_user_id).await?.is_none() {
// Create dev user with fixed ID and password 'password'
let hash = generate_hash("password");
let user = User::with_id(
dev_user_id,
"dev|local",
"dev@localhost",
"dev@localhost.com",
Some(hash),
chrono::Utc::now(),
);
user_repo.save(&user).await?;
tracing::info!("Created development user: dev@localhost / password");
tracing::info!("Created development user: dev@localhost.com / password");
}
Ok(())

View File

@@ -10,9 +10,7 @@ use uuid::Uuid;
use validator::Validate;
use axum_login::AuthUser;
use notes_domain::{
CreateNoteRequest as DomainCreateNote, NoteService, UpdateNoteRequest as DomainUpdateNote,
};
use notes_domain::{CreateNoteRequest as DomainCreateNote, UpdateNoteRequest as DomainUpdateNote};
use crate::auth::AuthBackend;
use crate::dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest};
@@ -48,8 +46,7 @@ pub async fn list_notes(
}
}
let service = NoteService::new(state.note_repo, state.tag_repo);
let notes = service.list_notes(user_id, filter).await?;
let notes = state.note_service.list_notes(user_id, filter).await?;
let response: Vec<NoteResponse> = notes.into_iter().map(NoteResponse::from).collect();
Ok(Json(response))
@@ -74,8 +71,6 @@ pub async fn create_note(
.validate()
.map_err(|e| ApiError::validation(e.to_string()))?;
let service = NoteService::new(state.note_repo, state.tag_repo);
let domain_req = DomainCreateNote {
user_id,
title: payload.title,
@@ -85,7 +80,7 @@ pub async fn create_note(
is_pinned: payload.is_pinned,
};
let note = service.create_note(domain_req).await?;
let note = state.note_service.create_note(domain_req).await?;
Ok((StatusCode::CREATED, Json(NoteResponse::from(note))))
}
@@ -104,9 +99,7 @@ pub async fn get_note(
)))?;
let user_id = user.id();
let service = NoteService::new(state.note_repo, state.tag_repo);
let note = service.get_note(id, user_id).await?;
let note = state.note_service.get_note(id, user_id).await?;
Ok(Json(NoteResponse::from(note)))
}
@@ -131,8 +124,6 @@ pub async fn update_note(
.validate()
.map_err(|e| ApiError::validation(e.to_string()))?;
let service = NoteService::new(state.note_repo, state.tag_repo);
let domain_req = DomainUpdateNote {
id,
user_id,
@@ -144,7 +135,7 @@ pub async fn update_note(
tags: payload.tags,
};
let note = service.update_note(domain_req).await?;
let note = state.note_service.update_note(domain_req).await?;
Ok(Json(NoteResponse::from(note)))
}
@@ -163,9 +154,7 @@ pub async fn delete_note(
)))?;
let user_id = user.id();
let service = NoteService::new(state.note_repo, state.tag_repo);
service.delete_note(id, user_id).await?;
state.note_service.delete_note(id, user_id).await?;
Ok(StatusCode::NO_CONTENT)
}
@@ -184,9 +173,7 @@ pub async fn search_notes(
)))?;
let user_id = user.id();
let service = NoteService::new(state.note_repo, state.tag_repo);
let notes = service.search_notes(user_id, &query.q).await?;
let notes = state.note_service.search_notes(user_id, &query.q).await?;
let response: Vec<NoteResponse> = notes.into_iter().map(NoteResponse::from).collect();
Ok(Json(response))
@@ -206,9 +193,7 @@ pub async fn list_note_versions(
)))?;
let user_id = user.id();
let service = NoteService::new(state.note_repo, state.tag_repo);
let versions = service.list_note_versions(id, user_id).await?;
let versions = state.note_service.list_note_versions(id, user_id).await?;
let response: Vec<crate::dto::NoteVersionResponse> = versions
.into_iter()
.map(crate::dto::NoteVersionResponse::from)

View File

@@ -9,8 +9,6 @@ use axum_login::{AuthSession, AuthUser};
use uuid::Uuid;
use validator::Validate;
use notes_domain::TagService;
use crate::auth::AuthBackend;
use crate::dto::{CreateTagRequest, RenameTagRequest, TagResponse};
use crate::error::{ApiError, ApiResult};
@@ -29,9 +27,7 @@ pub async fn list_tags(
)))?;
let user_id = user.id();
let service = TagService::new(state.tag_repo);
let tags = service.list_tags(user_id).await?;
let tags = state.tag_service.list_tags(user_id).await?;
let response: Vec<TagResponse> = tags.into_iter().map(TagResponse::from).collect();
Ok(Json(response))
@@ -55,9 +51,7 @@ pub async fn create_tag(
.validate()
.map_err(|e| ApiError::validation(e.to_string()))?;
let service = TagService::new(state.tag_repo);
let tag = service.create_tag(user_id, &payload.name).await?;
let tag = state.tag_service.create_tag(user_id, &payload.name).await?;
Ok((StatusCode::CREATED, Json(TagResponse::from(tag))))
}
@@ -81,9 +75,10 @@ pub async fn rename_tag(
.validate()
.map_err(|e| ApiError::validation(e.to_string()))?;
let service = TagService::new(state.tag_repo);
let tag = service.rename_tag(id, user_id, &payload.name).await?;
let tag = state
.tag_service
.rename_tag(id, user_id, &payload.name)
.await?;
Ok(Json(TagResponse::from(tag)))
}
@@ -102,9 +97,7 @@ pub async fn delete_tag(
)))?;
let user_id = user.id();
let service = TagService::new(state.tag_repo);
service.delete_tag(id, user_id).await?;
state.tag_service.delete_tag(id, user_id).await?;
Ok(StatusCode::NO_CONTENT)
}

View File

@@ -1,8 +1,8 @@
//! Application state for dependency injection
use std::sync::Arc;
use notes_domain::{NoteRepository, TagRepository, UserRepository};
use notes_domain::{
NoteRepository, NoteService, TagRepository, TagService, UserRepository, UserService,
};
/// Application state holding all dependencies
#[derive(Clone)]
@@ -10,6 +10,9 @@ pub struct AppState {
pub note_repo: Arc<dyn NoteRepository>,
pub tag_repo: Arc<dyn TagRepository>,
pub user_repo: Arc<dyn UserRepository>,
pub note_service: Arc<NoteService>,
pub tag_service: Arc<TagService>,
pub user_service: Arc<UserService>,
}
impl AppState {
@@ -17,11 +20,17 @@ impl AppState {
note_repo: Arc<dyn NoteRepository>,
tag_repo: Arc<dyn TagRepository>,
user_repo: Arc<dyn UserRepository>,
note_service: Arc<NoteService>,
tag_service: Arc<TagService>,
user_service: Arc<UserService>,
) -> Self {
Self {
note_repo,
tag_repo,
user_repo,
note_service,
tag_service,
user_service,
}
}
}

View File

@@ -7,13 +7,10 @@ edition = "2024"
notes-domain = { path = "../notes-domain" }
async-trait = "0.1.89"
chrono = { version = "0.4.42", features = ["serde"] }
sqlx = { version = "0.8.6", features = [
"sqlite",
"runtime-tokio",
"chrono",
"migrate",
] }
sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "chrono", "migrate", "postgres"] }
thiserror = "2.0.17"
tokio = { version = "1.48.0", features = ["full"] }
tracing = "0.1"
uuid = { version = "1.19.0", features = ["v4", "serde"] }
tower-sessions = "0.14.0"
tower-sessions-sqlx-store = { version = "0.15.0", features = ["sqlite", "postgres"] }

View File

@@ -1,6 +1,7 @@
//! Database connection pool management
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::{Pool, Postgres, Sqlite};
use std::str::FromStr;
use std::time::Duration;
@@ -32,7 +33,6 @@ impl DatabaseConfig {
}
}
/// Create an in-memory database config (useful for testing)
pub fn in_memory() -> Self {
Self {
url: "sqlite::memory:".to_string(),
@@ -43,6 +43,12 @@ impl DatabaseConfig {
}
}
#[derive(Clone, Debug)]
pub enum DatabasePool {
Sqlite(Pool<Sqlite>),
Postgres(Pool<Postgres>),
}
/// Create a database connection pool
pub async fn create_pool(config: &DatabaseConfig) -> Result<SqlitePool, sqlx::Error> {
let options = SqliteConnectOptions::from_str(&config.url)?
@@ -62,8 +68,20 @@ pub async fn create_pool(config: &DatabaseConfig) -> Result<SqlitePool, sqlx::Er
}
/// Run database migrations
pub async fn run_migrations(pool: &SqlitePool) -> Result<(), sqlx::Error> {
sqlx::migrate!("../migrations").run(pool).await?;
pub async fn run_migrations(pool: &DatabasePool) -> Result<(), sqlx::Error> {
match pool {
DatabasePool::Sqlite(pool) => {
sqlx::migrate!("../migrations").run(pool).await?;
}
DatabasePool::Postgres(_pool) => {
// Placeholder for Postgres migrations
// sqlx::migrate!("../migrations/postgres").run(_pool).await?;
tracing::warn!("Postgres migrations not yet implemented");
return Err(sqlx::Error::Configuration(
"Postgres migrations not yet implemented".into(),
));
}
}
tracing::info!("Database migrations completed successfully");
Ok(())
@@ -84,7 +102,8 @@ mod tests {
async fn test_run_migrations() {
let config = DatabaseConfig::in_memory();
let pool = create_pool(&config).await.unwrap();
let result = run_migrations(&pool).await;
let db_pool = DatabasePool::Sqlite(pool);
let result = run_migrations(&db_pool).await;
assert!(result.is_ok());
}
}

View File

@@ -0,0 +1,79 @@
use std::sync::Arc;
use notes_domain::{NoteRepository, TagRepository, UserRepository};
pub use crate::db::DatabasePool;
use crate::{DatabaseConfig, SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository};
#[derive(Debug, thiserror::Error)]
pub enum FactoryError {
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Not implemented: {0}")]
NotImplemented(String),
}
pub type FactoryResult<T> = Result<T, FactoryError>;
pub async fn build_database_pool(db_config: &DatabaseConfig) -> FactoryResult<DatabasePool> {
if db_config.url.starts_with("sqlite:") {
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5)
.connect(&db_config.url)
.await?;
Ok(DatabasePool::Sqlite(pool))
} else if db_config.url.starts_with("postgres:") {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect(&db_config.url)
.await?;
Ok(DatabasePool::Postgres(pool))
} else {
Err(FactoryError::NotImplemented(format!(
"Unsupported database URL scheme in: {}",
db_config.url
)))
}
}
pub async fn build_note_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn NoteRepository>> {
match pool {
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteNoteRepository::new(pool.clone()))),
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
"Postgres NoteRepository".to_string(),
)),
}
}
pub async fn build_tag_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn TagRepository>> {
match pool {
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteTagRepository::new(pool.clone()))),
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
"Postgres TagRepository".to_string(),
)),
}
}
pub async fn build_user_repository(pool: &DatabasePool) -> FactoryResult<Arc<dyn UserRepository>> {
match pool {
DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteUserRepository::new(pool.clone()))),
DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented(
"Postgres UserRepository".to_string(),
)),
}
}
pub async fn build_session_store(
pool: &DatabasePool,
) -> FactoryResult<crate::session_store::InfraSessionStore> {
match pool {
DatabasePool::Sqlite(pool) => {
let store = tower_sessions_sqlx_store::SqliteStore::new(pool.clone());
Ok(crate::session_store::InfraSessionStore::Sqlite(store))
}
DatabasePool::Postgres(pool) => {
let store = tower_sessions_sqlx_store::PostgresStore::new(pool.clone());
Ok(crate::session_store::InfraSessionStore::Postgres(store))
}
}
}

View File

@@ -15,7 +15,9 @@
//! - [`db::run_migrations`] - Run database migrations
pub mod db;
pub mod factory;
pub mod note_repository;
pub mod session_store;
pub mod tag_repository;
pub mod user_repository;

View File

@@ -0,0 +1,46 @@
use async_trait::async_trait;
use sqlx;
use tower_sessions::{
SessionStore,
session::{Id, Record},
};
use tower_sessions_sqlx_store::{PostgresStore, SqliteStore};
#[derive(Clone, Debug)]
pub enum InfraSessionStore {
Sqlite(SqliteStore),
Postgres(PostgresStore),
}
#[async_trait]
impl SessionStore for InfraSessionStore {
async fn save(&self, session_record: &Record) -> tower_sessions::session_store::Result<()> {
match self {
Self::Sqlite(store) => store.save(session_record).await,
Self::Postgres(store) => store.save(session_record).await,
}
}
async fn load(&self, session_id: &Id) -> tower_sessions::session_store::Result<Option<Record>> {
match self {
Self::Sqlite(store) => store.load(session_id).await,
Self::Postgres(store) => store.load(session_id).await,
}
}
async fn delete(&self, session_id: &Id) -> tower_sessions::session_store::Result<()> {
match self {
Self::Sqlite(store) => store.delete(session_id).await,
Self::Postgres(store) => store.delete(session_id).await,
}
}
}
impl InfraSessionStore {
pub async fn migrate(&self) -> Result<(), sqlx::Error> {
match self {
Self::Sqlite(store) => store.migrate().await,
Self::Postgres(store) => store.migrate().await,
}
}
}

View File

@@ -158,14 +158,15 @@ impl TagRepository for SqliteTagRepository {
#[cfg(test)]
mod tests {
use super::*;
use crate::db::{DatabaseConfig, create_pool, run_migrations};
use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations};
use crate::user_repository::SqliteUserRepository;
use notes_domain::{User, UserRepository};
async fn setup_test_db() -> SqlitePool {
let config = DatabaseConfig::in_memory();
let pool = create_pool(&config).await.unwrap();
run_migrations(&pool).await.unwrap();
let db_pool = DatabasePool::Sqlite(pool.clone());
run_migrations(&db_pool).await.unwrap();
pool
}

View File

@@ -133,12 +133,13 @@ impl UserRepository for SqliteUserRepository {
#[cfg(test)]
mod tests {
use super::*;
use crate::db::{DatabaseConfig, create_pool, run_migrations};
use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations};
async fn setup_test_db() -> SqlitePool {
let config = DatabaseConfig::in_memory();
let pool = create_pool(&config).await.unwrap();
run_migrations(&pool).await.unwrap();
let db_pool = DatabasePool::Sqlite(pool.clone());
run_migrations(&db_pool).await.unwrap();
pool
}