"""
Shared session management service.

Consolidates session creation, cookie handling, and account lockout logic that
was previously duplicated between the auth.py and totp.py routers. All auth
paths (password, TOTP, passkey) use these functions to ensure consistent
session-cap enforcement and lockout behavior.
"""

import math
import uuid
from datetime import datetime, timedelta

from fastapi import HTTPException, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.models.user import User
from app.models.session import UserSession
from app.services.auth import create_session_token
from app.config import settings as app_settings

# Number of consecutive failed logins that triggers a lockout.
_MAX_FAILED_LOGINS = 10
# How long an account stays locked once the threshold is reached.
_LOCKOUT_DURATION = timedelta(minutes=30)

# NOTE(review): all timestamps here use naive local time (datetime.now()); this
# must match how locked_until / expires_at / created_at are stored — confirm.


def set_session_cookie(response: Response, token: str) -> None:
    """Attach the session token to *response* as an httpOnly cookie.

    The cookie is named ``session``, scoped to ``/``, uses ``SameSite=Lax``,
    lives for ``SESSION_MAX_AGE_DAYS`` days, and is marked ``Secure`` according
    to ``COOKIE_SECURE``. *token* is expected to be the signed value returned
    by ``create_db_session``.
    """
    response.set_cookie(
        key="session",
        value=token,
        httponly=True,
        secure=app_settings.COOKIE_SECURE,
        max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
        samesite="lax",
        path="/",
    )


async def check_account_lockout(user: User) -> None:
    """Raise HTTP 423 (Locked) if *user* is currently locked out.

    Raises:
        HTTPException: status 423 when ``user.locked_until`` is in the future;
            the detail reports the remaining wait in whole minutes, rounded up.
    """
    # Read the clock once so the check and the reported wait agree.
    now = datetime.now()
    if user.locked_until and now < user.locked_until:
        # Round up: 30s left -> 1 minute, exactly 60s left -> 1 minute (not 2).
        remaining = math.ceil((user.locked_until - now).total_seconds() / 60)
        raise HTTPException(
            status_code=423,
            detail=f"Account locked. Try again in {remaining} minutes.",
        )


async def record_failed_login(db: AsyncSession, user: User) -> None:
    """Count a failed login for *user*, locking the account at the threshold.

    Once ``failed_login_count`` reaches ``_MAX_FAILED_LOGINS``,
    ``locked_until`` is set ``_LOCKOUT_DURATION`` into the future. The change
    is committed immediately.
    """
    user.failed_login_count += 1
    # NOTE(review): the counter is only cleared by record_successful_login, so
    # after a lock expires a single further failure re-locks the account for
    # the full duration. Confirm this is the intended policy.
    if user.failed_login_count >= _MAX_FAILED_LOGINS:
        user.locked_until = datetime.now() + _LOCKOUT_DURATION
    await db.commit()


async def record_successful_login(db: AsyncSession, user: User) -> None:
    """Clear *user*'s failure counter and lock, and stamp ``last_login_at``.

    The change is committed immediately.
    """
    user.failed_login_count = 0
    user.locked_until = None
    user.last_login_at = datetime.now()
    await db.commit()


async def _revoke_excess_sessions(
    db: AsyncSession,
    user: User,
    keep_session_id: str,
    now: datetime,
) -> None:
    """Revoke *user*'s oldest active sessions beyond ``MAX_SESSIONS_PER_USER``.

    The session ``keep_session_id`` is always retained and counts toward the
    cap. It is excluded from the query explicitly rather than relying on
    ``created_at`` ordering, because SQL leaves the order of tied timestamps
    unspecified and could otherwise pick the new session for revocation.
    """
    others = (
        await db.execute(
            select(UserSession)
            .where(
                UserSession.user_id == user.id,
                UserSession.id != keep_session_id,
                UserSession.revoked == False,  # noqa: E712
                UserSession.expires_at > now,
            )
            .order_by(UserSession.created_at.asc())
        )
    ).scalars().all()

    # One slot of the cap is occupied by the session being kept.
    allowed_others = max(app_settings.MAX_SESSIONS_PER_USER - 1, 0)
    excess = len(others) - allowed_others
    if excess <= 0:
        return
    for old_session in others[:excess]:
        old_session.revoked = True
    await db.flush()


async def create_db_session(
    db: AsyncSession,
    user: User,
    ip: str,
    user_agent: str | None,
) -> tuple[str, str]:
    """Insert a UserSession row and return ``(session_id, signed_cookie_token)``.

    The new session expires after ``SESSION_MAX_AGE_DAYS`` days. Afterwards
    the user's oldest active sessions are revoked so that at most
    ``MAX_SESSIONS_PER_USER`` remain; the newly created session is never one
    of those revoked.

    The row is flushed but not committed; committing is left to the caller.
    """
    now = datetime.now()
    session_id = uuid.uuid4().hex
    db.add(
        UserSession(
            id=session_id,
            user_id=user.id,
            expires_at=now + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS),
            # Truncate to what are presumably the column widths — confirm
            # against the UserSession model.
            ip_address=ip[:45] if ip else None,
            user_agent=user_agent[:255] if user_agent else None,
        )
    )
    # Flush so the new row exists before the cap is enforced.
    await db.flush()

    await _revoke_excess_sessions(db, user, keep_session_id=session_id, now=now)

    token = create_session_token(user.id, session_id)
    return session_id, token