Compare commits
2 Commits
8bdd5e2277
...
373e1c7c0a
| Author | SHA1 | Date | |
|---|---|---|---|
| 373e1c7c0a | |||
| d2412da057 |
@@ -1,284 +0,0 @@
|
|||||||
# Scheduling V2 — Design Spec
|
|
||||||
|
|
||||||
## Context
|
|
||||||
|
|
||||||
The current scheduler is a 48h rolling window with a flat block list per channel. This works as MVP but has two major gaps for everyday use:
|
|
||||||
|
|
||||||
1. **No weekly patterns** — users can't say "Monday runs X, weekends run Y"; all blocks repeat identically every day.
|
|
||||||
2. **No history or recovery** — overwriting a channel config loses the previous setup forever; a bug that resets a sequential series (e.g. Sopranos resets from S3E4 to S1E1) has no recovery path.
|
|
||||||
|
|
||||||
This spec covers two features: **weekly scheduling** and **schedule history**.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Feature 1: Weekly Scheduling (7-day grid)
|
|
||||||
|
|
||||||
### Data model
|
|
||||||
|
|
||||||
`ScheduleConfig` changes from a flat block list to a day-keyed map:
|
|
||||||
|
|
||||||
```rust
|
|
||||||
// BEFORE
|
|
||||||
pub struct ScheduleConfig {
|
|
||||||
pub blocks: Vec<ProgrammingBlock>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// AFTER
|
|
||||||
pub struct ScheduleConfig {
|
|
||||||
pub day_blocks: HashMap<Weekday, Vec<ProgrammingBlock>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum Weekday {
|
|
||||||
Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday,
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
`ProgrammingBlock` is otherwise unchanged. Block IDs remain UUIDs; each day has its own independent Vec, so the same "show" on Mon and Wed has two separate block entries (different IDs, independent continuity tracking).
|
|
||||||
|
|
||||||
### Migration (transparent, zero-downtime)
|
|
||||||
|
|
||||||
Existing `channels.schedule_config` stores `{"blocks":[...]}`. Use `#[serde(untagged)]` deserialization:
|
|
||||||
|
|
||||||
```rust
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum ScheduleConfigCompat {
|
|
||||||
V2(ScheduleConfig), // {"day_blocks": {"monday": [...], ...}}
|
|
||||||
V1(OldScheduleConfig), // {"blocks": [...]}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
V1→V2 conversion: clone the blocks Vec into all 7 days. The first `PUT /channels/:id` after deploy saves V2 format. Channels never touched continue to deserialize via V1 path indefinitely.
|
|
||||||
|
|
||||||
**Edge case**: if a payload has both `blocks` and `day_blocks` keys (e.g. partially migrated export), `#[serde(untagged)]` tries V2 first and succeeds — `day_blocks` is used and `blocks` is silently ignored. This is acceptable; the alternative (error on ambiguity) would break more use cases.
|
|
||||||
|
|
||||||
### ScheduleConfig helper methods
|
|
||||||
|
|
||||||
Three methods on `ScheduleConfig` must be updated:
|
|
||||||
|
|
||||||
- **`find_block_at(weekday: Weekday, time: NaiveTime) -> Option<&ProgrammingBlock>`** — searches `day_blocks[weekday]` for the block whose window contains `time`.
|
|
||||||
- **`next_block_start_after(weekday: Weekday, time: NaiveTime) -> Option<NaiveTime>`** — searches that day's vec; returns `None` if no block starts after `time` on that day (day-rollover is the caller's responsibility).
|
|
||||||
- **`earliest_block_start() -> Option<NaiveTime>`** — **iterates all days, returns the global earliest start time across the entire week**. This is the form needed by the background scheduler (which needs to know when any content starts). Empty day = no contribution; all days empty = `None`.
|
|
||||||
|
|
||||||
**Call-site update pattern for `broadcast.rs` (lines 64, 171):**
|
|
||||||
```rust
|
|
||||||
// derive weekday from slot start_at in channel timezone
|
|
||||||
let tz: chrono_tz::Tz = channel.timezone.parse().unwrap_or(chrono_tz::UTC);
|
|
||||||
let local_dt = slot.start_at.with_timezone(&tz);
|
|
||||||
let weekday = Weekday::from_chrono(local_dt.weekday()); // new From impl
|
|
||||||
let block = channel.schedule_config.find_block_at(weekday, local_dt.time());
|
|
||||||
```
|
|
||||||
|
|
||||||
The same derivation applies to `dto.rs` (`ScheduledSlotResponse::with_block_access`).
|
|
||||||
|
|
||||||
### MCP crate
|
|
||||||
|
|
||||||
`mcp/src/tools/channels.rs` manipulates `schedule_config.blocks` directly. After V2:
|
|
||||||
|
|
||||||
- The MCP `add_block` tool must accept a `day: Weekday` parameter (required). It pushes the new block to `day_blocks[day]`.
|
|
||||||
- The MCP `remove_block` tool must iterate all days' vecs (remove by block ID across all days, since block IDs are unique per entry).
|
|
||||||
- `mcp/src/server.rs` `set_schedule_config` must accept a `day_blocks` map. The old `blocks_json` string parameter is replaced with `day_blocks_json: String` (JSON object keyed by weekday name).
|
|
||||||
|
|
||||||
These are breaking changes to the MCP API — acceptable since MCP tools are internal/developer-facing.
|
|
||||||
|
|
||||||
### Generation engine
|
|
||||||
|
|
||||||
- Window: `valid_from + 7 days` (was 48h). Update `GeneratedSchedule` doc comment accordingly.
|
|
||||||
- Day iteration: already walks calendar days; now walks 7 days, looks up `day_blocks[weekday]` for each day.
|
|
||||||
- **Empty day**: if `day_blocks[weekday]` is empty or the key is absent, that day produces no slots — valid, not an error.
|
|
||||||
- Continuity (`find_last_slot_per_block`): unchanged.
|
|
||||||
|
|
||||||
### Files changed (backend)
|
|
||||||
- `domain/src/value_objects.rs` — add `Weekday` enum with `From<chrono::Weekday>` impl
|
|
||||||
- `domain/src/entities.rs` — `ScheduleConfig`, `OldScheduleConfig` compat struct, update helper method signatures, update `GeneratedSchedule` doc comment
|
|
||||||
- `domain/src/services.rs` — 7-day window, `day_blocks[weekday]` lookup per day
|
|
||||||
- `api/src/routes/channels/broadcast.rs` — update block lookups at lines 64 and 171 using weekday-derivation pattern above
|
|
||||||
- `api/src/dto.rs` — update `ScheduledSlotResponse::with_block_access` block lookup
|
|
||||||
- `mcp/src/tools/channels.rs` — `add_block` accepts `day` param; `remove_block` iterates all days
|
|
||||||
- `mcp/src/server.rs` — replace `blocks_json` with `day_blocks_json`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Feature 2: Schedule History
|
|
||||||
|
|
||||||
### 2a. Config version history
|
|
||||||
|
|
||||||
Every `PUT /channels/:id` auto-snapshots the previous config before overwriting. Users can pin named checkpoints and restore any version.
|
|
||||||
|
|
||||||
**New DB migration:**
|
|
||||||
```sql
|
|
||||||
CREATE TABLE channel_config_snapshots (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
channel_id TEXT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
|
|
||||||
config_json TEXT NOT NULL,
|
|
||||||
version_num INTEGER NOT NULL,
|
|
||||||
label TEXT, -- NULL = auto-saved, non-NULL = pinned
|
|
||||||
created_at TEXT NOT NULL,
|
|
||||||
UNIQUE (channel_id, version_num)
|
|
||||||
);
|
|
||||||
CREATE INDEX idx_config_snapshots_channel ON channel_config_snapshots(channel_id, version_num DESC);
|
|
||||||
```
|
|
||||||
|
|
||||||
**`version_num` assignment**: computed inside the write transaction as `SELECT COALESCE(MAX(version_num), 0) + 1 FROM channel_config_snapshots WHERE channel_id = ?`. The transaction serializes concurrent writes naturally in SQLite (single writer). The `UNIQUE` constraint is a safety net only — no 409 is exposed to the client; the server retries within the transaction if needed (in practice impossible with SQLite's serialized writes).
|
|
||||||
|
|
||||||
**New API endpoints (all require auth + channel ownership — same auth middleware as existing channel routes):**
|
|
||||||
```
|
|
||||||
GET /channels/:id/config/history
|
|
||||||
→ [{id, version_num, label, created_at}] -- channel_id omitted (implicit from URL)
|
|
||||||
|
|
||||||
PATCH /channels/:id/config/history/:snap_id
|
|
||||||
body: {"label": "Before S3 switchover"}
|
|
||||||
→ 404 if snap_id not found or not owned by this channel
|
|
||||||
→ 200 {id, version_num, label, created_at}
|
|
||||||
|
|
||||||
POST /channels/:id/config/history/:snap_id/restore
|
|
||||||
→ snapshots current config first, then replaces channel config with target snapshot
|
|
||||||
→ 404 if snap_id not found or not owned by this channel
|
|
||||||
→ 200 {channel}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Domain + infra changes:**
|
|
||||||
- `ChannelConfigSnapshot` entity (fields: id, channel_id, config, version_num, label, created_at)
|
|
||||||
- Extend `ChannelRepository` port: `save_config_snapshot`, `list_config_snapshots`, `get_config_snapshot`, `patch_config_snapshot_label`
|
|
||||||
- `ChannelService::update_channel` calls `save_config_snapshot` before writing new config
|
|
||||||
|
|
||||||
**Files changed (backend):**
|
|
||||||
- `domain/src/entities.rs` — add `ChannelConfigSnapshot`
|
|
||||||
- `domain/src/repositories.rs` — extend `ChannelRepository` port
|
|
||||||
- `infra/src/channel_repo.rs` — implement snapshot methods
|
|
||||||
- `migrations_sqlite/YYYYMMDD_add_config_snapshots.sql`
|
|
||||||
- `api/src/routes/channels.rs` — new history endpoints + DTOs for snapshot responses
|
|
||||||
|
|
||||||
### 2b. Generated schedule audit log
|
|
||||||
|
|
||||||
**Ownership check**: `get_schedule_by_id(channel_id, gen_id)` queries `generated_schedules WHERE id = :gen_id AND channel_id = :channel_id` — the `channel_id` column is the join, so no separate channel lookup is needed.
|
|
||||||
|
|
||||||
**New API endpoints (all require auth + channel ownership):**
|
|
||||||
```
|
|
||||||
GET /channels/:id/schedule/history
|
|
||||||
→ [{id, generation, valid_from, valid_until}] ordered by generation DESC
|
|
||||||
|
|
||||||
GET /channels/:id/schedule/history/:gen_id
|
|
||||||
→ full GeneratedSchedule with slots
|
|
||||||
→ 404 if gen_id not found or channel_id mismatch
|
|
||||||
|
|
||||||
POST /channels/:id/schedule/history/:gen_id/rollback
|
|
||||||
→ 404 if gen_id not found or channel_id mismatch
|
|
||||||
→ explicit two-step delete (no DB-level cascade from playback_records to generated_schedules):
|
|
||||||
1. DELETE FROM playback_records WHERE channel_id = ? AND generation > :target_generation
|
|
||||||
2. DELETE FROM generated_schedules WHERE channel_id = ? AND generation > :target_generation
|
|
||||||
(scheduled_slots cascade via FK from generated_schedules)
|
|
||||||
→ calls generate_schedule from now
|
|
||||||
→ 200 {new_schedule}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Repository changes:**
|
|
||||||
- `list_schedule_history(channel_id)` — headers only
|
|
||||||
- `get_schedule_by_id(channel_id, gen_id)` — full with slots
|
|
||||||
- `delete_schedules_after(channel_id, generation_num)` — two-step explicit delete as above
|
|
||||||
|
|
||||||
**Files changed (backend):**
|
|
||||||
- `domain/src/repositories.rs` — extend `ScheduleRepository`
|
|
||||||
- `infra/src/schedule_repo.rs` — implement list, get-by-id, delete-after
|
|
||||||
- `api/src/routes/channels.rs` — new history and rollback endpoints
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Frontend
|
|
||||||
|
|
||||||
### Weekly grid editor (`edit-channel-sheet.tsx`)
|
|
||||||
|
|
||||||
Replace the flat block list with a tabbed weekly grid:
|
|
||||||
|
|
||||||
- 7 day tabs (Mon–Sun), each showing that day's block list
|
|
||||||
- Blocks within a day: same card UI as current (drag to reorder, edit, delete)
|
|
||||||
- "Copy to →" dropdown per tab: duplicates block entries with new UUIDs into target day(s)
|
|
||||||
- "+ Add block for [Day]" button per tab
|
|
||||||
- "🕐 Config history" button in sheet footer → opens config history panel
|
|
||||||
|
|
||||||
### Config history panel (`config-history-sheet.tsx` — new)
|
|
||||||
|
|
||||||
- List of snapshots: version_num, timestamp, label (if pinned)
|
|
||||||
- Current version highlighted
|
|
||||||
- Pin button on current version (opens label input)
|
|
||||||
- Restore button on any past version (confirm dialog)
|
|
||||||
|
|
||||||
### Schedule audit log (`schedule-history-dialog.tsx` — new)
|
|
||||||
|
|
||||||
- Lists past generations: gen#, date range
|
|
||||||
- "Rollback to here" button with confirm dialog
|
|
||||||
|
|
||||||
### Types (`lib/types.ts`)
|
|
||||||
```ts
|
|
||||||
type Weekday = 'monday' | 'tuesday' | 'wednesday' | 'thursday' | 'friday' | 'saturday' | 'sunday'
|
|
||||||
const WEEKDAYS: Weekday[] = ['monday','tuesday','wednesday','thursday','friday','saturday','sunday']
|
|
||||||
|
|
||||||
interface ScheduleConfig {
|
|
||||||
day_blocks: Record<Weekday, ProgrammingBlock[]>
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ConfigSnapshot {
|
|
||||||
id: string
|
|
||||||
version_num: number
|
|
||||||
label: string | null
|
|
||||||
created_at: string
|
|
||||||
// channel_id intentionally omitted — always accessed via /channels/:id/config/history
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ScheduleHistoryEntry {
|
|
||||||
id: string
|
|
||||||
generation: number
|
|
||||||
valid_from: string
|
|
||||||
valid_until: string
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Zod schema (`lib/schemas.ts`)
|
|
||||||
```ts
|
|
||||||
const weekdaySchema = z.enum(['monday','tuesday','wednesday','thursday','friday','saturday','sunday'])
|
|
||||||
|
|
||||||
// replace blocks: z.array(blockSchema) with:
|
|
||||||
day_blocks: z.record(weekdaySchema, z.array(blockSchema)).default(
|
|
||||||
() => Object.fromEntries(WEEKDAYS.map(d => [d, []])) as Record<Weekday, ProgrammingBlock[]>
|
|
||||||
)
|
|
||||||
// A missing day key is valid (treated as empty). The default initializes all days to [].
|
|
||||||
```
|
|
||||||
|
|
||||||
### Channel export (`lib/channel-export.ts`)
|
|
||||||
Export format after V2: `day_blocks` map as-is (no flattening). The export JSON shape mirrors `ScheduleConfig` directly. Re-import reads via the same `ScheduleConfigCompat` deserialization path, so V1 exports remain importable indefinitely.
|
|
||||||
|
|
||||||
### New hooks (`hooks/use-channels.ts`)
|
|
||||||
- `useConfigHistory(channelId)`
|
|
||||||
- `useRestoreConfig()`
|
|
||||||
- `usePinSnapshot()`
|
|
||||||
- `useScheduleHistory(channelId)`
|
|
||||||
- `useScheduleGeneration(channelId, genId)` (lazy, for detail view)
|
|
||||||
- `useRollbackSchedule()`
|
|
||||||
|
|
||||||
### Files changed (frontend)
|
|
||||||
- `lib/types.ts`
|
|
||||||
- `lib/schemas.ts`
|
|
||||||
- `lib/channel-export.ts`
|
|
||||||
- `hooks/use-channels.ts`
|
|
||||||
- `dashboard/components/edit-channel-sheet.tsx`
|
|
||||||
- `dashboard/components/config-history-sheet.tsx` (new)
|
|
||||||
- `dashboard/components/schedule-history-dialog.tsx` (new)
|
|
||||||
- `app/(main)/dashboard/page.tsx` — wire new dialog triggers
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Verification
|
|
||||||
|
|
||||||
| Scenario | Expected |
|
|
||||||
|---|---|
|
|
||||||
| Load channel with old `{blocks:[...]}` config | Blocks appear on all 7 day tabs |
|
|
||||||
| `PUT /channels/:id` on old-format channel | Config saved as V2 `{day_blocks:{...}}`; snapshot v1 created |
|
|
||||||
| Channel with Mon+Sat blocks only → generate | Slots only on Mondays and Saturdays in 7-day window |
|
|
||||||
| Day with empty block list | No slots that day, no error |
|
|
||||||
| `PUT /channels/:id` twice | `GET /config/history` returns 2 entries with incrementing version_num |
|
|
||||||
| Pin snapshot | Label persists in history list |
|
|
||||||
| Restore snapshot | Config reverts; new snapshot created at top of history |
|
|
||||||
| `GET /schedule/history/:bad_id` | 404 |
|
|
||||||
| Generate 3 schedules → rollback to gen#1 | gen#2+3 deleted (schedules + playback_records); new generation resumes from gen#1 continuity |
|
|
||||||
| Sequential block at S4E2 → rollback → regenerate | New schedule starts at correct episode |
|
|
||||||
| Payload with both `blocks` and `day_blocks` keys | `day_blocks` used, `blocks` silently ignored |
|
|
||||||
| V1 export file re-imported after V2 deploy | Deserializes correctly via compat path |
|
|
||||||
@@ -36,6 +36,7 @@ pub struct Config {
|
|||||||
pub jwt_issuer: Option<String>,
|
pub jwt_issuer: Option<String>,
|
||||||
pub jwt_audience: Option<String>,
|
pub jwt_audience: Option<String>,
|
||||||
pub jwt_expiry_hours: u64,
|
pub jwt_expiry_hours: u64,
|
||||||
|
pub jwt_refresh_expiry_days: u64,
|
||||||
|
|
||||||
/// Whether the application is running in production mode
|
/// Whether the application is running in production mode
|
||||||
pub is_production: bool,
|
pub is_production: bool,
|
||||||
@@ -117,6 +118,11 @@ impl Config {
|
|||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
.unwrap_or(24);
|
.unwrap_or(24);
|
||||||
|
|
||||||
|
let jwt_refresh_expiry_days = env::var("JWT_REFRESH_EXPIRY_DAYS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(30);
|
||||||
|
|
||||||
let is_production = env::var("PRODUCTION")
|
let is_production = env::var("PRODUCTION")
|
||||||
.or_else(|_| env::var("RUST_ENV"))
|
.or_else(|_| env::var("RUST_ENV"))
|
||||||
.map(|v| v.to_lowercase() == "production" || v == "1" || v == "true")
|
.map(|v| v.to_lowercase() == "production" || v == "1" || v == "true")
|
||||||
@@ -165,6 +171,7 @@ impl Config {
|
|||||||
jwt_issuer,
|
jwt_issuer,
|
||||||
jwt_audience,
|
jwt_audience,
|
||||||
jwt_expiry_hours,
|
jwt_expiry_hours,
|
||||||
|
jwt_refresh_expiry_days,
|
||||||
is_production,
|
is_production,
|
||||||
allow_registration,
|
allow_registration,
|
||||||
jellyfin_base_url,
|
jellyfin_base_url,
|
||||||
|
|||||||
@@ -15,6 +15,15 @@ pub struct LoginRequest {
|
|||||||
pub email: Email,
|
pub email: Email,
|
||||||
/// Password is validated on deserialization (min 8 chars)
|
/// Password is validated on deserialization (min 8 chars)
|
||||||
pub password: Password,
|
pub password: Password,
|
||||||
|
/// When true, a refresh token is also issued for persistent sessions
|
||||||
|
#[serde(default)]
|
||||||
|
pub remember_me: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refresh token request
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct RefreshRequest {
|
||||||
|
pub refresh_token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register request with validated email and password newtypes
|
/// Register request with validated email and password newtypes
|
||||||
@@ -41,6 +50,9 @@ pub struct TokenResponse {
|
|||||||
pub access_token: String,
|
pub access_token: String,
|
||||||
pub token_type: String,
|
pub token_type: String,
|
||||||
pub expires_in: u64,
|
pub expires_in: u64,
|
||||||
|
/// Only present when remember_me was true at login, or on token refresh
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refresh_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per-provider info returned by `GET /config`.
|
/// Per-provider info returned by `GET /config`.
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ pub(crate) async fn validate_jwt_token(token: &str, state: &AppState) -> Result<
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
|
.ok_or_else(|| ApiError::Internal("JWT validator not configured".to_string()))?;
|
||||||
|
|
||||||
let claims = validator.validate_token(token).map_err(|e| {
|
let claims = validator.validate_access_token(token).map_err(|e| {
|
||||||
tracing::debug!("JWT validation failed: {:?}", e);
|
tracing::debug!("JWT validation failed: {:?}", e);
|
||||||
match e {
|
match e {
|
||||||
infra::auth::jwt::JwtError::Expired => {
|
infra::auth::jwt::JwtError::Expired => {
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ use axum::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
dto::{LoginRequest, RegisterRequest, TokenResponse, UserResponse},
|
dto::{LoginRequest, RefreshRequest, RegisterRequest, TokenResponse, UserResponse},
|
||||||
error::ApiError,
|
error::ApiError,
|
||||||
extractors::CurrentUser,
|
extractors::CurrentUser,
|
||||||
state::AppState,
|
state::AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::create_jwt;
|
use super::{create_jwt, create_refresh_jwt};
|
||||||
|
|
||||||
/// Login with email + password → JWT token
|
/// Login with email + password → JWT token
|
||||||
pub(super) async fn login(
|
pub(super) async fn login(
|
||||||
@@ -35,6 +35,11 @@ pub(super) async fn login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let token = create_jwt(&user, &state)?;
|
let token = create_jwt(&user, &state)?;
|
||||||
|
let refresh_token = if payload.remember_me {
|
||||||
|
Some(create_refresh_jwt(&user, &state)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let _ = state.activity_log_repo.log("user_login", user.email.as_ref(), None).await;
|
let _ = state.activity_log_repo.log("user_login", user.email.as_ref(), None).await;
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
@@ -43,6 +48,7 @@ pub(super) async fn login(
|
|||||||
access_token: token,
|
access_token: token,
|
||||||
token_type: "Bearer".to_string(),
|
token_type: "Bearer".to_string(),
|
||||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
refresh_token,
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@@ -71,6 +77,7 @@ pub(super) async fn register(
|
|||||||
access_token: token,
|
access_token: token,
|
||||||
token_type: "Bearer".to_string(),
|
token_type: "Bearer".to_string(),
|
||||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
refresh_token: None,
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@@ -90,6 +97,46 @@ pub(super) async fn me(CurrentUser(user): CurrentUser) -> Result<impl IntoRespon
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Exchange a valid refresh token for a new access + refresh token pair
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
pub(super) async fn refresh_token(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(payload): Json<RefreshRequest>,
|
||||||
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
|
let validator = state
|
||||||
|
.jwt_validator
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
|
||||||
|
|
||||||
|
let claims = validator
|
||||||
|
.validate_refresh_token(&payload.refresh_token)
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::debug!("Refresh token validation failed: {:?}", e);
|
||||||
|
ApiError::Unauthorized("Invalid or expired refresh token".to_string())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let user_id: uuid::Uuid = claims
|
||||||
|
.sub
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| ApiError::Unauthorized("Invalid user ID in token".to_string()))?;
|
||||||
|
|
||||||
|
let user = state
|
||||||
|
.user_service
|
||||||
|
.find_by_id(user_id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::Internal(format!("Failed to fetch user: {}", e)))?;
|
||||||
|
|
||||||
|
let access_token = create_jwt(&user, &state)?;
|
||||||
|
let new_refresh_token = create_refresh_jwt(&user, &state)?;
|
||||||
|
|
||||||
|
Ok(Json(TokenResponse {
|
||||||
|
access_token,
|
||||||
|
token_type: "Bearer".to_string(),
|
||||||
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
refresh_token: Some(new_refresh_token),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// Issue a new JWT for the currently authenticated user (OIDC→JWT exchange or token refresh)
|
/// Issue a new JWT for the currently authenticated user (OIDC→JWT exchange or token refresh)
|
||||||
#[cfg(feature = "auth-jwt")]
|
#[cfg(feature = "auth-jwt")]
|
||||||
pub(super) async fn get_token(
|
pub(super) async fn get_token(
|
||||||
@@ -102,5 +149,6 @@ pub(super) async fn get_token(
|
|||||||
access_token: token,
|
access_token: token,
|
||||||
token_type: "Bearer".to_string(),
|
token_type: "Bearer".to_string(),
|
||||||
expires_in: state.config.jwt_expiry_hours * 3600,
|
expires_in: state.config.jwt_expiry_hours * 3600,
|
||||||
|
refresh_token: None,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ pub fn router() -> Router<AppState> {
|
|||||||
.route("/me", get(local::me));
|
.route("/me", get(local::me));
|
||||||
|
|
||||||
#[cfg(feature = "auth-jwt")]
|
#[cfg(feature = "auth-jwt")]
|
||||||
let r = r.route("/token", post(local::get_token));
|
let r = r
|
||||||
|
.route("/token", post(local::get_token))
|
||||||
|
.route("/refresh", post(local::refresh_token));
|
||||||
|
|
||||||
#[cfg(feature = "auth-oidc")]
|
#[cfg(feature = "auth-oidc")]
|
||||||
let r = r
|
let r = r
|
||||||
@@ -28,7 +30,7 @@ pub fn router() -> Router<AppState> {
|
|||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper: create JWT for a user
|
/// Helper: create access JWT for a user
|
||||||
#[cfg(feature = "auth-jwt")]
|
#[cfg(feature = "auth-jwt")]
|
||||||
pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
||||||
let validator = state
|
let validator = state
|
||||||
@@ -45,3 +47,21 @@ pub(super) fn create_jwt(user: &domain::User, state: &AppState) -> Result<String
|
|||||||
pub(super) fn create_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
|
pub(super) fn create_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
|
||||||
Err(ApiError::Internal("JWT feature not enabled".to_string()))
|
Err(ApiError::Internal("JWT feature not enabled".to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper: create refresh JWT for a user
|
||||||
|
#[cfg(feature = "auth-jwt")]
|
||||||
|
pub(super) fn create_refresh_jwt(user: &domain::User, state: &AppState) -> Result<String, ApiError> {
|
||||||
|
let validator = state
|
||||||
|
.jwt_validator
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ApiError::Internal("JWT not configured".to_string()))?;
|
||||||
|
|
||||||
|
validator
|
||||||
|
.create_refresh_token(user)
|
||||||
|
.map_err(|e| ApiError::Internal(format!("Failed to create refresh token: {}", e)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "auth-jwt"))]
|
||||||
|
pub(super) fn create_refresh_jwt(_user: &domain::User, _state: &AppState) -> Result<String, ApiError> {
|
||||||
|
Err(ApiError::Internal("JWT feature not enabled".to_string()))
|
||||||
|
}
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ impl AppState {
|
|||||||
config.jwt_issuer.clone(),
|
config.jwt_issuer.clone(),
|
||||||
config.jwt_audience.clone(),
|
config.jwt_audience.clone(),
|
||||||
Some(config.jwt_expiry_hours),
|
Some(config.jwt_expiry_hours),
|
||||||
|
Some(config.jwt_refresh_expiry_days),
|
||||||
config.is_production,
|
config.is_production,
|
||||||
)?;
|
)?;
|
||||||
Some(Arc::new(JwtValidator::new(jwt_config)))
|
Some(Arc::new(JwtValidator::new(jwt_config)))
|
||||||
|
|||||||
@@ -20,8 +20,10 @@ pub struct JwtConfig {
|
|||||||
pub issuer: Option<String>,
|
pub issuer: Option<String>,
|
||||||
/// Expected audience (for validation)
|
/// Expected audience (for validation)
|
||||||
pub audience: Option<String>,
|
pub audience: Option<String>,
|
||||||
/// Token expiry in hours (default: 24)
|
/// Access token expiry in hours (default: 24)
|
||||||
pub expiry_hours: u64,
|
pub expiry_hours: u64,
|
||||||
|
/// Refresh token expiry in days (default: 30)
|
||||||
|
pub refresh_expiry_days: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl JwtConfig {
|
impl JwtConfig {
|
||||||
@@ -33,6 +35,7 @@ impl JwtConfig {
|
|||||||
issuer: Option<String>,
|
issuer: Option<String>,
|
||||||
audience: Option<String>,
|
audience: Option<String>,
|
||||||
expiry_hours: Option<u64>,
|
expiry_hours: Option<u64>,
|
||||||
|
refresh_expiry_days: Option<u64>,
|
||||||
is_production: bool,
|
is_production: bool,
|
||||||
) -> Result<Self, JwtError> {
|
) -> Result<Self, JwtError> {
|
||||||
// Validate secret strength in production
|
// Validate secret strength in production
|
||||||
@@ -48,6 +51,7 @@ impl JwtConfig {
|
|||||||
issuer,
|
issuer,
|
||||||
audience,
|
audience,
|
||||||
expiry_hours: expiry_hours.unwrap_or(24),
|
expiry_hours: expiry_hours.unwrap_or(24),
|
||||||
|
refresh_expiry_days: refresh_expiry_days.unwrap_or(30),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,10 +62,15 @@ impl JwtConfig {
|
|||||||
issuer: None,
|
issuer: None,
|
||||||
audience: None,
|
audience: None,
|
||||||
expiry_hours: 24,
|
expiry_hours: 24,
|
||||||
|
refresh_expiry_days: 30,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_token_type() -> String {
|
||||||
|
"access".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
/// JWT claims structure
|
/// JWT claims structure
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub struct JwtClaims {
|
pub struct JwtClaims {
|
||||||
@@ -79,6 +88,9 @@ pub struct JwtClaims {
|
|||||||
/// Audience
|
/// Audience
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub aud: Option<String>,
|
pub aud: Option<String>,
|
||||||
|
/// Token type: "access" or "refresh". Defaults to "access" for backward compat.
|
||||||
|
#[serde(default = "default_token_type")]
|
||||||
|
pub token_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// JWT-related errors
|
/// JWT-related errors
|
||||||
@@ -141,7 +153,7 @@ impl JwtValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a JWT token for the given user
|
/// Create an access JWT token for the given user
|
||||||
pub fn create_token(&self, user: &User) -> Result<String, JwtError> {
|
pub fn create_token(&self, user: &User) -> Result<String, JwtError> {
|
||||||
let now = SystemTime::now()
|
let now = SystemTime::now()
|
||||||
.duration_since(UNIX_EPOCH)
|
.duration_since(UNIX_EPOCH)
|
||||||
@@ -157,6 +169,30 @@ impl JwtValidator {
|
|||||||
iat: now,
|
iat: now,
|
||||||
iss: self.config.issuer.clone(),
|
iss: self.config.issuer.clone(),
|
||||||
aud: self.config.audience.clone(),
|
aud: self.config.audience.clone(),
|
||||||
|
token_type: "access".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let header = Header::new(Algorithm::HS256);
|
||||||
|
encode(&header, &claims, &self.encoding_key).map_err(JwtError::CreationFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a refresh JWT token for the given user (longer-lived)
|
||||||
|
pub fn create_refresh_token(&self, user: &User) -> Result<String, JwtError> {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("Time went backwards")
|
||||||
|
.as_secs() as usize;
|
||||||
|
|
||||||
|
let expiry = now + (self.config.refresh_expiry_days as usize * 86400);
|
||||||
|
|
||||||
|
let claims = JwtClaims {
|
||||||
|
sub: user.id.to_string(),
|
||||||
|
email: user.email.as_ref().to_string(),
|
||||||
|
exp: expiry,
|
||||||
|
iat: now,
|
||||||
|
iss: self.config.issuer.clone(),
|
||||||
|
aud: self.config.audience.clone(),
|
||||||
|
token_type: "refresh".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let header = Header::new(Algorithm::HS256);
|
let header = Header::new(Algorithm::HS256);
|
||||||
@@ -176,6 +212,24 @@ impl JwtValidator {
|
|||||||
Ok(token_data.claims)
|
Ok(token_data.claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate an access token — rejects refresh tokens
|
||||||
|
pub fn validate_access_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
||||||
|
let claims = self.validate_token(token)?;
|
||||||
|
if claims.token_type != "access" {
|
||||||
|
return Err(JwtError::ValidationFailed("Not an access token".to_string()));
|
||||||
|
}
|
||||||
|
Ok(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate a refresh token — rejects access tokens
|
||||||
|
pub fn validate_refresh_token(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
||||||
|
let claims = self.validate_token(token)?;
|
||||||
|
if claims.token_type != "refresh" {
|
||||||
|
return Err(JwtError::ValidationFailed("Not a refresh token".to_string()));
|
||||||
|
}
|
||||||
|
Ok(claims)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the user ID (subject) from a token without full validation
|
/// Get the user ID (subject) from a token without full validation
|
||||||
/// Useful for logging/debugging, but should not be trusted for auth
|
/// Useful for logging/debugging, but should not be trusted for auth
|
||||||
pub fn decode_unverified(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
pub fn decode_unverified(&self, token: &str) -> Result<JwtClaims, JwtError> {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
name = "mcp"
|
name = "mcp"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
default-run = "mcp"
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["sqlite", "jellyfin"]
|
default = ["sqlite", "jellyfin"]
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import { useConfig } from "@/hooks/use-channels";
|
|||||||
export default function LoginPage() {
|
export default function LoginPage() {
|
||||||
const [email, setEmail] = useState("");
|
const [email, setEmail] = useState("");
|
||||||
const [password, setPassword] = useState("");
|
const [password, setPassword] = useState("");
|
||||||
|
const [rememberMe, setRememberMe] = useState(false);
|
||||||
const { mutate: login, isPending, error } = useLogin();
|
const { mutate: login, isPending, error } = useLogin();
|
||||||
const { data: config } = useConfig();
|
const { data: config } = useConfig();
|
||||||
|
|
||||||
const handleSubmit = (e: React.FormEvent) => {
|
const handleSubmit = (e: React.FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
login({ email, password });
|
login({ email, password, rememberMe });
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -54,6 +55,23 @@ export default function LoginPage() {
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-1">
|
||||||
|
<label className="flex cursor-pointer items-center gap-2">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={rememberMe}
|
||||||
|
onChange={(e) => setRememberMe(e.target.checked)}
|
||||||
|
className="h-3.5 w-3.5 rounded border-zinc-600 bg-zinc-900 accent-white"
|
||||||
|
/>
|
||||||
|
<span className="text-xs text-zinc-400">Remember me</span>
|
||||||
|
</label>
|
||||||
|
{rememberMe && (
|
||||||
|
<p className="pl-5 text-xs text-amber-500/80">
|
||||||
|
A refresh token will be stored locally — don't share it.
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
{error && <p className="text-xs text-red-400">{error.message}</p>}
|
{error && <p className="text-xs text-red-400">{error.message}</p>}
|
||||||
|
|
||||||
<button
|
<button
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import { Toaster } from "@/components/ui/sonner";
|
|||||||
import { ApiRequestError } from "@/lib/api";
|
import { ApiRequestError } from "@/lib/api";
|
||||||
|
|
||||||
function QueryProvider({ children }: { children: React.ReactNode }) {
|
function QueryProvider({ children }: { children: React.ReactNode }) {
|
||||||
const { token, setToken } = useAuthContext();
|
const { token, setTokens } = useAuthContext();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const tokenRef = useRef(token);
|
const tokenRef = useRef(token);
|
||||||
useEffect(() => { tokenRef.current = token; }, [token]);
|
useEffect(() => { tokenRef.current = token; }, [token]);
|
||||||
@@ -29,7 +29,7 @@ function QueryProvider({ children }: { children: React.ReactNode }) {
|
|||||||
// Guests hitting 401 on restricted content should not be redirected.
|
// Guests hitting 401 on restricted content should not be redirected.
|
||||||
if (error instanceof ApiRequestError && error.status === 401 && tokenRef.current) {
|
if (error instanceof ApiRequestError && error.status === 401 && tokenRef.current) {
|
||||||
toast.warning("Session expired, please log in again.");
|
toast.warning("Session expired, please log in again.");
|
||||||
setToken(null);
|
setTokens(null, null, false);
|
||||||
router.push("/login");
|
router.push("/login");
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -39,7 +39,7 @@ function QueryProvider({ children }: { children: React.ReactNode }) {
|
|||||||
// Mutations always require auth — redirect on 401 regardless.
|
// Mutations always require auth — redirect on 401 regardless.
|
||||||
if (error instanceof ApiRequestError && error.status === 401) {
|
if (error instanceof ApiRequestError && error.status === 401) {
|
||||||
toast.warning("Session expired, please log in again.");
|
toast.warning("Session expired, please log in again.");
|
||||||
setToken(null);
|
setTokens(null, null, false);
|
||||||
router.push("/login");
|
router.push("/login");
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,42 +4,94 @@ import {
|
|||||||
createContext,
|
createContext,
|
||||||
useContext,
|
useContext,
|
||||||
useState,
|
useState,
|
||||||
|
useEffect,
|
||||||
type ReactNode,
|
type ReactNode,
|
||||||
} from "react";
|
} from "react";
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
import { api, setRefreshCallback } from "@/lib/api";
|
||||||
|
|
||||||
const TOKEN_KEY = "k-tv-token";
|
const ACCESS_KEY_LOCAL = "k-tv-token";
|
||||||
|
const ACCESS_KEY_SESSION = "k-tv-token-session";
|
||||||
|
const REFRESH_KEY = "k-tv-refresh-token";
|
||||||
|
|
||||||
interface AuthContextValue {
|
interface AuthContextValue {
|
||||||
token: string | null;
|
token: string | null;
|
||||||
/** True once the initial localStorage read has completed */
|
refreshToken: string | null;
|
||||||
|
/** Always true (lazy init reads storage synchronously) */
|
||||||
isLoaded: boolean;
|
isLoaded: boolean;
|
||||||
setToken: (token: string | null) => void;
|
setTokens: (access: string | null, refresh: string | null, remember: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const AuthContext = createContext<AuthContextValue | null>(null);
|
const AuthContext = createContext<AuthContextValue | null>(null);
|
||||||
|
|
||||||
export function AuthProvider({ children }: { children: ReactNode }) {
|
export function AuthProvider({ children }: { children: ReactNode }) {
|
||||||
|
const router = useRouter();
|
||||||
|
|
||||||
const [token, setTokenState] = useState<string | null>(() => {
|
const [token, setTokenState] = useState<string | null>(() => {
|
||||||
try {
|
try {
|
||||||
return localStorage.getItem(TOKEN_KEY);
|
return sessionStorage.getItem(ACCESS_KEY_SESSION) ?? localStorage.getItem(ACCESS_KEY_LOCAL);
|
||||||
} catch {
|
} catch {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
// isLoaded is always true: lazy init above reads localStorage synchronously
|
|
||||||
const [isLoaded] = useState(true);
|
|
||||||
|
|
||||||
const setToken = (t: string | null) => {
|
const [refreshToken, setRefreshTokenState] = useState<string | null>(() => {
|
||||||
setTokenState(t);
|
try {
|
||||||
if (t) {
|
return localStorage.getItem(REFRESH_KEY);
|
||||||
localStorage.setItem(TOKEN_KEY, t);
|
} catch {
|
||||||
} else {
|
return null;
|
||||||
localStorage.removeItem(TOKEN_KEY);
|
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const setTokens = (access: string | null, refresh: string | null, remember: boolean) => {
|
||||||
|
try {
|
||||||
|
if (access === null) {
|
||||||
|
sessionStorage.removeItem(ACCESS_KEY_SESSION);
|
||||||
|
localStorage.removeItem(ACCESS_KEY_LOCAL);
|
||||||
|
localStorage.removeItem(REFRESH_KEY);
|
||||||
|
} else if (remember) {
|
||||||
|
localStorage.setItem(ACCESS_KEY_LOCAL, access);
|
||||||
|
sessionStorage.removeItem(ACCESS_KEY_SESSION);
|
||||||
|
if (refresh) {
|
||||||
|
localStorage.setItem(REFRESH_KEY, refresh);
|
||||||
|
} else {
|
||||||
|
localStorage.removeItem(REFRESH_KEY);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sessionStorage.setItem(ACCESS_KEY_SESSION, access);
|
||||||
|
localStorage.removeItem(ACCESS_KEY_LOCAL);
|
||||||
|
localStorage.removeItem(REFRESH_KEY);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// storage unavailable — state-only fallback
|
||||||
|
}
|
||||||
|
setTokenState(access);
|
||||||
|
setRefreshTokenState(refresh);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Wire up the transparent refresh callback in api.ts
|
||||||
|
useEffect(() => {
|
||||||
|
if (refreshToken) {
|
||||||
|
setRefreshCallback(async () => {
|
||||||
|
try {
|
||||||
|
const data = await api.auth.refresh(refreshToken);
|
||||||
|
const newRefresh = data.refresh_token ?? null;
|
||||||
|
setTokens(data.access_token, newRefresh, true);
|
||||||
|
return data.access_token;
|
||||||
|
} catch {
|
||||||
|
setTokens(null, null, false);
|
||||||
|
router.push("/login");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
setRefreshCallback(null);
|
||||||
|
}
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [refreshToken]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AuthContext.Provider value={{ token, isLoaded, setToken }}>
|
<AuthContext.Provider value={{ token, refreshToken, isLoaded: true, setTokens }}>
|
||||||
{children}
|
{children}
|
||||||
</AuthContext.Provider>
|
</AuthContext.Provider>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -16,14 +16,21 @@ export function useCurrentUser() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function useLogin() {
|
export function useLogin() {
|
||||||
const { setToken } = useAuthContext();
|
const { setTokens } = useAuthContext();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: ({ email, password }: { email: string; password: string }) =>
|
mutationFn: ({
|
||||||
api.auth.login(email, password),
|
email,
|
||||||
onSuccess: (data) => {
|
password,
|
||||||
setToken(data.access_token);
|
rememberMe,
|
||||||
|
}: {
|
||||||
|
email: string;
|
||||||
|
password: string;
|
||||||
|
rememberMe: boolean;
|
||||||
|
}) => api.auth.login(email, password, rememberMe),
|
||||||
|
onSuccess: (data, { rememberMe }) => {
|
||||||
|
setTokens(data.access_token, data.refresh_token ?? null, rememberMe);
|
||||||
queryClient.invalidateQueries({ queryKey: ["me"] });
|
queryClient.invalidateQueries({ queryKey: ["me"] });
|
||||||
router.push("/dashboard");
|
router.push("/dashboard");
|
||||||
},
|
},
|
||||||
@@ -31,14 +38,14 @@ export function useLogin() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function useRegister() {
|
export function useRegister() {
|
||||||
const { setToken } = useAuthContext();
|
const { setTokens } = useAuthContext();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: ({ email, password }: { email: string; password: string }) =>
|
mutationFn: ({ email, password }: { email: string; password: string }) =>
|
||||||
api.auth.register(email, password),
|
api.auth.register(email, password),
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
setToken(data.access_token);
|
setTokens(data.access_token, null, false);
|
||||||
queryClient.invalidateQueries({ queryKey: ["me"] });
|
queryClient.invalidateQueries({ queryKey: ["me"] });
|
||||||
router.push("/dashboard");
|
router.push("/dashboard");
|
||||||
},
|
},
|
||||||
@@ -46,13 +53,13 @@ export function useRegister() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function useLogout() {
|
export function useLogout() {
|
||||||
const { token, setToken } = useAuthContext();
|
const { token, setTokens } = useAuthContext();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
return useMutation({
|
return useMutation({
|
||||||
mutationFn: () => (token ? api.auth.logout(token) : Promise.resolve()),
|
mutationFn: () => (token ? api.auth.logout(token) : Promise.resolve()),
|
||||||
onSettled: () => {
|
onSettled: () => {
|
||||||
setToken(null);
|
setTokens(null, null, false);
|
||||||
queryClient.clear();
|
queryClient.clear();
|
||||||
router.push("/login");
|
router.push("/login");
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -34,6 +34,23 @@ export class ApiRequestError extends Error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Called by AuthProvider when refreshToken changes — enables transparent 401 recovery
|
||||||
|
let refreshCallback: (() => Promise<string | null>) | null = null;
|
||||||
|
let refreshInFlight: Promise<string | null> | null = null;
|
||||||
|
|
||||||
|
export function setRefreshCallback(cb: (() => Promise<string | null>) | null) {
|
||||||
|
refreshCallback = cb;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function attemptRefresh(): Promise<string | null> {
|
||||||
|
if (!refreshCallback) return null;
|
||||||
|
if (refreshInFlight) return refreshInFlight;
|
||||||
|
refreshInFlight = refreshCallback().finally(() => {
|
||||||
|
refreshInFlight = null;
|
||||||
|
});
|
||||||
|
return refreshInFlight;
|
||||||
|
}
|
||||||
|
|
||||||
async function request<T>(
|
async function request<T>(
|
||||||
path: string,
|
path: string,
|
||||||
options: RequestInit & { token?: string } = {},
|
options: RequestInit & { token?: string } = {},
|
||||||
@@ -50,6 +67,35 @@ async function request<T>(
|
|||||||
|
|
||||||
const res = await fetch(`${API_BASE}${path}`, { ...init, headers });
|
const res = await fetch(`${API_BASE}${path}`, { ...init, headers });
|
||||||
|
|
||||||
|
// Transparent refresh: on 401, try to get a new access token and retry once.
|
||||||
|
// Skip for the refresh endpoint itself to avoid infinite loops.
|
||||||
|
if (res.status === 401 && path !== "/auth/refresh") {
|
||||||
|
const newToken = await attemptRefresh();
|
||||||
|
if (newToken) {
|
||||||
|
const retryHeaders = new Headers(init.headers);
|
||||||
|
retryHeaders.set("Authorization", `Bearer ${newToken}`);
|
||||||
|
if (init.body && !retryHeaders.has("Content-Type")) {
|
||||||
|
retryHeaders.set("Content-Type", "application/json");
|
||||||
|
}
|
||||||
|
const retryRes = await fetch(`${API_BASE}${path}`, {
|
||||||
|
...init,
|
||||||
|
headers: retryHeaders,
|
||||||
|
});
|
||||||
|
if (!retryRes.ok) {
|
||||||
|
let message = retryRes.statusText;
|
||||||
|
try {
|
||||||
|
const body = await retryRes.json();
|
||||||
|
message = body.message ?? body.error ?? message;
|
||||||
|
} catch {
|
||||||
|
// ignore parse error
|
||||||
|
}
|
||||||
|
throw new ApiRequestError(retryRes.status, message);
|
||||||
|
}
|
||||||
|
if (retryRes.status === 204) return null as T;
|
||||||
|
return retryRes.json() as Promise<T>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
let message = res.statusText;
|
let message = res.statusText;
|
||||||
try {
|
try {
|
||||||
@@ -77,10 +123,16 @@ export const api = {
|
|||||||
body: JSON.stringify({ email, password }),
|
body: JSON.stringify({ email, password }),
|
||||||
}),
|
}),
|
||||||
|
|
||||||
login: (email: string, password: string) =>
|
login: (email: string, password: string, rememberMe = false) =>
|
||||||
request<TokenResponse>("/auth/login", {
|
request<TokenResponse>("/auth/login", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: JSON.stringify({ email, password }),
|
body: JSON.stringify({ email, password, remember_me: rememberMe }),
|
||||||
|
}),
|
||||||
|
|
||||||
|
refresh: (refreshToken: string) =>
|
||||||
|
request<TokenResponse>("/auth/refresh", {
|
||||||
|
method: "POST",
|
||||||
|
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||||
}),
|
}),
|
||||||
|
|
||||||
logout: (token: string) =>
|
logout: (token: string) =>
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ export interface TokenResponse {
|
|||||||
access_token: string;
|
access_token: string;
|
||||||
token_type: string;
|
token_type: string;
|
||||||
expires_in: number;
|
expires_in: number;
|
||||||
|
refresh_token?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UserResponse {
|
export interface UserResponse {
|
||||||
|
|||||||
Reference in New Issue
Block a user