Phase 0: Consolidate session creation into shared service
Extract _create_db_session, _set_session_cookie, _check_account_lockout, _record_failed_login, and _record_successful_login from auth.py into services/session.py. Update totp.py to use shared service instead of its duplicate _create_full_session (which lacked session cap enforcement). Also fixes: - auth/status N+1 query (2 sequential queries -> single JOIN) - Rename verify_password route to verify_password_endpoint (shadow fix) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c5a309f4a1
commit
eebb34aa77
@ -16,7 +16,6 @@ Security layers:
|
|||||||
4. bcrypt→Argon2id transparent upgrade on first login
|
4. bcrypt→Argon2id transparent upgrade on first login
|
||||||
5. Role-based authorization via require_role() dependency factory
|
5. Role-based authorization via require_role() dependency factory
|
||||||
"""
|
"""
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -49,6 +48,13 @@ from app.services.auth import (
|
|||||||
create_mfa_enforce_token,
|
create_mfa_enforce_token,
|
||||||
)
|
)
|
||||||
from app.services.audit import get_client_ip, log_audit_event
|
from app.services.audit import get_client_ip, log_audit_event
|
||||||
|
from app.services.session import (
|
||||||
|
set_session_cookie,
|
||||||
|
check_account_lockout,
|
||||||
|
record_failed_login,
|
||||||
|
record_successful_login,
|
||||||
|
create_db_session,
|
||||||
|
)
|
||||||
from app.config import settings as app_settings
|
from app.config import settings as app_settings
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -59,22 +65,6 @@ router = APIRouter()
|
|||||||
# is indistinguishable from a wrong-password attempt.
|
# is indistinguishable from a wrong-password attempt.
|
||||||
_DUMMY_HASH = hash_password("timing-equalization-dummy")
|
_DUMMY_HASH = hash_password("timing-equalization-dummy")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Cookie helper
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _set_session_cookie(response: Response, token: str) -> None:
|
|
||||||
response.set_cookie(
|
|
||||||
key="session",
|
|
||||||
value=token,
|
|
||||||
httponly=True,
|
|
||||||
secure=app_settings.COOKIE_SECURE,
|
|
||||||
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
|
|
||||||
samesite="lax",
|
|
||||||
path="/",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auth dependencies — export get_current_user and get_current_settings
|
# Auth dependencies — export get_current_user and get_current_settings
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -130,7 +120,7 @@ async def get_current_user(
|
|||||||
await db.flush()
|
await db.flush()
|
||||||
# Re-issue cookie with fresh signed token to reset browser max_age timer
|
# Re-issue cookie with fresh signed token to reset browser max_age timer
|
||||||
fresh_token = create_session_token(user_id, session_id)
|
fresh_token = create_session_token(user_id, session_id)
|
||||||
_set_session_cookie(response, fresh_token)
|
set_session_cookie(response, fresh_token)
|
||||||
|
|
||||||
# Stash session on request so lock/unlock endpoints can access it
|
# Stash session on request so lock/unlock endpoints can access it
|
||||||
request.state.db_session = db_session
|
request.state.db_session = db_session
|
||||||
@ -190,82 +180,6 @@ def require_role(*allowed_roles: str):
|
|||||||
require_admin = require_role("admin")
|
require_admin = require_role("admin")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Account lockout helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _check_account_lockout(user: User) -> None:
|
|
||||||
"""Raise HTTP 423 if the account is currently locked."""
|
|
||||||
if user.locked_until and datetime.now() < user.locked_until:
|
|
||||||
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=423,
|
|
||||||
detail=f"Account locked. Try again in {remaining} minutes.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _record_failed_login(db: AsyncSession, user: User) -> None:
|
|
||||||
"""Increment failure counter; lock account after 10 failures."""
|
|
||||||
user.failed_login_count += 1
|
|
||||||
if user.failed_login_count >= 10:
|
|
||||||
user.locked_until = datetime.now() + timedelta(minutes=30)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
async def _record_successful_login(db: AsyncSession, user: User) -> None:
|
|
||||||
"""Reset failure counter and update last_login_at."""
|
|
||||||
user.failed_login_count = 0
|
|
||||||
user.locked_until = None
|
|
||||||
user.last_login_at = datetime.now()
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Session creation helper
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _create_db_session(
|
|
||||||
db: AsyncSession,
|
|
||||||
user: User,
|
|
||||||
ip: str,
|
|
||||||
user_agent: str | None,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""Insert a UserSession row and return (session_id, signed_cookie_token)."""
|
|
||||||
session_id = uuid.uuid4().hex
|
|
||||||
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
|
||||||
db_session = UserSession(
|
|
||||||
id=session_id,
|
|
||||||
user_id=user.id,
|
|
||||||
expires_at=expires_at,
|
|
||||||
ip_address=ip[:45] if ip else None,
|
|
||||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
|
||||||
)
|
|
||||||
db.add(db_session)
|
|
||||||
await db.flush()
|
|
||||||
|
|
||||||
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
|
|
||||||
active_sessions = (
|
|
||||||
await db.execute(
|
|
||||||
select(UserSession)
|
|
||||||
.where(
|
|
||||||
UserSession.user_id == user.id,
|
|
||||||
UserSession.revoked == False, # noqa: E712
|
|
||||||
UserSession.expires_at > datetime.now(),
|
|
||||||
)
|
|
||||||
.order_by(UserSession.created_at.asc())
|
|
||||||
)
|
|
||||||
).scalars().all()
|
|
||||||
|
|
||||||
max_sessions = app_settings.MAX_SESSIONS_PER_USER
|
|
||||||
if len(active_sessions) > max_sessions:
|
|
||||||
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
|
|
||||||
old_session.revoked = True
|
|
||||||
await db.flush()
|
|
||||||
|
|
||||||
token = create_session_token(user.id, session_id)
|
|
||||||
return session_id, token
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# User bootstrapping helper (Settings + default calendars)
|
# User bootstrapping helper (Settings + default calendars)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -321,8 +235,8 @@ async def setup(
|
|||||||
|
|
||||||
ip = get_client_ip(request)
|
ip = get_client_ip(request)
|
||||||
user_agent = request.headers.get("user-agent")
|
user_agent = request.headers.get("user-agent")
|
||||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
_, token = await create_db_session(db, new_user, ip, user_agent)
|
||||||
_set_session_cookie(response, token)
|
set_session_cookie(response, token)
|
||||||
|
|
||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
db, action="auth.setup_complete", actor_id=new_user.id, ip=ip,
|
db, action="auth.setup_complete", actor_id=new_user.id, ip=ip,
|
||||||
@ -366,10 +280,10 @@ async def login(
|
|||||||
# executes — prevents distinguishing "locked" from "wrong password" via timing.
|
# executes — prevents distinguishing "locked" from "wrong password" via timing.
|
||||||
valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash)
|
valid, new_hash = await averify_password_with_upgrade(data.password, user.password_hash)
|
||||||
|
|
||||||
await _check_account_lockout(user)
|
await check_account_lockout(user)
|
||||||
|
|
||||||
if not valid:
|
if not valid:
|
||||||
await _record_failed_login(db, user)
|
await record_failed_login(db, user)
|
||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
db, action="auth.login_failed", actor_id=user.id,
|
db, action="auth.login_failed", actor_id=user.id,
|
||||||
detail={"reason": "invalid_password"}, ip=client_ip,
|
detail={"reason": "invalid_password"}, ip=client_ip,
|
||||||
@ -378,7 +292,7 @@ async def login(
|
|||||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||||
|
|
||||||
# Block disabled accounts — checked AFTER password verification to avoid
|
# Block disabled accounts — checked AFTER password verification to avoid
|
||||||
# leaking account-state info, and BEFORE _record_successful_login so
|
# leaking account-state info, and BEFORE record_successful_login so
|
||||||
# last_login_at and lockout counters are not reset for inactive users.
|
# last_login_at and lockout counters are not reset for inactive users.
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
@ -391,7 +305,7 @@ async def login(
|
|||||||
if new_hash:
|
if new_hash:
|
||||||
user.password_hash = new_hash
|
user.password_hash = new_hash
|
||||||
|
|
||||||
await _record_successful_login(db, user)
|
await record_successful_login(db, user)
|
||||||
|
|
||||||
# SEC-03: MFA enforcement — block login entirely until MFA setup completes
|
# SEC-03: MFA enforcement — block login entirely until MFA setup completes
|
||||||
if user.mfa_enforce_pending and not user.totp_enabled:
|
if user.mfa_enforce_pending and not user.totp_enabled:
|
||||||
@ -419,8 +333,8 @@ async def login(
|
|||||||
if user.must_change_password:
|
if user.must_change_password:
|
||||||
# Issue a session but flag the frontend to show password change
|
# Issue a session but flag the frontend to show password change
|
||||||
user_agent = request.headers.get("user-agent")
|
user_agent = request.headers.get("user-agent")
|
||||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
_, token = await create_db_session(db, user, client_ip, user_agent)
|
||||||
_set_session_cookie(response, token)
|
set_session_cookie(response, token)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return {
|
return {
|
||||||
"authenticated": True,
|
"authenticated": True,
|
||||||
@ -428,8 +342,8 @@ async def login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
user_agent = request.headers.get("user-agent")
|
user_agent = request.headers.get("user-agent")
|
||||||
_, token = await _create_db_session(db, user, client_ip, user_agent)
|
_, token = await create_db_session(db, user, client_ip, user_agent)
|
||||||
_set_session_cookie(response, token)
|
set_session_cookie(response, token)
|
||||||
|
|
||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
db, action="auth.login_success", actor_id=user.id, ip=client_ip,
|
db, action="auth.login_success", actor_id=user.id, ip=client_ip,
|
||||||
@ -511,8 +425,8 @@ async def register(
|
|||||||
"mfa_token": enforce_token,
|
"mfa_token": enforce_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, token = await _create_db_session(db, new_user, ip, user_agent)
|
_, token = await create_db_session(db, new_user, ip, user_agent)
|
||||||
_set_session_cookie(response, token)
|
set_session_cookie(response, token)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
return {"message": "Registration successful", "authenticated": True}
|
return {"message": "Registration successful", "authenticated": True}
|
||||||
@ -564,32 +478,31 @@ async def auth_status(
|
|||||||
|
|
||||||
is_locked = False
|
is_locked = False
|
||||||
|
|
||||||
|
u = None
|
||||||
if not setup_required and session_cookie:
|
if not setup_required and session_cookie:
|
||||||
payload = verify_session_token(session_cookie)
|
payload = verify_session_token(session_cookie)
|
||||||
if payload:
|
if payload:
|
||||||
user_id = payload.get("uid")
|
user_id = payload.get("uid")
|
||||||
session_id = payload.get("sid")
|
session_id = payload.get("sid")
|
||||||
if user_id and session_id:
|
if user_id and session_id:
|
||||||
session_result = await db.execute(
|
# Single JOIN query (was 2 sequential queries — P-01 fix)
|
||||||
select(UserSession).where(
|
result = await db.execute(
|
||||||
|
select(UserSession, User)
|
||||||
|
.join(User, UserSession.user_id == User.id)
|
||||||
|
.where(
|
||||||
UserSession.id == session_id,
|
UserSession.id == session_id,
|
||||||
UserSession.user_id == user_id,
|
UserSession.user_id == user_id,
|
||||||
UserSession.revoked == False,
|
UserSession.revoked == False,
|
||||||
UserSession.expires_at > datetime.now(),
|
UserSession.expires_at > datetime.now(),
|
||||||
|
User.is_active == True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db_sess = session_result.scalar_one_or_none()
|
row = result.one_or_none()
|
||||||
if db_sess is not None:
|
if row is not None:
|
||||||
|
db_sess, u = row.tuple()
|
||||||
authenticated = True
|
authenticated = True
|
||||||
is_locked = db_sess.is_locked
|
is_locked = db_sess.is_locked
|
||||||
user_obj_result = await db.execute(
|
|
||||||
select(User).where(User.id == user_id, User.is_active == True)
|
|
||||||
)
|
|
||||||
u = user_obj_result.scalar_one_or_none()
|
|
||||||
if u:
|
|
||||||
role = u.role
|
role = u.role
|
||||||
else:
|
|
||||||
authenticated = False
|
|
||||||
|
|
||||||
# Check registration availability
|
# Check registration availability
|
||||||
registration_open = False
|
registration_open = False
|
||||||
@ -625,7 +538,7 @@ async def lock_session(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/verify-password")
|
@router.post("/verify-password")
|
||||||
async def verify_password(
|
async def verify_password_endpoint(
|
||||||
data: VerifyPasswordRequest,
|
data: VerifyPasswordRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
@ -635,11 +548,11 @@ async def verify_password(
|
|||||||
Verify the current user's password without changing anything.
|
Verify the current user's password without changing anything.
|
||||||
Used by the frontend lock screen to re-authenticate without a full login.
|
Used by the frontend lock screen to re-authenticate without a full login.
|
||||||
"""
|
"""
|
||||||
await _check_account_lockout(current_user)
|
await check_account_lockout(current_user)
|
||||||
|
|
||||||
valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash)
|
valid, new_hash = await averify_password_with_upgrade(data.password, current_user.password_hash)
|
||||||
if not valid:
|
if not valid:
|
||||||
await _record_failed_login(db, current_user)
|
await record_failed_login(db, current_user)
|
||||||
raise HTTPException(status_code=401, detail="Invalid password")
|
raise HTTPException(status_code=401, detail="Invalid password")
|
||||||
|
|
||||||
if new_hash:
|
if new_hash:
|
||||||
@ -661,11 +574,11 @@ async def change_password(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Change the current user's password. Requires old password verification."""
|
"""Change the current user's password. Requires old password verification."""
|
||||||
await _check_account_lockout(current_user)
|
await check_account_lockout(current_user)
|
||||||
|
|
||||||
valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash)
|
valid, _ = await averify_password_with_upgrade(data.old_password, current_user.password_hash)
|
||||||
if not valid:
|
if not valid:
|
||||||
await _record_failed_login(db, current_user)
|
await record_failed_login(db, current_user)
|
||||||
raise HTTPException(status_code=401, detail="Invalid current password")
|
raise HTTPException(status_code=401, detail="Invalid current password")
|
||||||
|
|
||||||
if data.new_password == data.old_password:
|
if data.new_password == data.old_password:
|
||||||
|
|||||||
@ -18,7 +18,6 @@ Security:
|
|||||||
- totp-verify uses mfa_token (not session cookie) — user is not yet authenticated
|
- totp-verify uses mfa_token (not session cookie) — user is not yet authenticated
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
|
||||||
import secrets
|
import secrets
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -32,17 +31,16 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.session import UserSession
|
|
||||||
from app.models.totp_usage import TOTPUsage
|
from app.models.totp_usage import TOTPUsage
|
||||||
from app.models.backup_code import BackupCode
|
from app.models.backup_code import BackupCode
|
||||||
from app.routers.auth import get_current_user, _set_session_cookie
|
from app.routers.auth import get_current_user
|
||||||
from app.services.audit import get_client_ip
|
from app.services.audit import get_client_ip
|
||||||
from app.services.auth import (
|
from app.services.auth import (
|
||||||
averify_password_with_upgrade,
|
averify_password_with_upgrade,
|
||||||
verify_mfa_token,
|
verify_mfa_token,
|
||||||
verify_mfa_enforce_token,
|
verify_mfa_enforce_token,
|
||||||
create_session_token,
|
|
||||||
)
|
)
|
||||||
|
from app.services.session import create_db_session, set_session_cookie
|
||||||
from app.services.totp import (
|
from app.services.totp import (
|
||||||
generate_totp_secret,
|
generate_totp_secret,
|
||||||
encrypt_totp_secret,
|
encrypt_totp_secret,
|
||||||
@ -52,7 +50,7 @@ from app.services.totp import (
|
|||||||
generate_qr_base64,
|
generate_qr_base64,
|
||||||
generate_backup_codes,
|
generate_backup_codes,
|
||||||
)
|
)
|
||||||
from app.config import settings as app_settings
|
|
||||||
|
|
||||||
# Argon2id for backup code hashing — treat each code like a password
|
# Argon2id for backup code hashing — treat each code like a password
|
||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
@ -162,29 +160,6 @@ async def _verify_backup_code(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _create_full_session(
|
|
||||||
db: AsyncSession,
|
|
||||||
user: User,
|
|
||||||
request: Request,
|
|
||||||
) -> str:
|
|
||||||
"""Create a UserSession row and return the signed cookie token."""
|
|
||||||
session_id = uuid.uuid4().hex
|
|
||||||
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
|
||||||
ip = get_client_ip(request)
|
|
||||||
user_agent = request.headers.get("user-agent")
|
|
||||||
|
|
||||||
db_session = UserSession(
|
|
||||||
id=session_id,
|
|
||||||
user_id=user.id,
|
|
||||||
expires_at=expires_at,
|
|
||||||
ip_address=ip[:45] if ip else None,
|
|
||||||
user_agent=(user_agent or "")[:255] if user_agent else None,
|
|
||||||
)
|
|
||||||
db.add(db_session)
|
|
||||||
await db.commit()
|
|
||||||
return create_session_token(user.id, session_id)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Routes
|
# Routes
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -312,8 +287,10 @@ async def totp_verify(
|
|||||||
user.last_login_at = datetime.now()
|
user.last_login_at = datetime.now()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
token = await _create_full_session(db, user, request)
|
ip = get_client_ip(request)
|
||||||
_set_session_cookie(response, token)
|
user_agent = request.headers.get("user-agent")
|
||||||
|
_, token = await create_db_session(db, user, ip, user_agent)
|
||||||
|
set_session_cookie(response, token)
|
||||||
return {"authenticated": True}
|
return {"authenticated": True}
|
||||||
|
|
||||||
# --- TOTP code path ---
|
# --- TOTP code path ---
|
||||||
@ -340,8 +317,10 @@ async def totp_verify(
|
|||||||
user.last_login_at = datetime.now()
|
user.last_login_at = datetime.now()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
token = await _create_full_session(db, user, request)
|
ip = get_client_ip(request)
|
||||||
_set_session_cookie(response, token)
|
user_agent = request.headers.get("user-agent")
|
||||||
|
_, token = await create_db_session(db, user, ip, user_agent)
|
||||||
|
set_session_cookie(response, token)
|
||||||
return {"authenticated": True}
|
return {"authenticated": True}
|
||||||
|
|
||||||
|
|
||||||
@ -513,9 +492,11 @@ async def enforce_confirm_totp(
|
|||||||
user.last_login_at = datetime.now()
|
user.last_login_at = datetime.now()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Issue a full session
|
# Issue a full session (now uses shared session service with cap enforcement)
|
||||||
token = await _create_full_session(db, user, request)
|
ip = get_client_ip(request)
|
||||||
_set_session_cookie(response, token)
|
user_agent = request.headers.get("user-agent")
|
||||||
|
_, token = await create_db_session(db, user, ip, user_agent)
|
||||||
|
set_session_cookie(response, token)
|
||||||
|
|
||||||
return {"authenticated": True}
|
return {"authenticated": True}
|
||||||
|
|
||||||
|
|||||||
103
backend/app/services/session.py
Normal file
103
backend/app/services/session.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
Shared session management service.
|
||||||
|
|
||||||
|
Consolidates session creation, cookie handling, and account lockout logic
|
||||||
|
that was previously duplicated between auth.py and totp.py routers.
|
||||||
|
All auth paths (password, TOTP, passkey) use these functions to ensure
|
||||||
|
consistent session cap enforcement and lockout behavior.
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Response
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.session import UserSession
|
||||||
|
from app.services.auth import create_session_token
|
||||||
|
from app.config import settings as app_settings
|
||||||
|
|
||||||
|
|
||||||
|
def set_session_cookie(response: Response, token: str) -> None:
|
||||||
|
"""Set httpOnly secure signed cookie on response."""
|
||||||
|
response.set_cookie(
|
||||||
|
key="session",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=app_settings.COOKIE_SECURE,
|
||||||
|
max_age=app_settings.SESSION_MAX_AGE_DAYS * 86400,
|
||||||
|
samesite="lax",
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_account_lockout(user: User) -> None:
|
||||||
|
"""Raise HTTP 423 if the account is currently locked."""
|
||||||
|
if user.locked_until and datetime.now() < user.locked_until:
|
||||||
|
remaining = int((user.locked_until - datetime.now()).total_seconds() / 60) + 1
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=423,
|
||||||
|
detail=f"Account locked. Try again in {remaining} minutes.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def record_failed_login(db: AsyncSession, user: User) -> None:
|
||||||
|
"""Increment failure counter; lock account after 10 failures."""
|
||||||
|
user.failed_login_count += 1
|
||||||
|
if user.failed_login_count >= 10:
|
||||||
|
user.locked_until = datetime.now() + timedelta(minutes=30)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def record_successful_login(db: AsyncSession, user: User) -> None:
|
||||||
|
"""Reset failure counter and update last_login_at."""
|
||||||
|
user.failed_login_count = 0
|
||||||
|
user.locked_until = None
|
||||||
|
user.last_login_at = datetime.now()
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def create_db_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
user: User,
|
||||||
|
ip: str,
|
||||||
|
user_agent: str | None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""Insert a UserSession row and return (session_id, signed_cookie_token).
|
||||||
|
|
||||||
|
Enforces MAX_SESSIONS_PER_USER by revoking oldest sessions beyond the cap.
|
||||||
|
"""
|
||||||
|
session_id = uuid.uuid4().hex
|
||||||
|
expires_at = datetime.now() + timedelta(days=app_settings.SESSION_MAX_AGE_DAYS)
|
||||||
|
db_session = UserSession(
|
||||||
|
id=session_id,
|
||||||
|
user_id=user.id,
|
||||||
|
expires_at=expires_at,
|
||||||
|
ip_address=ip[:45] if ip else None,
|
||||||
|
user_agent=(user_agent or "")[:255] if user_agent else None,
|
||||||
|
)
|
||||||
|
db.add(db_session)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
# Enforce concurrent session limit: revoke oldest sessions beyond the cap
|
||||||
|
active_sessions = (
|
||||||
|
await db.execute(
|
||||||
|
select(UserSession)
|
||||||
|
.where(
|
||||||
|
UserSession.user_id == user.id,
|
||||||
|
UserSession.revoked == False, # noqa: E712
|
||||||
|
UserSession.expires_at > datetime.now(),
|
||||||
|
)
|
||||||
|
.order_by(UserSession.created_at.asc())
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
max_sessions = app_settings.MAX_SESSIONS_PER_USER
|
||||||
|
if len(active_sessions) > max_sessions:
|
||||||
|
for old_session in active_sessions[: len(active_sessions) - max_sessions]:
|
||||||
|
old_session.revoked = True
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
token = create_session_token(user.id, session_id)
|
||||||
|
return session_id, token
|
||||||
Loading…
x
Reference in New Issue
Block a user