From eebb34aa778c235b776504818fe412275caa10f9 Mon Sep 17 00:00:00 2001 From: Kyle Pope Date: Tue, 17 Mar 2026 22:40:46 +0800 Subject: [PATCH] Phase 0: Consolidate session creation into shared service Extract _create_db_session, _set_session_cookie, _check_account_lockout, _record_failed_login, and _record_successful_login from auth.py into services/session.py. Update totp.py to use shared service instead of its duplicate _create_full_session (which lacked session cap enforcement). Also fixes: - auth/status N+1 query (2 sequential queries -> single JOIN) - Rename verify_password route to verify_password_endpoint (shadow fix) Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/app/routers/auth.py | 159 ++++++++------------------------ backend/app/routers/totp.py | 51 ++++------ backend/app/services/session.py | 103 +++++++++++++++++++++ 3 files changed, 155 insertions(+), 158 deletions(-) create mode 100644 backend/app/services/session.py diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 6a85c6b..fe97a1f 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -16,7 +16,6 @@ Security layers: 4. bcrypt→Argon2id transparent upgrade on first login 5. Role-based authorization via require_role() dependency factory """ -import uuid from datetime import datetime, timedelta from typing import Optional @@ -49,6 +48,13 @@ from app.services.auth import ( create_mfa_enforce_token, ) from app.services.audit import get_client_ip, log_audit_event +from app.services.session import ( + set_session_cookie, + check_account_lockout, + record_failed_login, + record_successful_login, + create_db_session, +) from app.config import settings as app_settings router = APIRouter() @@ -59,22 +65,6 @@ router = APIRouter() # is indistinguishable from a wrong-password attempt. _DUMMY_HASH = hash_password("timing-equalization-dummy") -# --------------------------------------------------------------------------- -# Cookie helper -# --------------------------------------------------------------------------- - -def _set_session_cookie(response: Response, token: str) -> None: - response.set_cookie( - key="session", - value=token, - httponly=True, - secure=app_settings.COOKIE_SECURE, - max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400, - samesite="lax", - path="/", - ) - - # --------------------------------------------------------------------------- # Auth dependencies — export get_current_user and get_current_settings # --------------------------------------------------------------------------- @@ -130,7 +120,7 @@ async def get_current_user( await db.flush() # Re-issue cookie with fresh signed token to reset browser max_age timer fresh_token = create_session_token(user_id, session_id) - _set_session_cookie(response, fresh_token) + set_session_cookie(response, fresh_token) # Stash session on request so lock/unlock endpoints can access it request.state.db_session = db_session @@ -190,82 +180,6 @@ def require_role(*allowed_roles: str): require_admin = require_role("admin") -# --------------------------------------------------------------------------- -# Account lockout helpers -# --------------------------------------------------------------------------- - -async def _check_account_lockout(user: User) -> None: - """Raise HTTP 423 if the account is currently locked.""" - if user.locked_until and datetime.now() < user.locked_until: - remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1 - raise HTTPException( - status_code=423, - detail=f"Account locked. Try again in {remaining} minutes.", - ) - - -async def _record_failed_login(db: AsyncSession, user: User) -> None: - """Increment failure counter; lock account after 10 failures.""" - user.failed_login_count += 1 - if user.failed_login_count >= 10: - user.locked_until = datetime.now() + timedelta(minutes=30) - await db.commit() - - -async def _record_successful_login(db: AsyncSession, user: User) -> None: - """Reset failure counter and update last_login_at.""" - user.failed_login_count = 0 - user.locked_until = None - user.last_login_at = datetime.now() - await db.commit() - - -# --------------------------------------------------------------------------- -# Session creation helper -# --------------------------------------------------------------------------- - -async def _create_db_session( - db: AsyncSession, - user: User, - ip: str, - user_agent: str | None, -) -> tuple[str, str]: - """Insert a UserSession row and return (session_id, signed_cookie_token).""" - session_id = uuid.uuid4().hex - expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS) - db_session = UserSession( - id=session_id, - user_id=user.id, - expires_at=expires_at, - ip_address=ip[:45] if ip else None, - user_agent=(user_agent or "")[:255] if user_agent else None, - ) - db.add(db_session) - await db.flush() - - # Enforce concurrent session limit: revoke oldest sessions beyond the cap - active_sessions = ( - await db.execute( - select(UserSession) - .where( - UserSession.user_id == user.id, - UserSession.revoked == False, # noqa: E712 - UserSession.expires_at > datetime.now(), - ) - .order_by(UserSession.created_at.asc()) - ) - ).scalars().all() - - max_sessions = app_settings.MAX_SESSIONS_PER_USER - if len(active_sessions) > max_sessions: - for old_session in active_sessions[: len(active_sessions) - max_sessions]: - old_session.revoked = True - await db.flush() - - token = create_session_token(user.id, session_id) - return session_id, token - - # --------------------------------------------------------------------------- # User bootstrapping helper (Settings + default calendars) # --------------------------------------------------------------------------- @@ -321,8 +235,8 @@ async def setup( ip = get_client_ip(request) user_agent = request.headers.get("user-agent") - _, token = await _create_db_session(db, new_user, ip, user_agent) - _set_session_cookie(response, token) + _, token = await create_db_session(db, new_user, ip, user_agent) + set_session_cookie(response, token) await log_audit_event( db, action="auth.setup_complete", actor_id=new_user.id, ip=ip, @@ -366,10 +280,10 @@ async def login( # executes — prevents distinguishing "locked" from "wrong password" via timing. valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash) - await _check_account_lockout(user) + await check_account_lockout(user) if not valid: - await _record_failed_login(db, user) + await record_failed_login(db, user) await log_audit_event( db, action="auth.login_failed", actor_id=user.id, detail={"reason": "invalid_password"}, ip=client_ip, @@ -378,7 +292,7 @@ async def login( raise HTTPException(status_code=401, detail="Invalid username or password") # Block disabled accounts — checked AFTER password verification to avoid - # leaking account-state info, and BEFORE _record_successful_login so + # leaking account-state info, and BEFORE record_successful_login so # last_login_at and lockout counters are not reset for inactive users. if not user.is_active: await log_audit_event( @@ -391,7 +305,7 @@ async def login( if new_hash: user.password_hash = new_hash - await _record_successful_login(db, user) + await record_successful_login(db, user) # SEC-03: MFA enforcement — block login entirely until MFA setup completes if user.mfa_enforce_pending and not user.totp_enabled: @@ -419,8 +333,8 @@ async def login( if user.must_change_password: # Issue a session but flag the frontend to show password change user_agent = request.headers.get("user-agent") - _, token = await _create_db_session(db, user, client_ip, user_agent) - _set_session_cookie(response, token) + _, token = await create_db_session(db, user, client_ip, user_agent) + set_session_cookie(response, token) await db.commit() return { "authenticated": True, @@ -428,8 +342,8 @@ async def login( } user_agent = request.headers.get("user-agent") - _, token = await _create_db_session(db, user, client_ip, user_agent) - _set_session_cookie(response, token) + _, token = await create_db_session(db, user, client_ip, user_agent) + set_session_cookie(response, token) await log_audit_event( db, action="auth.login_success", actor_id=user.id, ip=client_ip, @@ -511,8 +425,8 @@ async def register( "mfa_token": enforce_token, } - _, token = await _create_db_session(db, new_user, ip, user_agent) - _set_session_cookie(response, token) + _, token = await create_db_session(db, new_user, ip, user_agent) + set_session_cookie(response, token) await db.commit() return {"message": "Registration successful", "authenticated": True} @@ -564,32 +478,31 @@ async def auth_status( is_locked = False + u = None if not setup_required and session_cookie: payload = verify_session_token(session_cookie) if payload: user_id = payload.get("uid") session_id = payload.get("sid") if user_id and session_id: - session_result = await db.execute( - select(UserSession).where( + # Single JOIN query (was 2 sequential queries — P-01 fix) + result = await db.execute( + select(UserSession, User) + .join(User, UserSession.user_id == User.id) + .where( UserSession.id == session_id, UserSession.user_id == user_id, UserSession.revoked == False, UserSession.expires_at > datetime.now(), + User.is_active == True, ) ) - db_sess = session_result.scalar_one_or_none() - if db_sess is not None: + row = result.one_or_none() + if row is not None: + db_sess, u = row.tuple() authenticated = True is_locked = db_sess.is_locked - user_obj_result = await db.execute( - select(User).where(User.id == user_id, User.is_active == True) - ) - u = user_obj_result.scalar_one_or_none() - if u: - role = u.role - else: - authenticated = False + role = u.role # Check registration availability registration_open = False @@ -625,7 +538,7 @@ async def lock_session( @router.post("/verify-password") -async def verify_password( +async def verify_password_endpoint( data: VerifyPasswordRequest, request: Request, db: AsyncSession = Depends(get_db), @@ -635,11 +548,11 @@ async def verify_password( Verify the current user's password without changing anything. Used by the frontend lock screen to re-authenticate without a full login. """ - await _check_account_lockout(current_user) + await check_account_lockout(current_user) valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash) if not valid: - await _record_failed_login(db, current_user) + await record_failed_login(db, current_user) raise HTTPException(status_code=401, detail="Invalid password") if new_hash: @@ -661,11 +574,11 @@ async def change_password( current_user: User = Depends(get_current_user), ): """Change the current user's password. Requires old password verification.""" - await _check_account_lockout(current_user) + await check_account_lockout(current_user) valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash) if not valid: - await _record_failed_login(db, current_user) + await record_failed_login(db, current_user) raise HTTPException(status_code=401, detail="Invalid current password") if data.new_password == data.old_password: diff --git a/backend/app/routers/totp.py b/backend/app/routers/totp.py index 9839b18..39d84cf 100644 --- a/backend/app/routers/totp.py +++ b/backend/app/routers/totp.py @@ -18,7 +18,6 @@ Security: - totp-verify uses mfa_token (not session cookie) — user is not yet authenticated """ import asyncio -import uuid import secrets import logging from datetime import datetime, timedelta @@ -32,17 +31,16 @@ from sqlalchemy.exc import IntegrityError from app.database import get_db from app.models.user import User -from app.models.session import UserSession from app.models.totp_usage import TOTPUsage from app.models.backup_code import BackupCode -from app.routers.auth import get_current_user, _set_session_cookie +from app.routers.auth import get_current_user from app.services.audit import get_client_ip from app.services.auth import ( averify_password_with_upgrade, verify_mfa_token, verify_mfa_enforce_token, - create_session_token, ) +from app.services.session import create_db_session, set_session_cookie from app.services.totp import ( generate_totp_secret, encrypt_totp_secret, @@ -52,7 +50,7 @@ from app.services.totp import ( generate_qr_base64, generate_backup_codes, ) -from app.config import settings as app_settings + # Argon2id for backup code hashing — treat each code like a password from argon2 import PasswordHasher @@ -162,29 +160,6 @@ async def _verify_backup_code( return False -async def _create_full_session( - db: AsyncSession, - user: User, - request: Request, -) -> str: - """Create a UserSession row and return the signed cookie token.""" - session_id = uuid.uuid4().hex - expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS) - ip = get_client_ip(request) - user_agent = request.headers.get("user-agent") - - db_session = UserSession( - id=session_id, - user_id=user.id, - expires_at=expires_at, - ip_address=ip[:45] if ip else None, - user_agent=(user_agent or "")[:255] if user_agent else None, - ) - db.add(db_session) - await db.commit() - return create_session_token(user.id, session_id) - - # --------------------------------------------------------------------------- # Routes # --------------------------------------------------------------------------- @@ -312,8 +287,10 @@ async def totp_verify( user.last_login_at = datetime.now() await db.commit() - token = await _create_full_session(db, user, request) - _set_session_cookie(response, token) + ip = get_client_ip(request) + user_agent = request.headers.get("user-agent") + _, token = await create_db_session(db, user, ip, user_agent) + set_session_cookie(response, token) return {"authenticated": True} # --- TOTP code path --- @@ -340,8 +317,10 @@ async def totp_verify( user.last_login_at = datetime.now() await db.commit() - token = await _create_full_session(db, user, request) - _set_session_cookie(response, token) + ip = get_client_ip(request) + user_agent = request.headers.get("user-agent") + _, token = await create_db_session(db, user, ip, user_agent) + set_session_cookie(response, token) return {"authenticated": True} @@ -513,9 +492,11 @@ async def enforce_confirm_totp( user.last_login_at = datetime.now() await db.commit() - # Issue a full session - token = await _create_full_session(db, user, request) - _set_session_cookie(response, token) + # Issue a full session (now uses shared session service with cap enforcement) + ip = get_client_ip(request) + user_agent = request.headers.get("user-agent") + _, token = await create_db_session(db, user, ip, user_agent) + set_session_cookie(response, token) return {"authenticated": True} diff --git a/backend/app/services/session.py b/backend/app/services/session.py new file mode 100644 index 0000000..d687fb9 --- /dev/null +++ b/backend/app/services/session.py @@ -0,0 +1,103 @@ +""" +Shared session management service. + +Consolidates session creation, cookie handling, and account lockout logic +that was previously duplicated between auth.py and totp.py routers. +All auth paths (password, TOTP, passkey) use these functions to ensure +consistent session cap enforcement and lockout behavior. +""" +import uuid +from datetime import datetime, timedelta + +from fastapi import HTTPException, Response +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select + +from app.models.user import User +from app.models.session import UserSession +from app.services.auth import create_session_token +from app.config import settings as app_settings + + +def set_session_cookie(response: Response, token: str) -> None: + """Set httpOnly secure signed cookie on response.""" + response.set_cookie( + key="session", + value=token, + httponly=True, + secure=app_settings.COOKIE_SECURE, + max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400, + samesite="lax", + path="/", + ) + + +async def check_account_lockout(user: User) -> None: + """Raise HTTP 423 if the account is currently locked.""" + if user.locked_until and datetime.now() < user.locked_until: + remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1 + raise HTTPException( + status_code=423, + detail=f"Account locked. Try again in {remaining} minutes.", + ) + + +async def record_failed_login(db: AsyncSession, user: User) -> None: + """Increment failure counter; lock account after 10 failures.""" + user.failed_login_count += 1 + if user.failed_login_count >= 10: + user.locked_until = datetime.now() + timedelta(minutes=30) + await db.commit() + + +async def record_successful_login(db: AsyncSession, user: User) -> None: + """Reset failure counter and update last_login_at.""" + user.failed_login_count = 0 + user.locked_until = None + user.last_login_at = datetime.now() + await db.commit() + + +async def create_db_session( + db: AsyncSession, + user: User, + ip: str, + user_agent: str | None, +) -> tuple[str, str]: + """Insert a UserSession row and return (session_id, signed_cookie_token). + + Enforces MAX_SESSIONS_PER_USER by revoking oldest sessions beyond the cap. + """ + session_id = uuid.uuid4().hex + expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS) + db_session = UserSession( + id=session_id, + user_id=user.id, + expires_at=expires_at, + ip_address=ip[:45] if ip else None, + user_agent=(user_agent or "")[:255] if user_agent else None, + ) + db.add(db_session) + await db.flush() + + # Enforce concurrent session limit: revoke oldest sessions beyond the cap + active_sessions = ( + await db.execute( + select(UserSession) + .where( + UserSession.user_id == user.id, + UserSession.revoked == False, # noqa: E712 + UserSession.expires_at > datetime.now(), + ) + .order_by(UserSession.created_at.asc()) + ) + ).scalars().all() + + max_sessions = app_settings.MAX_SESSIONS_PER_USER + if len(active_sessions) > max_sessions: + for old_session in active_sessions[: len(active_sessions) - max_sessions]: + old_session.revoked = True + await db.flush() + + token = create_session_token(user.id, session_id) + return session_id, token