feat: integrate axum-governor for rate limiting and update dependencies
This commit is contained in:
@@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let addr = format!("{}:{}", host, port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
tracing::info!("Listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,54 +1,12 @@
|
||||
use std::sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::net::SocketAddr;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
use axum::{Router, http::StatusCode, middleware, response::IntoResponse, routing};
|
||||
use axum::{Router, routing};
|
||||
use axum_governor::{GovernorConfigBuilder, GovernorLayer, Quota, extractor::PeerIp};
|
||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||
|
||||
use crate::{handlers, state::AppState};
|
||||
|
||||
/// Simple global rate limiter: tracks request count per 60-second window.
|
||||
/// Not per-IP — suitable for a low-traffic personal app.
|
||||
#[derive(Clone)]
|
||||
struct RateLimiter {
|
||||
window: Arc<AtomicU64>,
|
||||
count: Arc<AtomicU64>,
|
||||
limit: u64,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
fn new(limit: u64) -> Self {
|
||||
Self {
|
||||
window: Arc::new(AtomicU64::new(0)),
|
||||
count: Arc::new(AtomicU64::new(0)),
|
||||
limit,
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&self) -> bool {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
/ 60;
|
||||
let prev = self.window.load(Ordering::Acquire);
|
||||
if now != prev {
|
||||
// compare_exchange ensures only one thread wins the window reset
|
||||
if self
|
||||
.window
|
||||
.compare_exchange(prev, now, Ordering::AcqRel, Ordering::Relaxed)
|
||||
.is_ok()
|
||||
{
|
||||
self.count.store(1, Ordering::Release);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
self.count.fetch_add(1, Ordering::Relaxed) + 1 <= self.limit
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_router(state: AppState, ap_router: Router) -> Router {
|
||||
let rate_limit = state.app_ctx.config.rate_limit;
|
||||
Router::new()
|
||||
@@ -60,8 +18,12 @@ pub fn build_router(state: AppState, ap_router: Router) -> Router {
|
||||
.merge(ap_router)
|
||||
}
|
||||
|
||||
fn per_minute(n: u64) -> Quota {
|
||||
let n = NonZeroU32::new(n.clamp(1, u32::MAX as u64) as u32).unwrap();
|
||||
Quota::requests_per_minute(n)
|
||||
}
|
||||
|
||||
fn html_routes(rate_limit: u64) -> Router<AppState> {
|
||||
let limiter = RateLimiter::new(rate_limit);
|
||||
let auth = Router::new()
|
||||
.route(
|
||||
"/login",
|
||||
@@ -72,18 +34,15 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
|
||||
"/register",
|
||||
routing::get(handlers::html::get_register_page).post(handlers::html::post_register),
|
||||
)
|
||||
.route_layer(middleware::from_fn(
|
||||
move |req: axum::extract::Request, next: middleware::Next| {
|
||||
let limiter = limiter.clone();
|
||||
async move {
|
||||
if limiter.check() {
|
||||
next.run(req).await
|
||||
} else {
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
},
|
||||
));
|
||||
.layer({
|
||||
let cfg = GovernorConfigBuilder::default()
|
||||
.with_extractor(PeerIp::default())
|
||||
.expect_connect_info()
|
||||
.quota_default(per_minute(rate_limit))
|
||||
.finish()
|
||||
.unwrap();
|
||||
GovernorLayer::new(cfg)
|
||||
});
|
||||
|
||||
Router::new()
|
||||
.route("/", routing::get(handlers::html::get_activity_feed))
|
||||
@@ -140,22 +99,16 @@ fn html_routes(rate_limit: u64) -> Router<AppState> {
|
||||
"/users/{id}/feed.rss",
|
||||
routing::get(handlers::rss::get_user_feed),
|
||||
)
|
||||
.layer(middleware::from_fn(crate::csrf::csrf_middleware))
|
||||
.layer(axum::middleware::from_fn(crate::csrf::csrf_middleware))
|
||||
}
|
||||
|
||||
fn api_routes(rate_limit: u64) -> Router<AppState> {
|
||||
let limiter = RateLimiter::new(rate_limit);
|
||||
let auth_rate_limit =
|
||||
middleware::from_fn(move |req: axum::extract::Request, next: middleware::Next| {
|
||||
let limiter = limiter.clone();
|
||||
async move {
|
||||
if limiter.check() {
|
||||
next.run(req).await
|
||||
} else {
|
||||
StatusCode::TOO_MANY_REQUESTS.into_response()
|
||||
}
|
||||
}
|
||||
});
|
||||
let cfg = GovernorConfigBuilder::default()
|
||||
.with_extractor(PeerIp::default())
|
||||
.expect_connect_info()
|
||||
.quota_default(per_minute(rate_limit))
|
||||
.finish()
|
||||
.unwrap();
|
||||
|
||||
Router::new().nest(
|
||||
"/api/v1",
|
||||
@@ -188,6 +141,6 @@ fn api_routes(rate_limit: u64) -> Router<AppState> {
|
||||
.route("/social/followers/accept", routing::post(handlers::api::accept_follower))
|
||||
.route("/social/followers/reject", routing::post(handlers::api::reject_follower))
|
||||
.route("/social/followers/remove", routing::post(handlers::api::remove_follower))
|
||||
.route_layer(auth_rate_limit),
|
||||
.layer(GovernorLayer::new(cfg)),
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user