feat(auth): refresh tokens + remember me

Backend: add refresh JWT (30d, token_type claim), POST /auth/refresh
endpoint (rotates token pair), remember_me on login, JWT_REFRESH_EXPIRY_DAYS
env var. Extractors now reject refresh tokens on protected routes.

Frontend: sessionStorage for non-remembered sessions, localStorage +
refresh token for remembered sessions. Transparent 401 recovery in
api.ts (retry once after refresh). Remember me checkbox on login page
with security note when checked.
This commit is contained in:
2026-03-19 22:24:26 +01:00
parent 8bdd5e2277
commit d2412da057
13 changed files with 307 additions and 35 deletions

View File

@@ -36,6 +36,7 @@ pub struct Config {
pub jwt_issuer: Option<String>, pub jwt_issuer: Option<String>,
pub jwt_audience: Option<String>, pub jwt_audience: Option<String>,
pub jwt_expiry_hours: u64, pub jwt_expiry_hours: u64,
pub jwt_refresh_expiry_days: u64,
/// Whether the application is running in production mode /// Whether the application is running in production mode
pub is_production: bool, pub is_production: bool,
@@ -117,6 +118,11 @@ impl Config {
.and_then(|s| s.parse().ok()) .and_then(|s| s.parse().ok())
.unwrap_or(24); .unwrap_or(24);
let jwt_refresh_expiry_days = env::var("JWT_REFRESH_EXPIRY_DAYS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30);
let is_production = env::var("PRODUCTION") let is_production = env::var("PRODUCTION")
.or_else(|_| env::var("RUST_ENV")) .or_else(|_| env::var("RUST_ENV"))
.map(|v| v.to_lowercase() == "production" || v == "1" || v == "true") .map(|v| v.to_lowercase() == "production" || v == "1" || v == "true")
@@ -165,6 +171,7 @@ impl Config {
jwt_issuer, jwt_issuer,
jwt_audience, jwt_audience,
jwt_expiry_hours, jwt_expiry_hours,
jwt_refresh_expiry_days,
is_production, is_production,
allow_registration, allow_registration,
jellyfin_base_url, jellyfin_base_url,

View File

@@ -15,6 +15,15 @@ pub struct LoginRequest {
pub email: Email, pub email: Email,
/// Password is validated on deserialization (min 8 chars) /// Password is validated on deserialization (min 8 chars)
pub password: Password, pub password: Password,
/// When true, a refresh token is also issued for persistent sessions
#[serde(default)]
pub remember_me: bool,
}
/// Refresh token request
#[derive(Debug, Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
} }
/// Register request with validated email and password newtypes /// Register request with validated email and password newtypes
@@ -41,6 +50,9 @@ pub struct TokenResponse {
pub access_token: String, pub access_token: String,
pub token_type: String, pub token_type: String,
pub expires_in: u64, pub expires_in: u64,
/// Only present when remember_me was true at login, or on token refresh
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
} }
/// Per-provider info returned by `GET /config`. /// Per-provider info returned by `GET /config`.

View File

@@ -122,7 +122,7 @@ pub(crate) async fn validate_jwt_token(token: &str, state: &AppState) -> Result<
.as_ref() .as_ref()
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?; .ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
let claims = validator.validate_token(token).map_err(|e| { let claims = validator.validate_access_token(token).map_err(|e| {
tracing::debug!("JWT validation failed: {:?}", e); tracing::debug!("JWT validation failed: {:?}", e);
match e { match e {
infra::auth::jwt::JwtError::Expired => { infra::auth::jwt::JwtError::Expired => {

View File

@@ -6,13 +6,13 @@ use axum::{
}; };
use crate::{ use crate::{
dto::{LoginRequest, RegisterRequest, TokenResponse, UserResponse}, dto::{LoginRequest, RefreshRequest, RegisterRequest, TokenResponse, UserResponse},
error::ApiError, error::ApiError,
extractors::CurrentUser, extractors::CurrentUser,
state::AppState, state::AppState,
}; };
use super::create_jwt; use super::{create_jwt, create_refresh_jwt};
/// Login with email + password → JWT token /// Login with email + password → JWT token
pub(super) async fn login( pub(super) async fn login(
@@ -35,6 +35,11 @@ pub(super) async fn login(
} }
let token = create_jwt(&user, &state)?; let token = create_jwt(&user, &state)?;
let refresh_token = if payload.remember_me {
Some(create_refresh_jwt(&user, &state)?)
} else {
None
};
let _ = state.activity_log_repo.log("user_login", user.email.as_ref(), None).await; let _ = state.activity_log_repo.log("user_login", user.email.as_ref(), None).await;
Ok(( Ok((
@@ -43,6 +48,7 @@ pub(super) async fn login(
access_token: token, access_token: token,
token_type: "Bearer".to_string(), token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600, expires_in: state.config.jwt_expiry_hours * 3600,
refresh_token,
}), }),
)) ))
} }
@@ -71,6 +77,7 @@ pub(super) async fn register(
access_token: token, access_token: token,
token_type: "Bearer".to_string(), token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600, expires_in: state.config.jwt_expiry_hours * 3600,
refresh_token: None,
}), }),
)) ))
} }
@@ -90,6 +97,46 @@ pub(super) async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoRespon
})) }))
} }
/// Exchange a valid refresh token for a new access + refresh token pair
#[cfg(feature = "auth-jwt")]
pub(super) async fn refresh_token(
State(state): State<AppState>,
Json(payload): Json<RefreshRequest>,
) -> Result<impl IntoResponse, ApiError> {
let validator = state
.jwt_validator
.as_ref()
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
let claims = validator
.validate_refresh_token(&payload.refresh_token)
.map_err(|e| {
tracing::debug!("Refresh token validation failed: {:?}", e);
ApiError::Unauthorized("Invalid or expired refresh token".to_string())
})?;
let user_id: uuid::Uuid = claims
.sub
.parse()
.map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?;
let user = state
.user_service
.find_by_id(user_id)
.await
.map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?;
let access_token = create_jwt(&user, &state)?;
let new_refresh_token = create_refresh_jwt(&user, &state)?;
Ok(Json(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600,
refresh_token: Some(new_refresh_token),
}))
}
/// Issue a new JWT for the currently authenticated user (OIDC→JWT exchange or token refresh) /// Issue a new JWT for the currently authenticated user (OIDC→JWT exchange or token refresh)
#[cfg(feature = "auth-jwt")] #[cfg(feature = "auth-jwt")]
pub(super) async fn get_token( pub(super) async fn get_token(
@@ -102,5 +149,6 @@ pub(super) async fn get_token(
access_token: token, access_token: token,
token_type: "Bearer".to_string(), token_type: "Bearer".to_string(),
expires_in: state.config.jwt_expiry_hours * 3600, expires_in: state.config.jwt_expiry_hours * 3600,
refresh_token: None,
})) }))
} }

