diff --git a/Cargo.lock b/Cargo.lock index 6064440..9827bbf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1216,6 +1216,8 @@ dependencies = [ "sqlx", "thiserror 2.0.17", "tokio", + "tower-sessions", + "tower-sessions-sqlx-store", "tracing", "uuid", ] diff --git a/notes-api/src/main.rs b/notes-api/src/main.rs index 90ec6cb..f8b0872 100644 --- a/notes-api/src/main.rs +++ b/notes-api/src/main.rs @@ -10,13 +10,9 @@ use axum_login::AuthManagerLayerBuilder; use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use tower_sessions::{Expiry, SessionManagerLayer}; -use tower_sessions_sqlx_store::SqliteStore; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use notes_infra::{ - DatabaseConfig, SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository, create_pool, - run_migrations, -}; +use notes_infra::{DatabaseConfig, run_migrations}; mod auth; mod config; @@ -46,29 +42,66 @@ async fn main() -> anyhow::Result<()> { // Setup database tracing::info!("Connecting to database: {}", config.database_url); let db_config = DatabaseConfig::new(&config.database_url); - let pool = create_pool(&db_config).await?; + + use notes_infra::factory::{ + build_database_pool, build_note_repository, build_session_store, build_tag_repository, + build_user_repository, + }; + let pool = build_database_pool(&db_config) + .await + .map_err(|e| anyhow::anyhow!(e))?; // Run migrations - tracing::info!("Running database migrations..."); - run_migrations(&pool).await?; + // The factory/infra layer abstracts the database type + if let Err(e) = run_migrations(&pool).await { + tracing::warn!( + "Migration error (might be expected if not implemented for this DB): {}", + e + ); + } - // Create a default user for development (optional now that we have registration) - create_dev_user(&pool).await?; + // Create a default user for development + create_dev_user(&pool).await.ok(); - // Create repositories - let note_repo = Arc::new(SqliteNoteRepository::new(pool.clone())); - let tag_repo = Arc::new(SqliteTagRepository::new(pool.clone())); - let user_repo = Arc::new(SqliteUserRepository::new(pool.clone())); + // Create repositories via factory + let note_repo = build_note_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + let tag_repo = build_tag_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + let user_repo = build_user_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // Create services + use notes_domain::{NoteService, TagService, UserService}; + let note_service = Arc::new(NoteService::new(note_repo.clone(), tag_repo.clone())); + let tag_service = Arc::new(TagService::new(tag_repo.clone())); + let user_service = Arc::new(UserService::new(user_repo.clone())); // Create application state - let state = AppState::new(note_repo, tag_repo, user_repo.clone()); + let state = AppState::new( + note_repo, + tag_repo, + user_repo.clone(), + note_service, + tag_service, + user_service, + ); // Auth backend let backend = AuthBackend::new(user_repo); // Session layer - let session_store = SqliteStore::new(pool.clone()); - session_store.migrate().await?; + // Use the factory to build the session store, agnostic of the underlying DB + let session_store = build_session_store(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + session_store + .migrate() + .await + .map_err(|e| anyhow::anyhow!(e))?; let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) // Set to true in production with HTTPS @@ -129,30 +162,29 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -/// Create a development user for testing -/// In production, users will be created via OIDC authentication -async fn create_dev_user(pool: &sqlx::SqlitePool) -> anyhow::Result<()> { - use notes_domain::{User, UserRepository}; - use notes_infra::SqliteUserRepository; +async fn create_dev_user(pool: ¬es_infra::db::DatabasePool) -> anyhow::Result<()> { + use notes_domain::User; + use notes_infra::factory::build_user_repository; use password_auth::generate_hash; use uuid::Uuid; - let user_repo = SqliteUserRepository::new(pool.clone()); + let user_repo = build_user_repository(pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; // Check if dev user exists let dev_user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); if user_repo.find_by_id(dev_user_id).await?.is_none() { - // Create dev user with fixed ID and password 'password' let hash = generate_hash("password"); let user = User::with_id( dev_user_id, "dev|local", - "dev@localhost", + "dev@localhost.com", Some(hash), chrono::Utc::now(), ); user_repo.save(&user).await?; - tracing::info!("Created development user: dev@localhost / password"); + tracing::info!("Created development user: dev@localhost.com / password"); } Ok(()) diff --git a/notes-api/src/routes/notes.rs b/notes-api/src/routes/notes.rs index b4da472..bbb4986 100644 --- a/notes-api/src/routes/notes.rs +++ b/notes-api/src/routes/notes.rs @@ -10,9 +10,7 @@ use uuid::Uuid; use validator::Validate; use axum_login::AuthUser; -use notes_domain::{ - CreateNoteRequest as DomainCreateNote, NoteService, UpdateNoteRequest as DomainUpdateNote, -}; +use notes_domain::{CreateNoteRequest as DomainCreateNote, UpdateNoteRequest as DomainUpdateNote}; use crate::auth::AuthBackend; use crate::dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}; @@ -48,8 +46,7 @@ pub async fn list_notes( } } - let service = NoteService::new(state.note_repo, state.tag_repo); - let notes = service.list_notes(user_id, filter).await?; + let notes = state.note_service.list_notes(user_id, filter).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); Ok(Json(response)) @@ -74,8 +71,6 @@ pub async fn create_note( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = NoteService::new(state.note_repo, state.tag_repo); - let domain_req = DomainCreateNote { user_id, title: payload.title, @@ -85,7 +80,7 @@ pub async fn create_note( is_pinned: payload.is_pinned, }; - let note = service.create_note(domain_req).await?; + let note = state.note_service.create_note(domain_req).await?; Ok((StatusCode::CREATED, Json(NoteResponse::from(note)))) } @@ -104,9 +99,7 @@ pub async fn get_note( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let note = service.get_note(id, user_id).await?; + let note = state.note_service.get_note(id, user_id).await?; Ok(Json(NoteResponse::from(note))) } @@ -131,8 +124,6 @@ pub async fn update_note( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = NoteService::new(state.note_repo, state.tag_repo); - let domain_req = DomainUpdateNote { id, user_id, @@ -144,7 +135,7 @@ pub async fn update_note( tags: payload.tags, }; - let note = service.update_note(domain_req).await?; + let note = state.note_service.update_note(domain_req).await?; Ok(Json(NoteResponse::from(note))) } @@ -163,9 +154,7 @@ pub async fn delete_note( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - service.delete_note(id, user_id).await?; + state.note_service.delete_note(id, user_id).await?; Ok(StatusCode::NO_CONTENT) } @@ -184,9 +173,7 @@ pub async fn search_notes( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let notes = service.search_notes(user_id, &query.q).await?; + let notes = state.note_service.search_notes(user_id, &query.q).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); Ok(Json(response)) @@ -206,9 +193,7 @@ pub async fn list_note_versions( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let versions = service.list_note_versions(id, user_id).await?; + let versions = state.note_service.list_note_versions(id, user_id).await?; let response: Vec = versions .into_iter() .map(crate::dto::NoteVersionResponse::from) diff --git a/notes-api/src/routes/tags.rs b/notes-api/src/routes/tags.rs index 0e959bf..b3e084a 100644 --- a/notes-api/src/routes/tags.rs +++ b/notes-api/src/routes/tags.rs @@ -9,8 +9,6 @@ use axum_login::{AuthSession, AuthUser}; use uuid::Uuid; use validator::Validate; -use notes_domain::TagService; - use crate::auth::AuthBackend; use crate::dto::{CreateTagRequest, RenameTagRequest, TagResponse}; use crate::error::{ApiError, ApiResult}; @@ -29,9 +27,7 @@ pub async fn list_tags( )))?; let user_id = user.id(); - let service = TagService::new(state.tag_repo); - - let tags = service.list_tags(user_id).await?; + let tags = state.tag_service.list_tags(user_id).await?; let response: Vec = tags.into_iter().map(TagResponse::from).collect(); Ok(Json(response)) @@ -55,9 +51,7 @@ pub async fn create_tag( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = TagService::new(state.tag_repo); - - let tag = service.create_tag(user_id, &payload.name).await?; + let tag = state.tag_service.create_tag(user_id, &payload.name).await?; Ok((StatusCode::CREATED, Json(TagResponse::from(tag)))) } @@ -81,9 +75,10 @@ pub async fn rename_tag( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = TagService::new(state.tag_repo); - - let tag = service.rename_tag(id, user_id, &payload.name).await?; + let tag = state + .tag_service + .rename_tag(id, user_id, &payload.name) + .await?; Ok(Json(TagResponse::from(tag))) } @@ -102,9 +97,7 @@ pub async fn delete_tag( )))?; let user_id = user.id(); - let service = TagService::new(state.tag_repo); - - service.delete_tag(id, user_id).await?; + state.tag_service.delete_tag(id, user_id).await?; Ok(StatusCode::NO_CONTENT) } diff --git a/notes-api/src/state.rs b/notes-api/src/state.rs index e957bbe..de3796e 100644 --- a/notes-api/src/state.rs +++ b/notes-api/src/state.rs @@ -1,8 +1,8 @@ -//! Application state for dependency injection - use std::sync::Arc; -use notes_domain::{NoteRepository, TagRepository, UserRepository}; +use notes_domain::{ + NoteRepository, NoteService, TagRepository, TagService, UserRepository, UserService, +}; /// Application state holding all dependencies #[derive(Clone)] @@ -10,6 +10,9 @@ pub struct AppState { pub note_repo: Arc, pub tag_repo: Arc, pub user_repo: Arc, + pub note_service: Arc, + pub tag_service: Arc, + pub user_service: Arc, } impl AppState { @@ -17,11 +20,17 @@ impl AppState { note_repo: Arc, tag_repo: Arc, user_repo: Arc, + note_service: Arc, + tag_service: Arc, + user_service: Arc, ) -> Self { Self { note_repo, tag_repo, user_repo, + note_service, + tag_service, + user_service, } } } diff --git a/notes-infra/Cargo.toml b/notes-infra/Cargo.toml index da31cf5..f3d9b74 100644 --- a/notes-infra/Cargo.toml +++ b/notes-infra/Cargo.toml @@ -7,13 +7,10 @@ edition = "2024" notes-domain = { path = "../notes-domain" } async-trait = "0.1.89" chrono = { version = "0.4.42", features = ["serde"] } -sqlx = { version = "0.8.6", features = [ - "sqlite", - "runtime-tokio", - "chrono", - "migrate", -] } +sqlx = { version = "0.8.6", features = ["sqlite", "runtime-tokio", "chrono", "migrate", "postgres"] } thiserror = "2.0.17" tokio = { version = "1.48.0", features = ["full"] } tracing = "0.1" uuid = { version = "1.19.0", features = ["v4", "serde"] } +tower-sessions = "0.14.0" +tower-sessions-sqlx-store = { version = "0.15.0", features = ["sqlite", "postgres"] } diff --git a/notes-infra/src/db.rs b/notes-infra/src/db.rs index 1672110..05287f5 100644 --- a/notes-infra/src/db.rs +++ b/notes-infra/src/db.rs @@ -1,6 +1,7 @@ //! Database connection pool management use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; +use sqlx::{Pool, Postgres, Sqlite}; use std::str::FromStr; use std::time::Duration; @@ -32,7 +33,6 @@ impl DatabaseConfig { } } - /// Create an in-memory database config (useful for testing) pub fn in_memory() -> Self { Self { url: "sqlite::memory:".to_string(), @@ -43,6 +43,12 @@ impl DatabaseConfig { } } +#[derive(Clone, Debug)] +pub enum DatabasePool { + Sqlite(Pool), + Postgres(Pool), +} + /// Create a database connection pool pub async fn create_pool(config: &DatabaseConfig) -> Result { let options = SqliteConnectOptions::from_str(&config.url)? @@ -62,8 +68,20 @@ pub async fn create_pool(config: &DatabaseConfig) -> Result Result<(), sqlx::Error> { - sqlx::migrate!("../migrations").run(pool).await?; +pub async fn run_migrations(pool: &DatabasePool) -> Result<(), sqlx::Error> { + match pool { + DatabasePool::Sqlite(pool) => { + sqlx::migrate!("../migrations").run(pool).await?; + } + DatabasePool::Postgres(_pool) => { + // Placeholder for Postgres migrations + // sqlx::migrate!("../migrations/postgres").run(_pool).await?; + tracing::warn!("Postgres migrations not yet implemented"); + return Err(sqlx::Error::Configuration( + "Postgres migrations not yet implemented".into(), + )); + } + } tracing::info!("Database migrations completed successfully"); Ok(()) @@ -84,7 +102,8 @@ mod tests { async fn test_run_migrations() { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - let result = run_migrations(&pool).await; + let db_pool = DatabasePool::Sqlite(pool); + let result = run_migrations(&db_pool).await; assert!(result.is_ok()); } } diff --git a/notes-infra/src/factory.rs b/notes-infra/src/factory.rs new file mode 100644 index 0000000..f5d190a --- /dev/null +++ b/notes-infra/src/factory.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use notes_domain::{NoteRepository, TagRepository, UserRepository}; + +pub use crate::db::DatabasePool; +use crate::{DatabaseConfig, SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository}; + +#[derive(Debug, thiserror::Error)] +pub enum FactoryError { + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Not implemented: {0}")] + NotImplemented(String), +} + +pub type FactoryResult = Result; + +pub async fn build_database_pool(db_config: &DatabaseConfig) -> FactoryResult { + if db_config.url.starts_with("sqlite:") { + let pool = sqlx::sqlite::SqlitePoolOptions::new() + .max_connections(5) + .connect(&db_config.url) + .await?; + Ok(DatabasePool::Sqlite(pool)) + } else if db_config.url.starts_with("postgres:") { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .connect(&db_config.url) + .await?; + Ok(DatabasePool::Postgres(pool)) + } else { + Err(FactoryError::NotImplemented(format!( + "Unsupported database URL scheme in: {}", + db_config.url + ))) + } +} + +pub async fn build_note_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteNoteRepository::new(pool.clone()))), + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres NoteRepository".to_string(), + )), + } +} + +pub async fn build_tag_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteTagRepository::new(pool.clone()))), + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres TagRepository".to_string(), + )), + } +} + +pub async fn build_user_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteUserRepository::new(pool.clone()))), + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres UserRepository".to_string(), + )), + } +} + +pub async fn build_session_store( + pool: &DatabasePool, +) -> FactoryResult { + match pool { + DatabasePool::Sqlite(pool) => { + let store = tower_sessions_sqlx_store::SqliteStore::new(pool.clone()); + Ok(crate::session_store::InfraSessionStore::Sqlite(store)) + } + DatabasePool::Postgres(pool) => { + let store = tower_sessions_sqlx_store::PostgresStore::new(pool.clone()); + Ok(crate::session_store::InfraSessionStore::Postgres(store)) + } + } +} diff --git a/notes-infra/src/lib.rs b/notes-infra/src/lib.rs index 549a713..b108745 100644 --- a/notes-infra/src/lib.rs +++ b/notes-infra/src/lib.rs @@ -15,7 +15,9 @@ //! - [`db::run_migrations`] - Run database migrations pub mod db; +pub mod factory; pub mod note_repository; +pub mod session_store; pub mod tag_repository; pub mod user_repository; diff --git a/notes-infra/src/session_store.rs b/notes-infra/src/session_store.rs new file mode 100644 index 0000000..727bfed --- /dev/null +++ b/notes-infra/src/session_store.rs @@ -0,0 +1,46 @@ +use async_trait::async_trait; +use sqlx; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; +use tower_sessions_sqlx_store::{PostgresStore, SqliteStore}; + +#[derive(Clone, Debug)] +pub enum InfraSessionStore { + Sqlite(SqliteStore), + Postgres(PostgresStore), +} + +#[async_trait] +impl SessionStore for InfraSessionStore { + async fn save(&self, session_record: &Record) -> tower_sessions::session_store::Result<()> { + match self { + Self::Sqlite(store) => store.save(session_record).await, + Self::Postgres(store) => store.save(session_record).await, + } + } + + async fn load(&self, session_id: &Id) -> tower_sessions::session_store::Result> { + match self { + Self::Sqlite(store) => store.load(session_id).await, + Self::Postgres(store) => store.load(session_id).await, + } + } + + async fn delete(&self, session_id: &Id) -> tower_sessions::session_store::Result<()> { + match self { + Self::Sqlite(store) => store.delete(session_id).await, + Self::Postgres(store) => store.delete(session_id).await, + } + } +} + +impl InfraSessionStore { + pub async fn migrate(&self) -> Result<(), sqlx::Error> { + match self { + Self::Sqlite(store) => store.migrate().await, + Self::Postgres(store) => store.migrate().await, + } + } +} diff --git a/notes-infra/src/tag_repository.rs b/notes-infra/src/tag_repository.rs index 9b65d10..31ba761 100644 --- a/notes-infra/src/tag_repository.rs +++ b/notes-infra/src/tag_repository.rs @@ -158,14 +158,15 @@ impl TagRepository for SqliteTagRepository { #[cfg(test)] mod tests { use super::*; - use crate::db::{DatabaseConfig, create_pool, run_migrations}; + use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations}; use crate::user_repository::SqliteUserRepository; use notes_domain::{User, UserRepository}; async fn setup_test_db() -> SqlitePool { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - run_migrations(&pool).await.unwrap(); + let db_pool = DatabasePool::Sqlite(pool.clone()); + run_migrations(&db_pool).await.unwrap(); pool } diff --git a/notes-infra/src/user_repository.rs b/notes-infra/src/user_repository.rs index d8d4416..056768b 100644 --- a/notes-infra/src/user_repository.rs +++ b/notes-infra/src/user_repository.rs @@ -133,12 +133,13 @@ impl UserRepository for SqliteUserRepository { #[cfg(test)] mod tests { use super::*; - use crate::db::{DatabaseConfig, create_pool, run_migrations}; + use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations}; async fn setup_test_db() -> SqlitePool { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - run_migrations(&pool).await.unwrap(); + let db_pool = DatabasePool::Sqlite(pool.clone()); + run_migrations(&db_pool).await.unwrap(); pool }