diff --git a/crates/adapters/template-askama/templates/activity_feed.html b/crates/adapters/template-askama/templates/activity_feed.html
index 676df2d..bb0160e 100644
--- a/crates/adapters/template-askama/templates/activity_feed.html
+++ b/crates/adapters/template-askama/templates/activity_feed.html
@@ -38,6 +38,7 @@
{% if ctx.is_current_user(entry.review().user_id().value()) %}
{% endif %}
diff --git a/crates/adapters/template-askama/templates/diary.html b/crates/adapters/template-askama/templates/diary.html
index 16d184e..ec01a65 100644
--- a/crates/adapters/template-askama/templates/diary.html
+++ b/crates/adapters/template-askama/templates/diary.html
@@ -30,6 +30,7 @@
{% if let Some(uid) = ctx.user_id %}
{% if *uid == entry.review().user_id().value() %}
{% endif %}
diff --git a/crates/adapters/template-askama/templates/followers.html b/crates/adapters/template-askama/templates/followers.html
index e94cc10..e98d6e6 100644
--- a/crates/adapters/template-askama/templates/followers.html
+++ b/crates/adapters/template-askama/templates/followers.html
@@ -17,6 +17,7 @@
{{ actor.url }}
diff --git a/crates/adapters/template-askama/templates/following.html b/crates/adapters/template-askama/templates/following.html
index 689afab..4070a18 100644
--- a/crates/adapters/template-askama/templates/following.html
+++ b/crates/adapters/template-askama/templates/following.html
@@ -17,6 +17,7 @@
{{ actor.url }}
diff --git a/crates/adapters/template-askama/templates/login.html b/crates/adapters/template-askama/templates/login.html
index 967c22c..819662e 100644
--- a/crates/adapters/template-askama/templates/login.html
+++ b/crates/adapters/template-askama/templates/login.html
@@ -13,6 +13,7 @@
Password
+
{% endblock %}
diff --git a/crates/adapters/template-askama/templates/new_review.html b/crates/adapters/template-askama/templates/new_review.html
index f5ddbd0..c7cdd29 100644
--- a/crates/adapters/template-askama/templates/new_review.html
+++ b/crates/adapters/template-askama/templates/new_review.html
@@ -35,6 +35,7 @@
Comment
+
{% endblock %}
diff --git a/crates/adapters/template-askama/templates/profile.html b/crates/adapters/template-askama/templates/profile.html
index 7100a07..0da26ae 100644
--- a/crates/adapters/template-askama/templates/profile.html
+++ b/crates/adapters/template-askama/templates/profile.html
@@ -29,6 +29,7 @@
Follow remote user
{% if let Some(err) = error %}
@@ -47,10 +48,12 @@
{{ actor.url }}
@@ -183,6 +186,7 @@
{% if ctx.is_current_user(entry.review().user_id().value()) %}
{% endif %}
diff --git a/crates/adapters/template-askama/templates/register.html b/crates/adapters/template-askama/templates/register.html
index bb924fb..ed3b550 100644
--- a/crates/adapters/template-askama/templates/register.html
+++ b/crates/adapters/template-askama/templates/register.html
@@ -19,6 +19,7 @@
Password
+
{% endblock %}
diff --git a/crates/application/src/ports.rs b/crates/application/src/ports.rs
index 3ec7329..29d19dc 100644
--- a/crates/application/src/ports.rs
+++ b/crates/application/src/ports.rs
@@ -18,6 +18,7 @@ pub struct HtmlPageContext {
pub rss_url: String,
pub page_title: String,
pub canonical_url: String,
+ pub csrf_token: String,
}
impl HtmlPageContext {
diff --git a/crates/presentation/src/csrf.rs b/crates/presentation/src/csrf.rs
new file mode 100644
index 0000000..c233a06
--- /dev/null
+++ b/crates/presentation/src/csrf.rs
@@ -0,0 +1,58 @@
+use axum::{
+ extract::Request,
+ http::{HeaderValue, header},
+ middleware::Next,
+ response::Response,
+};
+
+#[derive(Clone)]
+pub struct CsrfToken(pub String);
+
+pub fn extract_from_cookie(headers: &axum::http::HeaderMap) -> Option {
+ headers
+ .get(header::COOKIE)
+ .and_then(|v| v.to_str().ok())
+ .and_then(|cookies| {
+ cookies
+ .split(';')
+ .find_map(|c| c.trim().strip_prefix("csrf=").map(str::to_string))
+ })
+}
+
+fn secure_flag() -> &'static str {
+ if std::env::var("SECURE_COOKIES").as_deref() == Ok("true") {
+ "; Secure"
+ } else {
+ ""
+ }
+}
+
+pub async fn csrf_middleware(mut req: Request, next: Next) -> Response {
+ let existing = extract_from_cookie(req.headers());
+ let (token, needs_set) = match existing {
+ Some(t) => (t, false),
+ None => (uuid::Uuid::new_v4().to_string(), true),
+ };
+
+ req.extensions_mut().insert(CsrfToken(token.clone()));
+
+ let mut response = next.run(req).await;
+
+ if needs_set {
+ let cookie = format!(
+ "csrf={}; HttpOnly; Path=/; SameSite=Strict{}",
+ token,
+ secure_flag()
+ );
+ if let Ok(val) = HeaderValue::from_str(&cookie) {
+ response.headers_mut().append(header::SET_COOKIE, val);
+ }
+ }
+
+ response
+}
+
+/// Returns true if the form token does not match the cookie token.
+pub fn mismatch(token: &CsrfToken, form_value: &str) -> bool {
+ token.0 != form_value || form_value.is_empty()
+}
diff --git a/crates/presentation/src/dtos.rs b/crates/presentation/src/dtos.rs
index 1839cfe..acf2cf8 100644
--- a/crates/presentation/src/dtos.rs
+++ b/crates/presentation/src/dtos.rs
@@ -41,12 +41,16 @@ pub struct LogReviewForm {
#[serde(default, deserialize_with = "empty_string_as_none")]
pub comment: Option,
pub watched_at: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize)]
pub struct LoginForm {
pub email: String,
pub password: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize)]
@@ -54,6 +58,8 @@ pub struct RegisterForm {
pub email: String,
pub username: String,
pub password: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize)]
@@ -65,6 +71,8 @@ pub struct ErrorQuery {
pub struct DeleteRedirectForm {
#[serde(default)]
pub redirect_after: Option,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize, utoipa::ToSchema)]
@@ -240,16 +248,22 @@ impl From for GetDiaryQuery {
#[derive(Deserialize)]
pub struct FollowForm {
pub handle: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize)]
pub struct UnfollowForm {
pub actor_url: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(Deserialize)]
pub struct FollowerActionForm {
pub actor_url: String,
+ #[serde(rename = "_csrf", default)]
+ pub csrf_token: String,
}
#[derive(serde::Deserialize, Default)]
@@ -410,6 +424,7 @@ mod tests {
rating: 4,
comment: None,
watched_at: watched_at.to_string(),
+ csrf_token: String::new(),
}
}
diff --git a/crates/presentation/src/handlers.rs b/crates/presentation/src/handlers.rs
index 472ddf1..c9e61fe 100644
--- a/crates/presentation/src/handlers.rs
+++ b/crates/presentation/src/handlers.rs
@@ -6,7 +6,7 @@ pub mod html {
use axum::{
Form,
- extract::{Path, Query, State},
+ extract::{Extension, Path, Query, State},
http::{HeaderValue, StatusCode, header::SET_COOKIE},
response::{Html, IntoResponse, Redirect},
};
@@ -28,6 +28,7 @@ pub mod html {
use domain::{errors::DomainError, value_objects::UserId};
use crate::{
+ csrf::CsrfToken,
dtos::{
DiaryQueryParams, ErrorQuery, FollowForm, FollowerActionForm, LogReviewData,
LogReviewForm, LoginForm, RegisterForm, UnfollowForm,
@@ -36,7 +37,11 @@ pub mod html {
state::AppState,
};
- async fn build_page_context(state: &AppState, user_id: Option) -> HtmlPageContext {
+ async fn build_page_context(
+ state: &AppState,
+ user_id: Option,
+ csrf_token: String,
+ ) -> HtmlPageContext {
let uuid = user_id.as_ref().map(|u| u.value());
let user_email = if let Some(ref id) = user_id {
state
@@ -57,6 +62,7 @@ pub mod html {
rss_url: "/feed.rss".to_string(),
page_title: "Movies Diary".to_string(),
canonical_url: state.app_ctx.config.base_url.clone(),
+ csrf_token,
}
}
@@ -89,6 +95,7 @@ pub mod html {
pub async fn get_login_page(
State(state): State,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
let ctx = HtmlPageContext {
user_email: None,
@@ -97,6 +104,7 @@ pub mod html {
rss_url: "/feed.rss".to_string(),
page_title: "Login — Movies Diary".to_string(),
canonical_url: format!("{}/login", state.app_ctx.config.base_url),
+ csrf_token: csrf.0,
};
let html = state
.html_renderer
@@ -110,8 +118,12 @@ pub mod html {
pub async fn post_login(
State(state): State,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match login_uc::execute(
&state.app_ctx,
LoginCommand {
@@ -145,6 +157,7 @@ pub mod html {
pub async fn get_register_page(
State(state): State,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
if !state.app_ctx.config.allow_registration {
return Redirect::to("/").into_response();
@@ -156,6 +169,7 @@ pub mod html {
rss_url: "/feed.rss".to_string(),
page_title: "Register — Movies Diary".to_string(),
canonical_url: format!("{}/register", state.app_ctx.config.base_url),
+ csrf_token: csrf.0,
};
let html = state
.html_renderer
@@ -169,11 +183,15 @@ pub mod html {
pub async fn post_register(
State(state): State,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if !state.app_ctx.config.allow_registration {
return Redirect::to("/").into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
let email = form.email.clone();
let password = form.password.clone();
match register_uc::execute(
@@ -205,8 +223,9 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
- let mut ctx = build_page_context(&state, Some(user_id)).await;
+ let mut ctx = build_page_context(&state, Some(user_id), csrf.0).await;
ctx.page_title = "Log a Review — Movies Diary".to_string();
ctx.canonical_url = format!("{}/reviews/new", state.app_ctx.config.base_url);
let html = state
@@ -222,8 +241,12 @@ pub mod html {
pub async fn post_review(
State(state): State,
RequiredCookieUser(user_id): RequiredCookieUser,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
let data = match LogReviewData::try_from(form) {
Ok(d) => d,
Err(_) => {
@@ -243,9 +266,13 @@ pub mod html {
pub async fn post_delete_review(
State(state): State,
RequiredCookieUser(user_id): RequiredCookieUser,
+ Extension(csrf): Extension,
Path(review_id): Path,
Form(form): Form,
) -> impl IntoResponse {
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
let cmd = DeleteReviewCommand {
review_id,
requesting_user_id: user_id.value(),
@@ -312,8 +339,9 @@ pub mod html {
OptionalCookieUser(user_id): OptionalCookieUser,
State(state): State,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
- let ctx = build_page_context(&state, user_id).await;
+ let ctx = build_page_context(&state, user_id, csrf.0).await;
let query = application::queries::GetActivityFeedQuery {
limit: params.limit,
offset: params.offset,
@@ -342,8 +370,9 @@ pub mod html {
pub async fn get_users_list(
OptionalCookieUser(user_id): OptionalCookieUser,
State(state): State,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
- let mut ctx = build_page_context(&state, user_id).await;
+ let mut ctx = build_page_context(&state, user_id, csrf.0).await;
ctx.page_title = "Members — Movies Diary".to_string();
ctx.canonical_url = format!("{}/users", state.app_ctx.config.base_url);
match application::use_cases::get_users::execute(
@@ -369,6 +398,7 @@ pub mod html {
Path(profile_user_uuid): Path,
headers: axum::http::HeaderMap,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
// Content negotiation: AP clients request application/activity+json
let accept = headers
@@ -393,7 +423,7 @@ pub mod html {
};
}
- let mut ctx = build_page_context(&state, user_id.clone()).await;
+ let mut ctx = build_page_context(&state, user_id.clone(), csrf.0).await;
let view_str = params.view.as_deref().unwrap_or("recent");
let profile_view = match application::queries::ProfileView::from_str(view_str) {
Ok(v) => v,
@@ -520,11 +550,15 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Path(profile_user_uuid): Path,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match state.ap_service.follow(user_id.value(), &form.handle).await {
Ok(()) => Redirect::to(&format!("/users/{}", profile_user_uuid)).into_response(),
Err(e) => {
@@ -539,11 +573,15 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Path(profile_user_uuid): Path,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match state
.ap_service
.unfollow(user_id.value(), &form.actor_url)
@@ -566,11 +604,15 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Path(profile_user_uuid): Path,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match state
.ap_service
.accept_follower(user_id.value(), &form.actor_url)
@@ -588,11 +630,15 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Path(profile_user_uuid): Path,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match state
.ap_service
.reject_follower(user_id.value(), &form.actor_url)
@@ -611,11 +657,12 @@ pub mod html {
State(state): State,
Path(profile_user_uuid): Path,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
- let mut ctx = build_page_context(&state, Some(user_id.clone())).await;
+ let mut ctx = build_page_context(&state, Some(user_id.clone()), csrf.0).await;
ctx.page_title = "Following — Movies Diary".to_string();
ctx.canonical_url = format!(
"{}/users/{}/following-list",
@@ -658,11 +705,12 @@ pub mod html {
State(state): State,
Path(profile_user_uuid): Path,
Query(params): Query,
+ Extension(csrf): Extension,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
- let mut ctx = build_page_context(&state, Some(user_id.clone())).await;
+ let mut ctx = build_page_context(&state, Some(user_id.clone()), csrf.0).await;
ctx.page_title = "Followers — Movies Diary".to_string();
ctx.canonical_url = format!(
"{}/users/{}/followers-list",
@@ -708,11 +756,15 @@ pub mod html {
RequiredCookieUser(user_id): RequiredCookieUser,
State(state): State,
Path(profile_user_uuid): Path,
+ Extension(csrf): Extension,
Form(form): Form,
) -> impl IntoResponse {
if user_id.value() != profile_user_uuid {
return StatusCode::FORBIDDEN.into_response();
}
+ if crate::csrf::mismatch(&csrf, &form.csrf_token) {
+ return StatusCode::FORBIDDEN.into_response();
+ }
match state
.ap_service
.remove_follower(user_id.value(), &form.actor_url)
diff --git a/crates/presentation/src/lib.rs b/crates/presentation/src/lib.rs
index c8014cd..1e00fbe 100644
--- a/crates/presentation/src/lib.rs
+++ b/crates/presentation/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod csrf;
pub mod dtos;
pub mod errors;
pub mod event_handlers;
diff --git a/crates/presentation/src/routes.rs b/crates/presentation/src/routes.rs
index a3de5f4..72f973e 100644
--- a/crates/presentation/src/routes.rs
+++ b/crates/presentation/src/routes.rs
@@ -140,6 +140,7 @@ fn html_routes(rate_limit: u64) -> Router {
"/users/{id}/feed.rss",
routing::get(handlers::rss::get_user_feed),
)
+ .layer(middleware::from_fn(crate::csrf::csrf_middleware))
}
fn api_routes(rate_limit: u64) -> Router {