From c251a5c41fc3b8165d5f07a23438df0cd4fb14eb Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Mon, 1 Jun 2026 02:14:44 +0200 Subject: [PATCH] perf: concurrent worker with claim/execute split + graceful shutdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - JobRepository::claim_next() — atomic SELECT FOR UPDATE SKIP LOCKED + UPDATE status=processing in one query, no duplicate claims - ExecutePipelineHandler skips start() for already-claimed jobs - Sweep spawns N concurrent tasks via JoinSet, claims are fast+sequential, execution is slow+concurrent - Graceful shutdown: stop claiming, await all in-flight JoinSet tasks - WORKER_CONCURRENCY env (default: CPU cores) - DB_MAX_CONNECTIONS env (default: 20, was hardcoded 10) - VolumeFileResolver impl for InMemoryFileStorage (test fix) --- .env.example | 5 ++ crates/adapters/postgres/src/db.rs | 6 +- .../adapters/postgres/src/processing/mod.rs | 24 +++++- .../processing/commands/execute_pipeline.rs | 6 +- crates/application/src/testing/fakes.rs | 19 +++++ .../application/src/testing/repositories.rs | 16 ++++ crates/domain/src/processing/ports.rs | 1 + crates/worker/src/bootstrap.rs | 14 ++- crates/worker/src/config.rs | 5 ++ crates/worker/src/event_loop.rs | 30 +++---- crates/worker/src/factories/mod.rs | 2 +- crates/worker/src/factories/processing.rs | 12 ++- crates/worker/src/main.rs | 9 +- crates/worker/src/sweep.rs | 85 +++++++++++++++---- 14 files changed, 178 insertions(+), 56 deletions(-) diff --git a/.env.example b/.env.example index a43472c..787278b 100644 --- a/.env.example +++ b/.env.example @@ -39,6 +39,11 @@ STORAGE_PATH=./data/media # ============================================================================ # MAX_UPLOAD_BYTES=268435456 +# ============================================================================ +# Worker concurrency (default: number of CPU cores) +# ============================================================================ +# WORKER_CONCURRENCY=8 + # ============================================================================ # Trash (default 30 days before permanent purge) # ============================================================================ diff --git a/crates/adapters/postgres/src/db.rs b/crates/adapters/postgres/src/db.rs index 75e9f37..2f32673 100644 --- a/crates/adapters/postgres/src/db.rs +++ b/crates/adapters/postgres/src/db.rs @@ -1,8 +1,12 @@ pub type PgPool = sqlx::PgPool; pub async fn connect(url: &str) -> anyhow::Result { + let max_conn: u32 = std::env::var("DB_MAX_CONNECTIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(20); let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(10) + .max_connections(max_conn) .connect(url) .await?; Ok(pool) diff --git a/crates/adapters/postgres/src/processing/mod.rs b/crates/adapters/postgres/src/processing/mod.rs index c54b374..d727996 100644 --- a/crates/adapters/postgres/src/processing/mod.rs +++ b/crates/adapters/postgres/src/processing/mod.rs @@ -145,7 +145,29 @@ impl JobRepository for PostgresJobRepository { started_at, completed_at, error_message FROM jobs WHERE status = 'queued' ORDER BY priority DESC, created_at ASC - LIMIT 1", + LIMIT 1 + FOR UPDATE SKIP LOCKED", + ) + .fetch_optional(&self.pool) + .await + .map_pg()?; + + Ok(row.map(Into::into)) + } + + async fn claim_next(&self) -> Result, DomainError> { + let row = sqlx::query_as::<_, JobRow>( + "UPDATE jobs SET status = 'processing', started_at = NOW() + WHERE job_id = ( + SELECT job_id FROM jobs + WHERE status = 'queued' + ORDER BY priority DESC, created_at ASC + LIMIT 1 + FOR UPDATE SKIP LOCKED + ) + RETURNING job_id, job_type, target_asset_id, batch_id, status, priority, + payload, result_data, retry_count, max_retries, created_at, + started_at, completed_at, error_message", ) .fetch_optional(&self.pool) .await diff --git a/crates/application/src/processing/commands/execute_pipeline.rs b/crates/application/src/processing/commands/execute_pipeline.rs index d277561..78aaec5 100644 --- a/crates/application/src/processing/commands/execute_pipeline.rs +++ b/crates/application/src/processing/commands/execute_pipeline.rs @@ -60,8 +60,10 @@ impl ExecutePipelineHandler { .await? .ok_or_else(|| DomainError::NotFound(format!("Job {} not found", cmd.job_id)))?; - job.start()?; - self.job_repo.save(&job).await?; + if job.status == domain::entities::JobStatus::Queued { + job.start()?; + self.job_repo.save(&job).await?; + } let trigger = job_type_to_trigger(&job.job_type); let pipelines = self.pipeline_repo.find_by_trigger(trigger).await?; diff --git a/crates/application/src/testing/fakes.rs b/crates/application/src/testing/fakes.rs index 3dce0fe..e5b46fa 100644 --- a/crates/application/src/testing/fakes.rs +++ b/crates/application/src/testing/fakes.rs @@ -118,6 +118,25 @@ impl FileStoragePort for InMemoryFileStorage { } } +#[async_trait] +impl domain::ports::VolumeFileResolver for InMemoryFileStorage { + async fn open_by_volume( + &self, + _volume_id: &domain::value_objects::SystemId, + relative_path: &str, + ) -> Result<(domain::ports::DataStream, u64), DomainError> { + self.open_file(relative_path).await + } + + async fn read_by_volume( + &self, + _volume_id: &domain::value_objects::SystemId, + relative_path: &str, + ) -> Result { + self.read_file(relative_path).await + } +} + // --- StubSidecarWriter --- pub struct StubSidecarWriter; diff --git a/crates/application/src/testing/repositories.rs b/crates/application/src/testing/repositories.rs index ff393a1..d2c1ef0 100644 --- a/crates/application/src/testing/repositories.rs +++ b/crates/application/src/testing/repositories.rs @@ -289,6 +289,22 @@ impl JobRepository for InMemoryJobRepository { .cloned()) } + async fn claim_next(&self) -> Result, DomainError> { + let mut data = self.data.lock().await; + let id = data + .values() + .filter(|j| j.status == JobStatus::Queued) + .max_by_key(|j| j.priority) + .map(|j| j.job_id.to_string()); + if let Some(id) = id { + if let Some(job) = data.get_mut(&id) { + let _ = job.start(); + return Ok(Some(job.clone())); + } + } + Ok(None) + } + async fn find_by_batch(&self, batch_id: &SystemId) -> Result, DomainError> { Ok(self .data diff --git a/crates/domain/src/processing/ports.rs b/crates/domain/src/processing/ports.rs index 1b9240b..665f41d 100644 --- a/crates/domain/src/processing/ports.rs +++ b/crates/domain/src/processing/ports.rs @@ -10,6 +10,7 @@ use std::sync::Arc; pub trait JobRepository: Send + Sync { async fn find_by_id(&self, id: &SystemId) -> Result, DomainError>; async fn find_next_queued(&self) -> Result, DomainError>; + async fn claim_next(&self) -> Result, DomainError>; async fn find_by_batch(&self, batch_id: &SystemId) -> Result, DomainError>; async fn find_all( &self, diff --git a/crates/worker/src/bootstrap.rs b/crates/worker/src/bootstrap.rs index 34bd419..9b4fb2c 100644 --- a/crates/worker/src/bootstrap.rs +++ b/crates/worker/src/bootstrap.rs @@ -1,16 +1,16 @@ use std::sync::Arc; use application::catalog::DeleteAssetHandler; -use application::processing::{EnqueueJobHandler, ProcessNextJobHandler}; +use application::processing::{EnqueueJobHandler, ExecutePipelineHandler}; use domain::ports::{AssetRepository, JobRepository}; use crate::config::WorkerConfig; use crate::factories::{ - Repos, build_enqueue_handler, build_plugin_registry, build_process_next_handler, + Repos, build_enqueue_handler, build_executor, build_plugin_registry, }; pub struct WorkerServices { - pub process_next: Arc, + pub executor: Arc, pub enqueue: Arc, pub job_repo: Arc, pub asset_repo: Arc, @@ -57,11 +57,7 @@ pub async fn build(config: &WorkerConfig) -> anyhow::Result { event_pub.clone(), )); - let process_next = Arc::new(build_process_next_handler( - &repos, - registry, - event_pub.clone(), - )); + let executor = Arc::new(build_executor(&repos, registry, event_pub.clone())); let job_repo: Arc = repos.job.clone(); let asset_repo: Arc = repos.asset.clone(); let enqueue = Arc::new(build_enqueue_handler(&repos, event_pub.clone())); @@ -80,7 +76,7 @@ pub async fn build(config: &WorkerConfig) -> anyhow::Result { let event_consumer = adapters_event_transport::EventConsumerAdapter::new(consumer_source); Ok(WorkerServices { - process_next, + executor, enqueue, job_repo, asset_repo, diff --git a/crates/worker/src/config.rs b/crates/worker/src/config.rs index f3d1175..01af94b 100644 --- a/crates/worker/src/config.rs +++ b/crates/worker/src/config.rs @@ -5,6 +5,7 @@ pub struct WorkerConfig { pub fallback_sweep_secs: u64, pub storage_path: String, pub trash_retention_days: u64, + pub concurrency: usize, } impl WorkerConfig { @@ -22,6 +23,10 @@ impl WorkerConfig { .ok() .and_then(|v| v.parse().ok()) .unwrap_or(30), + concurrency: std::env::var("WORKER_CONCURRENCY") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or_else(|| std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4)), } } } diff --git a/crates/worker/src/event_loop.rs b/crates/worker/src/event_loop.rs index 9eaebf7..7cf43fa 100644 --- a/crates/worker/src/event_loop.rs +++ b/crates/worker/src/event_loop.rs @@ -2,9 +2,10 @@ use std::sync::Arc; use futures::StreamExt; use tokio::sync::watch; -use tracing::{error, info, warn}; +use tokio::task::JoinSet; +use tracing::{error, info}; -use application::processing::{EnqueueJobCommand, ProcessNextJobCommand, ProcessNextJobHandler}; +use application::processing::EnqueueJobCommand; use domain::entities::JobType; use domain::events::DomainEvent; use domain::ports::{EventConsumer, JobRepository}; @@ -25,11 +26,12 @@ fn enqueue_cmd(job_type: JobType, priority: u32, asset_id: SystemId) -> EnqueueJ pub async fn run(services: WorkerServices, mut shutdown: watch::Receiver) { info!("event loop: listening for NATS events"); let mut stream = services.event_consumer.consume(); + let mut in_flight = JoinSet::new(); loop { tokio::select! { _ = shutdown.changed() => { - info!("event loop: shutting down"); + info!("event loop: shutdown, waiting for {} in-flight jobs", in_flight.len()); break; } msg = stream.next() => { @@ -69,7 +71,11 @@ pub async fn run(services: WorkerServices, mut shutdown: watch::Receiver) } DomainEvent::JobEnqueued { job_id, job_type, .. } => { info!(job_id = %job_id, job_type = %job_type, "JobEnqueued → process"); - drain_one(&services.process_next).await; + crate::sweep::spawn_one( + &services.job_repo, + &services.executor, + &mut in_flight, + ); } other => { tracing::debug!(event = ?other, "unhandled event, acked"); @@ -78,6 +84,9 @@ pub async fn run(services: WorkerServices, mut shutdown: watch::Receiver) } } } + + while in_flight.join_next().await.is_some() {} + info!("event loop: all in-flight jobs finished"); } async fn handle_job_completed( @@ -97,16 +106,3 @@ async fn handle_job_completed( } } -async fn drain_one(handler: &Arc) { - match handler.execute(ProcessNextJobCommand).await { - Ok(Some(job)) => { - info!(job_id = %job.job_id, status = ?job.status, "processed job"); - } - Ok(None) => { - warn!("JobEnqueued but no queued job found"); - } - Err(e) => { - error!(error = %e, "error processing job"); - } - } -} diff --git a/crates/worker/src/factories/mod.rs b/crates/worker/src/factories/mod.rs index 0b82f2a..d23141c 100644 --- a/crates/worker/src/factories/mod.rs +++ b/crates/worker/src/factories/mod.rs @@ -4,4 +4,4 @@ mod processing; pub use infra::Repos; pub use plugins::build_plugin_registry; -pub use processing::{build_enqueue_handler, build_process_next_handler}; +pub use processing::{build_enqueue_handler, build_executor}; diff --git a/crates/worker/src/factories/processing.rs b/crates/worker/src/factories/processing.rs index 7b451bc..d7b4231 100644 --- a/crates/worker/src/factories/processing.rs +++ b/crates/worker/src/factories/processing.rs @@ -1,24 +1,22 @@ -use application::processing::{EnqueueJobHandler, ExecutePipelineHandler, ProcessNextJobHandler}; +use application::processing::{EnqueueJobHandler, ExecutePipelineHandler}; use domain::ports::{EventPublisher, PluginRegistry}; use std::sync::Arc; use super::Repos; -pub fn build_process_next_handler( +pub fn build_executor( repos: &Repos, registry: Arc, event_pub: Arc, -) -> ProcessNextJobHandler { - let execute_pipeline = Arc::new(ExecutePipelineHandler::new( +) -> ExecutePipelineHandler { + ExecutePipelineHandler::new( repos.job.clone(), repos.batch.clone(), repos.pipeline.clone(), repos.plugin.clone(), registry, event_pub, - )); - - ProcessNextJobHandler::new(repos.job.clone(), execute_pipeline) + ) } pub fn build_enqueue_handler( diff --git a/crates/worker/src/main.rs b/crates/worker/src/main.rs index 1b7e504..32842db 100644 --- a/crates/worker/src/main.rs +++ b/crates/worker/src/main.rs @@ -22,7 +22,7 @@ async fn main() -> anyhow::Result<()> { .init(); let config = config::WorkerConfig::from_env(); - info!("Worker starting"); + info!(concurrency = config.concurrency, "Worker starting"); let services = bootstrap::build(&config).await?; @@ -48,8 +48,10 @@ async fn main() -> anyhow::Result<()> { }); let sweep_interval = Duration::from_secs(config.fallback_sweep_secs); - tokio::spawn(sweep::run( - services.process_next.clone(), + let sweep_handle = tokio::spawn(sweep::run( + services.job_repo.clone(), + services.executor.clone(), + config.concurrency, sweep_interval, shutdown_rx.clone(), )); @@ -62,6 +64,7 @@ async fn main() -> anyhow::Result<()> { )); event_loop::run(services, shutdown_rx).await; + let _ = sweep_handle.await; info!("worker shutdown complete"); Ok(()) diff --git a/crates/worker/src/sweep.rs b/crates/worker/src/sweep.rs index 5e6433a..afa0c9b 100644 --- a/crates/worker/src/sweep.rs +++ b/crates/worker/src/sweep.rs @@ -2,42 +2,97 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::watch; +use tokio::task::JoinSet; use tracing::{error, info}; use application::catalog::DeleteAssetHandler; -use application::processing::{ProcessNextJobCommand, ProcessNextJobHandler}; -use domain::ports::AssetRepository; +use application::processing::{ExecutePipelineCommand, ExecutePipelineHandler}; +use domain::ports::{AssetRepository, JobRepository}; pub async fn run( - handler: Arc, + job_repo: Arc, + executor: Arc, + concurrency: usize, interval: Duration, mut shutdown: watch::Receiver, ) { - info!(every_secs = interval.as_secs(), "sweep task started"); + let mut in_flight = JoinSet::new(); + info!(every_secs = interval.as_secs(), concurrency, "sweep task started"); + loop { tokio::select! { _ = shutdown.changed() => { - info!("sweep task: shutting down"); + info!("sweep: shutdown, waiting for {} in-flight jobs", in_flight.len()); break; } _ = tokio::time::sleep(interval) => {} } info!("sweep: draining queued jobs"); - loop { - match handler.execute(ProcessNextJobCommand).await { - Ok(Some(job)) => { - info!(job_id = %job.job_id, status = ?job.status, "sweep: processed job"); - } - Ok(None) => break, - Err(e) => { - error!(error = %e, "sweep: error processing job"); - break; - } + drain(&job_repo, &executor, concurrency, &mut in_flight).await; + } + + while in_flight.join_next().await.is_some() {} + info!("sweep: all in-flight jobs finished"); +} + +async fn drain( + job_repo: &Arc, + executor: &Arc, + concurrency: usize, + in_flight: &mut JoinSet<()>, +) { + loop { + while in_flight.len() >= concurrency { + if in_flight.join_next().await.is_none() { + break; } } + + let job = match job_repo.claim_next().await { + Ok(Some(j)) => j, + Ok(None) => break, + Err(e) => { + error!(error = %e, "sweep: error claiming job"); + break; + } + }; + + info!(job_id = %job.job_id, job_type = ?job.job_type, "sweep: claimed"); + let exec = executor.clone(); + in_flight.spawn(async move { + let job_id = job.job_id; + match exec.execute(ExecutePipelineCommand { job_id }).await { + Ok(j) => info!(job_id = %j.job_id, status = ?j.status, "sweep: done"), + Err(e) => error!(job_id = %job_id, error = %e, "sweep: failed"), + } + }); } } +pub fn spawn_one( + job_repo: &Arc, + executor: &Arc, + in_flight: &mut JoinSet<()>, +) { + let repo = job_repo.clone(); + let exec = executor.clone(); + in_flight.spawn(async move { + let job = match repo.claim_next().await { + Ok(Some(j)) => j, + Ok(None) => return, + Err(e) => { + error!(error = %e, "error claiming job"); + return; + } + }; + let job_id = job.job_id; + match exec.execute(ExecutePipelineCommand { job_id }).await { + Ok(j) => info!(job_id = %j.job_id, status = ?j.status, "done"), + Err(e) => error!(job_id = %job_id, error = %e, "failed"), + } + }); +} + pub async fn purge_trash( asset_repo: Arc, delete_handler: Arc,