Kyle Pope b134ad9e8b Implement Stage 6 Track B: TOTP MFA (pyotp, Fernet-encrypted secrets, backup codes)
- models/totp_usage.py: replay-prevention table, unique on (user_id, code, window)
- models/backup_code.py: Argon2id-hashed recovery codes with used_at tracking
- services/totp.py: Fernet encrypt/decrypt, verify_totp_code returns actual window, QR base64, backup code generation
- routers/totp.py: setup (idempotent), confirm, totp-verify (mfa_token + TOTP or backup code), disable, regenerate, status
- alembic/024: creates totp_usage and backup_codes tables
- main.py: register totp router, import new models for Alembic discovery
- requirements.txt: add pyotp>=2.9.0, qrcode[pil]>=7.4.0, cryptography>=42.0.0
- jobs/notifications.py: periodic cleanup for totp_usage (5 min) and expired user_sessions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 04:18:05 +08:00

135 lines
4.7 KiB
Python

"""
TOTP service: secret generation/encryption, code verification, QR code generation,
backup code generation.
All TOTP secrets are Fernet-encrypted at rest using a key derived from SECRET_KEY.
Raw secrets are never logged and are only returned to the client once (at setup).
"""
import pyotp
import secrets
import string
import time
import io
import base64
import hashlib
import qrcode
from cryptography.fernet import Fernet
from app.config import settings as app_settings
# ---------------------------------------------------------------------------
# Fernet key derivation
# ---------------------------------------------------------------------------
def _get_fernet() -> Fernet:
    """Build the Fernet cipher used to protect TOTP secrets at rest.

    SECRET_KEY is hashed with SHA-256 to obtain exactly 32 bytes. Those bytes
    are URL-safe base64 encoded, which is the key format Fernet expects.
    """
    secret_bytes = app_settings.SECRET_KEY.encode()
    digest = hashlib.sha256(secret_bytes).digest()
    fernet_key = base64.urlsafe_b64encode(digest)
    return Fernet(fernet_key)
# ---------------------------------------------------------------------------
# Secret management
# ---------------------------------------------------------------------------
def generate_totp_secret() -> str:
    """Return a freshly generated random TOTP secret as a base32 string.

    Delegates to ``pyotp.random_base32`` (roughly 160 bits of entropy at its
    default length).
    """
    secret = pyotp.random_base32()
    return secret
def encrypt_totp_secret(raw: str) -> str:
    """Encrypt *raw* with the SECRET_KEY-derived Fernet cipher for DB storage.

    Returns the Fernet token decoded to a text string.
    """
    token = _get_fernet().encrypt(raw.encode())
    return token.decode()
def decrypt_totp_secret(encrypted: str) -> str:
    """Turn a stored Fernet token back into the raw TOTP secret.

    Inverse of encrypt_totp_secret. Any error raised by ``Fernet.decrypt``
    (e.g. ``InvalidToken`` on tampering or a changed SECRET_KEY) propagates
    to the caller.
    """
    cipher = _get_fernet()
    plaintext = cipher.decrypt(encrypted.encode())
    return plaintext.decode()
# ---------------------------------------------------------------------------
# Provisioning URI and QR code
# ---------------------------------------------------------------------------
def get_totp_uri(encrypted_secret: str, username: str) -> str:
    """Build the otpauth:// provisioning URI that authenticator apps scan.

    The stored secret is decrypted first. The account label is *username* and
    the issuer is taken from ``settings.TOTP_ISSUER``.
    """
    totp = pyotp.TOTP(decrypt_totp_secret(encrypted_secret))
    return totp.provisioning_uri(
        name=username,
        issuer_name=app_settings.TOTP_ISSUER,
    )
def generate_qr_base64(uri: str) -> str:
    """Render *uri* as a QR code and return the PNG image base64-encoded.

    The symbol starts at version 1 and is grown as needed (``fit=True``). It
    uses low error correction, 10-pixel modules and a 4-module quiet zone,
    drawn black on white.
    """
    code = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=10,
        border=4,
    )
    code.add_data(uri)
    code.make(fit=True)
    image = code.make_image(fill_color="black", back_color="white")
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# ---------------------------------------------------------------------------
# Code verification
# ---------------------------------------------------------------------------
def verify_totp_code(encrypted_secret: str, code: str, valid_window: int = 1) -> int | None:
    """
    Verify a TOTP code and return the matching time window, or None if invalid.

    Checks each window individually (T-valid_window ... T+valid_window) so the
    caller knows WHICH window matched. This is required for correct
    replay-prevention: the TOTPUsage row must record the actual matching
    window, not the current one.

    Codes are compared with secrets.compare_digest on UTF-8 *bytes*, in
    constant time. Bytes are used deliberately: compare_digest raises
    TypeError for str arguments that contain non-ASCII characters. Untrusted
    input such as full-width digits would otherwise crash the request instead
    of simply failing verification.

    Args:
        encrypted_secret: Fernet-encrypted TOTP secret as stored in the DB.
        code: user-supplied code; surrounding whitespace is ignored.
        valid_window: number of steps accepted on each side of the current
            one (0 = current window only; a negative value matches nothing).

    Returns:
        int — the floor(unix_time / interval) window value that matched
              (interval is 30 s for a default pyotp.TOTP, the same numbering
              as get_totp_window)
        None — no window matched (invalid code)
    """
    raw = decrypt_totp_secret(encrypted_secret)
    totp = pyotp.TOTP(raw)
    # Use the TOTP object's own step so window math and totp.at() never diverge.
    step = totp.interval
    # Normalise the user input once, outside the loop.
    candidate = code.strip().encode()
    current_window = int(time.time() // step)
    for offset in range(-valid_window, valid_window + 1):
        check_window = current_window + offset
        # pyotp.at() takes a unix timestamp; convert the window back to seconds.
        expected_code = totp.at(check_window * step).encode()
        if secrets.compare_digest(candidate, expected_code):
            return check_window  # Return the ACTUAL window that matched
    return None  # No window matched
# ---------------------------------------------------------------------------
# Backup codes
# ---------------------------------------------------------------------------
def generate_backup_codes(count: int = 10) -> list[str]:
    """
    Produce *count* one-time recovery codes shaped like ``XXXX-XXXX``.

    Each X is drawn independently from A-Z and 0-9 with the ``secrets``
    module (a CSPRNG). A non-positive *count* yields an empty list.
    """
    charset = string.ascii_uppercase + string.digits

    def _group() -> str:
        # Four independent CSPRNG picks from the charset.
        return "".join(secrets.choice(charset) for _ in range(4))

    return [f"{_group()}-{_group()}" for _ in range(count)]
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def get_totp_window() -> int:
    """Return the index of the current 30-second TOTP step.

    Equivalent to floor(unix_time / 30).
    """
    window, _remainder = divmod(time.time(), 30)
    return int(window)