Phase 0: Consolidate session creation into shared service

Extract _create_db_session, _set_session_cookie, _check_account_lockout,
_record_failed_login, and _record_successful_login from auth.py into
services/session.py. Update totp.py to use shared service instead of
its duplicate _create_full_session (which lacked session cap enforcement).

Also fixes:
- auth/status N+1 query (2 sequential queries -> single JOIN)
- Rename verify_password route to verify_password_endpoint (shadow fix)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Kyle 2026-03-17 22:40:46 +08:00
parent c5a309f4a1
commit eebb34aa77
3 changed files with 155 additions and 158 deletions

View File

@ -16,7 +16,6 @@ Security layers:
4. bcryptArgon2id transparent upgrade on first login 4. bcryptArgon2id transparent upgrade on first login
5. Role-based authorization via require_role() dependency factory 5. Role-based authorization via require_role() dependency factory
""" """
import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
@ -49,6 +48,13 @@ from app.services.auth import (
create_mfa_enforce_token, create_mfa_enforce_token,
) )
from app.services.audit import get_client_ip, log_audit_event from app.services.audit import get_client_ip, log_audit_event
from app.services.session import (
set_session_cookie,
check_account_lockout,
record_failed_login,
record_successful_login,
create_db_session,
)
from app.config import settings as app_settings from app.config import settings as app_settings
router = APIRouter() router = APIRouter()
@ -59,22 +65,6 @@ router = APIRouter()
# is indistinguishable from a wrong-password attempt. # is indistinguishable from a wrong-password attempt.
_DUMMY_HASH = hash_password("timing-equalization-dummy") _DUMMY_HASH = hash_password("timing-equalization-dummy")
# ---------------------------------------------------------------------------
# Cookie helper
# ---------------------------------------------------------------------------
def _set_session_cookie(response: Response, token: str) -> None:
response.set_cookie(
key="session",
value=token,
httponly=True,
secure=app_settings.COOKIE_SECURE,
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
samesite="lax",
path="/",
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Auth dependencies — export get_current_user and get_current_settings # Auth dependencies — export get_current_user and get_current_settings
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -130,7 +120,7 @@ async def get_current_user(
await db.flush() await db.flush()
# Re-issue cookie with fresh signed token to reset browser max_age timer # Re-issue cookie with fresh signed token to reset browser max_age timer
fresh_token = create_session_token(user_id, session_id) fresh_token = create_session_token(user_id, session_id)
_set_session_cookie(response, fresh_token) set_session_cookie(response, fresh_token)
# Stash session on request so lock/unlock endpoints can access it # Stash session on request so lock/unlock endpoints can access it
request.state.db_session = db_session request.state.db_session = db_session
@ -190,82 +180,6 @@ def require_role(*allowed_roles: str):
require_admin = require_role("admin") require_admin = require_role("admin")
# ---------------------------------------------------------------------------
# Account lockout helpers
# ---------------------------------------------------------------------------
async def _check_account_lockout(user: User) -> None:
"""Raise HTTP 423 if the account is currently locked."""
if user.locked_until and datetime.now() < user.locked_until:
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
raise HTTPException(
status_code=423,
detail=f"Account locked. Try again in {remaining} minutes.",
)
async def _record_failed_login(db: AsyncSession, user: User) -> None:
"""Increment failure counter; lock account after 10 failures."""
user.failed_login_count += 1
if user.failed_login_count >= 10:
user.locked_until = datetime.now() + timedelta(minutes=30)
await db.commit()
async def _record_successful_login(db: AsyncSession, user: User) -> None:
"""Reset failure counter and update last_login_at."""
user.failed_login_count = 0
user.locked_until = None
user.last_login_at = datetime.now()
await db.commit()
# ---------------------------------------------------------------------------
# Session creation helper
# ---------------------------------------------------------------------------
async def _create_db_session(
db: AsyncSession,
user: User,
ip: str,
user_agent: str | None,
) -> tuple[str, str]:
"""Insert a UserSession row and return (session_id, signed_cookie_token)."""
session_id = uuid.uuid4().hex
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
db_session = UserSession(
id=session_id,
user_id=user.id,
expires_at=expires_at,
ip_address=ip[:45] if ip else None,
user_agent=(user_agent or "")[:255] if user_agent else None,
)
db.add(db_session)
await db.flush()
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
active_sessions = (
await db.execute(
select(UserSession)
.where(
UserSession.user_id == user.id,
UserSession.revoked == False, # noqa: E712
UserSession.expires_at > datetime.now(),
)
.order_by(UserSession.created_at.asc())
)
).scalars().all()
max_sessions = app_settings.MAX_SESSIONS_PER_USER
if len(active_sessions) > max_sessions:
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
old_session.revoked = True
await db.flush()
token = create_session_token(user.id, session_id)
return session_id, token
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# User bootstrapping helper (Settings + default calendars) # User bootstrapping helper (Settings + default calendars)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -321,8 +235,8 @@ async def setup(
ip = get_client_ip(request) ip = get_client_ip(request)
user_agent = request.headers.get("user-agent") user_agent = request.headers.get("user-agent")
_, token = await _create_db_session(db, new_user, ip, user_agent) _, token = await create_db_session(db, new_user, ip, user_agent)
_set_session_cookie(response, token) set_session_cookie(response, token)
await log_audit_event( await log_audit_event(
db, action="auth.setup_complete", actor_id=new_user.id, ip=ip, db, action="auth.setup_complete", actor_id=new_user.id, ip=ip,
@ -366,10 +280,10 @@ async def login(
# executes — prevents distinguishing "locked" from "wrong password" via timing. # executes — prevents distinguishing "locked" from "wrong password" via timing.
valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash) valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash)
await _check_account_lockout(user) await check_account_lockout(user)
if not valid: if not valid:
await _record_failed_login(db, user) await record_failed_login(db, user)
await log_audit_event( await log_audit_event(
db, action="auth.login_failed", actor_id=user.id, db, action="auth.login_failed", actor_id=user.id,
detail={"reason": "invalid_password"}, ip=client_ip, detail={"reason": "invalid_password"}, ip=client_ip,
@ -378,7 +292,7 @@ async def login(
raise HTTPException(status_code=401, detail="Invalid username or password") raise HTTPException(status_code=401, detail="Invalid username or password")
# Block disabled accounts — checked AFTER password verification to avoid # Block disabled accounts — checked AFTER password verification to avoid
# leaking account-state info, and BEFORE _record_successful_login so # leaking account-state info, and BEFORE record_successful_login so
# last_login_at and lockout counters are not reset for inactive users. # last_login_at and lockout counters are not reset for inactive users.
if not user.is_active: if not user.is_active:
await log_audit_event( await log_audit_event(
@ -391,7 +305,7 @@ async def login(
if new_hash: if new_hash:
user.password_hash = new_hash user.password_hash = new_hash
await _record_successful_login(db, user) await record_successful_login(db, user)
# SEC-03: MFA enforcement — block login entirely until MFA setup completes # SEC-03: MFA enforcement — block login entirely until MFA setup completes
if user.mfa_enforce_pending and not user.totp_enabled: if user.mfa_enforce_pending and not user.totp_enabled:
@ -419,8 +333,8 @@ async def login(
if user.must_change_password: if user.must_change_password:
# Issue a session but flag the frontend to show password change # Issue a session but flag the frontend to show password change
user_agent = request.headers.get("user-agent") user_agent = request.headers.get("user-agent")
_, token = await _create_db_session(db, user, client_ip, user_agent) _, token = await create_db_session(db, user, client_ip, user_agent)
_set_session_cookie(response, token) set_session_cookie(response, token)
await db.commit() await db.commit()
return { return {
"authenticated": True, "authenticated": True,
@ -428,8 +342,8 @@ async def login(
} }
user_agent = request.headers.get("user-agent") user_agent = request.headers.get("user-agent")
_, token = await _create_db_session(db, user, client_ip, user_agent) _, token = await create_db_session(db, user, client_ip, user_agent)
_set_session_cookie(response, token) set_session_cookie(response, token)
await log_audit_event( await log_audit_event(
db, action="auth.login_success", actor_id=user.id, ip=client_ip, db, action="auth.login_success", actor_id=user.id, ip=client_ip,
@ -511,8 +425,8 @@ async def register(
"mfa_token": enforce_token, "mfa_token": enforce_token,
} }
_, token = await _create_db_session(db, new_user, ip, user_agent) _, token = await create_db_session(db, new_user, ip, user_agent)
_set_session_cookie(response, token) set_session_cookie(response, token)
await db.commit() await db.commit()
return {"message": "Registration successful", "authenticated": True} return {"message": "Registration successful", "authenticated": True}
@ -564,32 +478,31 @@ async def auth_status(
is_locked = False is_locked = False
u = None
if not setup_required and session_cookie: if not setup_required and session_cookie:
payload = verify_session_token(session_cookie) payload = verify_session_token(session_cookie)
if payload: if payload:
user_id = payload.get("uid") user_id = payload.get("uid")
session_id = payload.get("sid") session_id = payload.get("sid")
if user_id and session_id: if user_id and session_id:
session_result = await db.execute( # Single JOIN query (was 2 sequential queries — P-01 fix)
select(UserSession).where( result = await db.execute(
select(UserSession, User)
.join(User, UserSession.user_id == User.id)
.where(
UserSession.id == session_id, UserSession.id == session_id,
UserSession.user_id == user_id, UserSession.user_id == user_id,
UserSession.revoked == False, UserSession.revoked == False,
UserSession.expires_at > datetime.now(), UserSession.expires_at > datetime.now(),
User.is_active == True,
) )
) )
db_sess = session_result.scalar_one_or_none() row = result.one_or_none()
if db_sess is not None: if row is not None:
db_sess, u = row.tuple()
authenticated = True authenticated = True
is_locked = db_sess.is_locked is_locked = db_sess.is_locked
user_obj_result = await db.execute( role = u.role
select(User).where(User.id == user_id, User.is_active == True)
)
u = user_obj_result.scalar_one_or_none()
if u:
role = u.role
else:
authenticated = False
# Check registration availability # Check registration availability
registration_open = False registration_open = False
@ -625,7 +538,7 @@ async def lock_session(
@router.post("/verify-password") @router.post("/verify-password")
async def verify_password( async def verify_password_endpoint(
data: VerifyPasswordRequest, data: VerifyPasswordRequest,
request: Request, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
@ -635,11 +548,11 @@ async def verify_password(
Verify the current user's password without changing anything. Verify the current user's password without changing anything.
Used by the frontend lock screen to re-authenticate without a full login. Used by the frontend lock screen to re-authenticate without a full login.
""" """
await _check_account_lockout(current_user) await check_account_lockout(current_user)
valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash) valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash)
if not valid: if not valid:
await _record_failed_login(db, current_user) await record_failed_login(db, current_user)
raise HTTPException(status_code=401, detail="Invalid password") raise HTTPException(status_code=401, detail="Invalid password")
if new_hash: if new_hash:
@ -661,11 +574,11 @@ async def change_password(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
"""Change the current user's password. Requires old password verification.""" """Change the current user's password. Requires old password verification."""
await _check_account_lockout(current_user) await check_account_lockout(current_user)
valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash) valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash)
if not valid: if not valid:
await _record_failed_login(db, current_user) await record_failed_login(db, current_user)
raise HTTPException(status_code=401, detail="Invalid current password") raise HTTPException(status_code=401, detail="Invalid current password")
if data.new_password == data.old_password: if data.new_password == data.old_password:

View File

@ -18,7 +18,6 @@ Security:
- totp-verify uses mfa_token (not session cookie) user is not yet authenticated - totp-verify uses mfa_token (not session cookie) user is not yet authenticated
""" """
import asyncio import asyncio
import uuid
import secrets import secrets
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -32,17 +31,16 @@ from sqlalchemy.exc import IntegrityError
from app.database import get_db from app.database import get_db
from app.models.user import User from app.models.user import User
from app.models.session import UserSession
from app.models.totp_usage import TOTPUsage from app.models.totp_usage import TOTPUsage
from app.models.backup_code import BackupCode from app.models.backup_code import BackupCode
from app.routers.auth import get_current_user, _set_session_cookie from app.routers.auth import get_current_user
from app.services.audit import get_client_ip from app.services.audit import get_client_ip
from app.services.auth import ( from app.services.auth import (
averify_password_with_upgrade, averify_password_with_upgrade,
verify_mfa_token, verify_mfa_token,
verify_mfa_enforce_token, verify_mfa_enforce_token,
create_session_token,
) )
from app.services.session import create_db_session, set_session_cookie
from app.services.totp import ( from app.services.totp import (
generate_totp_secret, generate_totp_secret,
encrypt_totp_secret, encrypt_totp_secret,
@ -52,7 +50,7 @@ from app.services.totp import (
generate_qr_base64, generate_qr_base64,
generate_backup_codes, generate_backup_codes,
) )
from app.config import settings as app_settings
# Argon2id for backup code hashing — treat each code like a password # Argon2id for backup code hashing — treat each code like a password
from argon2 import PasswordHasher from argon2 import PasswordHasher
@ -162,29 +160,6 @@ async def _verify_backup_code(
return False return False
async def _create_full_session(
db: AsyncSession,
user: User,
request: Request,
) -> str:
"""Create a UserSession row and return the signed cookie token."""
session_id = uuid.uuid4().hex
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
ip = get_client_ip(request)
user_agent = request.headers.get("user-agent")
db_session = UserSession(
id=session_id,
user_id=user.id,
expires_at=expires_at,
ip_address=ip[:45] if ip else None,
user_agent=(user_agent or "")[:255] if user_agent else None,
)
db.add(db_session)
await db.commit()
return create_session_token(user.id, session_id)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Routes # Routes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -312,8 +287,10 @@ async def totp_verify(
user.last_login_at = datetime.now() user.last_login_at = datetime.now()
await db.commit() await db.commit()
token = await _create_full_session(db, user, request) ip = get_client_ip(request)
_set_session_cookie(response, token) user_agent = request.headers.get("user-agent")
_, token = await create_db_session(db, user, ip, user_agent)
set_session_cookie(response, token)
return {"authenticated": True} return {"authenticated": True}
# --- TOTP code path --- # --- TOTP code path ---
@ -340,8 +317,10 @@ async def totp_verify(
user.last_login_at = datetime.now() user.last_login_at = datetime.now()
await db.commit() await db.commit()
token = await _create_full_session(db, user, request) ip = get_client_ip(request)
_set_session_cookie(response, token) user_agent = request.headers.get("user-agent")
_, token = await create_db_session(db, user, ip, user_agent)
set_session_cookie(response, token)
return {"authenticated": True} return {"authenticated": True}
@ -513,9 +492,11 @@ async def enforce_confirm_totp(
user.last_login_at = datetime.now() user.last_login_at = datetime.now()
await db.commit() await db.commit()
# Issue a full session # Issue a full session (now uses shared session service with cap enforcement)
token = await _create_full_session(db, user, request) ip = get_client_ip(request)
_set_session_cookie(response, token) user_agent = request.headers.get("user-agent")
_, token = await create_db_session(db, user, ip, user_agent)
set_session_cookie(response, token)
return {"authenticated": True} return {"authenticated": True}

View File

@ -0,0 +1,103 @@
"""
Shared session management service.
Consolidates session creation, cookie handling, and account lockout logic
that was previously duplicated between auth.py and totp.py routers.
All auth paths (password, TOTP, passkey) use these functions to ensure
consistent session cap enforcement and lockout behavior.
"""
import uuid
from datetime import datetime, timedelta
from fastapi import HTTPException, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.user import User
from app.models.session import UserSession
from app.services.auth import create_session_token
from app.config import settings as app_settings
def set_session_cookie(response: Response, token: str) -> None:
"""Set httpOnly secure signed cookie on response."""
response.set_cookie(
key="session",
value=token,
httponly=True,
secure=app_settings.COOKIE_SECURE,
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
samesite="lax",
path="/",
)
async def check_account_lockout(user: User) -> None:
"""Raise HTTP 423 if the account is currently locked."""
if user.locked_until and datetime.now() < user.locked_until:
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
raise HTTPException(
status_code=423,
detail=f"Account locked. Try again in {remaining} minutes.",
)
async def record_failed_login(db: AsyncSession, user: User) -> None:
"""Increment failure counter; lock account after 10 failures."""
user.failed_login_count += 1
if user.failed_login_count >= 10:
user.locked_until = datetime.now() + timedelta(minutes=30)
await db.commit()
async def record_successful_login(db: AsyncSession, user: User) -> None:
"""Reset failure counter and update last_login_at."""
user.failed_login_count = 0
user.locked_until = None
user.last_login_at = datetime.now()
await db.commit()
async def create_db_session(
db: AsyncSession,
user: User,
ip: str,
user_agent: str | None,
) -> tuple[str, str]:
"""Insert a UserSession row and return (session_id, signed_cookie_token).
Enforces MAX_SESSIONS_PER_USER by revoking oldest sessions beyond the cap.
"""
session_id = uuid.uuid4().hex
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
db_session = UserSession(
id=session_id,
user_id=user.id,
expires_at=expires_at,
ip_address=ip[:45] if ip else None,
user_agent=(user_agent or "")[:255] if user_agent else None,
)
db.add(db_session)
await db.flush()
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
active_sessions = (
await db.execute(
select(UserSession)
.where(
UserSession.user_id == user.id,
UserSession.revoked == False, # noqa: E712
UserSession.expires_at > datetime.now(),
)
.order_by(UserSession.created_at.asc())
)
).scalars().all()
max_sessions = app_settings.MAX_SESSIONS_PER_USER
if len(active_sessions) > max_sessions:
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
old_session.revoked = True
await db.flush()
token = create_session_token(user.id, session_id)
return session_id, token