Phase 0: Consolidate session creation into shared service
Extract _create_db_session, _set_session_cookie, _check_account_lockout, _record_failed_login, and _record_successful_login from auth.py into services/session.py. Update totp.py to use shared service instead of its duplicate _create_full_session (which lacked session cap enforcement). Also fixes: - auth/status N+1 query (2 sequential queries -> single JOIN) - Rename verify_password route to verify_password_endpoint (shadow fix) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c5a309f4a1
commit
eebb34aa77
@ -16,7 +16,6 @@ Security layers:
|
||||
4. bcrypt→Argon2id transparent upgrade on first login
|
||||
5. Role-based authorization via require_role() dependency factory
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
@ -49,6 +48,13 @@ from app.services.auth import (
|
||||
create_mfa_enforce_token,
|
||||
)
|
||||
from app.services.audit import get_client_ip, log_audit_event
|
||||
from app.services.session import (
|
||||
set_session_cookie,
|
||||
check_account_lockout,
|
||||
record_failed_login,
|
||||
record_successful_login,
|
||||
create_db_session,
|
||||
)
|
||||
from app.config import settings as app_settings
|
||||
|
||||
router = APIRouter()
|
||||
@ -59,22 +65,6 @@ router = APIRouter()
|
||||
# is indistinguishable from a wrong-password attempt.
|
||||
_DUMMY_HASH = hash_password("timing-equalization-dummy")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cookie helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _set_session_cookie(response: Response, token: str) -> None:
|
||||
response.set_cookie(
|
||||
key="session",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=app_settings.COOKIE_SECURE,
|
||||
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
|
||||
samesite="lax",
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth dependencies — export get_current_user and get_current_settings
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -130,7 +120,7 @@ async def get_current_user(
|
||||
await db.flush()
|
||||
# Re-issue cookie with fresh signed token to reset browser max_age timer
|
||||
fresh_token = create_session_token(user_id, session_id)
|
||||
_set_session_cookie(response, fresh_token)
|
||||
set_session_cookie(response, fresh_token)
|
||||
|
||||
# Stash session on request so lock/unlock endpoints can access it
|
||||
request.state.db_session = db_session
|
||||
@ -190,82 +180,6 @@ def require_role(*allowed_roles: str):
|
||||
require_admin = require_role("admin")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account lockout helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _check_account_lockout(user: User) -> None:
|
||||
"""Raise HTTP 423 if the account is currently locked."""
|
||||
if user.locked_until and datetime.now() < user.locked_until:
|
||||
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
|
||||
raise HTTPException(
|
||||
status_code=423,
|
||||
detail=f"Account locked. Try again in {remaining} minutes.",
|
||||
)
|
||||
|
||||
|
||||
async def _record_failed_login(db: AsyncSession, user: User) -> None:
|
||||
"""Increment failure counter; lock account after 10 failures."""
|
||||
user.failed_login_count += 1
|
||||
if user.failed_login_count >= 10:
|
||||
user.locked_until = datetime.now() + timedelta(minutes=30)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _record_successful_login(db: AsyncSession, user: User) -> None:
|
||||
"""Reset failure counter and update last_login_at."""
|
||||
user.failed_login_count = 0
|
||||
user.locked_until = None
|
||||
user.last_login_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session creation helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _create_db_session(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
ip: str,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, str]:
|
||||
"""Insert a UserSession row and return (session_id, signed_cookie_token)."""
|
||||
session_id = uuid.uuid4().hex
|
||||
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
||||
db_session = UserSession(
|
||||
id=session_id,
|
||||
user_id=user.id,
|
||||
expires_at=expires_at,
|
||||
ip_address=ip[:45] if ip else None,
|
||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
||||
)
|
||||
db.add(db_session)
|
||||
await db.flush()
|
||||
|
||||
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
|
||||
active_sessions = (
|
||||
await db.execute(
|
||||
select(UserSession)
|
||||
.where(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.revoked == False, # noqa: E712
|
||||
UserSession.expires_at > datetime.now(),
|
||||
)
|
||||
.order_by(UserSession.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
max_sessions = app_settings.MAX_SESSIONS_PER_USER
|
||||
if len(active_sessions) > max_sessions:
|
||||
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
|
||||
old_session.revoked = True
|
||||
await db.flush()
|
||||
|
||||
token = create_session_token(user.id, session_id)
|
||||
return session_id, token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User bootstrapping helper (Settings + default calendars)
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -321,8 +235,8 @@ async def setup(
|
||||
|
||||
ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
_, token = await create_db_session(db, new_user, ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
|
||||
await log_audit_event(
|
||||
db, action="auth.setup_complete", actor_id=new_user.id, ip=ip,
|
||||
@ -366,10 +280,10 @@ async def login(
|
||||
# executes — prevents distinguishing "locked" from "wrong password" via timing.
|
||||
valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash)
|
||||
|
||||
await _check_account_lockout(user)
|
||||
await check_account_lockout(user)
|
||||
|
||||
if not valid:
|
||||
await _record_failed_login(db, user)
|
||||
await record_failed_login(db, user)
|
||||
await log_audit_event(
|
||||
db, action="auth.login_failed", actor_id=user.id,
|
||||
detail={"reason": "invalid_password"}, ip=client_ip,
|
||||
@ -378,7 +292,7 @@ async def login(
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
# Block disabled accounts — checked AFTER password verification to avoid
|
||||
# leaking account-state info, and BEFORE _record_successful_login so
|
||||
# leaking account-state info, and BEFORE record_successful_login so
|
||||
# last_login_at and lockout counters are not reset for inactive users.
|
||||
if not user.is_active:
|
||||
await log_audit_event(
|
||||
@ -391,7 +305,7 @@ async def login(
|
||||
if new_hash:
|
||||
user.password_hash = new_hash
|
||||
|
||||
await _record_successful_login(db, user)
|
||||
await record_successful_login(db, user)
|
||||
|
||||
# SEC-03: MFA enforcement — block login entirely until MFA setup completes
|
||||
if user.mfa_enforce_pending and not user.totp_enabled:
|
||||
@ -419,8 +333,8 @@ async def login(
|
||||
if user.must_change_password:
|
||||
# Issue a session but flag the frontend to show password change
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
_, token = await create_db_session(db, user, client_ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
await db.commit()
|
||||
return {
|
||||
"authenticated": True,
|
||||
@ -428,8 +342,8 @@ async def login(
|
||||
}
|
||||
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
_, token = await create_db_session(db, user, client_ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
|
||||
await log_audit_event(
|
||||
db, action="auth.login_success", actor_id=user.id, ip=client_ip,
|
||||
@ -511,8 +425,8 @@ async def register(
|
||||
"mfa_token": enforce_token,
|
||||
}
|
||||
|
||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
||||
_set_session_cookie(response, token)
|
||||
_, token = await create_db_session(db, new_user, ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Registration successful", "authenticated": True}
|
||||
@ -564,32 +478,31 @@ async def auth_status(
|
||||
|
||||
is_locked = False
|
||||
|
||||
u = None
|
||||
if not setup_required and session_cookie:
|
||||
payload = verify_session_token(session_cookie)
|
||||
if payload:
|
||||
user_id = payload.get("uid")
|
||||
session_id = payload.get("sid")
|
||||
if user_id and session_id:
|
||||
session_result = await db.execute(
|
||||
select(UserSession).where(
|
||||
# Single JOIN query (was 2 sequential queries — P-01 fix)
|
||||
result = await db.execute(
|
||||
select(UserSession, User)
|
||||
.join(User, UserSession.user_id == User.id)
|
||||
.where(
|
||||
UserSession.id == session_id,
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.revoked == False,
|
||||
UserSession.expires_at > datetime.now(),
|
||||
User.is_active == True,
|
||||
)
|
||||
)
|
||||
db_sess = session_result.scalar_one_or_none()
|
||||
if db_sess is not None:
|
||||
row = result.one_or_none()
|
||||
if row is not None:
|
||||
db_sess, u = row.tuple()
|
||||
authenticated = True
|
||||
is_locked = db_sess.is_locked
|
||||
user_obj_result = await db.execute(
|
||||
select(User).where(User.id == user_id, User.is_active == True)
|
||||
)
|
||||
u = user_obj_result.scalar_one_or_none()
|
||||
if u:
|
||||
role = u.role
|
||||
else:
|
||||
authenticated = False
|
||||
role = u.role
|
||||
|
||||
# Check registration availability
|
||||
registration_open = False
|
||||
@ -625,7 +538,7 @@ async def lock_session(
|
||||
|
||||
|
||||
@router.post("/verify-password")
|
||||
async def verify_password(
|
||||
async def verify_password_endpoint(
|
||||
data: VerifyPasswordRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@ -635,11 +548,11 @@ async def verify_password(
|
||||
Verify the current user's password without changing anything.
|
||||
Used by the frontend lock screen to re-authenticate without a full login.
|
||||
"""
|
||||
await _check_account_lockout(current_user)
|
||||
await check_account_lockout(current_user)
|
||||
|
||||
valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash)
|
||||
if not valid:
|
||||
await _record_failed_login(db, current_user)
|
||||
await record_failed_login(db, current_user)
|
||||
raise HTTPException(status_code=401, detail="Invalid password")
|
||||
|
||||
if new_hash:
|
||||
@ -661,11 +574,11 @@ async def change_password(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Change the current user's password. Requires old password verification."""
|
||||
await _check_account_lockout(current_user)
|
||||
await check_account_lockout(current_user)
|
||||
|
||||
valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash)
|
||||
if not valid:
|
||||
await _record_failed_login(db, current_user)
|
||||
await record_failed_login(db, current_user)
|
||||
raise HTTPException(status_code=401, detail="Invalid current password")
|
||||
|
||||
if data.new_password == data.old_password:
|
||||
|
||||
@ -18,7 +18,6 @@ Security:
|
||||
- totp-verify uses mfa_token (not session cookie) — user is not yet authenticated
|
||||
"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import secrets
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
@ -32,17 +31,16 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.session import UserSession
|
||||
from app.models.totp_usage import TOTPUsage
|
||||
from app.models.backup_code import BackupCode
|
||||
from app.routers.auth import get_current_user, _set_session_cookie
|
||||
from app.routers.auth import get_current_user
|
||||
from app.services.audit import get_client_ip
|
||||
from app.services.auth import (
|
||||
averify_password_with_upgrade,
|
||||
verify_mfa_token,
|
||||
verify_mfa_enforce_token,
|
||||
create_session_token,
|
||||
)
|
||||
from app.services.session import create_db_session, set_session_cookie
|
||||
from app.services.totp import (
|
||||
generate_totp_secret,
|
||||
encrypt_totp_secret,
|
||||
@ -52,7 +50,7 @@ from app.services.totp import (
|
||||
generate_qr_base64,
|
||||
generate_backup_codes,
|
||||
)
|
||||
from app.config import settings as app_settings
|
||||
|
||||
|
||||
# Argon2id for backup code hashing — treat each code like a password
|
||||
from argon2 import PasswordHasher
|
||||
@ -162,29 +160,6 @@ async def _verify_backup_code(
|
||||
return False
|
||||
|
||||
|
||||
async def _create_full_session(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
request: Request,
|
||||
) -> str:
|
||||
"""Create a UserSession row and return the signed cookie token."""
|
||||
session_id = uuid.uuid4().hex
|
||||
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
||||
ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
db_session = UserSession(
|
||||
id=session_id,
|
||||
user_id=user.id,
|
||||
expires_at=expires_at,
|
||||
ip_address=ip[:45] if ip else None,
|
||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
||||
)
|
||||
db.add(db_session)
|
||||
await db.commit()
|
||||
return create_session_token(user.id, session_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -312,8 +287,10 @@ async def totp_verify(
|
||||
user.last_login_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
token = await _create_full_session(db, user, request)
|
||||
_set_session_cookie(response, token)
|
||||
ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await create_db_session(db, user, ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
return {"authenticated": True}
|
||||
|
||||
# --- TOTP code path ---
|
||||
@ -340,8 +317,10 @@ async def totp_verify(
|
||||
user.last_login_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
token = await _create_full_session(db, user, request)
|
||||
_set_session_cookie(response, token)
|
||||
ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await create_db_session(db, user, ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
return {"authenticated": True}
|
||||
|
||||
|
||||
@ -513,9 +492,11 @@ async def enforce_confirm_totp(
|
||||
user.last_login_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
# Issue a full session
|
||||
token = await _create_full_session(db, user, request)
|
||||
_set_session_cookie(response, token)
|
||||
# Issue a full session (now uses shared session service with cap enforcement)
|
||||
ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent")
|
||||
_, token = await create_db_session(db, user, ip, user_agent)
|
||||
set_session_cookie(response, token)
|
||||
|
||||
return {"authenticated": True}
|
||||
|
||||
|
||||
103
backend/app/services/session.py
Normal file
103
backend/app/services/session.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""
|
||||
Shared session management service.
|
||||
|
||||
Consolidates session creation, cookie handling, and account lockout logic
|
||||
that was previously duplicated between auth.py and totp.py routers.
|
||||
All auth paths (password, TOTP, passkey) use these functions to ensure
|
||||
consistent session cap enforcement and lockout behavior.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import HTTPException, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.session import UserSession
|
||||
from app.services.auth import create_session_token
|
||||
from app.config import settings as app_settings
|
||||
|
||||
|
||||
def set_session_cookie(response: Response, token: str) -> None:
|
||||
"""Set httpOnly secure signed cookie on response."""
|
||||
response.set_cookie(
|
||||
key="session",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=app_settings.COOKIE_SECURE,
|
||||
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
|
||||
samesite="lax",
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
async def check_account_lockout(user: User) -> None:
|
||||
"""Raise HTTP 423 if the account is currently locked."""
|
||||
if user.locked_until and datetime.now() < user.locked_until:
|
||||
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
|
||||
raise HTTPException(
|
||||
status_code=423,
|
||||
detail=f"Account locked. Try again in {remaining} minutes.",
|
||||
)
|
||||
|
||||
|
||||
async def record_failed_login(db: AsyncSession, user: User) -> None:
|
||||
"""Increment failure counter; lock account after 10 failures."""
|
||||
user.failed_login_count += 1
|
||||
if user.failed_login_count >= 10:
|
||||
user.locked_until = datetime.now() + timedelta(minutes=30)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def record_successful_login(db: AsyncSession, user: User) -> None:
|
||||
"""Reset failure counter and update last_login_at."""
|
||||
user.failed_login_count = 0
|
||||
user.locked_until = None
|
||||
user.last_login_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def create_db_session(
|
||||
db: AsyncSession,
|
||||
user: User,
|
||||
ip: str,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, str]:
|
||||
"""Insert a UserSession row and return (session_id, signed_cookie_token).
|
||||
|
||||
Enforces MAX_SESSIONS_PER_USER by revoking oldest sessions beyond the cap.
|
||||
"""
|
||||
session_id = uuid.uuid4().hex
|
||||
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
||||
db_session = UserSession(
|
||||
id=session_id,
|
||||
user_id=user.id,
|
||||
expires_at=expires_at,
|
||||
ip_address=ip[:45] if ip else None,
|
||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
||||
)
|
||||
db.add(db_session)
|
||||
await db.flush()
|
||||
|
||||
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
|
||||
active_sessions = (
|
||||
await db.execute(
|
||||
select(UserSession)
|
||||
.where(
|
||||
UserSession.user_id == user.id,
|
||||
UserSession.revoked == False, # noqa: E712
|
||||
UserSession.expires_at > datetime.now(),
|
||||
)
|
||||
.order_by(UserSession.created_at.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
max_sessions = app_settings.MAX_SESSIONS_PER_USER
|
||||
if len(active_sessions) > max_sessions:
|
||||
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
|
||||
old_session.revoked = True
|
||||
await db.flush()
|
||||
|
||||
token = create_session_token(user.id, session_id)
|
||||
return session_id, token
|
||||
Loading…
x
Reference in New Issue
Block a user