""" Shared session management service. Consolidates session creation, cookie handling, and account lockout logic that was previously duplicated between auth.py and totp.py routers. All auth paths (password, TOTP, passkey) use these functions to ensure consistent session cap enforcement and lockout behavior. """ import uuid from datetime import datetime, timedelta from fastapi import HTTPException, Response from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, update from app.models.user import User from app.models.session import UserSession from app.services.auth import create_session_token from app.config import settings as app_settings def set_session_cookie(response: Response, token: str) -> None: """Set httpOnly secure signed cookie on response.""" response.set_cookie( key="session", value=token, httponly=True, secure=app_settings.COOKIE_SECURE, max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400, samesite="lax", path="/", ) async def check_account_lockout(user: User) -> None: """Raise HTTP 401 if the account is currently locked. Uses 401 (same status as wrong-password) so that status-code analysis cannot distinguish a locked account from an invalid credential (F-02). """ if user.locked_until and datetime.now() < user.locked_until: remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1 raise HTTPException( status_code=401, detail=f"Account temporarily locked. Try again in {remaining} minutes.", ) async def record_failed_login(db: AsyncSession, user: User) -> int: """Increment failure counter; lock account after 10 failures. Returns the number of attempts remaining before lockout (0 = just locked). Does NOT commit — caller owns the transaction boundary. """ user.failed_login_count += 1 remaining = max(0, 10 - user.failed_login_count) if user.failed_login_count >= 10: user.locked_until = datetime.now() + timedelta(minutes=30) await db.flush() return remaining async def record_successful_login(db: AsyncSession, user: User) -> None: """Reset failure counter and update last_login_at. Does NOT commit — caller owns the transaction boundary. """ user.failed_login_count = 0 user.locked_until = None user.last_login_at = datetime.now() await db.flush() async def create_db_session( db: AsyncSession, user: User, ip: str, user_agent: str | None, ) -> tuple[str, str]: """Insert a UserSession row and return (session_id, signed_cookie_token). Enforces MAX_SESSIONS_PER_USER by revoking oldest sessions beyond the cap. """ session_id = uuid.uuid4().hex expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS) db_session = UserSession( id=session_id, user_id=user.id, expires_at=expires_at, ip_address=ip[:45] if ip else None, user_agent=(user_agent or "")[:255] if user_agent else None, ) db.add(db_session) await db.flush() # Enforce concurrent session limit: revoke oldest sessions beyond the cap. # Perf-2: Query IDs only, bulk-update instead of loading full ORM objects. max_sessions = app_settings.MAX_SESSIONS_PER_USER active_ids = ( await db.execute( select(UserSession.id) .where( UserSession.user_id == user.id, UserSession.revoked == False, # noqa: E712 UserSession.expires_at > datetime.now(), ) .order_by(UserSession.created_at.asc()) ) ).scalars().all() if len(active_ids) > max_sessions: ids_to_revoke = active_ids[: len(active_ids) - max_sessions] await db.execute( update(UserSession) .where(UserSession.id.in_(ids_to_revoke)) .values(revoked=True) ) await db.flush() token = create_session_token(user.id, session_id) return session_id, token