feat: integrate axum-governor for rate limiting and update dependencies

This commit is contained in:
2026-05-09 22:35:08 +02:00
parent d89d373a91
commit a078d5315e
5 changed files with 143 additions and 82 deletions

98
Cargo.lock generated
View File

@@ -648,6 +648,26 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-governor"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5aacd16c0cb4532f3333899bdac866ab0928e92a4aa00f6d397cacab9ed692"
dependencies = [
"axum",
"dashmap",
"forwarded-header-value",
"governor",
"http 1.4.0",
"ipnet",
"nonzero_ext",
"pin-project-lite",
"serde_json",
"tokio",
"tower",
"tracing",
]
[[package]] [[package]]
name = "axum-macros" name = "axum-macros"
version = "0.5.1" version = "0.5.1"
@@ -1146,6 +1166,20 @@ dependencies = [
"syn 2.0.117", "syn 2.0.117",
] ]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "deltae" name = "deltae"
version = "0.3.2" version = "0.3.2"
@@ -1590,6 +1624,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "forwarded-header-value"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9"
dependencies = [
"nonempty",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "fs_extra" name = "fs_extra"
version = "1.3.0" version = "1.3.0"
@@ -1691,6 +1735,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.32" version = "0.3.32"
@@ -1764,6 +1814,26 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
[[package]]
name = "governor"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"hashbrown 0.16.1",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"smallvec",
"spinning_top",
"web-time",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.13" version = "0.4.13"
@@ -1783,6 +1853,12 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.5" version = "0.15.5"
@@ -2649,6 +2725,18 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonempty"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7"
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.50.3" version = "0.50.3"
@@ -3129,6 +3217,7 @@ dependencies = [
"async-trait", "async-trait",
"auth", "auth",
"axum", "axum",
"axum-governor",
"chrono", "chrono",
"doc", "doc",
"domain", "domain",
@@ -4127,6 +4216,15 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"

View File

@@ -5,6 +5,9 @@ edition = "2024"
[dependencies] [dependencies]
tower-http = { version = "0.6.8", features = ["fs", "trace", "tracing"] } tower-http = { version = "0.6.8", features = ["fs", "trace", "tracing"] }
infer = "0.19.0"
percent-encoding = "2"
axum-governor = "2"
axum = { workspace = true } axum = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
@@ -35,8 +38,6 @@ rss = { workspace = true }
export = { workspace = true } export = { workspace = true }
doc = { workspace = true } doc = { workspace = true }
utoipa = { version = "5.5.0", features = ["axum_extras", "uuid"] } utoipa = { version = "5.5.0", features = ["axum_extras", "uuid"] }
infer = "0.19.0"
percent-encoding = "2"
[dev-dependencies] [dev-dependencies]
tower = { version = "0.5", features = ["util"] } tower = { version = "0.5", features = ["util"] }

View File

@@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> {
let addr = format!("{}:{}", host, port); let addr = format!("{}:{}", host, port);
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
tracing::info!("Listening on {}", addr); tracing::info!("Listening on {}", addr);
axum::serve(listener, app).await?; axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>()).await?;
Ok(()) Ok(())
} }

View File

