Kyle Pope 9f7bbbfcbb Add per-user active session counts to IAM user list
Move active_sessions field from UserDetailResponse into UserListItem
so GET /admin/users returns session counts. Uses a correlated subquery
to count non-revoked, non-expired sessions per user.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-27 13:26:32 +08:00

716 lines
24 KiB
Python

"""
Admin router — full user management, system config, and audit log.
Security measures implemented:
SEC-02: Session revocation on role change
SEC-05: Block admin self-actions (own role/password/MFA/active status)
SEC-08: X-Requested-With header check (verify_xhr) on all state-mutating requests
SEC-13: Session revocation + ntfy alert on MFA disable
All routes require the `require_admin` dependency (which chains through
get_current_user, so the session cookie is always validated).
"""
import secrets
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.audit_log import AuditLog
from app.models.backup_code import BackupCode
from app.models.session import UserSession
from app.models.system_config import SystemConfig
from app.models.user import User
from app.routers.auth import (
_create_user_defaults,
get_current_user,
require_admin,
)
from app.schemas.admin import (
AdminDashboardResponse,
AuditLogEntry,
AuditLogResponse,
CreateUserRequest,
ResetPasswordResponse,
SystemConfigResponse,
SystemConfigUpdate,
ToggleActiveRequest,
ToggleMfaEnforceRequest,
UpdateUserRoleRequest,
UserDetailResponse,
UserListItem,
UserListResponse,
)
from app.services.audit import log_audit_event
from app.services.auth import hash_password
# ---------------------------------------------------------------------------
# CSRF guard — SEC-08
# ---------------------------------------------------------------------------
async def verify_xhr(request: Request) -> None:
    """
    CSRF guard (SEC-08) attached to every admin route.

    Safe methods (GET, HEAD, OPTIONS) pass straight through. Any other method
    must carry ``X-Requested-With: XMLHttpRequest``. Browsers cannot attach
    that custom header to a cross-origin request without a CORS pre-flight,
    which our CORS policy blocks.

    Raises:
        HTTPException: 403 when a state-mutating request lacks the header.
    """
    safe_methods = ("GET", "HEAD", "OPTIONS")
    if request.method in safe_methods:
        return
    header_value = request.headers.get("X-Requested-With")
    if header_value == "XMLHttpRequest":
        return
    raise HTTPException(status_code=403, detail="Invalid request origin")
# ---------------------------------------------------------------------------
# Router — all endpoints inherit require_admin + verify_xhr
# ---------------------------------------------------------------------------
# Router-level dependencies run for every endpoint registered below:
#   require_admin — admin access gate (per the module docstring it chains
#                   through get_current_user, so the session cookie is checked)
#   verify_xhr    — X-Requested-With check on non-safe methods (SEC-08)
router = APIRouter(
    dependencies=[
        Depends(require_admin),
        Depends(verify_xhr),
    ],
)
# ---------------------------------------------------------------------------
# Session revocation helper (used in multiple endpoints)
# ---------------------------------------------------------------------------
async def _revoke_all_sessions(db: AsyncSession, user_id: int) -> int:
    """
    Flag every not-yet-revoked session belonging to ``user_id`` as revoked.

    Expiry is not considered. Sessions already past ``expires_at`` are revoked
    too, so the returned count may include expired rows. The UPDATE is issued
    on ``db`` but not committed; the caller owns the transaction.

    Returns:
        Number of session rows that were flipped to revoked.
    """
    stmt = (
        sa.update(UserSession)
        .where(UserSession.user_id == user_id, UserSession.revoked == False)  # noqa: E712 — SQL expression
        .values(revoked=True)
        .returning(UserSession.id)
    )
    revoked_ids = (await db.execute(stmt)).fetchall()
    return len(revoked_ids)
# ---------------------------------------------------------------------------
# Self-action guard — SEC-05
# ---------------------------------------------------------------------------
def _guard_self_action(actor: User, target_id: int, action: str) -> None:
    """
    SEC-05: forbid an admin from performing ``action`` on their own account.

    ``action`` is a verb phrase spliced into the error message, e.g.
    "change role of" → "Admins cannot change role of their own account".

    Raises:
        HTTPException: 403 when ``actor.id`` equals ``target_id``.
    """
    if actor.id != target_id:
        return
    message = f"Admins cannot {action} their own account"
    raise HTTPException(status_code=403, detail=message)
