diff --git a/Cargo.lock b/Cargo.lock index ac04d3c..f494c9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -648,6 +648,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-governor" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5aacd16c0cb4532f3333899bdac866ab0928e92a4aa00f6d397cacab9ed692" +dependencies = [ + "axum", + "dashmap", + "forwarded-header-value", + "governor", + "http 1.4.0", + "ipnet", + "nonzero_ext", + "pin-project-lite", + "serde_json", + "tokio", + "tower", + "tracing", +] + [[package]] name = "axum-macros" version = "0.5.1" @@ -1146,6 +1166,20 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deltae" version = "0.3.2" @@ -1590,6 +1624,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror 1.0.69", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -1691,6 +1735,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.32" @@ -1764,6 +1814,26 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "governor" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "hashbrown 0.16.1", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.4.13" @@ -1783,6 +1853,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -2649,6 +2725,18 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -3129,6 +3217,7 @@ dependencies = [ "async-trait", "auth", "axum", + "axum-governor", "chrono", "doc", "domain", @@ -4127,6 +4216,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" diff --git a/crates/presentation/Cargo.toml b/crates/presentation/Cargo.toml index 9cd81ef..4122a8f 100644 --- a/crates/presentation/Cargo.toml +++ b/crates/presentation/Cargo.toml @@ -5,6 +5,9 @@ edition = "2024" [dependencies] tower-http = { version = "0.6.8", features = ["fs", "trace", "tracing"] } +infer = "0.19.0" +percent-encoding = "2" +axum-governor = "2" axum = { workspace = true } serde = { workspace = true } @@ -35,8 +38,6 @@ rss = { workspace = true } export = { workspace = true } doc = { workspace = true } utoipa = { version = "5.5.0", features = ["axum_extras", "uuid"] } -infer = "0.19.0" -percent-encoding = "2" [dev-dependencies] tower = { version = "0.5", features = ["util"] } diff --git a/crates/presentation/src/main.rs b/crates/presentation/src/main.rs index 3de2094..9be2881 100644 --- a/crates/presentation/src/main.rs +++ b/crates/presentation/src/main.rs @@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> { let addr = format!("{}:{}", host, port); let listener = TcpListener::bind(&addr).await?; tracing::info!("Listening on {}", addr); - axum::serve(listener, app).await?; + axum::serve(listener, app.into_make_service_with_connect_info::()).await?; Ok(()) } diff --git a/crates/presentation/src/routes.rs b/crates/presentation/src/routes.rs index 72f973e..fc28a70 100644 --- a/crates/presentation/src/routes.rs +++ b/crates/presentation/src/routes.rs @@ -1,54 +1,12 @@ -use std::sync::{ - Arc, - atomic::{AtomicU64, Ordering}, -}; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::net::SocketAddr; +use std::num::NonZeroU32; -use axum::{Router, http::StatusCode, middleware, response::IntoResponse, routing}; +use axum::{Router, routing}; +use axum_governor::{GovernorConfigBuilder, GovernorLayer, Quota, extractor::PeerIp}; use tower_http::{services::ServeDir, trace::TraceLayer}; use crate::{handlers, state::AppState}; -/// Simple global rate limiter: tracks request count per 60-second window. -/// Not per-IP — suitable for a low-traffic personal app. -#[derive(Clone)] -struct RateLimiter { - window: Arc, - count: Arc, - limit: u64, -} - -impl RateLimiter { - fn new(limit: u64) -> Self { - Self { - window: Arc::new(AtomicU64::new(0)), - count: Arc::new(AtomicU64::new(0)), - limit, - } - } - - fn check(&self) -> bool { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs() - / 60; - let prev = self.window.load(Ordering::Acquire); - if now != prev { - // compare_exchange ensures only one thread wins the window reset - if self - .window - .compare_exchange(prev, now, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - self.count.store(1, Ordering::Release); - return true; - } - } - self.count.fetch_add(1, Ordering::Relaxed) + 1 <= self.limit - } -} - pub fn build_router(state: AppState, ap_router: Router) -> Router { let rate_limit = state.app_ctx.config.rate_limit; Router::new() @@ -60,8 +18,12 @@ pub fn build_router(state: AppState, ap_router: Router) -> Router { .merge(ap_router) } +fn per_minute(n: u64) -> Quota { + let n = NonZeroU32::new(n.clamp(1, u32::MAX as u64) as u32).unwrap(); + Quota::requests_per_minute(n) +} + fn html_routes(rate_limit: u64) -> Router { - let limiter = RateLimiter::new(rate_limit); let auth = Router::new() .route( "/login", @@ -72,18 +34,15 @@ fn html_routes(rate_limit: u64) -> Router { "/register", routing::get(handlers::html::get_register_page).post(handlers::html::post_register), ) - .route_layer(middleware::from_fn( - move |req: axum::extract::Request, next: middleware::Next| { - let limiter = limiter.clone(); - async move { - if limiter.check() { - next.run(req).await - } else { - StatusCode::TOO_MANY_REQUESTS.into_response() - } - } - }, - )); + .layer({ + let cfg = GovernorConfigBuilder::default() + .with_extractor(PeerIp::default()) + .expect_connect_info() + .quota_default(per_minute(rate_limit)) + .finish() + .unwrap(); + GovernorLayer::new(cfg) + }); Router::new() .route("/", routing::get(handlers::html::get_activity_feed)) @@ -140,22 +99,16 @@ fn html_routes(rate_limit: u64) -> Router { "/users/{id}/feed.rss", routing::get(handlers::rss::get_user_feed), ) - .layer(middleware::from_fn(crate::csrf::csrf_middleware)) + .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware)) } fn api_routes(rate_limit: u64) -> Router { - let limiter = RateLimiter::new(rate_limit); - let auth_rate_limit = - middleware::from_fn(move |req: axum::extract::Request, next: middleware::Next| { - let limiter = limiter.clone(); - async move { - if limiter.check() { - next.run(req).await - } else { - StatusCode::TOO_MANY_REQUESTS.into_response() - } - } - }); + let cfg = GovernorConfigBuilder::default() + .with_extractor(PeerIp::default()) + .expect_connect_info() + .quota_default(per_minute(rate_limit)) + .finish() + .unwrap(); Router::new().nest( "/api/v1", @@ -188,6 +141,6 @@ fn api_routes(rate_limit: u64) -> Router { .route("/social/followers/accept", routing::post(handlers::api::accept_follower)) .route("/social/followers/reject", routing::post(handlers::api::reject_follower)) .route("/social/followers/remove", routing::post(handlers::api::remove_follower)) - .route_layer(auth_rate_limit), + .layer(GovernorLayer::new(cfg)), ) } diff --git a/crates/presentation/tests/api_test.rs b/crates/presentation/tests/api_test.rs index 8615942..5511845 100644 --- a/crates/presentation/tests/api_test.rs +++ b/crates/presentation/tests/api_test.rs @@ -161,16 +161,25 @@ async fn test_app() -> Router { routes::build_router(state, axum::Router::new()) } +/// Inject a fake peer IP so the GovernorLayer can extract ConnectInfo. +fn with_ip(req: Request) -> Request { + let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().unwrap(); + let mut req = req; + req.extensions_mut() + .insert(axum::extract::ConnectInfo::(addr)); + req +} + #[tokio::test] async fn get_api_diary_returns_empty_list() { let app = test_app().await; let response = app - .oneshot( + .oneshot(with_ip( Request::builder() .uri("/api/v1/diary") .body(Body::empty()) .unwrap(), - ) + )) .await .unwrap(); @@ -189,7 +198,7 @@ async fn get_api_diary_returns_empty_list() { async fn post_api_reviews_without_auth_returns_401() { let app = test_app().await; let response = app - .oneshot( + .oneshot(with_ip( Request::builder() .method("POST") .uri("/api/v1/reviews") @@ -198,7 +207,7 @@ async fn post_api_reviews_without_auth_returns_401() { r#"{"rating":4,"watched_at":"2026-01-01T20:00:00","manual_title":"Dune","manual_release_year":2021}"#, )) .unwrap(), - ) + )) .await .unwrap(); @@ -209,14 +218,14 @@ async fn post_api_reviews_without_auth_returns_401() { async fn post_api_auth_login_unknown_user_returns_401() { let app = test_app().await; let response = app - .oneshot( + .oneshot(with_ip( Request::builder() .method("POST") .uri("/api/v1/auth/login") .header("content-type", "application/json") .body(Body::from(r#"{"email":"a@b.com","password":"x"}"#)) .unwrap(), - ) + )) .await .unwrap();