diff --git a/Cargo.lock b/Cargo.lock index 6064440..9827bbf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1216,6 +1216,8 @@ dependencies = [ "sqlx", "thiserror 2.0.17", "tokio", + "tower-sessions", + "tower-sessions-sqlx-store", "tracing", "uuid", ] diff --git a/README.md b/README.md index 79a81ca..a948011 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ A modern, self-hosted note-taking application built with performance, security, - **Responsive**: Mobile-friendly UI built with Tailwind CSS. - **Architecture**: - **Backend**: Hexagonal Architecture (Domain, Infra, API layers) in Rust. + - **Infrastructure**: Configurable database backends (SQLite, Postgres). - **Frontend**: Modern React with TypeScript and Vite. - **Deployment**: Full Docker support with `compose.yml`. @@ -23,7 +24,7 @@ A modern, self-hosted note-taking application built with performance, security, ### Backend - **Language**: Rust - **Framework**: Axum -- **Database**: SQLite (SQLx) +- **Database**: SQLite (Default) or Postgres (Supported via feature flag) - **Dependency Injection**: Manual wiring for clear boundaries ### Frontend @@ -59,6 +60,16 @@ The frontend is automatically configured to talk to the backend. cargo run -p notes-api ``` +By default, this uses the **SQLite** backend. + +**Running with Postgres:** + +To use PostgreSQL, build with the `postgres` feature: +```bash +cargo run -p notes-api --no-default-features --features notes-infra/postgres +``` +*Note: Ensure your `DATABASE_URL` is set to a valid Postgres connection string.* + #### Frontend 1. Navigate to `k-notes-frontend`. @@ -74,7 +85,32 @@ bun install bun dev ``` -## 🏗️ Project Structure +## Database Architecture + +The backend follows a Hexagonal Architecture (Ports and Adapters). The `notes-domain` crate defines the repository capabilities (Ports), and `notes-infra` implements them (Adapters). + +### Supported Databases +- **SQLite**: Fully implemented (default). Ideal for single-instance, self-hosted deployments. +- **Postgres**: Structure is in place (via feature flag), ready for implementation. + +### Extending Database Support + +To add a new database (e.g., MySQL), follow these steps: + +1. **Dependencies**: Add the driver to `notes-infra/Cargo.toml` (e.g., `sqlx` with `mysql` feature) and create a feature flag. +2. **Configuration**: Update `DatabaseConfig` in `notes-infra/src/db.rs` to handle the new connection URL scheme and connection logic in `create_pool`. +3. **Repository Implementation**: + - Implement `NoteRepository`, `TagRepository`, and `UserRepository` traits for the new database in `notes-infra`. +4. **Factory Integration**: + - Update `notes-infra/src/factory.rs` to include a builder for the new repositories. + - Update `build_database_pool` and repository `build_*` functions to support the new database type match arm. +5. **Migrations**: + - Add migration files in `migrations/`. + - Update `run_migrations` in `db.rs` to execute them. + +This design ensures the `notes-api` layer remains completely agnostic to the underlying database technology. + +## Project Structure ``` ├── notes-api # API Interface (Axum, HTTP routes) diff --git a/notes-api/Cargo.toml b/notes-api/Cargo.toml index 529a9f1..15c2855 100644 --- a/notes-api/Cargo.toml +++ b/notes-api/Cargo.toml @@ -6,7 +6,7 @@ default-run = "notes-api" [dependencies] notes-domain = { path = "../notes-domain" } -notes-infra = { path = "../notes-infra" } +notes-infra = { path = "../notes-infra", features = ["sqlite"] } # Web framework axum = { version = "0.8.8", features = ["macros"] } diff --git a/notes-api/src/main.rs b/notes-api/src/main.rs index 90ec6cb..f8b0872 100644 --- a/notes-api/src/main.rs +++ b/notes-api/src/main.rs @@ -10,13 +10,9 @@ use axum_login::AuthManagerLayerBuilder; use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use tower_sessions::{Expiry, SessionManagerLayer}; -use tower_sessions_sqlx_store::SqliteStore; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use notes_infra::{ - DatabaseConfig, SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository, create_pool, - run_migrations, -}; +use notes_infra::{DatabaseConfig, run_migrations}; mod auth; mod config; @@ -46,29 +42,66 @@ async fn main() -> anyhow::Result<()> { // Setup database tracing::info!("Connecting to database: {}", config.database_url); let db_config = DatabaseConfig::new(&config.database_url); - let pool = create_pool(&db_config).await?; + + use notes_infra::factory::{ + build_database_pool, build_note_repository, build_session_store, build_tag_repository, + build_user_repository, + }; + let pool = build_database_pool(&db_config) + .await + .map_err(|e| anyhow::anyhow!(e))?; // Run migrations - tracing::info!("Running database migrations..."); - run_migrations(&pool).await?; + // The factory/infra layer abstracts the database type + if let Err(e) = run_migrations(&pool).await { + tracing::warn!( + "Migration error (might be expected if not implemented for this DB): {}", + e + ); + } - // Create a default user for development (optional now that we have registration) - create_dev_user(&pool).await?; + // Create a default user for development + create_dev_user(&pool).await.ok(); - // Create repositories - let note_repo = Arc::new(SqliteNoteRepository::new(pool.clone())); - let tag_repo = Arc::new(SqliteTagRepository::new(pool.clone())); - let user_repo = Arc::new(SqliteUserRepository::new(pool.clone())); + // Create repositories via factory + let note_repo = build_note_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + let tag_repo = build_tag_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + let user_repo = build_user_repository(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // Create services + use notes_domain::{NoteService, TagService, UserService}; + let note_service = Arc::new(NoteService::new(note_repo.clone(), tag_repo.clone())); + let tag_service = Arc::new(TagService::new(tag_repo.clone())); + let user_service = Arc::new(UserService::new(user_repo.clone())); // Create application state - let state = AppState::new(note_repo, tag_repo, user_repo.clone()); + let state = AppState::new( + note_repo, + tag_repo, + user_repo.clone(), + note_service, + tag_service, + user_service, + ); // Auth backend let backend = AuthBackend::new(user_repo); // Session layer - let session_store = SqliteStore::new(pool.clone()); - session_store.migrate().await?; + // Use the factory to build the session store, agnostic of the underlying DB + let session_store = build_session_store(&pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; + session_store + .migrate() + .await + .map_err(|e| anyhow::anyhow!(e))?; let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) // Set to true in production with HTTPS @@ -129,30 +162,29 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -/// Create a development user for testing -/// In production, users will be created via OIDC authentication -async fn create_dev_user(pool: &sqlx::SqlitePool) -> anyhow::Result<()> { - use notes_domain::{User, UserRepository}; - use notes_infra::SqliteUserRepository; +async fn create_dev_user(pool: ¬es_infra::db::DatabasePool) -> anyhow::Result<()> { + use notes_domain::User; + use notes_infra::factory::build_user_repository; use password_auth::generate_hash; use uuid::Uuid; - let user_repo = SqliteUserRepository::new(pool.clone()); + let user_repo = build_user_repository(pool) + .await + .map_err(|e| anyhow::anyhow!(e))?; // Check if dev user exists let dev_user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); if user_repo.find_by_id(dev_user_id).await?.is_none() { - // Create dev user with fixed ID and password 'password' let hash = generate_hash("password"); let user = User::with_id( dev_user_id, "dev|local", - "dev@localhost", + "dev@localhost.com", Some(hash), chrono::Utc::now(), ); user_repo.save(&user).await?; - tracing::info!("Created development user: dev@localhost / password"); + tracing::info!("Created development user: dev@localhost.com / password"); } Ok(()) diff --git a/notes-api/src/routes/notes.rs b/notes-api/src/routes/notes.rs index b4da472..bbb4986 100644 --- a/notes-api/src/routes/notes.rs +++ b/notes-api/src/routes/notes.rs @@ -10,9 +10,7 @@ use uuid::Uuid; use validator::Validate; use axum_login::AuthUser; -use notes_domain::{ - CreateNoteRequest as DomainCreateNote, NoteService, UpdateNoteRequest as DomainUpdateNote, -}; +use notes_domain::{CreateNoteRequest as DomainCreateNote, UpdateNoteRequest as DomainUpdateNote}; use crate::auth::AuthBackend; use crate::dto::{CreateNoteRequest, ListNotesQuery, NoteResponse, SearchQuery, UpdateNoteRequest}; @@ -48,8 +46,7 @@ pub async fn list_notes( } } - let service = NoteService::new(state.note_repo, state.tag_repo); - let notes = service.list_notes(user_id, filter).await?; + let notes = state.note_service.list_notes(user_id, filter).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); Ok(Json(response)) @@ -74,8 +71,6 @@ pub async fn create_note( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = NoteService::new(state.note_repo, state.tag_repo); - let domain_req = DomainCreateNote { user_id, title: payload.title, @@ -85,7 +80,7 @@ pub async fn create_note( is_pinned: payload.is_pinned, }; - let note = service.create_note(domain_req).await?; + let note = state.note_service.create_note(domain_req).await?; Ok((StatusCode::CREATED, Json(NoteResponse::from(note)))) } @@ -104,9 +99,7 @@ pub async fn get_note( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let note = service.get_note(id, user_id).await?; + let note = state.note_service.get_note(id, user_id).await?; Ok(Json(NoteResponse::from(note))) } @@ -131,8 +124,6 @@ pub async fn update_note( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = NoteService::new(state.note_repo, state.tag_repo); - let domain_req = DomainUpdateNote { id, user_id, @@ -144,7 +135,7 @@ pub async fn update_note( tags: payload.tags, }; - let note = service.update_note(domain_req).await?; + let note = state.note_service.update_note(domain_req).await?; Ok(Json(NoteResponse::from(note))) } @@ -163,9 +154,7 @@ pub async fn delete_note( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - service.delete_note(id, user_id).await?; + state.note_service.delete_note(id, user_id).await?; Ok(StatusCode::NO_CONTENT) } @@ -184,9 +173,7 @@ pub async fn search_notes( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let notes = service.search_notes(user_id, &query.q).await?; + let notes = state.note_service.search_notes(user_id, &query.q).await?; let response: Vec = notes.into_iter().map(NoteResponse::from).collect(); Ok(Json(response)) @@ -206,9 +193,7 @@ pub async fn list_note_versions( )))?; let user_id = user.id(); - let service = NoteService::new(state.note_repo, state.tag_repo); - - let versions = service.list_note_versions(id, user_id).await?; + let versions = state.note_service.list_note_versions(id, user_id).await?; let response: Vec = versions .into_iter() .map(crate::dto::NoteVersionResponse::from) diff --git a/notes-api/src/routes/tags.rs b/notes-api/src/routes/tags.rs index 0e959bf..b3e084a 100644 --- a/notes-api/src/routes/tags.rs +++ b/notes-api/src/routes/tags.rs @@ -9,8 +9,6 @@ use axum_login::{AuthSession, AuthUser}; use uuid::Uuid; use validator::Validate; -use notes_domain::TagService; - use crate::auth::AuthBackend; use crate::dto::{CreateTagRequest, RenameTagRequest, TagResponse}; use crate::error::{ApiError, ApiResult}; @@ -29,9 +27,7 @@ pub async fn list_tags( )))?; let user_id = user.id(); - let service = TagService::new(state.tag_repo); - - let tags = service.list_tags(user_id).await?; + let tags = state.tag_service.list_tags(user_id).await?; let response: Vec = tags.into_iter().map(TagResponse::from).collect(); Ok(Json(response)) @@ -55,9 +51,7 @@ pub async fn create_tag( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = TagService::new(state.tag_repo); - - let tag = service.create_tag(user_id, &payload.name).await?; + let tag = state.tag_service.create_tag(user_id, &payload.name).await?; Ok((StatusCode::CREATED, Json(TagResponse::from(tag)))) } @@ -81,9 +75,10 @@ pub async fn rename_tag( .validate() .map_err(|e| ApiError::validation(e.to_string()))?; - let service = TagService::new(state.tag_repo); - - let tag = service.rename_tag(id, user_id, &payload.name).await?; + let tag = state + .tag_service + .rename_tag(id, user_id, &payload.name) + .await?; Ok(Json(TagResponse::from(tag))) } @@ -102,9 +97,7 @@ pub async fn delete_tag( )))?; let user_id = user.id(); - let service = TagService::new(state.tag_repo); - - service.delete_tag(id, user_id).await?; + state.tag_service.delete_tag(id, user_id).await?; Ok(StatusCode::NO_CONTENT) } diff --git a/notes-api/src/state.rs b/notes-api/src/state.rs index e957bbe..de3796e 100644 --- a/notes-api/src/state.rs +++ b/notes-api/src/state.rs @@ -1,8 +1,8 @@ -//! Application state for dependency injection - use std::sync::Arc; -use notes_domain::{NoteRepository, TagRepository, UserRepository}; +use notes_domain::{ + NoteRepository, NoteService, TagRepository, TagService, UserRepository, UserService, +}; /// Application state holding all dependencies #[derive(Clone)] @@ -10,6 +10,9 @@ pub struct AppState { pub note_repo: Arc, pub tag_repo: Arc, pub user_repo: Arc, + pub note_service: Arc, + pub tag_service: Arc, + pub user_service: Arc, } impl AppState { @@ -17,11 +20,17 @@ impl AppState { note_repo: Arc, tag_repo: Arc, user_repo: Arc, + note_service: Arc, + tag_service: Arc, + user_service: Arc, ) -> Self { Self { note_repo, tag_repo, user_repo, + note_service, + tag_service, + user_service, } } } diff --git a/notes-infra/Cargo.toml b/notes-infra/Cargo.toml index da31cf5..e44cdff 100644 --- a/notes-infra/Cargo.toml +++ b/notes-infra/Cargo.toml @@ -3,17 +3,19 @@ name = "notes-infra" version = "0.1.0" edition = "2024" +[features] +default = ["sqlite"] +sqlite = ["sqlx/sqlite", "tower-sessions-sqlx-store/sqlite"] +postgres = ["sqlx/postgres", "tower-sessions-sqlx-store/postgres"] + [dependencies] notes-domain = { path = "../notes-domain" } async-trait = "0.1.89" chrono = { version = "0.4.42", features = ["serde"] } -sqlx = { version = "0.8.6", features = [ - "sqlite", - "runtime-tokio", - "chrono", - "migrate", -] } +sqlx = { version = "0.8.6", features = ["runtime-tokio", "chrono", "migrate"] } thiserror = "2.0.17" tokio = { version = "1.48.0", features = ["full"] } tracing = "0.1" uuid = { version = "1.19.0", features = ["v4", "serde"] } +tower-sessions = "0.14.0" +tower-sessions-sqlx-store = { version = "0.15.0", default-features = false } diff --git a/notes-infra/src/db.rs b/notes-infra/src/db.rs index 1672110..8d89678 100644 --- a/notes-infra/src/db.rs +++ b/notes-infra/src/db.rs @@ -1,6 +1,13 @@ //! Database connection pool management +use sqlx::Pool; +#[cfg(feature = "postgres")] +use sqlx::Postgres; +#[cfg(feature = "sqlite")] +use sqlx::Sqlite; +#[cfg(feature = "sqlite")] use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; +#[cfg(feature = "sqlite")] use std::str::FromStr; use std::time::Duration; @@ -32,7 +39,6 @@ impl DatabaseConfig { } } - /// Create an in-memory database config (useful for testing) pub fn in_memory() -> Self { Self { url: "sqlite::memory:".to_string(), @@ -43,7 +49,16 @@ impl DatabaseConfig { } } +#[derive(Clone, Debug)] +pub enum DatabasePool { + #[cfg(feature = "sqlite")] + Sqlite(Pool), + #[cfg(feature = "postgres")] + Postgres(Pool), +} + /// Create a database connection pool +#[cfg(feature = "sqlite")] pub async fn create_pool(config: &DatabaseConfig) -> Result { let options = SqliteConnectOptions::from_str(&config.url)? .create_if_missing(true) @@ -62,8 +77,28 @@ pub async fn create_pool(config: &DatabaseConfig) -> Result Result<(), sqlx::Error> { - sqlx::migrate!("../migrations").run(pool).await?; +pub async fn run_migrations(pool: &DatabasePool) -> Result<(), sqlx::Error> { + match pool { + #[cfg(feature = "sqlite")] + DatabasePool::Sqlite(pool) => { + sqlx::migrate!("../migrations").run(pool).await?; + } + #[cfg(feature = "postgres")] + DatabasePool::Postgres(_pool) => { + // Placeholder for Postgres migrations + // sqlx::migrate!("../migrations/postgres").run(_pool).await?; + tracing::warn!("Postgres migrations not yet implemented"); + return Err(sqlx::Error::Configuration( + "Postgres migrations not yet implemented".into(), + )); + } + #[allow(unreachable_patterns)] + _ => { + return Err(sqlx::Error::Configuration( + "No database feature enabled".into(), + )); + } + } tracing::info!("Database migrations completed successfully"); Ok(()) @@ -84,7 +119,8 @@ mod tests { async fn test_run_migrations() { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - let result = run_migrations(&pool).await; + let db_pool = DatabasePool::Sqlite(pool); + let result = run_migrations(&db_pool).await; assert!(result.is_ok()); } } diff --git a/notes-infra/src/factory.rs b/notes-infra/src/factory.rs new file mode 100644 index 0000000..4a8807f --- /dev/null +++ b/notes-infra/src/factory.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use crate::{DatabaseConfig, db::DatabasePool}; +#[cfg(feature = "sqlite")] +use crate::{SqliteNoteRepository, SqliteTagRepository, SqliteUserRepository}; +use notes_domain::{NoteRepository, TagRepository, UserRepository}; + +#[derive(Debug, thiserror::Error)] +pub enum FactoryError { + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Not implemented: {0}")] + NotImplemented(String), +} + +pub type FactoryResult = Result; + +pub async fn build_database_pool(db_config: &DatabaseConfig) -> FactoryResult { + if db_config.url.starts_with("sqlite:") { + #[cfg(feature = "sqlite")] + { + let pool = sqlx::sqlite::SqlitePoolOptions::new() + .max_connections(5) + .connect(&db_config.url) + .await?; + Ok(DatabasePool::Sqlite(pool)) + } + #[cfg(not(feature = "sqlite"))] + Err(FactoryError::NotImplemented( + "SQLite feature not enabled".to_string(), + )) + } else if db_config.url.starts_with("postgres:") { + #[cfg(feature = "postgres")] + { + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .connect(&db_config.url) + .await?; + Ok(DatabasePool::Postgres(pool)) + } + #[cfg(not(feature = "postgres"))] + Err(FactoryError::NotImplemented( + "Postgres feature not enabled".to_string(), + )) + } else { + Err(FactoryError::NotImplemented(format!( + "Unsupported database URL scheme in: {}", + db_config.url + ))) + } +} + +pub async fn build_note_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + #[cfg(feature = "sqlite")] + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteNoteRepository::new(pool.clone()))), + #[cfg(feature = "postgres")] + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres NoteRepository".to_string(), + )), + #[allow(unreachable_patterns)] + _ => Err(FactoryError::NotImplemented( + "No database feature enabled".to_string(), + )), + } +} + +pub async fn build_tag_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + #[cfg(feature = "sqlite")] + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteTagRepository::new(pool.clone()))), + #[cfg(feature = "postgres")] + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres TagRepository".to_string(), + )), + #[allow(unreachable_patterns)] + _ => Err(FactoryError::NotImplemented( + "No database feature enabled".to_string(), + )), + } +} + +pub async fn build_user_repository(pool: &DatabasePool) -> FactoryResult> { + match pool { + #[cfg(feature = "sqlite")] + DatabasePool::Sqlite(pool) => Ok(Arc::new(SqliteUserRepository::new(pool.clone()))), + #[cfg(feature = "postgres")] + DatabasePool::Postgres(_) => Err(FactoryError::NotImplemented( + "Postgres UserRepository".to_string(), + )), + #[allow(unreachable_patterns)] + _ => Err(FactoryError::NotImplemented( + "No database feature enabled".to_string(), + )), + } +} + +pub async fn build_session_store( + pool: &DatabasePool, +) -> FactoryResult { + match pool { + #[cfg(feature = "sqlite")] + DatabasePool::Sqlite(pool) => { + let store = tower_sessions_sqlx_store::SqliteStore::new(pool.clone()); + Ok(crate::session_store::InfraSessionStore::Sqlite(store)) + } + #[cfg(feature = "postgres")] + DatabasePool::Postgres(pool) => { + let store = tower_sessions_sqlx_store::PostgresStore::new(pool.clone()); + Ok(crate::session_store::InfraSessionStore::Postgres(store)) + } + #[allow(unreachable_patterns)] + _ => Err(FactoryError::NotImplemented( + "No database feature enabled".to_string(), + )), + } +} diff --git a/notes-infra/src/lib.rs b/notes-infra/src/lib.rs index 549a713..bc9a051 100644 --- a/notes-infra/src/lib.rs +++ b/notes-infra/src/lib.rs @@ -15,12 +15,22 @@ //! - [`db::run_migrations`] - Run database migrations pub mod db; +pub mod factory; +#[cfg(feature = "sqlite")] pub mod note_repository; +pub mod session_store; +#[cfg(feature = "sqlite")] pub mod tag_repository; +#[cfg(feature = "sqlite")] pub mod user_repository; // Re-export for convenience -pub use db::{DatabaseConfig, create_pool, run_migrations}; +#[cfg(feature = "sqlite")] +pub use db::create_pool; +pub use db::{DatabaseConfig, run_migrations}; +#[cfg(feature = "sqlite")] pub use note_repository::SqliteNoteRepository; +#[cfg(feature = "sqlite")] pub use tag_repository::SqliteTagRepository; +#[cfg(feature = "sqlite")] pub use user_repository::SqliteUserRepository; diff --git a/notes-infra/src/session_store.rs b/notes-infra/src/session_store.rs new file mode 100644 index 0000000..462aa85 --- /dev/null +++ b/notes-infra/src/session_store.rs @@ -0,0 +1,73 @@ +use async_trait::async_trait; +use sqlx; +use tower_sessions::{ + SessionStore, + session::{Id, Record}, +}; +#[cfg(feature = "postgres")] +use tower_sessions_sqlx_store::PostgresStore; +#[cfg(feature = "sqlite")] +use tower_sessions_sqlx_store::SqliteStore; + +#[derive(Clone, Debug)] +pub enum InfraSessionStore { + #[cfg(feature = "sqlite")] + Sqlite(SqliteStore), + #[cfg(feature = "postgres")] + Postgres(PostgresStore), +} + +#[async_trait] +impl SessionStore for InfraSessionStore { + async fn save(&self, session_record: &Record) -> tower_sessions::session_store::Result<()> { + match self { + #[cfg(feature = "sqlite")] + Self::Sqlite(store) => store.save(session_record).await, + #[cfg(feature = "postgres")] + Self::Postgres(store) => store.save(session_record).await, + #[allow(unreachable_patterns)] + _ => Err(tower_sessions::session_store::Error::Backend( + "No backend enabled".to_string(), + )), + } + } + + async fn load(&self, session_id: &Id) -> tower_sessions::session_store::Result> { + match self { + #[cfg(feature = "sqlite")] + Self::Sqlite(store) => store.load(session_id).await, + #[cfg(feature = "postgres")] + Self::Postgres(store) => store.load(session_id).await, + #[allow(unreachable_patterns)] + _ => Err(tower_sessions::session_store::Error::Backend( + "No backend enabled".to_string(), + )), + } + } + + async fn delete(&self, session_id: &Id) -> tower_sessions::session_store::Result<()> { + match self { + #[cfg(feature = "sqlite")] + Self::Sqlite(store) => store.delete(session_id).await, + #[cfg(feature = "postgres")] + Self::Postgres(store) => store.delete(session_id).await, + #[allow(unreachable_patterns)] + _ => Err(tower_sessions::session_store::Error::Backend( + "No backend enabled".to_string(), + )), + } + } +} + +impl InfraSessionStore { + pub async fn migrate(&self) -> Result<(), sqlx::Error> { + match self { + #[cfg(feature = "sqlite")] + Self::Sqlite(store) => store.migrate().await, + #[cfg(feature = "postgres")] + Self::Postgres(store) => store.migrate().await, + #[allow(unreachable_patterns)] + _ => Err(sqlx::Error::Configuration("No backend enabled".into())), + } + } +} diff --git a/notes-infra/src/tag_repository.rs b/notes-infra/src/tag_repository.rs index 9b65d10..31ba761 100644 --- a/notes-infra/src/tag_repository.rs +++ b/notes-infra/src/tag_repository.rs @@ -158,14 +158,15 @@ impl TagRepository for SqliteTagRepository { #[cfg(test)] mod tests { use super::*; - use crate::db::{DatabaseConfig, create_pool, run_migrations}; + use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations}; use crate::user_repository::SqliteUserRepository; use notes_domain::{User, UserRepository}; async fn setup_test_db() -> SqlitePool { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - run_migrations(&pool).await.unwrap(); + let db_pool = DatabasePool::Sqlite(pool.clone()); + run_migrations(&db_pool).await.unwrap(); pool } diff --git a/notes-infra/src/user_repository.rs b/notes-infra/src/user_repository.rs index d8d4416..056768b 100644 --- a/notes-infra/src/user_repository.rs +++ b/notes-infra/src/user_repository.rs @@ -133,12 +133,13 @@ impl UserRepository for SqliteUserRepository { #[cfg(test)] mod tests { use super::*; - use crate::db::{DatabaseConfig, create_pool, run_migrations}; + use crate::db::{DatabaseConfig, DatabasePool, create_pool, run_migrations}; async fn setup_test_db() -> SqlitePool { let config = DatabaseConfig::in_memory(); let pool = create_pool(&config).await.unwrap(); - run_migrations(&pool).await.unwrap(); + let db_pool = DatabasePool::Sqlite(pool.clone()); + run_migrations(&db_pool).await.unwrap(); pool }