feat: integrate axum-governor for rate limiting and update dependencies

This commit is contained in:
2026-05-09 22:35:08 +02:00
parent d89d373a91
commit a078d5315e
5 changed files with 143 additions and 82 deletions

View File

@@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> {
let addr = format!("{}:{}", host, port);
let listener = TcpListener::bind(&addr).await?;
tracing::info!("Listening on {}", addr);
axum::serve(listener, app).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>()).await?;
Ok(())
}

View File

@@ -1,54 +1,12 @@
use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use std::time::{SystemTime, UNIX_EPOCH};
use std::net::SocketAddr;
use std::num::NonZeroU32;
use axum::{Router, http::StatusCode, middleware, response::IntoResponse, routing};
use axum::{Router, routing};
use axum_governor::{GovernorConfigBuilder, GovernorLayer, Quota, extractor::PeerIp};
use tower_http::{services::ServeDir, trace::TraceLayer};
use crate::{handlers, state::AppState};
/// Simple global rate limiter: tracks request count per 60-second window.
/// Not per-IP — suitable for a low-traffic personal app.
#[derive(Clone)]
struct RateLimiter {
window: Arc<AtomicU64>,
count: Arc<AtomicU64>,
limit: u64,
}
impl RateLimiter {
fn new(limit: u64) -> Self {
Self {
window: Arc::new(AtomicU64::new(0)),
count: Arc::new(AtomicU64::new(0)),
limit,
}
}
fn check(&self) -> bool {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
/ 60;
let prev = self.window.load(Ordering::Acquire);
if now != prev {
// compare_exchange ensures only one thread wins the window reset
if self
.window
.compare_exchange(prev, now, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.count.store(1, Ordering::Release);
return true;
}
}
self.count.fetch_add(1, Ordering::Relaxed) + 1 <= self.limit
}
}
pub fn build_router(state: AppState, ap_router: Router) -> Router {
let rate_limit = state.app_ctx.config.rate_limit;
Router::new()
@@ -60,8 +18,12 @@ pub fn build_router(state: AppState, ap_router: Router) -> Router {
.merge(ap_router)
}
fn per_minute(n: u64) -> Quota {
let n = NonZeroU32::new(n.clamp(1, u32::MAX as u64) as u32).unwrap();
Quota::requests_per_minute(n)
}
fn html_routes(rate_limit: u64) -> Router<AppState> {
let limiter = RateLimiter::new(rate_limit);
let auth = Router::new()
.route(
"/login",
@@ -72,18 +34,15 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
"/register",
routing::get(handlers::html::get_register_page).post(handlers::html::post_register),
)
.route_layer(middleware::from_fn(
move |req: axum::extract::Request, next: middleware::Next| {
let limiter = limiter.clone();
async move {
if limiter.check() {
next.run(req).await
} else {
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
},
));
.layer({
let cfg = GovernorConfigBuilder::default()
.with_extractor(PeerIp::default())
.expect_connect_info()
.quota_default(per_minute(rate_limit))
.finish()
.unwrap();
GovernorLayer::new(cfg)
});
Router::new()
.route("/", routing::get(handlers::html::get_activity_feed))
@@ -140,22 +99,16 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
"/users/{id}/feed.rss",
routing::get(handlers::rss::get_user_feed),
)
.layer(middleware::from_fn(crate::csrf::csrf_middleware))
.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
}
fn api_routes(rate_limit: u64) -> Router<AppState> {
let limiter = RateLimiter::new(rate_limit);
let auth_rate_limit =
middleware::from_fn(move |req: axum::extract::Request, next: middleware::Next| {
let limiter = limiter.clone();
async move {
if limiter.check() {
next.run(req).await
} else {
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
});
let cfg = GovernorConfigBuilder::default()
.with_extractor(PeerIp::default())
.expect_connect_info()
.quota_default(per_minute(rate_limit))
.finish()
.unwrap();
Router::new().nest(
"/api/v1",
@@ -188,6 +141,6 @@ fn api_routes(rate_limit: u64) -> Router<AppState> {
.route("/social/followers/accept", routing::post(handlers::api::accept_follower))
.route("/social/followers/reject", routing::post(handlers::api::reject_follower))
.route("/social/followers/remove", routing::post(handlers::api::remove_follower))
.route_layer(auth_rate_limit),
.layer(GovernorLayer::new(cfg)),
)
}