# ---------------------------------------------------------------------------
# GET /users
# ---------------------------------------------------------------------------
@router.get("/users", response_model=UserListResponse)
async def list_users(
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
):
    """
    List every user, oldest account first, with a live session count.

    A correlated scalar subquery counts, per user row, the sessions that are
    neither revoked nor past ``expires_at``. All counts arrive in the same
    single query as the user rows.
    """
    now = datetime.now()
    session_count = (
        sa.select(sa.func.count())
        .select_from(UserSession)
        .where(
            UserSession.user_id == User.id,
            UserSession.revoked == False,  # noqa: E712 — SQL expression
            UserSession.expires_at > now,
        )
        .correlate(User)
        .scalar_subquery()
        .label("active_sessions")
    )
    query = sa.select(User, session_count).order_by(User.created_at)
    rows = (await db.execute(query)).all()

    items = []
    for user, active_count in rows:
        # Drop any active_sessions value derived from the ORM object and
        # substitute the count computed by the subquery.
        base_fields = UserListItem.model_validate(user).model_dump(exclude={"active_sessions"})
        items.append(UserListItem(**base_fields, active_sessions=active_count))
    return UserListResponse(users=items, total=len(rows))
# ---------------------------------------------------------------------------
# GET /users/{user_id}
# ---------------------------------------------------------------------------
@router.get("/users/{user_id}", response_model=UserDetailResponse)
async def get_user(
    user_id: int,
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
):
    """
    Return one user plus the number of sessions that are neither revoked nor expired.

    Raises:
        HTTPException: 404 if no user has ``user_id``.
    """
    result = await db.execute(sa.select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    session_result = await db.execute(
        sa.select(sa.func.count()).select_from(UserSession).where(
            UserSession.user_id == user_id,
            UserSession.revoked == False,  # noqa: E712 — SQL expression
            UserSession.expires_at > datetime.now(),
        )
    )
    active_sessions = session_result.scalar_one()
    # UserListItem carries its own active_sessions field. Dumping it unfiltered
    # while also passing active_sessions= raises "got multiple values for
    # keyword argument 'active_sessions'", so exclude it from the dump.
    base_fields = UserListItem.model_validate(user).model_dump(exclude={"active_sessions"})
    return UserDetailResponse(**base_fields, active_sessions=active_sessions)
# ---------------------------------------------------------------------------
# POST /users
# ---------------------------------------------------------------------------
@router.post("/users", response_model=UserDetailResponse, status_code=201)
async def create_user(
    data: CreateUserRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Admin-create a user with Settings and default calendars.

    The account is flagged ``must_change_password`` so the user sets their own
    credential. An ``admin.user_created`` audit event is written in the same
    transaction.

    Returns:
        The created user with ``active_sessions=0``.

    Raises:
        HTTPException: 409 if ``data.username`` is already taken.
    """
    existing = await db.execute(sa.select(User).where(User.username == data.username))
    if existing.scalar_one_or_none():
        raise HTTPException(status_code=409, detail="Username already taken")
    # NOTE(review): check-then-insert is not atomic. Two concurrent creates with
    # the same username rely on a DB unique constraint (if any) to reject one.
    new_user = User(
        username=data.username,
        password_hash=hash_password(data.password),
        role=data.role,
        last_password_change_at=datetime.now(),
        # Force password change so the user sets their own credential
        must_change_password=True,
    )
    db.add(new_user)
    await db.flush()  # populate new_user.id
    await _create_user_defaults(db, new_user.id)
    await log_audit_event(
        db,
        action="admin.user_created",
        actor_id=actor.id,
        target_id=new_user.id,
        detail={"username": new_user.username, "role": new_user.role},
        ip=request.client.host if request.client else None,
    )
    await db.commit()
    # UserListItem carries its own active_sessions field. Exclude it from the
    # dump to avoid passing the keyword twice (TypeError).
    base_fields = UserListItem.model_validate(new_user).model_dump(exclude={"active_sessions"})
    return UserDetailResponse(**base_fields, active_sessions=0)
# ---------------------------------------------------------------------------
# PUT /users/{user_id}/role — SEC-02, SEC-05
# ---------------------------------------------------------------------------
@router.put("/users/{user_id}/role")
async def update_user_role(
    user_id: int,
    data: UpdateUserRoleRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Assign ``data.role`` to another user.

    - SEC-05: an admin may not change their own role.
    - Refuses to demote an admin when the total count of admin-role users is
      one or fewer.
    - SEC-02: revokes every session of the target so the new role takes
      effect immediately.

    Raises:
        HTTPException: 403 on self-action, 404 for an unknown user, 409 when
        demoting the last admin.
    """
    _guard_self_action(actor, user_id, "change role of")

    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    is_demotion = target.role == "admin" and data.role != "admin"
    if is_demotion:
        admins = await db.scalar(
            sa.select(sa.func.count()).select_from(User).where(User.role == "admin")
        )
        if admins <= 1:
            raise HTTPException(
                status_code=409,
                detail="Cannot demote the last admin account",
            )

    previous_role = target.role
    target.role = data.role

    revoked = await _revoke_all_sessions(db, user_id)
    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.role_changed",
        actor_id=actor.id,
        target_id=user_id,
        detail={"old_role": previous_role, "new_role": data.role, "sessions_revoked": revoked},
        ip=client_ip,
    )
    await db.commit()
    return {"message": f"Role updated to '{data.role}'. {revoked} session(s) revoked."}
# ---------------------------------------------------------------------------
# POST /users/{user_id}/reset-password — SEC-05
# ---------------------------------------------------------------------------
@router.post("/users/{user_id}/reset-password", response_model=ResetPasswordResponse)
async def reset_user_password(
    user_id: int,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Replace another user's password with a random temporary one (SEC-05 guarded).

    A ``secrets.token_urlsafe(16)`` password is hashed onto the user,
    ``must_change_password`` is set, and every session is revoked. The
    plaintext goes back to the admin once in the response. Only its hash is
    written to the user row.

    Raises:
        HTTPException: 403 on self-action, 404 for an unknown user.
    """
    _guard_self_action(actor, user_id, "reset the password of")

    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    temp_password = secrets.token_urlsafe(16)
    target.password_hash = hash_password(temp_password)
    target.must_change_password = True
    target.last_password_change_at = datetime.now()

    revoked = await _revoke_all_sessions(db, user_id)
    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.password_reset",
        actor_id=actor.id,
        target_id=user_id,
        detail={"sessions_revoked": revoked},
        ip=client_ip,
    )
    await db.commit()
    return ResetPasswordResponse(
        message=f"Password reset. {revoked} session(s) revoked. User must change password on next login.",
        temporary_password=temp_password,
    )
# ---------------------------------------------------------------------------
# POST /users/{user_id}/disable-mfa — SEC-05, SEC-13
# ---------------------------------------------------------------------------
@router.post("/users/{user_id}/disable-mfa")
async def disable_user_mfa(
    user_id: int,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Turn off TOTP for another user and revoke all their sessions (SEC-05, SEC-13).

    Clears ``totp_secret``, ``totp_enabled`` and ``mfa_enforce_pending``, and
    deletes every BackupCode row for the user.

    NOTE(review): the module docstring lists an ntfy alert under SEC-13, but
    this handler sends none. Confirm whether that alert lives elsewhere.

    Raises:
        HTTPException: 403 on self-action, 404 for an unknown user, 409 if
        the user does not have MFA enabled.
    """
    _guard_self_action(actor, user_id, "disable MFA for")

    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")
    if not target.totp_enabled:
        raise HTTPException(status_code=409, detail="MFA is not enabled for this user")

    # Wipe TOTP state on the user row.
    target.totp_secret = None
    target.totp_enabled = False
    target.mfa_enforce_pending = False
    # Delete every backup code belonging to the user.
    await db.execute(sa.delete(BackupCode).where(BackupCode.user_id == user_id))

    # SEC-13: revoke sessions so the MFA downgrade takes effect immediately.
    revoked = await _revoke_all_sessions(db, user_id)
    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.mfa_disabled",
        actor_id=actor.id,
        target_id=user_id,
        detail={"sessions_revoked": revoked},
        ip=client_ip,
    )
    await db.commit()
    return {"message": f"MFA disabled. {revoked} session(s) revoked."}
# ---------------------------------------------------------------------------
# PUT /users/{user_id}/enforce-mfa — SEC-05
# ---------------------------------------------------------------------------
@router.put("/users/{user_id}/enforce-mfa")
async def toggle_mfa_enforce(
    user_id: int,
    data: ToggleMfaEnforceRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Set another user's ``mfa_enforce_pending`` flag to ``data.enforce`` (SEC-05 guarded).

    The flag is intended to make the user's next login prompt for MFA setup.
    That prompt is handled outside this router. No sessions are revoked here.

    Raises:
        HTTPException: 403 on self-action, 404 for an unknown user.
    """
    _guard_self_action(actor, user_id, "toggle MFA enforcement for")

    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    target.mfa_enforce_pending = data.enforce
    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.mfa_enforce_toggled",
        actor_id=actor.id,
        target_id=user_id,
        detail={"enforce": data.enforce},
        ip=client_ip,
    )
    await db.commit()
    return {"message": f"MFA enforcement {'enabled' if data.enforce else 'disabled'} for user."}
# ---------------------------------------------------------------------------
# PUT /users/{user_id}/active — SEC-05
# ---------------------------------------------------------------------------
@router.put("/users/{user_id}/active")
async def toggle_user_active(
    user_id: int,
    data: ToggleActiveRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Activate or deactivate another user's account (SEC-05 guarded).

    Deactivation also revokes every session of the user. Activation leaves
    sessions untouched and reports zero revoked.

    Raises:
        HTTPException: 403 on self-action, 404 for an unknown user.
    """
    _guard_self_action(actor, user_id, "change active status of")

    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    target.is_active = data.is_active
    deactivating = not data.is_active
    revoked = await _revoke_all_sessions(db, user_id) if deactivating else 0

    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.user_deactivated" if deactivating else "admin.user_activated",
        actor_id=actor.id,
        target_id=user_id,
        detail={"sessions_revoked": revoked},
        ip=client_ip,
    )
    await db.commit()

    if deactivating:
        state = f"deactivated ({revoked} session(s) revoked)"
    else:
        state = "activated"
    return {"message": f"User {state}."}
# ---------------------------------------------------------------------------
# DELETE /users/{user_id}/sessions
# ---------------------------------------------------------------------------
@router.delete("/users/{user_id}/sessions")
async def revoke_user_sessions(
    user_id: int,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Forcibly revoke every non-revoked session of a user and audit the action.

    No SEC-05 self-action guard is applied, so an admin may target their own
    account with this endpoint.

    Raises:
        HTTPException: 404 for an unknown user.
    """
    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    if not lookup.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="User not found")

    revoked = await _revoke_all_sessions(db, user_id)
    client_ip = request.client.host if request.client else None
    await log_audit_event(
        db,
        action="admin.sessions_revoked",
        actor_id=actor.id,
        target_id=user_id,
        detail={"sessions_revoked": revoked},
        ip=client_ip,
    )
    await db.commit()
    return {"message": f"{revoked} session(s) revoked."}
# ---------------------------------------------------------------------------
# GET /users/{user_id}/sessions
# ---------------------------------------------------------------------------
@router.get("/users/{user_id}/sessions")
async def list_user_sessions(
    user_id: int,
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
):
    """
    List a user's live sessions (not revoked, not past ``expires_at``), newest first.

    Returns:
        ``{"sessions": [...], "total": n}``. Each session dict exposes id,
        created_at, expires_at, ip_address and user_agent.

    Raises:
        HTTPException: 404 for an unknown user.
    """
    lookup = await db.execute(sa.select(User).where(User.id == user_id))
    if not lookup.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="User not found")

    live_sessions_q = (
        sa.select(UserSession)
        .where(
            UserSession.user_id == user_id,
            UserSession.revoked == False,  # noqa: E712 — SQL expression
            UserSession.expires_at > datetime.now(),
        )
        .order_by(UserSession.created_at.desc())
    )
    sessions = (await db.execute(live_sessions_q)).scalars().all()

    serialized = [
        {
            "id": sess.id,
            "created_at": sess.created_at,
            "expires_at": sess.expires_at,
            "ip_address": sess.ip_address,
            "user_agent": sess.user_agent,
        }
        for sess in sessions
    ]
    return {"sessions": serialized, "total": len(sessions)}
# ---------------------------------------------------------------------------
# GET /config
# ---------------------------------------------------------------------------
@router.get("/config", response_model=SystemConfigResponse)
async def get_system_config(
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
):
    """
    Return the singleton system configuration row (``id == 1``).

    If the row does not exist yet, it is created with only ``id=1`` set and
    committed before being returned, so this GET may write to the database.
    """
    lookup = await db.execute(sa.select(SystemConfig).where(SystemConfig.id == 1))
    config = lookup.scalar_one_or_none()
    if config:
        return config
    # First access: bootstrap the singleton.
    config = SystemConfig(id=1)
    db.add(config)
    await db.commit()
    return config
# ---------------------------------------------------------------------------
# PUT /config
# ---------------------------------------------------------------------------
@router.put("/config", response_model=SystemConfigResponse)
async def update_system_config(
    data: SystemConfigUpdate,
    request: Request,
    db: AsyncSession = Depends(get_db),
    actor: User = Depends(get_current_user),
):
    """
    Partially update the singleton system config.

    Only fields whose value in ``data`` is not None are applied. The singleton
    row is created and flushed first if missing. An ``admin.config_updated``
    audit event recording the applied values is written only when at least
    one field changed. The transaction is committed in every case.
    """
    lookup = await db.execute(sa.select(SystemConfig).where(SystemConfig.id == 1))
    config = lookup.scalar_one_or_none()
    if not config:
        config = SystemConfig(id=1)
        db.add(config)
        await db.flush()

    changes: dict = {}
    # Order matters only for the audit detail dict's key order.
    for field in ("allow_registration", "enforce_mfa_new_users"):
        new_value = getattr(data, field)
        if new_value is None:
            continue
        changes[field] = new_value
        setattr(config, field, new_value)

    if changes:
        await log_audit_event(
            db,
            action="admin.config_updated",
            actor_id=actor.id,
            detail=changes,
            ip=request.client.host if request.client else None,
        )
    await db.commit()
    return config
# ---------------------------------------------------------------------------
# GET /dashboard
# ---------------------------------------------------------------------------
@router.get("/dashboard", response_model=AdminDashboardResponse)
async def admin_dashboard(
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
):
    """
    Aggregate stats for the admin portal dashboard.

    Collects:
    - user counts: all, active, admin-role, and TOTP-enabled;
    - the number of sessions neither revoked nor expired;
    - MFA adoption, i.e. TOTP users over all users, rounded to 4 places
      (0.0 when there are no users);
    - the 10 most recent logins;
    - the 10 newest audit entries, with actor/target usernames resolved via
      outer joins.
    """

    def user_count(*criteria):
        # count(*) over User rows, optionally filtered by the given criteria.
        stmt = sa.select(sa.func.count()).select_from(User)
        return stmt.where(*criteria) if criteria else stmt

    total_users = await db.scalar(user_count())
    active_users = await db.scalar(user_count(User.is_active == True))  # noqa: E712
    admin_count = await db.scalar(user_count(User.role == "admin"))
    totp_count = await db.scalar(user_count(User.totp_enabled == True))  # noqa: E712
    active_sessions = await db.scalar(
        sa.select(sa.func.count())
        .select_from(UserSession)
        .where(
            UserSession.revoked == False,  # noqa: E712
            UserSession.expires_at > datetime.now(),
        )
    )
    mfa_adoption = totp_count / total_users if total_users else 0.0

    # 10 most recent logins (users that have never logged in are skipped).
    login_rows = await db.execute(
        sa.select(User.username, User.last_login_at)
        .where(User.last_login_at != None)  # noqa: E711 — renders IS NOT NULL
        .order_by(User.last_login_at.desc())
        .limit(10)
    )
    recent_logins = []
    for login in login_rows:
        recent_logins.append({"username": login.username, "last_login_at": login.last_login_at})

    # 10 newest audit entries; outer joins keep entries with no actor/target.
    actor_user = sa.orm.aliased(User, name="actor_user")
    target_user = sa.orm.aliased(User, name="target_user")
    audit_rows = await db.execute(
        sa.select(
            AuditLog,
            actor_user.username.label("actor_username"),
            target_user.username.label("target_username"),
        )
        .outerjoin(actor_user, AuditLog.actor_user_id == actor_user.id)
        .outerjoin(target_user, AuditLog.target_user_id == target_user.id)
        .order_by(AuditLog.created_at.desc())
        .limit(10)
    )
    recent_audit_entries = [
        {
            "action": entry.AuditLog.action,
            "actor_username": entry.actor_username,
            "target_username": entry.target_username,
            "created_at": entry.AuditLog.created_at,
        }
        for entry in audit_rows
    ]

    return AdminDashboardResponse(
        total_users=total_users or 0,
        active_users=active_users or 0,
        admin_count=admin_count or 0,
        active_sessions=active_sessions or 0,
        mfa_adoption_rate=round(mfa_adoption, 4),
        recent_logins=recent_logins,
        recent_audit_entries=recent_audit_entries,
    )
# ---------------------------------------------------------------------------
# GET /audit-log
# ---------------------------------------------------------------------------
@router.get("/audit-log", response_model=AuditLogResponse)
async def get_audit_log(
    db: AsyncSession = Depends(get_db),
    _actor: User = Depends(get_current_user),
    action: Optional[str] = Query(None, description="Filter by action string (prefix match)"),
    target_user_id: Optional[int] = Query(None, description="Filter by target user ID"),
    page: int = Query(1, ge=1, description="Page number (1-indexed)"),
    per_page: int = Query(50, ge=1, le=200, description="Results per page"),
):
    """
    Paginated audit log, newest first, with optional filters.

    Actor and target user IDs are resolved to usernames via LEFT OUTER JOINs.
    Entries whose actor/target is NULL or has no matching user still appear,
    with a ``None`` username.

    Args:
        action: literal prefix of the action string, e.g. ``"admin.user_"``.
        target_user_id: only return entries targeting this user.
        page, per_page: 1-indexed page number and page size (1–200).

    Returns:
        AuditLogResponse with this page's entries and the total number of
        matching entries (counted before pagination).
    """
    # Aliases for the two user joins
    actor_user = sa.orm.aliased(User, name="actor_user")
    target_user = sa.orm.aliased(User, name="target_user")
    # Base query — left outer join so entries with NULL actor/target still appear
    base_q = (
        sa.select(
            AuditLog,
            actor_user.username.label("actor_username"),
            target_user.username.label("target_username"),
        )
        .outerjoin(actor_user, AuditLog.actor_user_id == actor_user.id)
        .outerjoin(target_user, AuditLog.target_user_id == target_user.id)
    )
    if action:
        # autoescape=True escapes LIKE wildcards, so "%" and "_" in the filter
        # match literally. The previous like(f"{action}%") treated "_" (common
        # in action names) as a single-character wildcard.
        base_q = base_q.where(AuditLog.action.startswith(action, autoescape=True))
    if target_user_id is not None:
        base_q = base_q.where(AuditLog.target_user_id == target_user_id)

    # Count before pagination
    count_q = sa.select(sa.func.count()).select_from(base_q.subquery())
    total = await db.scalar(count_q) or 0

    # Paginate. AuditLog.id breaks created_at ties so OFFSET paging is
    # deterministic; otherwise rows sharing a timestamp can repeat or be
    # skipped across pages.
    offset = (page - 1) * per_page
    rows_result = await db.execute(
        base_q.order_by(AuditLog.created_at.desc(), AuditLog.id.desc())
        .offset(offset)
        .limit(per_page)
    )
    entries = [
        AuditLogEntry(
            id=row.AuditLog.id,
            actor_username=row.actor_username,
            target_username=row.target_username,
            action=row.AuditLog.action,
            detail=row.AuditLog.detail,
            ip_address=row.AuditLog.ip_address,
            created_at=row.AuditLog.created_at,
        )
        for row in rows_result
    ]
    return AuditLogResponse(entries=entries, total=total)