@@ -1,54 +1,12 @@
use std::sync::{ use std::net::SocketAddr;
Arc, use std::num::NonZeroU32;
atomic::{AtomicU64, Ordering},
};
use std::time::{SystemTime, UNIX_EPOCH};
use axum::{Router, http::StatusCode, middleware, response::IntoResponse, routing}; use axum::{Router, routing};
use axum_governor::{GovernorConfigBuilder, GovernorLayer, Quota, extractor::PeerIp};
use tower_http::{services::ServeDir, trace::TraceLayer}; use tower_http::{services::ServeDir, trace::TraceLayer};
use crate::{handlers, state::AppState}; use crate::{handlers, state::AppState};
/// Simple global rate limiter: tracks request count per 60-second window.
/// Not per-IP — suitable for a low-traffic personal app.
#[derive(Clone)]
struct RateLimiter {
window: Arc<AtomicU64>,
count: Arc<AtomicU64>,
limit: u64,
}
impl RateLimiter {
fn new(limit: u64) -> Self {
Self {
window: Arc::new(AtomicU64::new(0)),
count: Arc::new(AtomicU64::new(0)),
limit,
}
}
fn check(&self) -> bool {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
/ 60;
let prev = self.window.load(Ordering::Acquire);
if now != prev {
// compare_exchange ensures only one thread wins the window reset
if self
.window
.compare_exchange(prev, now, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.count.store(1, Ordering::Release);
return true;
}
}
self.count.fetch_add(1, Ordering::Relaxed) + 1 <= self.limit
}
}
pub fn build_router(state: AppState, ap_router: Router) -> Router { pub fn build_router(state: AppState, ap_router: Router) -> Router {
let rate_limit = state.app_ctx.config.rate_limit; let rate_limit = state.app_ctx.config.rate_limit;
Router::new() Router::new()
@@ -60,8 +18,12 @@ pub fn build_router(state: AppState, ap_router: Router) -> Router {
.merge(ap_router) .merge(ap_router)
} }
fn per_minute(n: u64) -> Quota {
let n = NonZeroU32::new(n.clamp(1, u32::MAX as u64) as u32).unwrap();
Quota::requests_per_minute(n)
}
fn html_routes(rate_limit: u64) -> Router<AppState> { fn html_routes(rate_limit: u64) -> Router<AppState> {
let limiter = RateLimiter::new(rate_limit);
let auth = Router::new() let auth = Router::new()
.route( .route(
"/login", "/login",
@@ -72,18 +34,15 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
"/register", "/register",
routing::get(handlers::html::get_register_page).post(handlers::html::post_register), routing::get(handlers::html::get_register_page).post(handlers::html::post_register),
) )
.route_layer(middleware::from_fn( .layer({
move |req: axum::extract::Request, next: middleware::Next| { let cfg = GovernorConfigBuilder::default()
let limiter = limiter.clone(); .with_extractor(PeerIp::default())
async move { .expect_connect_info()
if limiter.check() { .quota_default(per_minute(rate_limit))
next.run(req).await .finish()
} else { .unwrap();
StatusCode::TOO_MANY_REQUESTS.into_response() GovernorLayer::new(cfg)
} });
}
},
));
Router::new() Router::new()
.route("/", routing::get(handlers::html::get_activity_feed)) .route("/", routing::get(handlers::html::get_activity_feed))
@@ -140,22 +99,16 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
"/users/{id}/feed.rss", "/users/{id}/feed.rss",
routing::get(handlers::rss::get_user_feed), routing::get(handlers::rss::get_user_feed),
) )
.layer(middleware::from_fn(crate::csrf::csrf_middleware)) .layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
} }
fn api_routes(rate_limit: u64) -> Router<AppState> { fn api_routes(rate_limit: u64) -> Router<AppState> {
let limiter = RateLimiter::new(rate_limit); let cfg = GovernorConfigBuilder::default()
let auth_rate_limit = .with_extractor(PeerIp::default())
middleware::from_fn(move |req: axum::extract::Request, next: middleware::Next| { .expect_connect_info()
let limiter = limiter.clone(); .quota_default(per_minute(rate_limit))
async move { .finish()
if limiter.check() { .unwrap();
next.run(req).await
} else {
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
});
Router::new().nest( Router::new().nest(
"/api/v1", "/api/v1",
@@ -188,6 +141,6 @@ fn api_routes(rate_limit: u64) -> Router<AppState> {
.route("/social/followers/accept", routing::post(handlers::api::accept_follower)) .route("/social/followers/accept", routing::post(handlers::api::accept_follower))
.route("/social/followers/reject", routing::post(handlers::api::reject_follower)) .route("/social/followers/reject", routing::post(handlers::api::reject_follower))
.route("/social/followers/remove", routing::post(handlers::api::remove_follower)) .route("/social/followers/remove", routing::post(handlers::api::remove_follower))
.route_layer(auth_rate_limit), .layer(GovernorLayer::new(cfg)),
) )
} }

View File

@@ -161,16 +161,25 @@ async fn test_app() -> Router {
routes::build_router(state, axum::Router::new()) routes::build_router(state, axum::Router::new())
} }
/// Inject a fake peer IP so the GovernorLayer can extract ConnectInfo.
fn with_ip(req: Request<Body>) -> Request<Body> {
let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().unwrap();
let mut req = req;
req.extensions_mut()
.insert(axum::extract::ConnectInfo::<std::net::SocketAddr>(addr));
req
}
#[tokio::test] #[tokio::test]
async fn get_api_diary_returns_empty_list() { async fn get_api_diary_returns_empty_list() {
let app = test_app().await; let app = test_app().await;
let response = app let response = app
.oneshot( .oneshot(with_ip(
Request::builder() Request::builder()
.uri("/api/v1/diary") .uri("/api/v1/diary")
.body(Body::empty()) .body(Body::empty())
.unwrap(), .unwrap(),
) ))
.await .await
.unwrap(); .unwrap();
@@ -189,7 +198,7 @@ async fn get_api_diary_returns_empty_list() {
async fn post_api_reviews_without_auth_returns_401() { async fn post_api_reviews_without_auth_returns_401() {
let app = test_app().await; let app = test_app().await;
let response = app let response = app
.oneshot( .oneshot(with_ip(
Request::builder() Request::builder()
.method("POST") .method("POST")
.uri("/api/v1/reviews") .uri("/api/v1/reviews")
@@ -198,7 +207,7 @@ async fn post_api_reviews_without_auth_returns_401() {
r#"{"rating":4,"watched_at":"2026-01-01T20:00:00","manual_title":"Dune","manual_release_year":2021}"#, r#"{"rating":4,"watched_at":"2026-01-01T20:00:00","manual_title":"Dune","manual_release_year":2021}"#,
)) ))
.unwrap(), .unwrap(),
) ))
.await .await
.unwrap(); .unwrap();
@@ -209,14 +218,14 @@ async fn post_api_reviews_without_auth_returns_401() {
async fn post_api_auth_login_unknown_user_returns_401() { async fn post_api_auth_login_unknown_user_returns_401() {
let app = test_app().await; let app = test_app().await;
let response = app let response = app
.oneshot( .oneshot(with_ip(
Request::builder() Request::builder()
.method("POST") .method("POST")
.uri("/api/v1/auth/login") .uri("/api/v1/auth/login")
.header("content-type", "application/json") .header("content-type", "application/json")
.body(Body::from(r#"{"email":"a@b.com","password":"x"}"#)) .body(Body::from(r#"{"email":"a@b.com","password":"x"}"#))
.unwrap(), .unwrap(),
) ))
.await .await
.unwrap(); .unwrap();