View File

@@ -18,7 +18,9 @@ pub fn router() -> Router<AppState> {
.route("/me", get(local::me)); .route("/me", get(local::me));
#[cfg(feature = "auth-jwt")] #[cfg(feature = "auth-jwt")]
let r = r.route("/token", post(local::get_token)); let r = r
.route("/token", post(local::get_token))
.route("/refresh", post(local::refresh_token));
#[cfg(feature = "auth-oidc")] #[cfg(feature = "auth-oidc")]
let r = r let r = r
@@ -28,7 +30,7 @@ pub fn router() -> Router<AppState> {
r r
} }
/// Helper: create JWT for a user /// Helper: create access JWT for a user
#[cfg(feature = "auth-jwt")] #[cfg(feature = "auth-jwt")]
pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> { pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
let validator = state let validator = state
@@ -45,3 +47,21 @@ pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String
pub(super) fn create_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> { pub(super) fn create_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
Err(ApiError::Internal("JWT feature not enabled".to_string())) Err(ApiError::Internal("JWT feature not enabled".to_string()))
} }
/// Helper: create refresh JWT for a user
#[cfg(feature = "auth-jwt")]
pub(super) fn create_refresh_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
let validator = state
.jwt_validator
.as_ref()
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
validator
.create_refresh_token(user)
.map_err(|e| ApiError::Internal(format!("Failed to create refresh token: {}", e)))
}
#[cfg(not(feature = "auth-jwt"))]
pub(super) fn create_refresh_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
Err(ApiError::Internal("JWT feature not enabled".to_string()))
}

View File

