From d89d373a91030257df96051598a7bde693a08cf5 Mon Sep 17 00:00:00 2001 From: Gabriel Kaszewski Date: Sat, 9 May 2026 22:09:19 +0200 Subject: [PATCH] feat: implement CSRF protection across forms and routes --- .../templates/activity_feed.html | 1 + .../template-askama/templates/diary.html | 1 + .../template-askama/templates/followers.html | 1 + .../template-askama/templates/following.html | 1 + .../template-askama/templates/login.html | 1 + .../template-askama/templates/new_review.html | 1 + .../template-askama/templates/profile.html | 4 ++ .../template-askama/templates/register.html | 1 + crates/application/src/ports.rs | 1 + crates/presentation/src/csrf.rs | 58 ++++++++++++++++ crates/presentation/src/dtos.rs | 15 ++++ crates/presentation/src/handlers.rs | 68 ++++++++++++++++--- crates/presentation/src/lib.rs | 1 + crates/presentation/src/routes.rs | 1 + 14 files changed, 147 insertions(+), 8 deletions(-) create mode 100644 crates/presentation/src/csrf.rs diff --git a/crates/adapters/template-askama/templates/activity_feed.html b/crates/adapters/template-askama/templates/activity_feed.html index 676df2d..bb0160e 100644 --- a/crates/adapters/template-askama/templates/activity_feed.html +++ b/crates/adapters/template-askama/templates/activity_feed.html @@ -38,6 +38,7 @@ {% if ctx.is_current_user(entry.review().user_id().value()) %}
+
{% endif %} diff --git a/crates/adapters/template-askama/templates/diary.html b/crates/adapters/template-askama/templates/diary.html index 16d184e..ec01a65 100644 --- a/crates/adapters/template-askama/templates/diary.html +++ b/crates/adapters/template-askama/templates/diary.html @@ -30,6 +30,7 @@ {% if let Some(uid) = ctx.user_id %} {% if *uid == entry.review().user_id().value() %}
+
{% endif %} diff --git a/crates/adapters/template-askama/templates/followers.html b/crates/adapters/template-askama/templates/followers.html index e94cc10..e98d6e6 100644 --- a/crates/adapters/template-askama/templates/followers.html +++ b/crates/adapters/template-askama/templates/followers.html @@ -17,6 +17,7 @@ {{ actor.url }}
+
diff --git a/crates/adapters/template-askama/templates/following.html b/crates/adapters/template-askama/templates/following.html index 689afab..4070a18 100644 --- a/crates/adapters/template-askama/templates/following.html +++ b/crates/adapters/template-askama/templates/following.html @@ -17,6 +17,7 @@ {{ actor.url }}
+
diff --git a/crates/adapters/template-askama/templates/login.html b/crates/adapters/template-askama/templates/login.html index 967c22c..819662e 100644 --- a/crates/adapters/template-askama/templates/login.html +++ b/crates/adapters/template-askama/templates/login.html @@ -13,6 +13,7 @@ Password
+ {% endblock %} diff --git a/crates/adapters/template-askama/templates/new_review.html b/crates/adapters/template-askama/templates/new_review.html index f5ddbd0..c7cdd29 100644 --- a/crates/adapters/template-askama/templates/new_review.html +++ b/crates/adapters/template-askama/templates/new_review.html @@ -35,6 +35,7 @@ Comment
+ {% endblock %} diff --git a/crates/adapters/template-askama/templates/profile.html b/crates/adapters/template-askama/templates/profile.html index 7100a07..0da26ae 100644 --- a/crates/adapters/template-askama/templates/profile.html +++ b/crates/adapters/template-askama/templates/profile.html @@ -29,6 +29,7 @@

Follow remote user

