Implement multi-user RBAC: database, auth, routing, admin API (Phases 1-6)
Phase 1: Add role, mfa_enforce_pending, must_change_password to users table. Create system_config (singleton) and audit_log tables. Migration 026. Phase 2: Add user_id FK to all 8 data tables (todos, reminders, projects, calendars, people, locations, event_templates, ntfy_sent) with 4-step nullable→backfill→FK→NOT NULL pattern. Migrations 027-034. Phase 3: Harden auth schemas (extra="forbid" on RegisterRequest), add MFA enforcement token serializer with distinct salt, rewrite auth router with require_role() factory and registration endpoint. Phase 4: Scope all 12 routers by user_id, fix dependency type bugs, bound weather cache (SEC-15), multi-user ntfy dispatch. Phase 5: Create admin router (14 endpoints), admin schemas, audit service, rate limiting in nginx. SEC-08 CSRF via X-Requested-With. Phase 6: Update frontend types, useAuth hook (role/isAdmin/register), App.tsx (AdminRoute guard), Sidebar (admin link), api.ts (XHR header). Security findings addressed: SEC-01, SEC-02, SEC-03, SEC-04, SEC-05, SEC-06, SEC-07, SEC-08, SEC-12, SEC-13, SEC-15. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
2ec70d9344
commit
d8bdae8ec3
101
backend/alembic/versions/026_add_user_role_and_system_config.py
Normal file
101
backend/alembic/versions/026_add_user_role_and_system_config.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""Add role, mfa_enforce_pending, must_change_password to users; create system_config table.
|
||||
|
||||
Revision ID: 026
|
||||
Revises: 025
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "026"
|
||||
down_revision = "025"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Add role column with server_default for existing rows
|
||||
op.add_column("users", sa.Column(
|
||||
"role", sa.String(30), nullable=False, server_default="standard"
|
||||
))
|
||||
|
||||
# 2. Add MFA enforcement pending flag
|
||||
op.add_column("users", sa.Column(
|
||||
"mfa_enforce_pending", sa.Boolean(), nullable=False, server_default="false"
|
||||
))
|
||||
|
||||
# 3. Add forced password change flag (SEC-12)
|
||||
op.add_column("users", sa.Column(
|
||||
"must_change_password", sa.Boolean(), nullable=False, server_default="false"
|
||||
))
|
||||
|
||||
# 4. Add last_password_change_at audit column
|
||||
op.add_column("users", sa.Column(
|
||||
"last_password_change_at", sa.DateTime(), nullable=True
|
||||
))
|
||||
|
||||
# 5. Add CHECK constraint on role values (SEC-16)
|
||||
op.create_check_constraint(
|
||||
"ck_users_role",
|
||||
"users",
|
||||
"role IN ('admin', 'standard', 'public_event_manager')"
|
||||
)
|
||||
|
||||
# 6. Promote the first (existing) user to admin
|
||||
op.execute(
|
||||
"UPDATE users SET role = 'admin' WHERE id = (SELECT MIN(id) FROM users)"
|
||||
)
|
||||
|
||||
# 7. Create system_config table (singleton pattern -- always id=1)
|
||||
op.create_table(
|
||||
"system_config",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("allow_registration", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column("enforce_mfa_new_users", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.text("NOW()")),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.text("NOW()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
# SEC-09: Enforce singleton row
|
||||
sa.CheckConstraint("id = 1", name="ck_system_config_singleton"),
|
||||
)
|
||||
|
||||
# 8. Seed the singleton row
|
||||
op.execute(
|
||||
"INSERT INTO system_config (id, allow_registration, enforce_mfa_new_users) "
|
||||
"VALUES (1, false, false)"
|
||||
)
|
||||
|
||||
# 9. Create audit_log table
|
||||
op.create_table(
|
||||
"audit_log",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("actor_user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("target_user_id", sa.Integer(), nullable=True),
|
||||
sa.Column("action", sa.String(100), nullable=False),
|
||||
sa.Column("detail", sa.Text(), nullable=True),
|
||||
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.text("NOW()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["actor_user_id"], ["users.id"]),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_user_id"], ["users.id"], ondelete="SET NULL"
|
||||
),
|
||||
)
|
||||
op.create_index("ix_audit_log_actor_user_id", "audit_log", ["actor_user_id"])
|
||||
op.create_index("ix_audit_log_target_user_id", "audit_log", ["target_user_id"])
|
||||
op.create_index("ix_audit_log_action", "audit_log", ["action"])
|
||||
op.create_index("ix_audit_log_created_at", "audit_log", ["created_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_audit_log_created_at", table_name="audit_log")
|
||||
op.drop_index("ix_audit_log_action", table_name="audit_log")
|
||||
op.drop_index("ix_audit_log_target_user_id", table_name="audit_log")
|
||||
op.drop_index("ix_audit_log_actor_user_id", table_name="audit_log")
|
||||
op.drop_table("audit_log")
|
||||
op.drop_table("system_config")
|
||||
op.drop_constraint("ck_users_role", "users", type_="check")
|
||||
op.drop_column("users", "last_password_change_at")
|
||||
op.drop_column("users", "must_change_password")
|
||||
op.drop_column("users", "mfa_enforce_pending")
|
||||
op.drop_column("users", "role")
|
||||
38
backend/alembic/versions/027_add_user_id_to_todos.py
Normal file
38
backend/alembic/versions/027_add_user_id_to_todos.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Add user_id FK to todos table.
|
||||
|
||||
Revision ID: 027
|
||||
Revises: 026
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "027"
|
||||
down_revision = "026"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("todos", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE todos SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_todos_user_id", "todos", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("todos", "user_id", nullable=False)
|
||||
op.create_index("ix_todos_user_id", "todos", ["user_id"])
|
||||
op.create_index("ix_todos_user_completed", "todos", ["user_id", "completed"])
|
||||
op.create_index("ix_todos_user_due_date", "todos", ["user_id", "due_date"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_todos_user_due_date", table_name="todos")
|
||||
op.drop_index("ix_todos_user_completed", table_name="todos")
|
||||
op.drop_index("ix_todos_user_id", table_name="todos")
|
||||
op.drop_constraint("fk_todos_user_id", "todos", type_="foreignkey")
|
||||
op.drop_column("todos", "user_id")
|
||||
36
backend/alembic/versions/028_add_user_id_to_reminders.py
Normal file
36
backend/alembic/versions/028_add_user_id_to_reminders.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Add user_id FK to reminders table.
|
||||
|
||||
Revision ID: 028
|
||||
Revises: 027
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "028"
|
||||
down_revision = "027"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("reminders", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE reminders SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_reminders_user_id", "reminders", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("reminders", "user_id", nullable=False)
|
||||
op.create_index("ix_reminders_user_id", "reminders", ["user_id"])
|
||||
op.create_index("ix_reminders_user_remind_at", "reminders", ["user_id", "remind_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_reminders_user_remind_at", table_name="reminders")
|
||||
op.drop_index("ix_reminders_user_id", table_name="reminders")
|
||||
op.drop_constraint("fk_reminders_user_id", "reminders", type_="foreignkey")
|
||||
op.drop_column("reminders", "user_id")
|
||||
36
backend/alembic/versions/029_add_user_id_to_projects.py
Normal file
36
backend/alembic/versions/029_add_user_id_to_projects.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Add user_id FK to projects table.
|
||||
|
||||
Revision ID: 029
|
||||
Revises: 028
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "029"
|
||||
down_revision = "028"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("projects", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE projects SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_projects_user_id", "projects", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("projects", "user_id", nullable=False)
|
||||
op.create_index("ix_projects_user_id", "projects", ["user_id"])
|
||||
op.create_index("ix_projects_user_status", "projects", ["user_id", "status"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_projects_user_status", table_name="projects")
|
||||
op.drop_index("ix_projects_user_id", table_name="projects")
|
||||
op.drop_constraint("fk_projects_user_id", "projects", type_="foreignkey")
|
||||
op.drop_column("projects", "user_id")
|
||||
36
backend/alembic/versions/030_add_user_id_to_calendars.py
Normal file
36
backend/alembic/versions/030_add_user_id_to_calendars.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Add user_id FK to calendars table.
|
||||
|
||||
Revision ID: 030
|
||||
Revises: 029
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "030"
|
||||
down_revision = "029"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("calendars", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE calendars SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_calendars_user_id", "calendars", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("calendars", "user_id", nullable=False)
|
||||
op.create_index("ix_calendars_user_id", "calendars", ["user_id"])
|
||||
op.create_index("ix_calendars_user_default", "calendars", ["user_id", "is_default"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_calendars_user_default", table_name="calendars")
|
||||
op.drop_index("ix_calendars_user_id", table_name="calendars")
|
||||
op.drop_constraint("fk_calendars_user_id", "calendars", type_="foreignkey")
|
||||
op.drop_column("calendars", "user_id")
|
||||
36
backend/alembic/versions/031_add_user_id_to_people.py
Normal file
36
backend/alembic/versions/031_add_user_id_to_people.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Add user_id FK to people table.
|
||||
|
||||
Revision ID: 031
|
||||
Revises: 030
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "031"
|
||||
down_revision = "030"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("people", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE people SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_people_user_id", "people", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("people", "user_id", nullable=False)
|
||||
op.create_index("ix_people_user_id", "people", ["user_id"])
|
||||
op.create_index("ix_people_user_name", "people", ["user_id", "name"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_people_user_name", table_name="people")
|
||||
op.drop_index("ix_people_user_id", table_name="people")
|
||||
op.drop_constraint("fk_people_user_id", "people", type_="foreignkey")
|
||||
op.drop_column("people", "user_id")
|
||||
34
backend/alembic/versions/032_add_user_id_to_locations.py
Normal file
34
backend/alembic/versions/032_add_user_id_to_locations.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Add user_id FK to locations table.
|
||||
|
||||
Revision ID: 032
|
||||
Revises: 031
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "032"
|
||||
down_revision = "031"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("locations", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE locations SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_locations_user_id", "locations", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("locations", "user_id", nullable=False)
|
||||
op.create_index("ix_locations_user_id", "locations", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_locations_user_id", table_name="locations")
|
||||
op.drop_constraint("fk_locations_user_id", "locations", type_="foreignkey")
|
||||
op.drop_column("locations", "user_id")
|
||||
@ -0,0 +1,34 @@
|
||||
"""Add user_id FK to event_templates table.
|
||||
|
||||
Revision ID: 033
|
||||
Revises: 032
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "033"
|
||||
down_revision = "032"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("event_templates", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE event_templates SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_event_templates_user_id", "event_templates", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("event_templates", "user_id", nullable=False)
|
||||
op.create_index("ix_event_templates_user_id", "event_templates", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_event_templates_user_id", table_name="event_templates")
|
||||
op.drop_constraint("fk_event_templates_user_id", "event_templates", type_="foreignkey")
|
||||
op.drop_column("event_templates", "user_id")
|
||||
46
backend/alembic/versions/034_add_user_id_to_ntfy_sent.py
Normal file
46
backend/alembic/versions/034_add_user_id_to_ntfy_sent.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Add user_id FK to ntfy_sent table, rebuild unique constraint as composite.
|
||||
|
||||
Revision ID: 034
|
||||
Revises: 033
|
||||
Create Date: 2026-02-26
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "034"
|
||||
down_revision = "033"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("ntfy_sent", sa.Column("user_id", sa.Integer(), nullable=True))
|
||||
op.execute(
|
||||
"UPDATE ntfy_sent SET user_id = ("
|
||||
" SELECT id FROM users WHERE role = 'admin' ORDER BY id LIMIT 1"
|
||||
")"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_ntfy_sent_user_id", "ntfy_sent", "users",
|
||||
["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
op.alter_column("ntfy_sent", "user_id", nullable=False)
|
||||
|
||||
# Drop old unique constraint on notification_key alone
|
||||
op.drop_constraint("ntfy_sent_notification_key_key", "ntfy_sent", type_="unique")
|
||||
|
||||
# Create composite unique constraint (per-user dedup)
|
||||
op.create_unique_constraint(
|
||||
"uq_ntfy_sent_user_key", "ntfy_sent", ["user_id", "notification_key"]
|
||||
)
|
||||
op.create_index("ix_ntfy_sent_user_id", "ntfy_sent", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_ntfy_sent_user_id", table_name="ntfy_sent")
|
||||
op.drop_constraint("uq_ntfy_sent_user_key", "ntfy_sent", type_="unique")
|
||||
op.create_unique_constraint(
|
||||
"ntfy_sent_notification_key_key", "ntfy_sent", ["notification_key"]
|
||||
)
|
||||
op.drop_constraint("fk_ntfy_sent_user_id", "ntfy_sent", type_="foreignkey")
|
||||
op.drop_column("ntfy_sent", "user_id")
|
||||
@ -19,6 +19,7 @@ from app.database import AsyncSessionLocal
|
||||
from app.models.settings import Settings
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.calendar_event import CalendarEvent
|
||||
from app.models.calendar import Calendar
|
||||
from app.models.todo import Todo
|
||||
from app.models.project import Project
|
||||
from app.models.ntfy_sent import NtfySent
|
||||
@ -55,10 +56,11 @@ async def _mark_sent(db: AsyncSession, key: str) -> None:
|
||||
|
||||
async def _dispatch_reminders(db: AsyncSession, settings: Settings, now: datetime) -> None:
|
||||
"""Send notifications for reminders that are currently due and not dismissed/snoozed."""
|
||||
# Mirror the filter from /api/reminders/due
|
||||
# Mirror the filter from /api/reminders/due, scoped to this user
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
and_(
|
||||
Reminder.user_id == settings.user_id,
|
||||
Reminder.remind_at <= now,
|
||||
Reminder.is_dismissed == False, # noqa: E712
|
||||
Reminder.is_active == True, # noqa: E712
|
||||
@ -72,8 +74,8 @@ async def _dispatch_reminders(db: AsyncSession, settings: Settings, now: datetim
|
||||
if reminder.snoozed_until and reminder.snoozed_until > now:
|
||||
continue # respect snooze
|
||||
|
||||
# Key ties notification to the specific day to handle re-fires after midnight
|
||||
key = f"reminder:{reminder.id}:{reminder.remind_at.date()}"
|
||||
# Key includes user_id to prevent cross-user dedup collisions
|
||||
key = f"reminder:{settings.user_id}:{reminder.id}:{reminder.remind_at.date()}"
|
||||
if await _already_sent(db, key):
|
||||
continue
|
||||
|
||||
@ -98,9 +100,13 @@ async def _dispatch_events(db: AsyncSession, settings: Settings, now: datetime)
|
||||
# Window: events starting between now and (now + lead_minutes)
|
||||
window_end = now + timedelta(minutes=lead_minutes)
|
||||
|
||||
# Scope events through calendar ownership
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == settings.user_id)
|
||||
|
||||
result = await db.execute(
|
||||
select(CalendarEvent).where(
|
||||
and_(
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
CalendarEvent.start_datetime >= now,
|
||||
CalendarEvent.start_datetime <= window_end,
|
||||
# Exclude recurring parent templates — they duplicate the child instance rows.
|
||||
@ -116,8 +122,8 @@ async def _dispatch_events(db: AsyncSession, settings: Settings, now: datetime)
|
||||
today = now.date()
|
||||
|
||||
for event in events:
|
||||
# Key includes the minute-precision start to avoid re-firing during the window
|
||||
key = f"event:{event.id}:{event.start_datetime.strftime('%Y-%m-%dT%H:%M')}"
|
||||
# Key includes user_id to prevent cross-user dedup collisions
|
||||
key = f"event:{settings.user_id}:{event.id}:{event.start_datetime.strftime('%Y-%m-%dT%H:%M')}"
|
||||
if await _already_sent(db, key):
|
||||
continue
|
||||
|
||||
@ -141,13 +147,13 @@ async def _dispatch_events(db: AsyncSession, settings: Settings, now: datetime)
|
||||
|
||||
async def _dispatch_todos(db: AsyncSession, settings: Settings, today) -> None:
|
||||
"""Send notifications for incomplete todos due within the configured lead days."""
|
||||
from datetime import date as date_type
|
||||
lead_days = settings.ntfy_todo_lead_days
|
||||
cutoff = today + timedelta(days=lead_days)
|
||||
|
||||
result = await db.execute(
|
||||
select(Todo).where(
|
||||
and_(
|
||||
Todo.user_id == settings.user_id,
|
||||
Todo.completed == False, # noqa: E712
|
||||
Todo.due_date != None, # noqa: E711
|
||||
Todo.due_date <= cutoff,
|
||||
@ -157,7 +163,8 @@ async def _dispatch_todos(db: AsyncSession, settings: Settings, today) -> None:
|
||||
todos = result.scalars().all()
|
||||
|
||||
for todo in todos:
|
||||
key = f"todo:{todo.id}:{today}"
|
||||
# Key includes user_id to prevent cross-user dedup collisions
|
||||
key = f"todo:{settings.user_id}:{todo.id}:{today}"
|
||||
if await _already_sent(db, key):
|
||||
continue
|
||||
|
||||
@ -185,6 +192,7 @@ async def _dispatch_projects(db: AsyncSession, settings: Settings, today) -> Non
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
and_(
|
||||
Project.user_id == settings.user_id,
|
||||
Project.due_date != None, # noqa: E711
|
||||
Project.due_date <= cutoff,
|
||||
Project.status != "completed",
|
||||
@ -194,7 +202,8 @@ async def _dispatch_projects(db: AsyncSession, settings: Settings, today) -> Non
|
||||
projects = result.scalars().all()
|
||||
|
||||
for project in projects:
|
||||
key = f"project:{project.id}:{today}"
|
||||
# Key includes user_id to prevent cross-user dedup collisions
|
||||
key = f"project:{settings.user_id}:{project.id}:{today}"
|
||||
if await _already_sent(db, key):
|
||||
continue
|
||||
|
||||
@ -213,6 +222,18 @@ async def _dispatch_projects(db: AsyncSession, settings: Settings, today) -> Non
|
||||
await _mark_sent(db, key)
|
||||
|
||||
|
||||
async def _dispatch_for_user(db: AsyncSession, settings: Settings, now: datetime) -> None:
|
||||
"""Run all notification dispatches for a single user's settings."""
|
||||
if settings.ntfy_reminders_enabled:
|
||||
await _dispatch_reminders(db, settings, now)
|
||||
if settings.ntfy_events_enabled:
|
||||
await _dispatch_events(db, settings, now)
|
||||
if settings.ntfy_todos_enabled:
|
||||
await _dispatch_todos(db, settings, now.date())
|
||||
if settings.ntfy_projects_enabled:
|
||||
await _dispatch_projects(db, settings, now.date())
|
||||
|
||||
|
||||
async def _purge_old_sent_records(db: AsyncSession) -> None:
|
||||
"""Remove ntfy_sent entries older than 7 days to keep the table lean."""
|
||||
# See DATETIME NOTE at top of file re: naive datetime usage
|
||||
@ -240,29 +261,35 @@ async def run_notification_dispatch() -> None:
|
||||
"""
|
||||
Main dispatch function called by APScheduler every 60 seconds.
|
||||
Uses AsyncSessionLocal directly — not the get_db() request-scoped dependency.
|
||||
|
||||
Iterates over ALL users with ntfy enabled. Per-user errors are caught and
|
||||
logged individually so one user's failure does not prevent others from
|
||||
receiving notifications.
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(select(Settings))
|
||||
settings = result.scalar_one_or_none()
|
||||
# Fetch all Settings rows that have ntfy enabled
|
||||
result = await db.execute(
|
||||
select(Settings).where(Settings.ntfy_enabled == True) # noqa: E712
|
||||
)
|
||||
all_settings = result.scalars().all()
|
||||
|
||||
if not settings or not settings.ntfy_enabled:
|
||||
if not all_settings:
|
||||
return
|
||||
|
||||
# See DATETIME NOTE at top of file re: naive datetime usage
|
||||
now = datetime.now()
|
||||
today = now.date()
|
||||
|
||||
if settings.ntfy_reminders_enabled:
|
||||
await _dispatch_reminders(db, settings, now)
|
||||
if settings.ntfy_events_enabled:
|
||||
await _dispatch_events(db, settings, now)
|
||||
if settings.ntfy_todos_enabled:
|
||||
await _dispatch_todos(db, settings, today)
|
||||
if settings.ntfy_projects_enabled:
|
||||
await _dispatch_projects(db, settings, today)
|
||||
for user_settings in all_settings:
|
||||
try:
|
||||
await _dispatch_for_user(db, user_settings, now)
|
||||
except Exception:
|
||||
# Isolate per-user failures — log and continue to next user
|
||||
logger.exception(
|
||||
"ntfy dispatch failed for user_id=%s", user_settings.user_id
|
||||
)
|
||||
|
||||
# Daily housekeeping: purge stale dedup records
|
||||
# Daily housekeeping: purge stale dedup records (shared across all users)
|
||||
await _purge_old_sent_records(db)
|
||||
|
||||
# Security housekeeping runs every cycle regardless of ntfy_enabled
|
||||
|
||||
@ -7,7 +7,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from app.config import settings
|
||||
from app.database import engine
|
||||
from app.routers import auth, todos, events, calendars, reminders, projects, people, locations, settings as settings_router, dashboard, weather, event_templates
|
||||
from app.routers import totp
|
||||
from app.routers import totp, admin
|
||||
from app.jobs.notifications import run_notification_dispatch
|
||||
|
||||
# Import models so Alembic's autogenerate can discover them
|
||||
@ -15,6 +15,8 @@ from app.models import user as _user_model # noqa: F401
|
||||
from app.models import session as _session_model # noqa: F401
|
||||
from app.models import totp_usage as _totp_usage_model # noqa: F401
|
||||
from app.models import backup_code as _backup_code_model # noqa: F401
|
||||
from app.models import system_config as _system_config_model # noqa: F401
|
||||
from app.models import audit_log as _audit_log_model # noqa: F401
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -68,6 +70,7 @@ app.include_router(dashboard.router, prefix="/api", tags=["Dashboard"])
|
||||
app.include_router(weather.router, prefix="/api/weather", tags=["Weather"])
|
||||
app.include_router(event_templates.router, prefix="/api/event-templates", tags=["Event Templates"])
|
||||
app.include_router(totp.router, prefix="/api/auth", tags=["TOTP MFA"])
|
||||
app.include_router(admin.router, prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@ -13,6 +13,8 @@ from app.models.session import UserSession
|
||||
from app.models.ntfy_sent import NtfySent
|
||||
from app.models.totp_usage import TOTPUsage
|
||||
from app.models.backup_code import BackupCode
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.audit_log import AuditLog
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
@ -30,4 +32,6 @@ __all__ = [
|
||||
"NtfySent",
|
||||
"TOTPUsage",
|
||||
"BackupCode",
|
||||
"SystemConfig",
|
||||
"AuditLog",
|
||||
]
|
||||
|
||||
27
backend/app/models/audit_log.py
Normal file
27
backend/app/models/audit_log.py
Normal file
@ -0,0 +1,27 @@
|
||||
from sqlalchemy import String, Text, Integer, ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""
|
||||
Append-only audit trail for admin actions and auth events.
|
||||
No DELETE endpoint — this table is immutable once written.
|
||||
"""
|
||||
__tablename__ = "audit_log"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
actor_user_id: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=True, index=True
|
||||
)
|
||||
target_user_id: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
detail: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
ip_address: Mapped[Optional[str]] = mapped_column(String(45), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
default=func.now(), server_default=func.now(), index=True
|
||||
)
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import String, Boolean, func
|
||||
from sqlalchemy import String, Boolean, Integer, ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
@ -9,6 +9,9 @@ class Calendar(Base):
|
||||
__tablename__ = "calendars"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
color: Mapped[str] = mapped_column(String(20), nullable=False, default="#3b82f6")
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, default=False, server_default="false")
|
||||
|
||||
@ -9,6 +9,9 @@ class EventTemplate(Base):
|
||||
__tablename__ = "event_templates"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import String, Text, Boolean, func, text
|
||||
from sqlalchemy import String, Text, Boolean, Integer, ForeignKey, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
@ -9,6 +9,9 @@ class Location(Base):
|
||||
__tablename__ = "locations"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
address: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(100), default="other")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import String, func
|
||||
from sqlalchemy import String, Integer, ForeignKey, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
@ -8,7 +8,7 @@ class NtfySent(Base):
|
||||
"""
|
||||
Deduplication table for ntfy notifications.
|
||||
Prevents the background job from re-sending the same notification
|
||||
within a given time window.
|
||||
within a given time window. Scoped per-user.
|
||||
|
||||
Key format: "{type}:{entity_id}:{date_window}"
|
||||
Examples:
|
||||
@ -18,7 +18,13 @@ class NtfySent(Base):
|
||||
"project:3:2026-02-25"
|
||||
"""
|
||||
__tablename__ = "ntfy_sent"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "notification_key", name="uq_ntfy_sent_user_key"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
notification_key: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
notification_key: Mapped[str] = mapped_column(String(255), index=True)
|
||||
sent_at: Mapped[datetime] = mapped_column(default=func.now(), server_default=func.now())
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import String, Text, Date, Boolean, func, text
|
||||
from sqlalchemy import String, Text, Date, Boolean, Integer, ForeignKey, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List
|
||||
@ -9,6 +9,9 @@ class Person(Base):
|
||||
__tablename__ = "people"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Boolean, String, Text, Date, func
|
||||
from sqlalchemy import Boolean, String, Text, Date, Integer, ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from datetime import datetime, date
|
||||
from typing import Optional, List
|
||||
@ -10,6 +10,9 @@ class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="not_started")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import String, Text, Boolean, func
|
||||
from sqlalchemy import String, Text, Boolean, Integer, ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
@ -9,6 +9,9 @@ class Reminder(Base):
|
||||
__tablename__ = "reminders"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
remind_at: Mapped[Optional[datetime]] = mapped_column(nullable=True)
|
||||
|
||||
27
backend/app/models/system_config.py
Normal file
27
backend/app/models/system_config.py
Normal file
@ -0,0 +1,27 @@
|
||||
from sqlalchemy import Boolean, CheckConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class SystemConfig(Base):
|
||||
"""
|
||||
Singleton system configuration table (always id=1).
|
||||
Stores global toggles for registration, MFA enforcement, etc.
|
||||
"""
|
||||
__tablename__ = "system_config"
|
||||
__table_args__ = (
|
||||
CheckConstraint("id = 1", name="ck_system_config_singleton"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
allow_registration: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, server_default="false"
|
||||
)
|
||||
enforce_mfa_new_users: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, server_default="false"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(default=func.now(), server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
default=func.now(), onupdate=func.now(), server_default=func.now()
|
||||
)
|
||||
@ -9,6 +9,9 @@ class Todo(Base):
|
||||
__tablename__ = "todos"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
priority: Mapped[str] = mapped_column(String(20), default="medium")
|
||||
|
||||
@ -23,7 +23,23 @@ class User(Base):
|
||||
# Account state
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# RBAC
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(30), nullable=False, default="standard", server_default="standard"
|
||||
)
|
||||
|
||||
# MFA enforcement (admin can toggle; checked at login)
|
||||
mfa_enforce_pending: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, server_default="false"
|
||||
)
|
||||
|
||||
# Forced password change (set after admin reset)
|
||||
must_change_password: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, server_default="false"
|
||||
)
|
||||
|
||||
# Audit
|
||||
created_at: Mapped[datetime] = mapped_column(default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(nullable=True, default=None)
|
||||
last_password_change_at: Mapped[datetime | None] = mapped_column(nullable=True, default=None)
|
||||
|
||||
687
backend/app/routers/admin.py
Normal file
687
backend/app/routers/admin.py
Normal file
@ -0,0 +1,687 @@
|
||||
"""
|
||||
Admin router — full user management, system config, and audit log.
|
||||
|
||||
Security measures implemented:
|
||||
SEC-02: Session revocation on role change
|
||||
SEC-05: Block admin self-actions (own role/password/MFA/active status)
|
||||
SEC-08: X-Requested-With header check (verify_xhr) on all state-mutating requests
|
||||
SEC-13: Session revocation + ntfy alert on MFA disable
|
||||
|
||||
All routes require the `require_admin` dependency (which chains through
|
||||
get_current_user, so the session cookie is always validated).
|
||||
"""
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.models.backup_code import BackupCode
|
||||
from app.models.session import UserSession
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.user import User
|
||||
from app.routers.auth import (
|
||||
_create_user_defaults,
|
||||
get_current_user,
|
||||
require_admin,
|
||||
)
|
||||
from app.schemas.admin import (
|
||||
AdminDashboardResponse,
|
||||
AuditLogEntry,
|
||||
AuditLogResponse,
|
||||
CreateUserRequest,
|
||||
ResetPasswordResponse,
|
||||
SystemConfigResponse,
|
||||
SystemConfigUpdate,
|
||||
ToggleActiveRequest,
|
||||
ToggleMfaEnforceRequest,
|
||||
UpdateUserRoleRequest,
|
||||
UserDetailResponse,
|
||||
UserListItem,
|
||||
UserListResponse,
|
||||
)
|
||||
from app.services.audit import log_audit_event
|
||||
from app.services.auth import hash_password
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSRF guard — SEC-08
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def verify_xhr(request: Request) -> None:
|
||||
"""
|
||||
Lightweight CSRF mitigation: require X-Requested-With on state-mutating
|
||||
requests. Browsers never send this header cross-origin without CORS
|
||||
pre-flight, which our CORS policy blocks.
|
||||
"""
|
||||
if request.method not in ("GET", "HEAD", "OPTIONS"):
|
||||
if request.headers.get("X-Requested-With") != "XMLHttpRequest":
|
||||
raise HTTPException(status_code=403, detail="Invalid request origin")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router — all endpoints inherit require_admin + verify_xhr
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(
|
||||
dependencies=[Depends(require_admin), Depends(verify_xhr)],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session revocation helper (used in multiple endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _revoke_all_sessions(db: AsyncSession, user_id: int) -> int:
|
||||
"""Mark every active session for user_id as revoked. Returns count revoked."""
|
||||
result = await db.execute(
|
||||
sa.update(UserSession)
|
||||
.where(UserSession.user_id == user_id, UserSession.revoked == False)
|
||||
.values(revoked=True)
|
||||
.returning(UserSession.id)
|
||||
)
|
||||
return len(result.fetchall())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Self-action guard — SEC-05
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _guard_self_action(actor: User, target_id: int, action: str) -> None:
|
||||
"""Raise 403 if an admin attempts a privileged action against their own account."""
|
||||
if actor.id == target_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Admins cannot {action} their own account",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/users", response_model=UserListResponse)
|
||||
async def list_users(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return all users with basic stats."""
|
||||
result = await db.execute(sa.select(User).order_by(User.created_at))
|
||||
users = result.scalars().all()
|
||||
return UserListResponse(
|
||||
users=[UserListItem.model_validate(u) for u in users],
|
||||
total=len(users),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users/{user_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserDetailResponse)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return a single user with their active session count."""
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
session_result = await db.execute(
|
||||
sa.select(sa.func.count()).select_from(UserSession).where(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.revoked == False,
|
||||
UserSession.expires_at > datetime.now(),
|
||||
)
|
||||
)
|
||||
active_sessions = session_result.scalar_one()
|
||||
|
||||
return UserDetailResponse(
|
||||
**UserListItem.model_validate(user).model_dump(),
|
||||
active_sessions=active_sessions,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /users
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/users", response_model=UserDetailResponse, status_code=201)
|
||||
async def create_user(
|
||||
data: CreateUserRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Admin-create a user with Settings and default calendars."""
|
||||
existing = await db.execute(sa.select(User).where(User.username == data.username))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail="Username already taken")
|
||||
|
||||
new_user = User(
|
||||
username=data.username,
|
||||
password_hash=hash_password(data.password),
|
||||
role=data.role,
|
||||
last_password_change_at=datetime.now(),
|
||||
# Force password change so the user sets their own credential
|
||||
must_change_password=True,
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # populate new_user.id
|
||||
|
||||
await _create_user_defaults(db, new_user.id)
|
||||
await db.commit()
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.user_created",
|
||||
actor_id=actor.id,
|
||||
target_id=new_user.id,
|
||||
detail={"username": new_user.username, "role": new_user.role},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return UserDetailResponse(
|
||||
**UserListItem.model_validate(new_user).model_dump(),
|
||||
active_sessions=0,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /users/{user_id}/role — SEC-02, SEC-05
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.put("/users/{user_id}/role")
|
||||
async def update_user_role(
|
||||
user_id: int,
|
||||
data: UpdateUserRoleRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Change a user's role.
|
||||
Blocks demotion of the last admin (SEC-05 variant).
|
||||
Revokes all sessions after role change (SEC-02).
|
||||
"""
|
||||
_guard_self_action(actor, user_id, "change role of")
|
||||
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Prevent demoting the last admin
|
||||
if user.role == "admin" and data.role != "admin":
|
||||
admin_count = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(User).where(User.role == "admin")
|
||||
)
|
||||
if admin_count <= 1:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot demote the last admin account",
|
||||
)
|
||||
|
||||
old_role = user.role
|
||||
user.role = data.role
|
||||
|
||||
# SEC-02: revoke sessions so the new role takes effect immediately
|
||||
revoked = await _revoke_all_sessions(db, user_id)
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.role_changed",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"old_role": old_role, "new_role": data.role, "sessions_revoked": revoked},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": f"Role updated to '{data.role}'. {revoked} session(s) revoked."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /users/{user_id}/reset-password — SEC-05
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/users/{user_id}/reset-password", response_model=ResetPasswordResponse)
|
||||
async def reset_user_password(
|
||||
user_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Generate a temporary password, revoke all sessions, and mark must_change_password.
|
||||
The admin is shown the plaintext temp password once — it is not stored.
|
||||
"""
|
||||
_guard_self_action(actor, user_id, "reset the password of")
|
||||
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
temp_password = secrets.token_urlsafe(16)
|
||||
user.password_hash = hash_password(temp_password)
|
||||
user.must_change_password = True
|
||||
user.last_password_change_at = datetime.now()
|
||||
|
||||
revoked = await _revoke_all_sessions(db, user_id)
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.password_reset",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"sessions_revoked": revoked},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ResetPasswordResponse(
|
||||
message=f"Password reset. {revoked} session(s) revoked. User must change password on next login.",
|
||||
temporary_password=temp_password,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /users/{user_id}/disable-mfa — SEC-05, SEC-13
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/users/{user_id}/disable-mfa")
|
||||
async def disable_user_mfa(
|
||||
user_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Clear TOTP secret + backup codes and revoke all sessions (SEC-13).
|
||||
"""
|
||||
_guard_self_action(actor, user_id, "disable MFA for")
|
||||
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if not user.totp_enabled:
|
||||
raise HTTPException(status_code=409, detail="MFA is not enabled for this user")
|
||||
|
||||
# Clear TOTP data
|
||||
user.totp_secret = None
|
||||
user.totp_enabled = False
|
||||
user.mfa_enforce_pending = False
|
||||
|
||||
# Remove all backup codes
|
||||
await db.execute(
|
||||
sa.delete(BackupCode).where(BackupCode.user_id == user_id)
|
||||
)
|
||||
|
||||
# SEC-13: revoke sessions so the MFA downgrade takes effect immediately
|
||||
revoked = await _revoke_all_sessions(db, user_id)
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.mfa_disabled",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"sessions_revoked": revoked},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": f"MFA disabled. {revoked} session(s) revoked."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /users/{user_id}/enforce-mfa — SEC-05
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.put("/users/{user_id}/enforce-mfa")
|
||||
async def toggle_mfa_enforce(
|
||||
user_id: int,
|
||||
data: ToggleMfaEnforceRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Toggle the mfa_enforce_pending flag. Next login will prompt MFA setup."""
|
||||
_guard_self_action(actor, user_id, "toggle MFA enforcement for")
|
||||
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.mfa_enforce_pending = data.enforce
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.mfa_enforce_toggled",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"enforce": data.enforce},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": f"MFA enforcement {'enabled' if data.enforce else 'disabled'} for user."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /users/{user_id}/active — SEC-05
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.put("/users/{user_id}/active")
|
||||
async def toggle_user_active(
|
||||
user_id: int,
|
||||
data: ToggleActiveRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Enable or disable a user account.
|
||||
Revoking an account also revokes all active sessions immediately.
|
||||
"""
|
||||
_guard_self_action(actor, user_id, "change active status of")
|
||||
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.is_active = data.is_active
|
||||
revoked = 0
|
||||
|
||||
if not data.is_active:
|
||||
revoked = await _revoke_all_sessions(db, user_id)
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.user_deactivated" if not data.is_active else "admin.user_activated",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"sessions_revoked": revoked},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
state = "activated" if data.is_active else f"deactivated ({revoked} session(s) revoked)"
|
||||
return {"message": f"User {state}."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /users/{user_id}/sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.delete("/users/{user_id}/sessions")
|
||||
async def revoke_user_sessions(
|
||||
user_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Forcibly revoke all active sessions for a user."""
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
revoked = await _revoke_all_sessions(db, user_id)
|
||||
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.sessions_revoked",
|
||||
actor_id=actor.id,
|
||||
target_id=user_id,
|
||||
detail={"sessions_revoked": revoked},
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": f"{revoked} session(s) revoked."}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /users/{user_id}/sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/users/{user_id}/sessions")
|
||||
async def list_user_sessions(
|
||||
user_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""List all active (non-revoked, non-expired) sessions for a user."""
|
||||
result = await db.execute(sa.select(User).where(User.id == user_id))
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
sessions_result = await db.execute(
|
||||
sa.select(UserSession).where(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.revoked == False,
|
||||
UserSession.expires_at > datetime.now(),
|
||||
).order_by(UserSession.created_at.desc())
|
||||
)
|
||||
sessions = sessions_result.scalars().all()
|
||||
|
||||
return {
|
||||
"sessions": [
|
||||
{
|
||||
"id": s.id,
|
||||
"created_at": s.created_at,
|
||||
"expires_at": s.expires_at,
|
||||
"ip_address": s.ip_address,
|
||||
"user_agent": s.user_agent,
|
||||
}
|
||||
for s in sessions
|
||||
],
|
||||
"total": len(sessions),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/config", response_model=SystemConfigResponse)
|
||||
async def get_system_config(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Fetch the singleton system configuration row."""
|
||||
result = await db.execute(sa.select(SystemConfig).where(SystemConfig.id == 1))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
# Bootstrap the singleton if it doesn't exist yet
|
||||
config = SystemConfig(id=1)
|
||||
db.add(config)
|
||||
await db.commit()
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.put("/config", response_model=SystemConfigResponse)
|
||||
async def update_system_config(
|
||||
data: SystemConfigUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Update one or more system config fields (partial update)."""
|
||||
result = await db.execute(sa.select(SystemConfig).where(SystemConfig.id == 1))
|
||||
config = result.scalar_one_or_none()
|
||||
if not config:
|
||||
config = SystemConfig(id=1)
|
||||
db.add(config)
|
||||
await db.flush()
|
||||
|
||||
changes: dict = {}
|
||||
if data.allow_registration is not None:
|
||||
changes["allow_registration"] = data.allow_registration
|
||||
config.allow_registration = data.allow_registration
|
||||
if data.enforce_mfa_new_users is not None:
|
||||
changes["enforce_mfa_new_users"] = data.enforce_mfa_new_users
|
||||
config.enforce_mfa_new_users = data.enforce_mfa_new_users
|
||||
|
||||
if changes:
|
||||
await log_audit_event(
|
||||
db,
|
||||
action="admin.config_updated",
|
||||
actor_id=actor.id,
|
||||
detail=changes,
|
||||
ip=request.client.host if request.client else None,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /dashboard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/dashboard", response_model=AdminDashboardResponse)
|
||||
async def admin_dashboard(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
):
|
||||
"""Aggregate stats for the admin portal dashboard."""
|
||||
total_users = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(User)
|
||||
)
|
||||
active_users = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(User).where(User.is_active == True)
|
||||
)
|
||||
admin_count = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(User).where(User.role == "admin")
|
||||
)
|
||||
totp_count = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(User).where(User.totp_enabled == True)
|
||||
)
|
||||
active_sessions = await db.scalar(
|
||||
sa.select(sa.func.count()).select_from(UserSession).where(
|
||||
UserSession.revoked == False,
|
||||
UserSession.expires_at > datetime.now(),
|
||||
)
|
||||
)
|
||||
|
||||
mfa_adoption = (totp_count / total_users) if total_users else 0.0
|
||||
|
||||
# 10 most recent logins — join to get username
|
||||
actor_alias = sa.alias(User.__table__, name="actor")
|
||||
recent_logins_result = await db.execute(
|
||||
sa.select(User.username, User.last_login_at)
|
||||
.where(User.last_login_at != None)
|
||||
.order_by(User.last_login_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
recent_logins = [
|
||||
{"username": row.username, "last_login_at": row.last_login_at}
|
||||
for row in recent_logins_result
|
||||
]
|
||||
|
||||
# 10 most recent audit entries
|
||||
recent_audit_result = await db.execute(
|
||||
sa.select(AuditLog).order_by(AuditLog.created_at.desc()).limit(10)
|
||||
)
|
||||
recent_audit_entries = [
|
||||
{
|
||||
"id": e.id,
|
||||
"action": e.action,
|
||||
"actor_user_id": e.actor_user_id,
|
||||
"target_user_id": e.target_user_id,
|
||||
"detail": e.detail,
|
||||
"created_at": e.created_at,
|
||||
}
|
||||
for e in recent_audit_result.scalars()
|
||||
]
|
||||
|
||||
return AdminDashboardResponse(
|
||||
total_users=total_users or 0,
|
||||
active_users=active_users or 0,
|
||||
admin_count=admin_count or 0,
|
||||
active_sessions=active_sessions or 0,
|
||||
mfa_adoption_rate=round(mfa_adoption, 4),
|
||||
recent_logins=recent_logins,
|
||||
recent_audit_entries=recent_audit_entries,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /audit-log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/audit-log", response_model=AuditLogResponse)
|
||||
async def get_audit_log(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_actor: User = Depends(get_current_user),
|
||||
action: Optional[str] = Query(None, description="Filter by action string (prefix match)"),
|
||||
target_user_id: Optional[int] = Query(None, description="Filter by target user ID"),
|
||||
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
|
||||
per_page: int = Query(50, ge=1, le=200, description="Results per page"),
|
||||
):
|
||||
"""
|
||||
Paginated audit log with optional filters.
|
||||
Resolves actor and target user IDs to usernames via a JOIN.
|
||||
"""
|
||||
# Aliases for the two user joins
|
||||
actor_user = sa.orm.aliased(User, name="actor_user")
|
||||
target_user = sa.orm.aliased(User, name="target_user")
|
||||
|
||||
# Base query — left outer join so entries with NULL actor/target still appear
|
||||
base_q = (
|
||||
sa.select(
|
||||
AuditLog,
|
||||
actor_user.username.label("actor_username"),
|
||||
target_user.username.label("target_username"),
|
||||
)
|
||||
.outerjoin(actor_user, AuditLog.actor_user_id == actor_user.id)
|
||||
.outerjoin(target_user, AuditLog.target_user_id == target_user.id)
|
||||
)
|
||||
|
||||
if action:
|
||||
base_q = base_q.where(AuditLog.action.like(f"{action}%"))
|
||||
if target_user_id is not None:
|
||||
base_q = base_q.where(AuditLog.target_user_id == target_user_id)
|
||||
|
||||
# Count before pagination
|
||||
count_q = sa.select(sa.func.count()).select_from(
|
||||
base_q.subquery()
|
||||
)
|
||||
total = await db.scalar(count_q) or 0
|
||||
|
||||
# Paginate
|
||||
offset = (page - 1) * per_page
|
||||
rows_result = await db.execute(
|
||||
base_q.order_by(AuditLog.created_at.desc()).offset(offset).limit(per_page)
|
||||
)
|
||||
|
||||
entries = [
|
||||
AuditLogEntry(
|
||||
id=row.AuditLog.id,
|
||||
actor_username=row.actor_username,
|
||||
target_username=row.target_username,
|
||||
action=row.AuditLog.action,
|
||||
detail=row.AuditLog.detail,
|
||||
ip_address=row.AuditLog.ip_address,
|
||||
created_at=row.AuditLog.created_at,
|
||||
)
|
||||
for row in rows_result
|
||||
]
|
||||
|
||||
return AuditLogResponse(entries=entries, total=total)
|
||||
@ -1,18 +1,20 @@
|
||||
"""
|
||||
Authentication router — username/password with DB-backed sessions and account lockout.
|
||||
Authentication router — username/password with DB-backed sessions, account lockout,
|
||||
role-based access control, and multi-user registration.
|
||||
|
||||
Session flow:
|
||||
POST /setup → create User + Settings row → issue session cookie
|
||||
POST /login → verify credentials → check lockout → insert UserSession → issue cookie
|
||||
→ if TOTP enabled: return mfa_token instead of full session
|
||||
POST /setup → create admin User + Settings + calendars → issue session cookie
|
||||
POST /login → verify credentials → check lockout → MFA/enforce checks → issue session
|
||||
POST /register → create standard user (when registration enabled)
|
||||
POST /logout → mark session revoked in DB → delete cookie
|
||||
GET /status → verify user exists + session valid
|
||||
GET /status → verify user exists + session valid + role + registration_open
|
||||
|
||||
Security layers:
|
||||
1. Nginx limit_req_zone (real-IP, 10 req/min burst 5) — outer guard on all auth endpoints
|
||||
2. DB-backed account lockout (10 failures → 30-min lock, HTTP 423) — per-user guard
|
||||
1. Nginx limit_req_zone (real-IP, 10 req/min burst 5) — outer guard on auth endpoints
|
||||
2. DB-backed account lockout (10 failures → 30-min lock, HTTP 423)
|
||||
3. Session revocation stored in DB (survives container restarts)
|
||||
4. bcrypt→Argon2id transparent upgrade on first login with migrated hash
|
||||
4. bcrypt→Argon2id transparent upgrade on first login
|
||||
5. Role-based authorization via require_role() dependency factory
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
@ -26,14 +28,21 @@ from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.session import UserSession
|
||||
from app.models.settings import Settings
|
||||
from app.schemas.auth import SetupRequest, LoginRequest, ChangePasswordRequest, VerifyPasswordRequest
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.calendar import Calendar
|
||||
from app.schemas.auth import (
|
||||
SetupRequest, LoginRequest, RegisterRequest,
|
||||
ChangePasswordRequest, VerifyPasswordRequest,
|
||||
)
|
||||
from app.services.auth import (
|
||||
hash_password,
|
||||
verify_password_with_upgrade,
|
||||
create_session_token,
|
||||
verify_session_token,
|
||||
create_mfa_token,
|
||||
create_mfa_enforce_token,
|
||||
)
|
||||
from app.services.audit import log_audit_event
|
||||
from app.config import settings as app_settings
|
||||
|
||||
router = APIRouter()
|
||||
@ -64,8 +73,6 @@ async def get_current_user(
|
||||
) -> User:
|
||||
"""
|
||||
Dependency that verifies the session cookie and returns the authenticated User.
|
||||
Replaces the old get_current_session (which returned Settings).
|
||||
Any router that hasn't been updated will get a compile-time type error.
|
||||
"""
|
||||
if not session_cookie:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
@ -119,6 +126,24 @@ async def get_current_settings(
|
||||
return settings_obj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Role-based authorization dependencies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def require_role(*allowed_roles: str):
|
||||
"""Factory: returns a dependency that enforces role membership."""
|
||||
async def _check(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
if current_user.role not in allowed_roles:
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
return current_user
|
||||
return _check
|
||||
|
||||
# Convenience aliases
|
||||
require_admin = require_role("admin")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account lockout helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -166,7 +191,7 @@ async def _create_db_session(
|
||||
id=session_id,
|
||||
user_id=user.id,
|
||||
expires_at=expires_at,
|
||||
ip_address=ip[:45] if ip else None, # clamp to column width
|
||||
ip_address=ip[:45] if ip else None,
|
||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
||||
)
|
||||
db.add(db_session)
|
||||
@ -175,6 +200,25 @@ async def _create_db_session(
|
||||
return session_id, token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User bootstrapping helper (Settings + default calendars)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _create_user_defaults(db: AsyncSession, user_id: int) -> None:
|
||||
"""Create Settings row and default calendars for a new user."""
|
||||
db.add(Settings(user_id=user_id))
|
||||
db.add(Calendar(
|
||||
name="Personal", color="#3b82f6",
|
||||
is_default=True, is_system=False, is_visible=True,
|
||||
user_id=user_id,
|
||||
))
|
||||
db.add(Calendar(
|
||||
name="Birthdays", color="#f59e0b",
|
||||
is_default=False, is_system=True, is_visible=True,
|
||||
user_id=user_id,
|
||||
))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -187,7 +231,7 @@ async def setup(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
First-time setup: create the User record and a linked Settings row.
|
||||
First-time setup: create the admin User + Settings + default calendars.
|
||||
Only works when no users exist (i.e., fresh install).
|
||||
"""
|
||||
existing = await db.execute(select(User))
|
||||
@ -195,13 +239,16 @@ async def setup(
|
||||
raise HTTPException(status_code=400, detail="Setup already completed")
|
||||
|
||||
password_hash = hash_password(data.password)
|
||||
new_user = User(username=data.username, password_hash=password_hash)
|
||||
new_user = User(
|
||||
username=data.username,
|
||||
password_hash=password_hash,
|
||||
role="admin",
|
||||
last_password_change_at=datetime.now(),
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # assign new_user.id before creating Settings
|
||||
await db.flush()
|
||||
|
||||
# Create Settings row linked to this user with all defaults
|
||||
new_settings = Settings(user_id=new_user.id)
|
||||
db.add(new_settings)
|
||||
await _create_user_defaults(db, new_user.id)
|
||||
await db.commit()
|
||||
|
||||
ip = request.client.host if request.client else "unknown"
|
||||
@ -209,6 +256,11 @@ async def setup(
|
||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
|
||||
await log_audit_event(
|
||||
db, action="auth.setup_complete", actor_id=new_user.id, ip=ip,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Setup completed successfully", "authenticated": True}
|
||||
|
||||
|
||||
@ -223,15 +275,15 @@ async def login(
|
||||
Authenticate with username + password.
|
||||
|
||||
Returns:
|
||||
{ authenticated: true } — on success (no TOTP)
|
||||
{ authenticated: true } — on success (no TOTP, no enforcement)
|
||||
{ authenticated: false, totp_required: true, mfa_token: "..." } — TOTP pending
|
||||
HTTP 401 — wrong credentials (generic; never reveals which field is wrong)
|
||||
{ authenticated: false, mfa_setup_required: true, mfa_token: "..." } — MFA enforcement
|
||||
{ authenticated: false, must_change_password: true } — forced password change after admin reset
|
||||
HTTP 401 — wrong credentials
|
||||
HTTP 423 — account locked
|
||||
HTTP 429 — IP rate limited
|
||||
"""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# Lookup user — do NOT differentiate "user not found" from "wrong password"
|
||||
result = await db.execute(select(User).where(User.username == data.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
@ -240,20 +292,36 @@ async def login(
|
||||
|
||||
await _check_account_lockout(user)
|
||||
|
||||
# Transparent bcrypt→Argon2id upgrade
|
||||
valid, new_hash = verify_password_with_upgrade(data.password, user.password_hash)
|
||||
|
||||
if not valid:
|
||||
await _record_failed_login(db, user)
|
||||
await log_audit_event(
|
||||
db, action="auth.login_failed", actor_id=user.id,
|
||||
detail={"reason": "invalid_password"}, ip=client_ip,
|
||||
)
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
# Persist upgraded hash if migration happened
|
||||
if new_hash:
|
||||
user.password_hash = new_hash
|
||||
|
||||
await _record_successful_login(db, user)
|
||||
|
||||
# If TOTP is enabled, issue a short-lived MFA challenge token instead of a full session
|
||||
# SEC-03: MFA enforcement — block login entirely until MFA setup completes
|
||||
if user.mfa_enforce_pending and not user.totp_enabled:
|
||||
enforce_token = create_mfa_enforce_token(user.id)
|
||||
await log_audit_event(
|
||||
db, action="auth.mfa_enforce_prompted", actor_id=user.id, ip=client_ip,
|
||||
)
|
||||
await db.commit()
|
||||
return {
|
||||
"authenticated": False,
|
||||
"mfa_setup_required": True,
|
||||
"mfa_token": enforce_token,
|
||||
}
|
||||
|
||||
# If TOTP is enabled, issue a short-lived MFA challenge token
|
||||
if user.totp_enabled:
|
||||
mfa_token = create_mfa_token(user.id)
|
||||
return {
|
||||
@ -262,13 +330,97 @@ async def login(
|
||||
"mfa_token": mfa_token,
|
||||
}
|
||||
|
||||
# SEC-12: Forced password change after admin reset
|
||||
if user.must_change_password:
|
||||
# Issue a session but flag the frontend to show password change
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
return {
|
||||
"authenticated": True,
|
||||
"must_change_password": True,
|
||||
}
|
||||
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
|
||||
await log_audit_event(
|
||||
db, action="auth.login_success", actor_id=user.id, ip=client_ip,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {"authenticated": True}
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register(
|
||||
data: RegisterRequest,
|
||||
response: Response,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create a new standard user account.
|
||||
Only available when system_config.allow_registration is True.
|
||||
"""
|
||||
config_result = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.id == 1)
|
||||
)
|
||||
config = config_result.scalar_one_or_none()
|
||||
if not config or not config.allow_registration:
|
||||
raise HTTPException(status_code=403, detail="Registration is not available")
|
||||
|
||||
# Check username availability (generic error to prevent enumeration)
|
||||
existing = await db.execute(
|
||||
select(User).where(User.username == data.username)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Registration failed")
|
||||
|
||||
password_hash = hash_password(data.password)
|
||||
# SEC-01: Explicit field assignment — never **data.model_dump()
|
||||
new_user = User(
|
||||
username=data.username,
|
||||
password_hash=password_hash,
|
||||
role="standard",
|
||||
last_password_change_at=datetime.now(),
|
||||
)
|
||||
|
||||
# Check if MFA enforcement is enabled for new users
|
||||
if config.enforce_mfa_new_users:
|
||||
new_user.mfa_enforce_pending = True
|
||||
|
||||
db.add(new_user)
|
||||
await db.flush()
|
||||
|
||||
await _create_user_defaults(db, new_user.id)
|
||||
await db.commit()
|
||||
|
||||
ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
await log_audit_event(
|
||||
db, action="auth.registration", actor_id=new_user.id, ip=ip,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# If MFA enforcement is pending, don't issue a session — require MFA setup first
|
||||
if new_user.mfa_enforce_pending:
|
||||
enforce_token = create_mfa_enforce_token(new_user.id)
|
||||
return {
|
||||
"message": "Registration successful",
|
||||
"authenticated": False,
|
||||
"mfa_setup_required": True,
|
||||
"mfa_token": enforce_token,
|
||||
}
|
||||
|
||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
|
||||
return {"message": "Registration successful", "authenticated": True}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
response: Response,
|
||||
@ -304,13 +456,13 @@ async def auth_status(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Check authentication status and whether initial setup has been performed.
|
||||
Used by the frontend to decide whether to show login vs setup screen.
|
||||
Check authentication status, role, and whether initial setup/registration is available.
|
||||
"""
|
||||
user_result = await db.execute(select(User))
|
||||
existing_user = user_result.scalar_one_or_none()
|
||||
setup_required = existing_user is None
|
||||
authenticated = False
|
||||
role = None
|
||||
|
||||
if not setup_required and session_cookie:
|
||||
payload = verify_session_token(session_cookie)
|
||||
@ -326,9 +478,32 @@ async def auth_status(
|
||||
UserSession.expires_at > datetime.now(),
|
||||
)
|
||||
)
|
||||
authenticated = session_result.scalar_one_or_none() is not None
|
||||
if session_result.scalar_one_or_none() is not None:
|
||||
authenticated = True
|
||||
user_obj_result = await db.execute(
|
||||
select(User).where(User.id == user_id, User.is_active == True)
|
||||
)
|
||||
u = user_obj_result.scalar_one_or_none()
|
||||
if u:
|
||||
role = u.role
|
||||
else:
|
||||
authenticated = False
|
||||
|
||||
return {"authenticated": authenticated, "setup_required": setup_required}
|
||||
# Check registration availability
|
||||
registration_open = False
|
||||
if not setup_required:
|
||||
config_result = await db.execute(
|
||||
select(SystemConfig).where(SystemConfig.id == 1)
|
||||
)
|
||||
config = config_result.scalar_one_or_none()
|
||||
registration_open = config.allow_registration if config else False
|
||||
|
||||
return {
|
||||
"authenticated": authenticated,
|
||||
"setup_required": setup_required,
|
||||
"role": role,
|
||||
"registration_open": registration_open,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/verify-password")
|
||||
@ -340,8 +515,6 @@ async def verify_password(
|
||||
"""
|
||||
Verify the current user's password without changing anything.
|
||||
Used by the frontend lock screen to re-authenticate without a full login.
|
||||
Also handles transparent bcrypt→Argon2id upgrade.
|
||||
Shares the same lockout guards as /login. Nginx limit_req_zone handles IP rate limiting.
|
||||
"""
|
||||
await _check_account_lockout(current_user)
|
||||
|
||||
@ -350,7 +523,6 @@ async def verify_password(
|
||||
await _record_failed_login(db, current_user)
|
||||
raise HTTPException(status_code=401, detail="Invalid password")
|
||||
|
||||
# Persist upgraded hash if migration happened
|
||||
if new_hash:
|
||||
current_user.password_hash = new_hash
|
||||
await db.commit()
|
||||
@ -373,6 +545,12 @@ async def change_password(
|
||||
raise HTTPException(status_code=401, detail="Invalid current password")
|
||||
|
||||
current_user.password_hash = hash_password(data.new_password)
|
||||
current_user.last_password_change_at = datetime.now()
|
||||
|
||||
# Clear forced password change flag if set (SEC-12)
|
||||
if current_user.must_change_password:
|
||||
current_user.must_change_password = False
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Password changed successfully"}
|
||||
|
||||
@ -18,7 +18,11 @@ async def get_calendars(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
result = await db.execute(select(Calendar).order_by(Calendar.is_default.desc(), Calendar.name.asc()))
|
||||
result = await db.execute(
|
||||
select(Calendar)
|
||||
.where(Calendar.user_id == current_user.id)
|
||||
.order_by(Calendar.is_default.desc(), Calendar.name.asc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@ -34,6 +38,7 @@ async def create_calendar(
|
||||
is_default=False,
|
||||
is_system=False,
|
||||
is_visible=True,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
db.add(new_calendar)
|
||||
await db.commit()
|
||||
@ -48,7 +53,9 @@ async def update_calendar(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
result = await db.execute(select(Calendar).where(Calendar.id == calendar_id))
|
||||
result = await db.execute(
|
||||
select(Calendar).where(Calendar.id == calendar_id, Calendar.user_id == current_user.id)
|
||||
)
|
||||
calendar = result.scalar_one_or_none()
|
||||
|
||||
if not calendar:
|
||||
@ -74,7 +81,9 @@ async def delete_calendar(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
result = await db.execute(select(Calendar).where(Calendar.id == calendar_id))
|
||||
result = await db.execute(
|
||||
select(Calendar).where(Calendar.id == calendar_id, Calendar.user_id == current_user.id)
|
||||
)
|
||||
calendar = result.scalar_one_or_none()
|
||||
|
||||
if not calendar:
|
||||
@ -86,8 +95,13 @@ async def delete_calendar(
|
||||
if calendar.is_default:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete the default calendar")
|
||||
|
||||
# Reassign all events on this calendar to the default calendar
|
||||
default_result = await db.execute(select(Calendar).where(Calendar.is_default == True))
|
||||
# Reassign all events on this calendar to the user's default calendar
|
||||
default_result = await db.execute(
|
||||
select(Calendar).where(
|
||||
Calendar.user_id == current_user.id,
|
||||
Calendar.is_default == True,
|
||||
)
|
||||
)
|
||||
default_calendar = default_result.scalar_one_or_none()
|
||||
|
||||
if default_calendar:
|
||||
|
||||
@ -8,9 +8,11 @@ from app.database import get_db
|
||||
from app.models.settings import Settings
|
||||
from app.models.todo import Todo
|
||||
from app.models.calendar_event import CalendarEvent
|
||||
from app.models.calendar import Calendar
|
||||
from app.models.reminder import Reminder
|
||||
from app.models.project import Project
|
||||
from app.routers.auth import get_current_settings
|
||||
from app.models.user import User
|
||||
from app.routers.auth import get_current_user, get_current_settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -26,16 +28,21 @@ _not_parent_template = or_(
|
||||
async def get_dashboard(
|
||||
client_date: Optional[date] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_settings: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
"""Get aggregated dashboard data."""
|
||||
today = client_date or date.today()
|
||||
upcoming_cutoff = today + timedelta(days=current_settings.upcoming_days)
|
||||
|
||||
# Subquery: calendar IDs belonging to this user (for event scoping)
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
# Today's events (exclude parent templates — they are hidden, children are shown)
|
||||
today_start = datetime.combine(today, datetime.min.time())
|
||||
today_end = datetime.combine(today, datetime.max.time())
|
||||
events_query = select(CalendarEvent).where(
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
CalendarEvent.start_datetime >= today_start,
|
||||
CalendarEvent.start_datetime <= today_end,
|
||||
_not_parent_template,
|
||||
@ -45,6 +52,7 @@ async def get_dashboard(
|
||||
|
||||
# Upcoming todos (not completed, with due date from today through upcoming_days)
|
||||
todos_query = select(Todo).where(
|
||||
Todo.user_id == current_user.id,
|
||||
Todo.completed == False,
|
||||
Todo.due_date.isnot(None),
|
||||
Todo.due_date >= today,
|
||||
@ -55,6 +63,7 @@ async def get_dashboard(
|
||||
|
||||
# Active reminders (not dismissed, is_active = true, from today onward)
|
||||
reminders_query = select(Reminder).where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.is_active == True,
|
||||
Reminder.is_dismissed == False,
|
||||
Reminder.remind_at >= today_start
|
||||
@ -62,26 +71,32 @@ async def get_dashboard(
|
||||
reminders_result = await db.execute(reminders_query)
|
||||
active_reminders = reminders_result.scalars().all()
|
||||
|
||||
# Project stats
|
||||
total_projects_result = await db.execute(select(func.count(Project.id)))
|
||||
# Project stats (scoped to user)
|
||||
total_projects_result = await db.execute(
|
||||
select(func.count(Project.id)).where(Project.user_id == current_user.id)
|
||||
)
|
||||
total_projects = total_projects_result.scalar()
|
||||
|
||||
projects_by_status_query = select(
|
||||
Project.status,
|
||||
func.count(Project.id).label("count")
|
||||
).group_by(Project.status)
|
||||
).where(Project.user_id == current_user.id).group_by(Project.status)
|
||||
projects_by_status_result = await db.execute(projects_by_status_query)
|
||||
projects_by_status = {row[0]: row[1] for row in projects_by_status_result}
|
||||
|
||||
# Total incomplete todos count
|
||||
# Total incomplete todos count (scoped to user)
|
||||
total_incomplete_result = await db.execute(
|
||||
select(func.count(Todo.id)).where(Todo.completed == False)
|
||||
select(func.count(Todo.id)).where(
|
||||
Todo.user_id == current_user.id,
|
||||
Todo.completed == False,
|
||||
)
|
||||
)
|
||||
total_incomplete_todos = total_incomplete_result.scalar()
|
||||
|
||||
# Starred events (upcoming, ordered by date)
|
||||
# Starred events (upcoming, ordered by date, scoped to user's calendars)
|
||||
now = datetime.now()
|
||||
starred_query = select(CalendarEvent).where(
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
CalendarEvent.is_starred == True,
|
||||
CalendarEvent.start_datetime > now,
|
||||
_not_parent_template,
|
||||
@ -143,7 +158,8 @@ async def get_upcoming(
|
||||
days: int = Query(default=7, ge=1, le=90),
|
||||
client_date: Optional[date] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_settings: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
"""Get unified list of upcoming items (todos, events, reminders) sorted by date."""
|
||||
today = client_date or date.today()
|
||||
@ -151,8 +167,12 @@ async def get_upcoming(
|
||||
cutoff_datetime = datetime.combine(cutoff_date, datetime.max.time())
|
||||
today_start = datetime.combine(today, datetime.min.time())
|
||||
|
||||
# Get upcoming todos with due dates (today onward only)
|
||||
# Subquery: calendar IDs belonging to this user
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
# Get upcoming todos with due dates (today onward only, scoped to user)
|
||||
todos_query = select(Todo).where(
|
||||
Todo.user_id == current_user.id,
|
||||
Todo.completed == False,
|
||||
Todo.due_date.isnot(None),
|
||||
Todo.due_date >= today,
|
||||
@ -161,8 +181,9 @@ async def get_upcoming(
|
||||
todos_result = await db.execute(todos_query)
|
||||
todos = todos_result.scalars().all()
|
||||
|
||||
# Get upcoming events (from today onward, exclude parent templates)
|
||||
# Get upcoming events (from today onward, exclude parent templates, scoped to user's calendars)
|
||||
events_query = select(CalendarEvent).where(
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
CalendarEvent.start_datetime >= today_start,
|
||||
CalendarEvent.start_datetime <= cutoff_datetime,
|
||||
_not_parent_template,
|
||||
@ -170,8 +191,9 @@ async def get_upcoming(
|
||||
events_result = await db.execute(events_query)
|
||||
events = events_result.scalars().all()
|
||||
|
||||
# Get upcoming reminders (today onward only)
|
||||
# Get upcoming reminders (today onward only, scoped to user)
|
||||
reminders_query = select(Reminder).where(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.is_active == True,
|
||||
Reminder.is_dismissed == False,
|
||||
Reminder.remind_at >= today_start,
|
||||
|
||||
@ -20,7 +20,11 @@ async def list_templates(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(select(EventTemplate).order_by(EventTemplate.name))
|
||||
result = await db.execute(
|
||||
select(EventTemplate)
|
||||
.where(EventTemplate.user_id == current_user.id)
|
||||
.order_by(EventTemplate.name)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@ -30,7 +34,7 @@ async def create_template(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
template = EventTemplate(**payload.model_dump())
|
||||
template = EventTemplate(**payload.model_dump(), user_id=current_user.id)
|
||||
db.add(template)
|
||||
await db.commit()
|
||||
await db.refresh(template)
|
||||
@ -45,7 +49,10 @@ async def update_template(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(EventTemplate).where(EventTemplate.id == template_id)
|
||||
select(EventTemplate).where(
|
||||
EventTemplate.id == template_id,
|
||||
EventTemplate.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
if template is None:
|
||||
@ -66,7 +73,10 @@ async def delete_template(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(EventTemplate).where(EventTemplate.id == template_id)
|
||||
select(EventTemplate).where(
|
||||
EventTemplate.id == template_id,
|
||||
EventTemplate.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
if template is None:
|
||||
|
||||
@ -105,15 +105,29 @@ def _birthday_events_for_range(
|
||||
return virtual_events
|
||||
|
||||
|
||||
async def _get_default_calendar_id(db: AsyncSession) -> int:
|
||||
"""Return the id of the default calendar, raising 500 if not found."""
|
||||
result = await db.execute(select(Calendar).where(Calendar.is_default == True))
|
||||
async def _get_default_calendar_id(db: AsyncSession, user_id: int) -> int:
|
||||
"""Return the id of the user's default calendar, raising 500 if not found."""
|
||||
result = await db.execute(
|
||||
select(Calendar).where(
|
||||
Calendar.user_id == user_id,
|
||||
Calendar.is_default == True,
|
||||
)
|
||||
)
|
||||
default = result.scalar_one_or_none()
|
||||
if not default:
|
||||
raise HTTPException(status_code=500, detail="No default calendar configured")
|
||||
return default.id
|
||||
|
||||
|
||||
async def _verify_calendar_ownership(db: AsyncSession, calendar_id: int, user_id: int) -> None:
|
||||
"""Raise 404 if calendar_id does not belong to user_id (SEC-04)."""
|
||||
result = await db.execute(
|
||||
select(Calendar).where(Calendar.id == calendar_id, Calendar.user_id == user_id)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Calendar not found")
|
||||
|
||||
|
||||
@router.get("/", response_model=None)
|
||||
async def get_events(
|
||||
start: Optional[date] = Query(None),
|
||||
@ -128,9 +142,13 @@ async def get_events(
|
||||
recurrence_rule IS NOT NULL) are excluded — their materialised children
|
||||
are what get displayed on the calendar.
|
||||
"""
|
||||
# Scope events through calendar ownership
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
query = (
|
||||
select(CalendarEvent)
|
||||
.options(selectinload(CalendarEvent.calendar))
|
||||
.where(CalendarEvent.calendar_id.in_(user_calendar_ids))
|
||||
)
|
||||
|
||||
# Exclude parent template rows — they are not directly rendered
|
||||
@ -154,14 +172,24 @@ async def get_events(
|
||||
|
||||
response: List[dict] = [_event_to_dict(e) for e in events]
|
||||
|
||||
# Fetch Birthdays calendar; only generate virtual events if visible
|
||||
# Fetch the user's Birthdays system calendar; only generate virtual events if visible
|
||||
bday_result = await db.execute(
|
||||
select(Calendar).where(Calendar.name == "Birthdays", Calendar.is_system == True)
|
||||
select(Calendar).where(
|
||||
Calendar.user_id == current_user.id,
|
||||
Calendar.name == "Birthdays",
|
||||
Calendar.is_system == True,
|
||||
)
|
||||
)
|
||||
bday_calendar = bday_result.scalar_one_or_none()
|
||||
|
||||
if bday_calendar and bday_calendar.is_visible:
|
||||
people_result = await db.execute(select(Person).where(Person.birthday.isnot(None)))
|
||||
# Scope birthday people to this user
|
||||
people_result = await db.execute(
|
||||
select(Person).where(
|
||||
Person.user_id == current_user.id,
|
||||
Person.birthday.isnot(None),
|
||||
)
|
||||
)
|
||||
people = people_result.scalars().all()
|
||||
|
||||
virtual = _birthday_events_for_range(
|
||||
@ -187,9 +215,12 @@ async def create_event(
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
# Resolve calendar_id to default if not provided
|
||||
# Resolve calendar_id to user's default if not provided
|
||||
if not data.get("calendar_id"):
|
||||
data["calendar_id"] = await _get_default_calendar_id(db)
|
||||
data["calendar_id"] = await _get_default_calendar_id(db, current_user.id)
|
||||
else:
|
||||
# SEC-04: verify the target calendar belongs to the requesting user
|
||||
await _verify_calendar_ownership(db, data["calendar_id"], current_user.id)
|
||||
|
||||
# Serialize RecurrenceRule object to JSON string for DB storage
|
||||
# Exclude None values so defaults in recurrence service work correctly
|
||||
@ -245,10 +276,15 @@ async def get_event(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
result = await db.execute(
|
||||
select(CalendarEvent)
|
||||
.options(selectinload(CalendarEvent.calendar))
|
||||
.where(CalendarEvent.id == event_id)
|
||||
.where(
|
||||
CalendarEvent.id == event_id,
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
)
|
||||
)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
@ -265,10 +301,15 @@ async def update_event(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
result = await db.execute(
|
||||
select(CalendarEvent)
|
||||
.options(selectinload(CalendarEvent.calendar))
|
||||
.where(CalendarEvent.id == event_id)
|
||||
.where(
|
||||
CalendarEvent.id == event_id,
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
)
|
||||
)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
@ -285,6 +326,10 @@ async def update_event(
|
||||
if rule_obj is not None:
|
||||
update_data["recurrence_rule"] = json.dumps({k: v for k, v in rule_obj.items() if v is not None}) if rule_obj else None
|
||||
|
||||
# SEC-04: if calendar_id is being changed, verify the target belongs to the user
|
||||
if "calendar_id" in update_data and update_data["calendar_id"] is not None:
|
||||
await _verify_calendar_ownership(db, update_data["calendar_id"], current_user.id)
|
||||
|
||||
start = update_data.get("start_datetime", event.start_datetime)
|
||||
end_dt = update_data.get("end_datetime", event.end_datetime)
|
||||
if end_dt is not None and end_dt < start:
|
||||
@ -381,7 +426,14 @@ async def delete_event(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(select(CalendarEvent).where(CalendarEvent.id == event_id))
|
||||
user_calendar_ids = select(Calendar.id).where(Calendar.user_id == current_user.id)
|
||||
|
||||
result = await db.execute(
|
||||
select(CalendarEvent).where(
|
||||
CalendarEvent.id == event_id,
|
||||
CalendarEvent.calendar_id.in_(user_calendar_ids),
|
||||
)
|
||||
)
|
||||
event = result.scalar_one_or_none()
|
||||
|
||||
if not event:
|
||||
|
||||
@ -29,14 +29,15 @@ async def search_locations(
|
||||
"""Search locations from local DB and Nominatim OSM."""
|
||||
results: List[LocationSearchResult] = []
|
||||
|
||||
# Local DB search
|
||||
# Local DB search — scoped to user's locations
|
||||
local_query = (
|
||||
select(Location)
|
||||
.where(
|
||||
Location.user_id == current_user.id,
|
||||
or_(
|
||||
Location.name.ilike(f"%{q}%"),
|
||||
Location.address.ilike(f"%{q}%"),
|
||||
)
|
||||
),
|
||||
)
|
||||
.limit(5)
|
||||
)
|
||||
@ -89,7 +90,7 @@ async def get_locations(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all locations with optional category filter."""
|
||||
query = select(Location)
|
||||
query = select(Location).where(Location.user_id == current_user.id)
|
||||
|
||||
if category:
|
||||
query = query.where(Location.category == category)
|
||||
@ -109,7 +110,7 @@ async def create_location(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new location."""
|
||||
new_location = Location(**location.model_dump())
|
||||
new_location = Location(**location.model_dump(), user_id=current_user.id)
|
||||
db.add(new_location)
|
||||
await db.commit()
|
||||
await db.refresh(new_location)
|
||||
@ -124,7 +125,9 @@ async def get_location(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific location by ID."""
|
||||
result = await db.execute(select(Location).where(Location.id == location_id))
|
||||
result = await db.execute(
|
||||
select(Location).where(Location.id == location_id, Location.user_id == current_user.id)
|
||||
)
|
||||
location = result.scalar_one_or_none()
|
||||
|
||||
if not location:
|
||||
@ -141,7 +144,9 @@ async def update_location(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a location."""
|
||||
result = await db.execute(select(Location).where(Location.id == location_id))
|
||||
result = await db.execute(
|
||||
select(Location).where(Location.id == location_id, Location.user_id == current_user.id)
|
||||
)
|
||||
location = result.scalar_one_or_none()
|
||||
|
||||
if not location:
|
||||
@ -168,7 +173,9 @@ async def delete_location(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a location."""
|
||||
result = await db.execute(select(Location).where(Location.id == location_id))
|
||||
result = await db.execute(
|
||||
select(Location).where(Location.id == location_id, Location.user_id == current_user.id)
|
||||
)
|
||||
location = result.scalar_one_or_none()
|
||||
|
||||
if not location:
|
||||
|
||||
@ -37,7 +37,7 @@ async def get_people(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all people with optional search and category filter."""
|
||||
query = select(Person)
|
||||
query = select(Person).where(Person.user_id == current_user.id)
|
||||
|
||||
if search:
|
||||
term = f"%{search}%"
|
||||
@ -75,7 +75,7 @@ async def create_person(
|
||||
parts = data['name'].split(' ', 1)
|
||||
data['first_name'] = parts[0]
|
||||
data['last_name'] = parts[1] if len(parts) > 1 else None
|
||||
new_person = Person(**data)
|
||||
new_person = Person(**data, user_id=current_user.id)
|
||||
new_person.name = _compute_display_name(
|
||||
new_person.first_name,
|
||||
new_person.last_name,
|
||||
@ -96,7 +96,9 @@ async def get_person(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific person by ID."""
|
||||
result = await db.execute(select(Person).where(Person.id == person_id))
|
||||
result = await db.execute(
|
||||
select(Person).where(Person.id == person_id, Person.user_id == current_user.id)
|
||||
)
|
||||
person = result.scalar_one_or_none()
|
||||
|
||||
if not person:
|
||||
@ -113,7 +115,9 @@ async def update_person(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a person and refresh the denormalised display name."""
|
||||
result = await db.execute(select(Person).where(Person.id == person_id))
|
||||
result = await db.execute(
|
||||
select(Person).where(Person.id == person_id, Person.user_id == current_user.id)
|
||||
)
|
||||
person = result.scalar_one_or_none()
|
||||
|
||||
if not person:
|
||||
@ -147,7 +151,9 @@ async def delete_person(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a person."""
|
||||
result = await db.execute(select(Person).where(Person.id == person_id))
|
||||
result = await db.execute(
|
||||
select(Person).where(Person.id == person_id, Person.user_id == current_user.id)
|
||||
)
|
||||
person = result.scalar_one_or_none()
|
||||
|
||||
if not person:
|
||||
|
||||
@ -49,7 +49,12 @@ async def get_projects(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all projects with their tasks. Optionally filter by tracked status."""
|
||||
query = select(Project).options(*_project_load_options()).order_by(Project.created_at.desc())
|
||||
query = (
|
||||
select(Project)
|
||||
.options(*_project_load_options())
|
||||
.where(Project.user_id == current_user.id)
|
||||
.order_by(Project.created_at.desc())
|
||||
)
|
||||
if tracked is not None:
|
||||
query = query.where(Project.is_tracked == tracked)
|
||||
result = await db.execute(query)
|
||||
@ -77,6 +82,7 @@ async def get_tracked_tasks(
|
||||
selectinload(ProjectTask.parent_task),
|
||||
)
|
||||
.where(
|
||||
Project.user_id == current_user.id,
|
||||
Project.is_tracked == True,
|
||||
ProjectTask.due_date.isnot(None),
|
||||
ProjectTask.due_date >= today,
|
||||
@ -110,7 +116,7 @@ async def create_project(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new project."""
|
||||
new_project = Project(**project.model_dump())
|
||||
new_project = Project(**project.model_dump(), user_id=current_user.id)
|
||||
db.add(new_project)
|
||||
await db.commit()
|
||||
|
||||
@ -127,7 +133,11 @@ async def get_project(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific project by ID with its tasks."""
|
||||
query = select(Project).options(*_project_load_options()).where(Project.id == project_id)
|
||||
query = (
|
||||
select(Project)
|
||||
.options(*_project_load_options())
|
||||
.where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
@ -145,7 +155,9 @@ async def update_project(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a project."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
@ -171,7 +183,9 @@ async def delete_project(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a project and all its tasks."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
@ -190,7 +204,10 @@ async def get_project_tasks(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get top-level tasks for a specific project (subtasks are nested)."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
# Verify project ownership first
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
@ -219,7 +236,10 @@ async def create_project_task(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new task or subtask for a project."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
# Verify project ownership first
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
@ -265,7 +285,10 @@ async def reorder_tasks(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Bulk update sort_order for tasks."""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
# Verify project ownership first
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
@ -296,6 +319,13 @@ async def update_project_task(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a project task."""
|
||||
# Verify project ownership first, then fetch task scoped to that project
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectTask).where(
|
||||
ProjectTask.id == task_id,
|
||||
@ -332,6 +362,13 @@ async def delete_project_task(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a project task (cascades to subtasks)."""
|
||||
# Verify project ownership first, then fetch task scoped to that project
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectTask).where(
|
||||
ProjectTask.id == task_id,
|
||||
@ -358,6 +395,13 @@ async def create_task_comment(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Add a comment to a task."""
|
||||
# Verify project ownership first, then fetch task scoped to that project
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectTask).where(
|
||||
ProjectTask.id == task_id,
|
||||
@ -386,6 +430,13 @@ async def delete_task_comment(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a task comment."""
|
||||
# Verify project ownership first, then fetch comment scoped through task
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == current_user.id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(TaskComment).where(
|
||||
TaskComment.id == comment_id,
|
||||
|
||||
@ -22,7 +22,7 @@ async def get_reminders(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all reminders with optional filters."""
|
||||
query = select(Reminder)
|
||||
query = select(Reminder).where(Reminder.user_id == current_user.id)
|
||||
|
||||
if active is not None:
|
||||
query = query.where(Reminder.is_active == active)
|
||||
@ -48,6 +48,7 @@ async def get_due_reminders(
|
||||
now = client_now or datetime.now()
|
||||
query = select(Reminder).where(
|
||||
and_(
|
||||
Reminder.user_id == current_user.id,
|
||||
Reminder.remind_at <= now,
|
||||
Reminder.is_dismissed == False,
|
||||
Reminder.is_active == True,
|
||||
@ -74,7 +75,12 @@ async def snooze_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Snooze a reminder for N minutes from now."""
|
||||
result = await db.execute(select(Reminder).where(Reminder.id == reminder_id))
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.id == reminder_id,
|
||||
Reminder.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
|
||||
if not reminder:
|
||||
@ -99,7 +105,7 @@ async def create_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new reminder."""
|
||||
new_reminder = Reminder(**reminder.model_dump())
|
||||
new_reminder = Reminder(**reminder.model_dump(), user_id=current_user.id)
|
||||
db.add(new_reminder)
|
||||
await db.commit()
|
||||
await db.refresh(new_reminder)
|
||||
@ -114,7 +120,12 @@ async def get_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific reminder by ID."""
|
||||
result = await db.execute(select(Reminder).where(Reminder.id == reminder_id))
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.id == reminder_id,
|
||||
Reminder.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
|
||||
if not reminder:
|
||||
@ -131,7 +142,12 @@ async def update_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Update a reminder."""
|
||||
result = await db.execute(select(Reminder).where(Reminder.id == reminder_id))
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.id == reminder_id,
|
||||
Reminder.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
|
||||
if not reminder:
|
||||
@ -164,7 +180,12 @@ async def delete_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a reminder."""
|
||||
result = await db.execute(select(Reminder).where(Reminder.id == reminder_id))
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.id == reminder_id,
|
||||
Reminder.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
|
||||
if not reminder:
|
||||
@ -183,7 +204,12 @@ async def dismiss_reminder(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Dismiss a reminder."""
|
||||
result = await db.execute(select(Reminder).where(Reminder.id == reminder_id))
|
||||
result = await db.execute(
|
||||
select(Reminder).where(
|
||||
Reminder.id == reminder_id,
|
||||
Reminder.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
reminder = result.scalar_one_or_none()
|
||||
|
||||
if not reminder:
|
||||
|
||||
@ -73,15 +73,17 @@ def _calculate_recurrence(
|
||||
return reset_at, next_due
|
||||
|
||||
|
||||
async def _reactivate_recurring_todos(db: AsyncSession) -> None:
|
||||
async def _reactivate_recurring_todos(db: AsyncSession, user_id: int) -> None:
|
||||
"""Auto-reactivate recurring todos whose reset_at has passed.
|
||||
|
||||
Uses flush (not commit) so changes are visible to the subsequent query
|
||||
within the same transaction. The caller's commit handles persistence.
|
||||
Scoped to a single user to avoid cross-user reactivation.
|
||||
"""
|
||||
now = datetime.now()
|
||||
query = select(Todo).where(
|
||||
and_(
|
||||
Todo.user_id == user_id,
|
||||
Todo.completed == True,
|
||||
Todo.recurrence_rule.isnot(None),
|
||||
Todo.reset_at.isnot(None),
|
||||
@ -110,13 +112,14 @@ async def get_todos(
|
||||
category: Optional[str] = Query(None),
|
||||
search: Optional[str] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
"""Get all todos with optional filters."""
|
||||
# Reactivate any recurring todos whose reset time has passed
|
||||
await _reactivate_recurring_todos(db)
|
||||
await _reactivate_recurring_todos(db, current_user.id)
|
||||
|
||||
query = select(Todo)
|
||||
query = select(Todo).where(Todo.user_id == current_user.id)
|
||||
|
||||
if completed is not None:
|
||||
query = query.where(Todo.completed == completed)
|
||||
@ -144,10 +147,10 @@ async def get_todos(
|
||||
async def create_todo(
|
||||
todo: TodoCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Create a new todo."""
|
||||
new_todo = Todo(**todo.model_dump())
|
||||
new_todo = Todo(**todo.model_dump(), user_id=current_user.id)
|
||||
db.add(new_todo)
|
||||
await db.commit()
|
||||
await db.refresh(new_todo)
|
||||
@ -159,10 +162,12 @@ async def create_todo(
|
||||
async def get_todo(
|
||||
todo_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific todo by ID."""
|
||||
result = await db.execute(select(Todo).where(Todo.id == todo_id))
|
||||
result = await db.execute(
|
||||
select(Todo).where(Todo.id == todo_id, Todo.user_id == current_user.id)
|
||||
)
|
||||
todo = result.scalar_one_or_none()
|
||||
|
||||
if not todo:
|
||||
@ -176,10 +181,13 @@ async def update_todo(
|
||||
todo_id: int,
|
||||
todo_update: TodoUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
"""Update a todo."""
|
||||
result = await db.execute(select(Todo).where(Todo.id == todo_id))
|
||||
result = await db.execute(
|
||||
select(Todo).where(Todo.id == todo_id, Todo.user_id == current_user.id)
|
||||
)
|
||||
todo = result.scalar_one_or_none()
|
||||
|
||||
if not todo:
|
||||
@ -210,7 +218,7 @@ async def update_todo(
|
||||
reset_at, next_due = _calculate_recurrence(
|
||||
todo.recurrence_rule,
|
||||
todo.due_date,
|
||||
current_user.first_day_of_week,
|
||||
current_settings.first_day_of_week,
|
||||
)
|
||||
todo.reset_at = reset_at
|
||||
todo.next_due_date = next_due
|
||||
@ -229,10 +237,12 @@ async def update_todo(
|
||||
async def delete_todo(
|
||||
todo_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a todo."""
|
||||
result = await db.execute(select(Todo).where(Todo.id == todo_id))
|
||||
result = await db.execute(
|
||||
select(Todo).where(Todo.id == todo_id, Todo.user_id == current_user.id)
|
||||
)
|
||||
todo = result.scalar_one_or_none()
|
||||
|
||||
if not todo:
|
||||
@ -248,10 +258,13 @@ async def delete_todo(
|
||||
async def toggle_todo(
|
||||
todo_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
"""Toggle todo completion status. For recurring todos, calculates reset schedule."""
|
||||
result = await db.execute(select(Todo).where(Todo.id == todo_id))
|
||||
result = await db.execute(
|
||||
select(Todo).where(Todo.id == todo_id, Todo.user_id == current_user.id)
|
||||
)
|
||||
todo = result.scalar_one_or_none()
|
||||
|
||||
if not todo:
|
||||
@ -267,7 +280,7 @@ async def toggle_todo(
|
||||
reset_at, next_due = _calculate_recurrence(
|
||||
todo.recurrence_rule,
|
||||
todo.due_date,
|
||||
current_user.first_day_of_week,
|
||||
current_settings.first_day_of_week,
|
||||
)
|
||||
todo.reset_at = reset_at
|
||||
todo.next_due_date = next_due
|
||||
|
||||
@ -3,6 +3,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from datetime import datetime, timedelta
|
||||
from collections import OrderedDict
|
||||
import asyncio
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
@ -11,13 +12,37 @@ import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.settings import Settings
|
||||
from app.models.user import User
|
||||
from app.config import settings as app_settings
|
||||
from app.routers.auth import get_current_user, get_current_settings
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_cache: dict = {}
|
||||
# SEC-15: Bounded LRU cache keyed by (user_id, location) — max 100 entries.
|
||||
# OrderedDict preserves insertion order; move_to_end on hit, popitem(last=False)
|
||||
# to evict the oldest when capacity is exceeded.
|
||||
_CACHE_MAX = 100
|
||||
_cache: OrderedDict = OrderedDict()
|
||||
|
||||
|
||||
def _cache_get(key: tuple) -> dict | None:
|
||||
"""Return cached entry if it exists and hasn't expired."""
|
||||
entry = _cache.get(key)
|
||||
if entry and datetime.now() < entry["expires_at"]:
|
||||
_cache.move_to_end(key) # LRU: promote to most-recently-used
|
||||
return entry["data"]
|
||||
if entry:
|
||||
del _cache[key] # expired — evict immediately
|
||||
return None
|
||||
|
||||
|
||||
def _cache_set(key: tuple, data: dict) -> None:
|
||||
"""Store an entry; evict the oldest if over capacity."""
|
||||
if key in _cache:
|
||||
_cache.move_to_end(key)
|
||||
_cache[key] = {"data": data, "expires_at": datetime.now() + timedelta(hours=1)}
|
||||
while len(_cache) > _CACHE_MAX:
|
||||
_cache.popitem(last=False) # evict LRU (oldest)
|
||||
|
||||
|
||||
class GeoSearchResult(BaseModel):
|
||||
@ -66,23 +91,24 @@ async def search_locations(
|
||||
@router.get("/")
|
||||
async def get_weather(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: Settings = Depends(get_current_settings)
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_settings: Settings = Depends(get_current_settings),
|
||||
):
|
||||
city = current_user.weather_city
|
||||
lat = current_user.weather_lat
|
||||
lon = current_user.weather_lon
|
||||
city = current_settings.weather_city
|
||||
lat = current_settings.weather_lat
|
||||
lon = current_settings.weather_lon
|
||||
|
||||
if not city and (lat is None or lon is None):
|
||||
raise HTTPException(status_code=400, detail="No weather location configured")
|
||||
|
||||
# Build cache key from coordinates or city
|
||||
# Cache key includes user_id so each user gets isolated cache entries
|
||||
use_coords = lat is not None and lon is not None
|
||||
cache_key = f"{lat},{lon}" if use_coords else city
|
||||
location_key = f"{lat},{lon}" if use_coords else city
|
||||
cache_key = (current_user.id, location_key)
|
||||
|
||||
# Check cache
|
||||
now = datetime.now()
|
||||
if _cache.get("expires_at") and now < _cache["expires_at"] and _cache.get("cache_key") == cache_key:
|
||||
return _cache["data"]
|
||||
cached = _cache_get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
api_key = app_settings.OPENWEATHERMAP_API_KEY
|
||||
if not api_key:
|
||||
@ -122,11 +148,7 @@ async def get_weather(
|
||||
"city": current_data["name"],
|
||||
}
|
||||
|
||||
# Cache for 1 hour
|
||||
_cache["data"] = weather_result
|
||||
_cache["expires_at"] = now + timedelta(hours=1)
|
||||
_cache["cache_key"] = cache_key
|
||||
|
||||
_cache_set(cache_key, weather_result)
|
||||
return weather_result
|
||||
|
||||
except urllib.error.URLError:
|
||||
|
||||
133
backend/app/schemas/admin.py
Normal file
133
backend/app/schemas/admin.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
Admin API schemas — Pydantic v2.
|
||||
|
||||
All admin-facing request/response shapes live here to keep the admin router
|
||||
clean and testable in isolation.
|
||||
"""
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from app.schemas.auth import _validate_username, _validate_password_strength
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User list / detail
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UserListItem(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
role: str
|
||||
is_active: bool
|
||||
last_login_at: Optional[datetime] = None
|
||||
last_password_change_at: Optional[datetime] = None
|
||||
totp_enabled: bool
|
||||
mfa_enforce_pending: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: list[UserListItem]
|
||||
total: int
|
||||
|
||||
|
||||
class UserDetailResponse(UserListItem):
|
||||
active_sessions: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mutating user requests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""Admin-created user — allows role selection (unlike public RegisterRequest)."""
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
username: str
|
||||
password: str
|
||||
role: Literal["admin", "standard", "public_event_manager"] = "standard"
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: str) -> str:
|
||||
return _validate_username(v)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
return _validate_password_strength(v)
|
||||
|
||||
|
||||
class UpdateUserRoleRequest(BaseModel):
|
||||
role: Literal["admin", "standard", "public_event_manager"]
|
||||
|
||||
|
||||
class ToggleActiveRequest(BaseModel):
|
||||
is_active: bool
|
||||
|
||||
|
||||
class ToggleMfaEnforceRequest(BaseModel):
|
||||
enforce: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SystemConfigResponse(BaseModel):
|
||||
allow_registration: bool
|
||||
enforce_mfa_new_users: bool
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SystemConfigUpdate(BaseModel):
|
||||
allow_registration: Optional[bool] = None
|
||||
enforce_mfa_new_users: Optional[bool] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin dashboard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AdminDashboardResponse(BaseModel):
|
||||
total_users: int
|
||||
active_users: int
|
||||
admin_count: int
|
||||
active_sessions: int
|
||||
mfa_adoption_rate: float
|
||||
recent_logins: list[dict]
|
||||
recent_audit_entries: list[dict]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password reset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ResetPasswordResponse(BaseModel):
|
||||
message: str
|
||||
temporary_password: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audit log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AuditLogEntry(BaseModel):
|
||||
id: int
|
||||
actor_username: Optional[str] = None
|
||||
target_username: Optional[str] = None
|
||||
action: str
|
||||
detail: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AuditLogResponse(BaseModel):
|
||||
entries: list[AuditLogEntry]
|
||||
total: int
|
||||
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
|
||||
def _validate_password_strength(v: str) -> str:
|
||||
@ -21,6 +21,16 @@ def _validate_password_strength(v: str) -> str:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_username(v: str) -> str:
|
||||
"""Shared username validation."""
|
||||
v = v.strip().lower()
|
||||
if not 3 <= len(v) <= 50:
|
||||
raise ValueError("Username must be 3–50 characters")
|
||||
if not re.fullmatch(r"[a-z0-9_\-]+", v):
|
||||
raise ValueError("Username may only contain letters, numbers, _ and -")
|
||||
return v
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
@ -28,12 +38,29 @@ class SetupRequest(BaseModel):
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: str) -> str:
|
||||
v = v.strip().lower()
|
||||
if not 3 <= len(v) <= 50:
|
||||
raise ValueError("Username must be 3–50 characters")
|
||||
if not re.fullmatch(r"[a-z0-9_\-]+", v):
|
||||
raise ValueError("Username may only contain letters, numbers, _ and -")
|
||||
return v
|
||||
return _validate_username(v)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
return _validate_password_strength(v)
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""
|
||||
Public registration schema — SEC-01: extra="forbid" prevents role injection.
|
||||
An attacker sending {"username": "...", "password": "...", "role": "admin"}
|
||||
will get a 422 Validation Error instead of silent acceptance.
|
||||
"""
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: str) -> str:
|
||||
return _validate_username(v)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
|
||||
22
backend/app/services/audit.py
Normal file
22
backend/app/services/audit.py
Normal file
@ -0,0 +1,22 @@
|
||||
import json
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.audit_log import AuditLog
|
||||
|
||||
|
||||
async def log_audit_event(
|
||||
db: AsyncSession,
|
||||
action: str,
|
||||
actor_id: int | None = None,
|
||||
target_id: int | None = None,
|
||||
detail: dict | None = None,
|
||||
ip: str | None = None,
|
||||
) -> None:
|
||||
"""Record an action in the audit log. Does NOT commit — caller handles transaction."""
|
||||
entry = AuditLog(
|
||||
actor_user_id=actor_id,
|
||||
target_user_id=target_id,
|
||||
action=action,
|
||||
detail=json.dumps(detail) if detail else None,
|
||||
ip_address=ip[:45] if ip else None,
|
||||
)
|
||||
db.add(entry)
|
||||
@ -126,3 +126,32 @@ def verify_mfa_token(token: str) -> int | None:
|
||||
return data["uid"]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MFA enforcement tokens (SEC-03: distinct salt from challenge tokens)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_mfa_enforce_serializer = URLSafeTimedSerializer(
|
||||
secret_key=app_settings.SECRET_KEY,
|
||||
salt="mfa-enforce-setup-v1",
|
||||
)
|
||||
|
||||
|
||||
def create_mfa_enforce_token(user_id: int) -> str:
|
||||
"""Create a short-lived token for MFA enforcement setup (not a session)."""
|
||||
return _mfa_enforce_serializer.dumps({"uid": user_id})
|
||||
|
||||
|
||||
def verify_mfa_enforce_token(token: str) -> int | None:
|
||||
"""
|
||||
Verify an MFA enforcement setup token.
|
||||
Returns user_id on success, None if invalid or expired (5-minute TTL).
|
||||
"""
|
||||
try:
|
||||
data = _mfa_enforce_serializer.loads(
|
||||
token, max_age=app_settings.MFA_TOKEN_MAX_AGE_SECONDS
|
||||
)
|
||||
return data["uid"]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
# Rate limiting zones (before server block)
|
||||
limit_req_zone $binary_remote_addr zone=auth_limit:10m rate=10r/m;
|
||||
# SEC-14: Registration endpoint — slightly more permissive than strict auth endpoints
|
||||
limit_req_zone $binary_remote_addr zone=register_limit:10m rate=5r/m;
|
||||
# Admin API — generous for legitimate use but still guards against scraping/brute-force
|
||||
limit_req_zone $binary_remote_addr zone=admin_limit:10m rate=30r/m;
|
||||
|
||||
# Use X-Forwarded-Proto from upstream proxy when present, fall back to $scheme for direct access
|
||||
map $http_x_forwarded_proto $forwarded_proto {
|
||||
@ -60,6 +64,20 @@ server {
|
||||
include /etc/nginx/proxy-params.conf;
|
||||
}
|
||||
|
||||
# SEC-14: Rate-limit public registration endpoint
|
||||
location /api/auth/register {
|
||||
limit_req zone=register_limit burst=3 nodelay;
|
||||
limit_req_status 429;
|
||||
include /etc/nginx/proxy-params.conf;
|
||||
}
|
||||
|
||||
# Admin API — rate-limited separately from general /api traffic
|
||||
location /api/admin/ {
|
||||
limit_req zone=admin_limit burst=10 nodelay;
|
||||
limit_req_status 429;
|
||||
include /etc/nginx/proxy-params.conf;
|
||||
}
|
||||
|
||||
# API proxy
|
||||
location /api {
|
||||
proxy_pass http://backend:8000;
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { lazy, Suspense } from 'react';
|
||||
import { Routes, Route, Navigate } from 'react-router-dom';
|
||||
import { useAuth } from '@/hooks/useAuth';
|
||||
import LockScreen from '@/components/auth/LockScreen';
|
||||
@ -12,6 +13,8 @@ import PeoplePage from '@/components/people/PeoplePage';
|
||||
import LocationsPage from '@/components/locations/LocationsPage';
|
||||
import SettingsPage from '@/components/settings/SettingsPage';
|
||||
|
||||
const AdminPortal = lazy(() => import('@/components/admin/AdminPortal'));
|
||||
|
||||
function ProtectedRoute({ children }: { children: React.ReactNode }) {
|
||||
const { authStatus, isLoading } = useAuth();
|
||||
|
||||
@ -30,6 +33,24 @@ function ProtectedRoute({ children }: { children: React.ReactNode }) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
function AdminRoute({ children }: { children: React.ReactNode }) {
|
||||
const { authStatus, isLoading } = useAuth();
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex h-screen items-center justify-center">
|
||||
<div className="text-muted-foreground">Loading...</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!authStatus?.authenticated || authStatus?.role !== 'admin') {
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
}
|
||||
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<Routes>
|
||||
@ -52,6 +73,16 @@ function App() {
|
||||
<Route path="people" element={<PeoplePage />} />
|
||||
<Route path="locations" element={<LocationsPage />} />
|
||||
<Route path="settings" element={<SettingsPage />} />
|
||||
<Route
|
||||
path="admin/*"
|
||||
element={
|
||||
<AdminRoute>
|
||||
<Suspense fallback={<div className="flex h-full items-center justify-center text-muted-foreground">Loading...</div>}>
|
||||
<AdminPortal />
|
||||
</Suspense>
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
</Route>
|
||||
</Routes>
|
||||
);
|
||||
|
||||
@ -16,6 +16,7 @@ import {
|
||||
X,
|
||||
LogOut,
|
||||
Lock,
|
||||
Shield,
|
||||
} from 'lucide-react';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { useAuth } from '@/hooks/useAuth';
|
||||
@ -44,7 +45,7 @@ interface SidebarProps {
|
||||
export default function Sidebar({ collapsed, onToggle, mobileOpen, onMobileClose }: SidebarProps) {
|
||||
const navigate = useNavigate();
|
||||
const location = useLocation();
|
||||
const { logout } = useAuth();
|
||||
const { logout, isAdmin } = useAuth();
|
||||
const { lock } = useLock();
|
||||
const [projectsExpanded, setProjectsExpanded] = useState(false);
|
||||
|
||||
@ -193,6 +194,16 @@ export default function Sidebar({ collapsed, onToggle, mobileOpen, onMobileClose
|
||||
<Lock className="h-5 w-5 shrink-0" />
|
||||
{showExpanded && <span>Lock</span>}
|
||||
</button>
|
||||
{isAdmin && (
|
||||
<NavLink
|
||||
to="/admin"
|
||||
onClick={mobileOpen ? onMobileClose : undefined}
|
||||
className={navLinkClass}
|
||||
>
|
||||
<Shield className="h-5 w-5 shrink-0" />
|
||||
{showExpanded && <span>Admin</span>}
|
||||
</NavLink>
|
||||
)}
|
||||
<NavLink
|
||||
to="/settings"
|
||||
onClick={mobileOpen ? onMobileClose : undefined}
|
||||
|
||||
@ -5,8 +5,8 @@ import type { AuthStatus, LoginResponse } from '@/types';
|
||||
|
||||
export function useAuth() {
|
||||
const queryClient = useQueryClient();
|
||||
// Ephemeral MFA token — not in TanStack cache, lives only during the TOTP challenge step
|
||||
const [mfaToken, setMfaToken] = useState<string | null>(null);
|
||||
const [mfaSetupRequired, setMfaSetupRequired] = useState(false);
|
||||
|
||||
const authQuery = useQuery({
|
||||
queryKey: ['auth'],
|
||||
@ -23,11 +23,34 @@ export function useAuth() {
|
||||
return data;
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
if ('mfa_token' in data && data.totp_required) {
|
||||
// MFA required — store token locally, do NOT mark as authenticated yet
|
||||
if ('mfa_setup_required' in data && data.mfa_setup_required) {
|
||||
// MFA enforcement — user must set up TOTP before accessing app
|
||||
setMfaSetupRequired(true);
|
||||
setMfaToken(data.mfa_token);
|
||||
} else if ('mfa_token' in data && 'totp_required' in data && data.totp_required) {
|
||||
// Regular TOTP challenge
|
||||
setMfaToken(data.mfa_token);
|
||||
setMfaSetupRequired(false);
|
||||
} else {
|
||||
setMfaToken(null);
|
||||
setMfaSetupRequired(false);
|
||||
queryClient.invalidateQueries({ queryKey: ['auth'] });
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
const registerMutation = useMutation({
|
||||
mutationFn: async ({ username, password }: { username: string; password: string }) => {
|
||||
const { data } = await api.post<LoginResponse & { message?: string }>('/auth/register', { username, password });
|
||||
return data;
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
if ('mfa_setup_required' in data && data.mfa_setup_required) {
|
||||
setMfaSetupRequired(true);
|
||||
setMfaToken(data.mfa_token);
|
||||
} else {
|
||||
setMfaToken(null);
|
||||
setMfaSetupRequired(false);
|
||||
queryClient.invalidateQueries({ queryKey: ['auth'] });
|
||||
}
|
||||
},
|
||||
@ -43,6 +66,7 @@ export function useAuth() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
setMfaToken(null);
|
||||
setMfaSetupRequired(false);
|
||||
queryClient.invalidateQueries({ queryKey: ['auth'] });
|
||||
},
|
||||
});
|
||||
@ -64,6 +88,7 @@ export function useAuth() {
|
||||
},
|
||||
onSuccess: () => {
|
||||
setMfaToken(null);
|
||||
setMfaSetupRequired(false);
|
||||
queryClient.invalidateQueries({ queryKey: ['auth'] });
|
||||
},
|
||||
});
|
||||
@ -71,12 +96,18 @@ export function useAuth() {
|
||||
return {
|
||||
authStatus: authQuery.data,
|
||||
isLoading: authQuery.isLoading,
|
||||
mfaRequired: mfaToken !== null,
|
||||
role: authQuery.data?.role ?? null,
|
||||
isAdmin: authQuery.data?.role === 'admin',
|
||||
mfaRequired: mfaToken !== null && !mfaSetupRequired,
|
||||
mfaSetupRequired,
|
||||
mfaToken,
|
||||
login: loginMutation.mutateAsync,
|
||||
register: registerMutation.mutateAsync,
|
||||
verifyTotp: totpVerifyMutation.mutateAsync,
|
||||
setup: setupMutation.mutateAsync,
|
||||
logout: logoutMutation.mutateAsync,
|
||||
isLoginPending: loginMutation.isPending,
|
||||
isRegisterPending: registerMutation.isPending,
|
||||
isTotpPending: totpVerifyMutation.isPending,
|
||||
isSetupPending: setupMutation.isPending,
|
||||
};
|
||||
|
||||
@ -4,6 +4,7 @@ const api = axios.create({
|
||||
baseURL: '/api',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Requested-With': 'XMLHttpRequest',
|
||||
},
|
||||
withCredentials: true,
|
||||
});
|
||||
|
||||
@ -188,14 +188,19 @@ export interface Location {
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export type UserRole = 'admin' | 'standard' | 'public_event_manager';
|
||||
|
||||
export interface AuthStatus {
|
||||
authenticated: boolean;
|
||||
setup_required: boolean;
|
||||
role: UserRole | null;
|
||||
registration_open: boolean;
|
||||
}
|
||||
|
||||
// Login response discriminated union
|
||||
export interface LoginSuccessResponse {
|
||||
authenticated: true;
|
||||
must_change_password?: boolean;
|
||||
}
|
||||
|
||||
export interface LoginMfaRequiredResponse {
|
||||
@ -204,7 +209,64 @@ export interface LoginMfaRequiredResponse {
|
||||
mfa_token: string;
|
||||
}
|
||||
|
||||
export type LoginResponse = LoginSuccessResponse | LoginMfaRequiredResponse;
|
||||
export interface LoginMfaSetupRequiredResponse {
|
||||
authenticated: false;
|
||||
mfa_setup_required: true;
|
||||
mfa_token: string;
|
||||
}
|
||||
|
||||
export type LoginResponse = LoginSuccessResponse | LoginMfaRequiredResponse | LoginMfaSetupRequiredResponse;
|
||||
|
||||
// Admin types
|
||||
export interface AdminUser {
|
||||
id: number;
|
||||
username: string;
|
||||
role: UserRole;
|
||||
is_active: boolean;
|
||||
last_login_at: string | null;
|
||||
last_password_change_at: string | null;
|
||||
totp_enabled: boolean;
|
||||
mfa_enforce_pending: boolean;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export interface AdminUserDetail extends AdminUser {
|
||||
active_sessions: number;
|
||||
}
|
||||
|
||||
export interface SystemConfig {
|
||||
allow_registration: boolean;
|
||||
enforce_mfa_new_users: boolean;
|
||||
}
|
||||
|
||||
export interface AuditLogEntry {
|
||||
id: number;
|
||||
actor_username: string | null;
|
||||
target_username: string | null;
|
||||
action: string;
|
||||
detail: string | null;
|
||||
ip_address: string | null;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export interface AdminDashboardData {
|
||||
total_users: number;
|
||||
active_users: number;
|
||||
admin_count: number;
|
||||
active_sessions: number;
|
||||
mfa_adoption_rate: number;
|
||||
recent_logins: Array<{
|
||||
username: string;
|
||||
last_login_at: string;
|
||||
ip_address: string;
|
||||
}>;
|
||||
recent_audit_entries: Array<{
|
||||
action: string;
|
||||
actor_username: string | null;
|
||||
target_username: string | null;
|
||||
created_at: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
// TOTP setup response (from POST /api/auth/totp/setup)
|
||||
export interface TotpSetupResponse {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user