From 38b4774a6368056a1ce7f878f4c9eca6a1be66d0 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Thu, 14 May 2026 15:37:38 +0200 Subject: [PATCH] feat(bootstrap): configurable HOST, CORS_ORIGINS, and optional rate limiting --- .env.example | 7 ++++ crates/bootstrap/Cargo.toml | 2 ++ crates/bootstrap/src/config.rs | 6 ++++ crates/bootstrap/src/main.rs | 65 ++++++++++++++++++++++++++++++---- 4 files changed, 74 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index 92675a3..d15d6c4 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,13 @@ BASE_URL=http://localhost:3000 # Optional HOST=0.0.0.0 PORT=3000 + +# CORS — comma-separated allowed origins, or * for permissive (default: *) +CORS_ORIGINS=* +# CORS_ORIGINS=https://your-nextjs-app.example.com + +# Rate limiting — max requests per minute per IP (disabled by default) +# RATE_LIMIT=60 ALLOW_REGISTRATION=true # set to false to disable new sign-ups RUST_ENV=development # set to "production" to disable AP debug mode diff --git a/crates/bootstrap/Cargo.toml b/crates/bootstrap/Cargo.toml index c21d6df..0bcf085 100644 --- a/crates/bootstrap/Cargo.toml +++ b/crates/bootstrap/Cargo.toml @@ -27,3 +27,5 @@ tower-http = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } dotenvy = { workspace = true } +tower_governor = "0.8" +http = "1" diff --git a/crates/bootstrap/src/config.rs b/crates/bootstrap/src/config.rs index 15e700f..59a8ded 100644 --- a/crates/bootstrap/src/config.rs +++ b/crates/bootstrap/src/config.rs @@ -8,6 +8,9 @@ pub struct Config { pub allow_registration: bool, /// true when RUST_ENV != "production" — enables AP debug mode pub debug: bool, + pub host: String, + pub cors_origins: String, + pub rate_limit: Option, } impl Config { @@ -31,6 +34,9 @@ impl Config { debug: std::env::var("RUST_ENV") .map(|v| v != "production") .unwrap_or(true), + host: std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into()), + cors_origins: std::env::var("CORS_ORIGINS").unwrap_or_else(|_| "*".into()), + rate_limit: std::env::var("RATE_LIMIT").ok().and_then(|v| v.parse().ok()), } } } diff --git a/crates/bootstrap/src/main.rs b/crates/bootstrap/src/main.rs index 121541a..c77acde 100644 --- a/crates/bootstrap/src/main.rs +++ b/crates/bootstrap/src/main.rs @@ -1,7 +1,9 @@ mod config; mod factory; -use tower_http::cors::CorsLayer; +use std::net::SocketAddr; +use std::sync::Arc; +use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -14,12 +16,63 @@ async fn main() { let infra = factory::build(&cfg).await; - let app = presentation::routes::router(&infra.fed_config) - .with_state(infra.state) - .layer(CorsLayer::permissive()); + // CORS + let cors = if cfg.cors_origins.trim() == "*" { + CorsLayer::permissive() + } else { + let origins: Vec = cfg + .cors_origins + .split(',') + .map(|o| o.trim()) + .filter_map(|o| o.parse().ok()) + .collect(); + CorsLayer::new() + .allow_origin(AllowOrigin::list(origins)) + .allow_methods(tower_http::cors::Any) + .allow_headers(tower_http::cors::Any) + }; - let addr = format!("0.0.0.0:{}", cfg.port); + let base = presentation::routes::router(&infra.fed_config) + .with_state(infra.state) + .layer(cors); + + let addr = format!("{}:{}", cfg.host, cfg.port); tracing::info!("Listening on {addr}"); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); + + if let Some(rate_limit) = cfg.rate_limit { + use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; // crate: tower_governor + + // per_millisecond sets the token replenishment interval. + // rate_limit = max requests/minute => replenish every (60000 / rate_limit) ms. + let ms = (60_000u64).saturating_div(rate_limit as u64).max(1); + let governor_conf = Arc::new( + GovernorConfigBuilder::default() + .per_millisecond(ms) + .burst_size(rate_limit) + .use_headers() + .finish() + .expect("valid rate limit config"), + ); + + let limiter = governor_conf.limiter().clone(); + tokio::spawn(async move { + let mut interval = + tokio::time::interval(std::time::Duration::from_secs(60)); + loop { + interval.tick().await; + limiter.retain_recent(); + } + }); + + let app = base.layer(GovernorLayer::new(governor_conf)); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .unwrap(); + } else { + axum::serve(listener, base).await.unwrap(); + } }