+
{% if let Some(err) = error %} @@ -47,10 +48,12 @@ {{ actor.url }}
+
+
@@ -183,6 +186,7 @@ {% if ctx.is_current_user(entry.review().user_id().value()) %}
+
{% endif %} diff --git a/crates/adapters/template-askama/templates/register.html b/crates/adapters/template-askama/templates/register.html index bb924fb..ed3b550 100644 --- a/crates/adapters/template-askama/templates/register.html +++ b/crates/adapters/template-askama/templates/register.html @@ -19,6 +19,7 @@ Password
+ {% endblock %} diff --git a/crates/application/src/ports.rs b/crates/application/src/ports.rs index 3ec7329..29d19dc 100644 --- a/crates/application/src/ports.rs +++ b/crates/application/src/ports.rs @@ -18,6 +18,7 @@ pub struct HtmlPageContext { pub rss_url: String, pub page_title: String, pub canonical_url: String, + pub csrf_token: String, } impl HtmlPageContext { diff --git a/crates/presentation/src/csrf.rs b/crates/presentation/src/csrf.rs new file mode 100644 index 0000000..c233a06 --- /dev/null +++ b/crates/presentation/src/csrf.rs @@ -0,0 +1,58 @@ +use axum::{ + extract::Request, + http::{HeaderValue, header}, + middleware::Next, + response::Response, +}; + +#[derive(Clone)] +pub struct CsrfToken(pub String); + +pub fn extract_from_cookie(headers: &axum::http::HeaderMap) -> Option { + headers + .get(header::COOKIE) + .and_then(|v| v.to_str().ok()) + .and_then(|cookies| { + cookies + .split(';') + .find_map(|c| c.trim().strip_prefix("csrf=").map(str::to_string)) + }) +} + +fn secure_flag() -> &'static str { + if std::env::var("SECURE_COOKIES").as_deref() == Ok("true") { + "; Secure" + } else { + "" + } +} + +pub async fn csrf_middleware(mut req: Request, next: Next) -> Response { + let existing = extract_from_cookie(req.headers()); + let (token, needs_set) = match existing { + Some(t) => (t, false), + None => (uuid::Uuid::new_v4().to_string(), true), + }; + + req.extensions_mut().insert(CsrfToken(token.clone())); + + let mut response = next.run(req).await; + + if needs_set { + let cookie = format!( + "csrf={}; HttpOnly; Path=/; SameSite=Strict{}", + token, + secure_flag() + ); + if let Ok(val) = HeaderValue::from_str(&cookie) { + response.headers_mut().append(header::SET_COOKIE, val); + } + } + + response +} + +/// Returns true if the form token does not match the cookie token. +pub fn mismatch(token: &CsrfToken, form_value: &str) -> bool { + token.0 != form_value || form_value.is_empty() +} diff --git a/crates/presentation/src/dtos.rs b/crates/presentation/src/dtos.rs index 1839cfe..acf2cf8 100644 --- a/crates/presentation/src/dtos.rs +++ b/crates/presentation/src/dtos.rs @@ -41,12 +41,16 @@ pub struct LogReviewForm { #[serde(default, deserialize_with = "empty_string_as_none")] pub comment: Option, pub watched_at: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize)] pub struct LoginForm { pub email: String, pub password: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize)] @@ -54,6 +58,8 @@ pub struct RegisterForm { pub email: String, pub username: String, pub password: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize)] @@ -65,6 +71,8 @@ pub struct ErrorQuery { pub struct DeleteRedirectForm { #[serde(default)] pub redirect_after: Option, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize, utoipa::ToSchema)] @@ -240,16 +248,22 @@ impl From for GetDiaryQuery { #[derive(Deserialize)] pub struct FollowForm { pub handle: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize)] pub struct UnfollowForm { pub actor_url: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(Deserialize)] pub struct FollowerActionForm { pub actor_url: String, + #[serde(rename = "_csrf", default)] + pub csrf_token: String, } #[derive(serde::Deserialize, Default)] @@ -410,6 +424,7 @@ mod tests { rating: 4, comment: None, watched_at: watched_at.to_string(), + csrf_token: String::new(), } } diff --git a/crates/presentation/src/handlers.rs b/crates/presentation/src/handlers.rs index 472ddf1..c9e61fe 100644 --- a/crates/presentation/src/handlers.rs +++ b/crates/presentation/src/handlers.rs @@ -6,7 +6,7 @@ pub mod html { use axum::{ Form, - extract::{Path, Query, State}, + extract::{Extension, Path, Query, State}, http::{HeaderValue, StatusCode, header::SET_COOKIE}, response::{Html, IntoResponse, Redirect}, }; @@ -28,6 +28,7 @@ pub mod html { use domain::{errors::DomainError, value_objects::UserId}; use crate::{ + csrf::CsrfToken, dtos::{ DiaryQueryParams, ErrorQuery, FollowForm, FollowerActionForm, LogReviewData, LogReviewForm, LoginForm, RegisterForm, UnfollowForm, @@ -36,7 +37,11 @@ pub mod html { state::AppState, }; - async fn build_page_context(state: &AppState, user_id: Option) -> HtmlPageContext { + async fn build_page_context( + state: &AppState, + user_id: Option, + csrf_token: String, + ) -> HtmlPageContext { let uuid = user_id.as_ref().map(|u| u.value()); let user_email = if let Some(ref id) = user_id { state @@ -57,6 +62,7 @@ pub mod html { rss_url: "/feed.rss".to_string(), page_title: "Movies Diary".to_string(), canonical_url: state.app_ctx.config.base_url.clone(), + csrf_token, } } @@ -89,6 +95,7 @@ pub mod html { pub async fn get_login_page( State(state): State, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { let ctx = HtmlPageContext { user_email: None, @@ -97,6 +104,7 @@ pub mod html { rss_url: "/feed.rss".to_string(), page_title: "Login — Movies Diary".to_string(), canonical_url: format!("{}/login", state.app_ctx.config.base_url), + csrf_token: csrf.0, }; let html = state .html_renderer @@ -110,8 +118,12 @@ pub mod html { pub async fn post_login( State(state): State, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match login_uc::execute( &state.app_ctx, LoginCommand { @@ -145,6 +157,7 @@ pub mod html { pub async fn get_register_page( State(state): State, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { if !state.app_ctx.config.allow_registration { return Redirect::to("/").into_response(); @@ -156,6 +169,7 @@ pub mod html { rss_url: "/feed.rss".to_string(), page_title: "Register — Movies Diary".to_string(), canonical_url: format!("{}/register", state.app_ctx.config.base_url), + csrf_token: csrf.0, }; let html = state .html_renderer @@ -169,11 +183,15 @@ pub mod html { pub async fn post_register( State(state): State, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if !state.app_ctx.config.allow_registration { return Redirect::to("/").into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } let email = form.email.clone(); let password = form.password.clone(); match register_uc::execute( @@ -205,8 +223,9 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { - let mut ctx = build_page_context(&state, Some(user_id)).await; + let mut ctx = build_page_context(&state, Some(user_id), csrf.0).await; ctx.page_title = "Log a Review — Movies Diary".to_string(); ctx.canonical_url = format!("{}/reviews/new", state.app_ctx.config.base_url); let html = state @@ -222,8 +241,12 @@ pub mod html { pub async fn post_review( State(state): State, RequiredCookieUser(user_id): RequiredCookieUser, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } let data = match LogReviewData::try_from(form) { Ok(d) => d, Err(_) => { @@ -243,9 +266,13 @@ pub mod html { pub async fn post_delete_review( State(state): State, RequiredCookieUser(user_id): RequiredCookieUser, + Extension(csrf): Extension, Path(review_id): Path, Form(form): Form, ) -> impl IntoResponse { + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } let cmd = DeleteReviewCommand { review_id, requesting_user_id: user_id.value(), @@ -312,8 +339,9 @@ pub mod html { OptionalCookieUser(user_id): OptionalCookieUser, State(state): State, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { - let ctx = build_page_context(&state, user_id).await; + let ctx = build_page_context(&state, user_id, csrf.0).await; let query = application::queries::GetActivityFeedQuery { limit: params.limit, offset: params.offset, @@ -342,8 +370,9 @@ pub mod html { pub async fn get_users_list( OptionalCookieUser(user_id): OptionalCookieUser, State(state): State, + Extension(csrf): Extension, ) -> impl IntoResponse { - let mut ctx = build_page_context(&state, user_id).await; + let mut ctx = build_page_context(&state, user_id, csrf.0).await; ctx.page_title = "Members — Movies Diary".to_string(); ctx.canonical_url = format!("{}/users", state.app_ctx.config.base_url); match application::use_cases::get_users::execute( @@ -369,6 +398,7 @@ pub mod html { Path(profile_user_uuid): Path, headers: axum::http::HeaderMap, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { // Content negotiation: AP clients request application/activity+json let accept = headers @@ -393,7 +423,7 @@ pub mod html { }; } - let mut ctx = build_page_context(&state, user_id.clone()).await; + let mut ctx = build_page_context(&state, user_id.clone(), csrf.0).await; let view_str = params.view.as_deref().unwrap_or("recent"); let profile_view = match application::queries::ProfileView::from_str(view_str) { Ok(v) => v, @@ -520,11 +550,15 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Path(profile_user_uuid): Path, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match state.ap_service.follow(user_id.value(), &form.handle).await { Ok(()) => Redirect::to(&format!("/users/{}", profile_user_uuid)).into_response(), Err(e) => { @@ -539,11 +573,15 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Path(profile_user_uuid): Path, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match state .ap_service .unfollow(user_id.value(), &form.actor_url) @@ -566,11 +604,15 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Path(profile_user_uuid): Path, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match state .ap_service .accept_follower(user_id.value(), &form.actor_url) @@ -588,11 +630,15 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Path(profile_user_uuid): Path, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match state .ap_service .reject_follower(user_id.value(), &form.actor_url) @@ -611,11 +657,12 @@ pub mod html { State(state): State, Path(profile_user_uuid): Path, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } - let mut ctx = build_page_context(&state, Some(user_id.clone())).await; + let mut ctx = build_page_context(&state, Some(user_id.clone()), csrf.0).await; ctx.page_title = "Following — Movies Diary".to_string(); ctx.canonical_url = format!( "{}/users/{}/following-list", @@ -658,11 +705,12 @@ pub mod html { State(state): State, Path(profile_user_uuid): Path, Query(params): Query, + Extension(csrf): Extension, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } - let mut ctx = build_page_context(&state, Some(user_id.clone())).await; + let mut ctx = build_page_context(&state, Some(user_id.clone()), csrf.0).await; ctx.page_title = "Followers — Movies Diary".to_string(); ctx.canonical_url = format!( "{}/users/{}/followers-list", @@ -708,11 +756,15 @@ pub mod html { RequiredCookieUser(user_id): RequiredCookieUser, State(state): State, Path(profile_user_uuid): Path, + Extension(csrf): Extension, Form(form): Form, ) -> impl IntoResponse { if user_id.value() != profile_user_uuid { return StatusCode::FORBIDDEN.into_response(); } + if crate::csrf::mismatch(&csrf, &form.csrf_token) { + return StatusCode::FORBIDDEN.into_response(); + } match state .ap_service .remove_follower(user_id.value(), &form.actor_url) diff --git a/crates/presentation/src/lib.rs b/crates/presentation/src/lib.rs index c8014cd..1e00fbe 100644 --- a/crates/presentation/src/lib.rs +++ b/crates/presentation/src/lib.rs @@ -1,3 +1,4 @@ +pub mod csrf; pub mod dtos; pub mod errors; pub mod event_handlers; diff --git a/crates/presentation/src/routes.rs b/crates/presentation/src/routes.rs index a3de5f4..72f973e 100644 --- a/crates/presentation/src/routes.rs +++ b/crates/presentation/src/routes.rs @@ -140,6 +140,7 @@ fn html_routes(rate_limit: u64) -> Router { "/users/{id}/feed.rss", routing::get(handlers::rss::get_user_feed), ) + .layer(middleware::from_fn(crate::csrf::csrf_middleware)) } fn api_routes(rate_limit: u64) -> Router {