@@ -124,6 +124,7 @@ impl AppState {
config.jwt_issuer.clone(), config.jwt_issuer.clone(),
config.jwt_audience.clone(), config.jwt_audience.clone(),
Some(config.jwt_expiry_hours), Some(config.jwt_expiry_hours),
Some(config.jwt_refresh_expiry_days),
config.is_production, config.is_production,
)?; )?;
Some(Arc::new(JwtValidator::new(jwt_config))) Some(Arc::new(JwtValidator::new(jwt_config)))

View File

@@ -20,8 +20,10 @@ pub struct JwtConfig {
pub issuer: Option<String>, pub issuer: Option<String>,
/// Expected audience (for validation) /// Expected audience (for validation)
pub audience: Option<String>, pub audience: Option<String>,
/// Token expiry in hours (default: 24) /// Access token expiry in hours (default: 24)
pub expiry_hours: u64, pub expiry_hours: u64,
/// Refresh token expiry in days (default: 30)
pub refresh_expiry_days: u64,
} }
impl JwtConfig { impl JwtConfig {
@@ -33,6 +35,7 @@ impl JwtConfig {
issuer: Option<String>, issuer: Option<String>,
audience: Option<String>, audience: Option<String>,
expiry_hours: Option<u64>, expiry_hours: Option<u64>,
refresh_expiry_days: Option<u64>,
is_production: bool, is_production: bool,
) -> Result<Self, JwtError> { ) -> Result<Self, JwtError> {
// Validate secret strength in production // Validate secret strength in production
@@ -48,6 +51,7 @@ impl JwtConfig {
issuer, issuer,
audience, audience,
expiry_hours: expiry_hours.unwrap_or(24), expiry_hours: expiry_hours.unwrap_or(24),
refresh_expiry_days: refresh_expiry_days.unwrap_or(30),
}) })
} }
@@ -58,10 +62,15 @@ impl JwtConfig {
issuer: None, issuer: None,
audience: None, audience: None,
expiry_hours: 24, expiry_hours: 24,
refresh_expiry_days: 30,
} }
} }
} }
fn default_token_type() -> String {
"access".to_string()
}
/// JWT claims structure /// JWT claims structure
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JwtClaims { pub struct JwtClaims {
@@ -79,6 +88,9 @@ pub struct JwtClaims {
/// Audience /// Audience
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub aud: Option<String>, pub aud: Option<String>,
/// Token type: "access" or "refresh". Defaults to "access" for backward compat.
#[serde(default = "default_token_type")]
pub token_type: String,
} }
/// JWT-related errors /// JWT-related errors
@@ -141,7 +153,7 @@ impl JwtValidator {
} }
} }
/// Create a JWT token for the given user /// Create an access JWT token for the given user
pub fn create_token(&self, user: &User) -> Result<String, JwtError> { pub fn create_token(&self, user: &User) -> Result<String, JwtError> {
let now = SystemTime::now() let now = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
@@ -157,6 +169,30 @@ impl JwtValidator {
iat: now, iat: now,
iss: self.config.issuer.clone(), iss: self.config.issuer.clone(),
aud: self.config.audience.clone(), aud: self.config.audience.clone(),
token_type: "access".to_string(),
};
let header = Header::new(Algorithm::HS256);
encode(&header, &claims, &self.encoding_key).map_err(JwtError::CreationFailed)
}
/// Create a refresh JWT token for the given user (longer-lived)
pub fn create_refresh_token(&self, user: &User) -> Result<String, JwtError> {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as usize;
let expiry = now + (self.config.refresh_expiry_days as usize * 86400);
let claims = JwtClaims {
sub: user.id.to_string(),
email: user.email.as_ref().to_string(),
exp: expiry,
iat: now,
iss: self.config.issuer.clone(),
aud: self.config.audience.clone(),
token_type: "refresh".to_string(),
}; };
let header = Header::new(Algorithm::HS256); let header = Header::new(Algorithm::HS256);
@@ -176,6 +212,24 @@ impl JwtValidator {
Ok(token_data.claims) Ok(token_data.claims)
} }
/// Validate an access token — rejects refresh tokens
pub fn validate_access_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
let claims = self.validate_token(token)?;
if claims.token_type != "access" {
return Err(JwtError::ValidationFailed("Not an access token".to_string()));
}
Ok(claims)
}
/// Validate a refresh token — rejects access tokens
pub fn validate_refresh_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
let claims = self.validate_token(token)?;
if claims.token_type != "refresh" {
return Err(JwtError::ValidationFailed("Not a refresh token".to_string()));
}
Ok(claims)
}
/// Get the user ID (subject) from a token without full validation /// Get the user ID (subject) from a token without full validation
/// Useful for logging/debugging, but should not be trusted for auth /// Useful for logging/debugging, but should not be trusted for auth
pub fn decode_unverified(&self, token: &str) -> Result<JwtClaims, JwtError> { pub fn decode_unverified(&self, token: &str) -> Result<JwtClaims, JwtError> {

View File

@@ -8,12 +8,13 @@ import { useConfig } from "@/hooks/use-channels";
export default function LoginPage() { export default function LoginPage() {
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
const [rememberMe, setRememberMe] = useState(false);
const { mutate: login, isPending, error } = useLogin(); const { mutate: login, isPending, error } = useLogin();
const { data: config } = useConfig(); const { data: config } = useConfig();
const handleSubmit = (e: React.FormEvent) => { const handleSubmit = (e: React.FormEvent) => {
e.preventDefault(); e.preventDefault();
login({ email, password }); login({ email, password, rememberMe });
}; };
return ( return (
@@ -54,6 +55,23 @@ export default function LoginPage() {
/> />
</div> </div>
<div className="space-y-1">
<label className="flex cursor-pointer items-center gap-2">
<input
type="checkbox"
checked={rememberMe}
onChange={(e) => setRememberMe(e.target.checked)}
className="h-3.5 w-3.5 rounded border-zinc-600 bg-zinc-900 accent-white"
/>
<span className="text-xs text-zinc-400">Remember me</span>
</label>
{rememberMe && (
<p className="pl-5 text-xs text-amber-500/80">
A refresh token will be stored locally don&apos;t share it.
</p>
)}
</div>
{error && <p className="text-xs text-red-400">{error.message}</p>} {error && <p className="text-xs text-red-400">{error.message}</p>}
<button <button

View File

@@ -15,7 +15,7 @@ import { Toaster } from "@/components/ui/sonner";
import { ApiRequestError } from "@/lib/api"; import { ApiRequestError } from "@/lib/api";
function QueryProvider({ children }: { children: React.ReactNode }) { function QueryProvider({ children }: { children: React.ReactNode }) {
const { token, setToken } = useAuthContext(); const { token, setTokens } = useAuthContext();
const router = useRouter(); const router = useRouter();
const tokenRef = useRef(token); const tokenRef = useRef(token);
useEffect(() => { tokenRef.current = token; }, [token]); useEffect(() => { tokenRef.current = token; }, [token]);
@@ -29,7 +29,7 @@ function QueryProvider({ children }: { children: React.ReactNode }) {
// Guests hitting 401 on restricted content should not be redirected. // Guests hitting 401 on restricted content should not be redirected.
if (error instanceof ApiRequestError && error.status === 401 && tokenRef.current) { if (error instanceof ApiRequestError && error.status === 401 && tokenRef.current) {
toast.warning("Session expired, please log in again."); toast.warning("Session expired, please log in again.");
setToken(null); setTokens(null, null, false);
router.push("/login"); router.push("/login");
} }
}, },
@@ -39,7 +39,7 @@ function QueryProvider({ children }: { children: React.ReactNode }) {
// Mutations always require auth — redirect on 401 regardless. // Mutations always require auth — redirect on 401 regardless.
if (error instanceof ApiRequestError && error.status === 401) { if (error instanceof ApiRequestError && error.status === 401) {
toast.warning("Session expired, please log in again."); toast.warning("Session expired, please log in again.");
setToken(null); setTokens(null, null, false);
router.push("/login"); router.push("/login");
} }
}, },

View File

@@ -4,42 +4,94 @@ import {
createContext, createContext,
useContext, useContext,
useState, useState,
useEffect,
type ReactNode, type ReactNode,
} from "react"; } from "react";
import { useRouter } from "next/navigation";
import { api, setRefreshCallback } from "@/lib/api";
const TOKEN_KEY = "k-tv-token"; const ACCESS_KEY_LOCAL = "k-tv-token";
const ACCESS_KEY_SESSION = "k-tv-token-session";
const REFRESH_KEY = "k-tv-refresh-token";
interface AuthContextValue { interface AuthContextValue {
token: string | null; token: string | null;
/** True once the initial localStorage read has completed */ refreshToken: string | null;
/** Always true (lazy init reads storage synchronously) */
isLoaded: boolean; isLoaded: boolean;
setToken: (token: string | null) => void; setTokens: (access: string | null, refresh: string | null, remember: boolean) => void;
} }
const AuthContext = createContext<AuthContextValue | null>(null); const AuthContext = createContext<AuthContextValue | null>(null);
export function AuthProvider({ children }: { children: ReactNode }) { export function AuthProvider({ children }: { children: ReactNode }) {
const router = useRouter();
const [token, setTokenState] = useState<string | null>(() => { const [token, setTokenState] = useState<string | null>(() => {
try { try {
return localStorage.getItem(TOKEN_KEY); return sessionStorage.getItem(ACCESS_KEY_SESSION) ?? localStorage.getItem(ACCESS_KEY_LOCAL);
} catch { } catch {
return null; return null;
} }
}); });
// isLoaded is always true: lazy init above reads localStorage synchronously
const [isLoaded] = useState(true);
const setToken = (t: string | null) => { const [refreshToken, setRefreshTokenState] = useState<string | null>(() => {
setTokenState(t); try {
if (t) { return localStorage.getItem(REFRESH_KEY);
localStorage.setItem(TOKEN_KEY, t); } catch {
} else { return null;
localStorage.removeItem(TOKEN_KEY);
} }
});
const setTokens = (access: string | null, refresh: string | null, remember: boolean) => {
try {
if (access === null) {
sessionStorage.removeItem(ACCESS_KEY_SESSION);
localStorage.removeItem(ACCESS_KEY_LOCAL);
localStorage.removeItem(REFRESH_KEY);
} else if (remember) {
localStorage.setItem(ACCESS_KEY_LOCAL, access);
sessionStorage.removeItem(ACCESS_KEY_SESSION);
if (refresh) {
localStorage.setItem(REFRESH_KEY, refresh);
} else {
localStorage.removeItem(REFRESH_KEY);
}
} else {
sessionStorage.setItem(ACCESS_KEY_SESSION, access);
localStorage.removeItem(ACCESS_KEY_LOCAL);
localStorage.removeItem(REFRESH_KEY);
}
} catch {
// storage unavailable — state-only fallback
}
setTokenState(access);
setRefreshTokenState(refresh);
}; };
// Wire up the transparent refresh callback in api.ts
useEffect(() => {
if (refreshToken) {
setRefreshCallback(async () => {
try {
const data = await api.auth.refresh(refreshToken);
const newRefresh = data.refresh_token ?? null;
setTokens(data.access_token, newRefresh, true);
return data.access_token;
} catch {
setTokens(null, null, false);
router.push("/login");
return null;
}
});
} else {
setRefreshCallback(null);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [refreshToken]);
return ( return (
<AuthContext.Provider value={{ token, isLoaded, setToken }}> <AuthContext.Provider value={{ token, refreshToken, isLoaded: true, setTokens }}>
{children} {children}
</AuthContext.Provider> </AuthContext.Provider>
); );

View File

@@ -16,14 +16,21 @@ export function useCurrentUser() {
} }
export function useLogin() { export function useLogin() {
const { setToken } = useAuthContext(); const { setTokens } = useAuthContext();
const router = useRouter(); const router = useRouter();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
return useMutation({ return useMutation({
mutationFn: ({ email, password }: { email: string; password: string }) => mutationFn: ({
api.auth.login(email, password), email,
onSuccess: (data) => { password,
setToken(data.access_token); rememberMe,
}: {
email: string;
password: string;
rememberMe: boolean;
}) => api.auth.login(email, password, rememberMe),
onSuccess: (data, { rememberMe }) => {
setTokens(data.access_token, data.refresh_token ?? null, rememberMe);
queryClient.invalidateQueries({ queryKey: ["me"] }); queryClient.invalidateQueries({ queryKey: ["me"] });
router.push("/dashboard"); router.push("/dashboard");
}, },
@@ -31,14 +38,14 @@ export function useLogin() {
} }
export function useRegister() { export function useRegister() {
const { setToken } = useAuthContext(); const { setTokens } = useAuthContext();
const router = useRouter(); const router = useRouter();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
return useMutation({ return useMutation({
mutationFn: ({ email, password }: { email: string; password: string }) => mutationFn: ({ email, password }: { email: string; password: string }) =>
api.auth.register(email, password), api.auth.register(email, password),
onSuccess: (data) => { onSuccess: (data) => {
setToken(data.access_token); setTokens(data.access_token, null, false);
queryClient.invalidateQueries({ queryKey: ["me"] }); queryClient.invalidateQueries({ queryKey: ["me"] });
router.push("/dashboard"); router.push("/dashboard");
}, },
@@ -46,13 +53,13 @@ export function useRegister() {
} }
export function useLogout() { export function useLogout() {
const { token, setToken } = useAuthContext(); const { token, setTokens } = useAuthContext();
const router = useRouter(); const router = useRouter();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
return useMutation({ return useMutation({
mutationFn: () => (token ? api.auth.logout(token) : Promise.resolve()), mutationFn: () => (token ? api.auth.logout(token) : Promise.resolve()),
onSettled: () => { onSettled: () => {
setToken(null); setTokens(null, null, false);
queryClient.clear(); queryClient.clear();
router.push("/login"); router.push("/login");
}, },

View File

@@ -34,6 +34,23 @@ export class ApiRequestError extends Error {
} }
} }
// Called by AuthProvider when refreshToken changes — enables transparent 401 recovery
let refreshCallback: (() => Promise<string | null>) | null = null;
let refreshInFlight: Promise<string | null> | null = null;
export function setRefreshCallback(cb: (() => Promise<string | null>) | null) {
refreshCallback = cb;
}
async function attemptRefresh(): Promise<string | null> {
if (!refreshCallback) return null;
if (refreshInFlight) return refreshInFlight;
refreshInFlight = refreshCallback().finally(() => {
refreshInFlight = null;
});
return refreshInFlight;
}
async function request<T>( async function request<T>(
path: string, path: string,
options: RequestInit & { token?: string } = {}, options: RequestInit & { token?: string } = {},
@@ -50,6 +67,35 @@ async function request<T>(
const res = await fetch(`${API_BASE}${path}`, { ...init, headers }); const res = await fetch(`${API_BASE}${path}`, { ...init, headers });
// Transparent refresh: on 401, try to get a new access token and retry once.
// Skip for the refresh endpoint itself to avoid infinite loops.
if (res.status === 401 && path !== "/auth/refresh") {
const newToken = await attemptRefresh();
if (newToken) {
const retryHeaders = new Headers(init.headers);
retryHeaders.set("Authorization", `Bearer ${newToken}`);
if (init.body && !retryHeaders.has("Content-Type")) {
retryHeaders.set("Content-Type", "application/json");
}
const retryRes = await fetch(`${API_BASE}${path}`, {
...init,
headers: retryHeaders,
});
if (!retryRes.ok) {
let message = retryRes.statusText;
try {
const body = await retryRes.json();
message = body.message ?? body.error ?? message;
} catch {
// ignore parse error
}
throw new ApiRequestError(retryRes.status, message);
}
if (retryRes.status === 204) return null as T;
return retryRes.json() as Promise<T>;
}
}
if (!res.ok) { if (!res.ok) {
let message = res.statusText; let message = res.statusText;
try { try {
@@ -77,10 +123,16 @@ export const api = {
body: JSON.stringify({ email, password }), body: JSON.stringify({ email, password }),
}), }),
login: (email: string, password: string) => login: (email: string, password: string, rememberMe = false) =>
request<TokenResponse>("/auth/login", { request<TokenResponse>("/auth/login", {
method: "POST", method: "POST",
body: JSON.stringify({ email, password }), body: JSON.stringify({ email, password, remember_me: rememberMe }),
}),
refresh: (refreshToken: string) =>
request<TokenResponse>("/auth/refresh", {
method: "POST",
body: JSON.stringify({ refresh_token: refreshToken }),
}), }),
logout: (token: string) => logout: (token: string) =>

View File

@@ -178,6 +178,7 @@ export interface TokenResponse {
access_token: string; access_token: string;
token_type: string; token_type: string;
expires_in: number; expires_in: number;
refresh_token?: string;
} }
export interface UserResponse { export interface UserResponse {