Phase 2: Backend critical path optimizations

- AC-1: Merge get_current_user into single JOIN query (session + user in
  one round-trip instead of two sequential queries per request)
- AC-2: Wrap all Argon2id hash/verify calls in run_in_executor to avoid
  blocking the async event loop (~150ms per operation)
- AW-7: Add connection pool config (pool_size=10, pool_pre_ping=True,
  pool_recycle=1800) to prevent connection exhaustion under load
- AC-4: Batch-fetch tasks in reorder_tasks with IN clause instead of
  N sequential queries during Kanban drag operations
- AW-4: Bulk NtfySent inserts with single commit per user instead of
  per-notification commits in the dispatch job

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kyle 2026-03-13 00:05:54 +08:00
parent dbad9c69b3
commit 1f2083ee61
5 changed files with 60 additions and 34 deletions

View File

@ -2,11 +2,15 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
from app.config import settings from app.config import settings
# Create async engine # Create async engine with tuned pool (AW-7)
engine = create_async_engine( engine = create_async_engine(
settings.DATABASE_URL, settings.DATABASE_URL,
echo=False, echo=False,
future=True future=True,
pool_size=10,
max_overflow=5,
pool_pre_ping=True,
pool_recycle=1800,
) )
# Create async session factory # Create async session factory

View File

@ -56,8 +56,8 @@ async def _get_sent_keys(db: AsyncSession, user_id: int) -> set[str]:
async def _mark_sent(db: AsyncSession, key: str, user_id: int) -> None: async def _mark_sent(db: AsyncSession, key: str, user_id: int) -> None:
"""Stage a sent record — caller must commit (AW-4: bulk commit per user)."""
db.add(NtfySent(notification_key=key, user_id=user_id)) db.add(NtfySent(notification_key=key, user_id=user_id))
await db.commit()
# ── Dispatch functions ──────────────────────────────────────────────────────── # ── Dispatch functions ────────────────────────────────────────────────────────
@ -248,6 +248,9 @@ async def _dispatch_for_user(db: AsyncSession, settings: Settings, now: datetime
if settings.ntfy_projects_enabled: if settings.ntfy_projects_enabled:
await _dispatch_projects(db, settings, now.date(), sent_keys) await _dispatch_projects(db, settings, now.date(), sent_keys)
# AW-4: Single commit per user instead of per-notification
await db.commit()
async def _purge_old_sent_records(db: AsyncSession) -> None: async def _purge_old_sent_records(db: AsyncSession) -> None:
"""Remove ntfy_sent entries older than 7 days to keep the table lean.""" """Remove ntfy_sent entries older than 7 days to keep the table lean."""

View File

@ -10,6 +10,7 @@ Security measures implemented:
All routes require the `require_admin` dependency (which chains through All routes require the `require_admin` dependency (which chains through
get_current_user, so the session cookie is always validated). get_current_user, so the session cookie is always validated).
""" """
import asyncio
import secrets import secrets
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@ -222,10 +223,11 @@ async def create_user(
if email_exists.scalar_one_or_none(): if email_exists.scalar_one_or_none():
raise HTTPException(status_code=409, detail="Email already in use") raise HTTPException(status_code=409, detail="Email already in use")
loop = asyncio.get_running_loop()
new_user = User( new_user = User(
username=data.username, username=data.username,
umbral_name=data.username, umbral_name=data.username,
password_hash=hash_password(data.password), password_hash=await loop.run_in_executor(None, hash_password, data.password),
role=data.role, role=data.role,
email=email, email=email,
first_name=data.first_name, first_name=data.first_name,
@ -341,7 +343,8 @@ async def reset_user_password(
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
temp_password = secrets.token_urlsafe(16) temp_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(temp_password) loop = asyncio.get_running_loop()
user.password_hash = await loop.run_in_executor(None, hash_password, temp_password)
user.must_change_password = True user.must_change_password = True
user.last_password_change_at = datetime.now() user.last_password_change_at = datetime.now()

View File

@ -16,6 +16,7 @@ Security layers:
4. bcryptArgon2id transparent upgrade on first login 4. bcryptArgon2id transparent upgrade on first login
5. Role-based authorization via require_role() dependency factory 5. Role-based authorization via require_role() dependency factory
""" """
import asyncio
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
@ -101,25 +102,22 @@ async def get_current_user(
if user_id is None or session_id is None: if user_id is None or session_id is None:
raise HTTPException(status_code=401, detail="Malformed session token") raise HTTPException(status_code=401, detail="Malformed session token")
# Verify session is active in DB (covers revocation + expiry) # AC-1: Single JOIN query for session + user (was 2 sequential queries)
session_result = await db.execute( result = await db.execute(
select(UserSession).where( select(UserSession, User)
.join(User, UserSession.user_id == User.id)
.where(
UserSession.id == session_id, UserSession.id == session_id,
UserSession.user_id == user_id, UserSession.user_id == user_id,
UserSession.revoked == False, UserSession.revoked == False,
UserSession.expires_at > datetime.now(), UserSession.expires_at > datetime.now(),
User.is_active == True,
) )
) )
db_session = session_result.scalar_one_or_none() row = result.one_or_none()
if not db_session: if not row:
raise HTTPException(status_code=401, detail="Session has been revoked or expired") raise HTTPException(status_code=401, detail="Session expired or user inactive")
db_session, user = row.tuple()
user_result = await db.execute(
select(User).where(User.id == user_id, User.is_active == True)
)
user = user_result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=401, detail="User not found or inactive")
# L-03: Sliding window renewal — extend session if >1 day has elapsed since # L-03: Sliding window renewal — extend session if >1 day has elapsed since
# last renewal (i.e. remaining time < SESSION_MAX_AGE_DAYS - 1 day). # last renewal (i.e. remaining time < SESSION_MAX_AGE_DAYS - 1 day).
@ -299,7 +297,8 @@ async def setup(
if user_count.scalar_one() > 0: if user_count.scalar_one() > 0:
raise HTTPException(status_code=400, detail="Setup already completed") raise HTTPException(status_code=400, detail="Setup already completed")
password_hash = hash_password(data.password) loop = asyncio.get_running_loop()
password_hash = await loop.run_in_executor(None, hash_password, data.password)
new_user = User( new_user = User(
username=data.username, username=data.username,
umbral_name=data.username, umbral_name=data.username,
@ -352,12 +351,18 @@ async def login(
if not user: if not user:
# M-02: Run Argon2id against a dummy hash so the response time is # M-02: Run Argon2id against a dummy hash so the response time is
# indistinguishable from a wrong-password attempt (prevents username enumeration). # indistinguishable from a wrong-password attempt (prevents username enumeration).
verify_password("x", _DUMMY_HASH) # AC-2: run_in_executor to avoid blocking the event loop (~150ms)
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, verify_password, "x", _DUMMY_HASH)
raise HTTPException(status_code=401, detail="Invalid username or password") raise HTTPException(status_code=401, detail="Invalid username or password")
# M-02: Run password verification BEFORE lockout check so Argon2id always # M-02: Run password verification BEFORE lockout check so Argon2id always
# executes — prevents distinguishing "locked" from "wrong password" via timing. # executes — prevents distinguishing "locked" from "wrong password" via timing.
valid, new_hash = verify_password_with_upgrade(data.password, user.password_hash) # AC-2: run_in_executor to avoid blocking the event loop
loop = asyncio.get_running_loop()
valid, new_hash = await loop.run_in_executor(
None, verify_password_with_upgrade, data.password, user.password_hash
)
await _check_account_lockout(user) await _check_account_lockout(user)
@ -465,7 +470,8 @@ async def register(
if existing_email.scalar_one_or_none(): if existing_email.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Registration could not be completed. Please check your details and try again.") raise HTTPException(status_code=400, detail="Registration could not be completed. Please check your details and try again.")
password_hash = hash_password(data.password) loop = asyncio.get_running_loop()
password_hash = await loop.run_in_executor(None, hash_password, data.password)
# SEC-01: Explicit field assignment — never **data.model_dump() # SEC-01: Explicit field assignment — never **data.model_dump()
new_user = User( new_user = User(
username=data.username, username=data.username,
@ -630,7 +636,10 @@ async def verify_password(
""" """
await _check_account_lockout(current_user) await _check_account_lockout(current_user)
valid, new_hash = verify_password_with_upgrade(data.password, current_user.password_hash) loop = asyncio.get_running_loop()
valid, new_hash = await loop.run_in_executor(
None, verify_password_with_upgrade, data.password, current_user.password_hash
)
if not valid: if not valid:
await _record_failed_login(db, current_user) await _record_failed_login(db, current_user)
raise HTTPException(status_code=401, detail="Invalid password") raise HTTPException(status_code=401, detail="Invalid password")
@ -656,7 +665,10 @@ async def change_password(
"""Change the current user's password. Requires old password verification.""" """Change the current user's password. Requires old password verification."""
await _check_account_lockout(current_user) await _check_account_lockout(current_user)
valid, _ = verify_password_with_upgrade(data.old_password, current_user.password_hash) loop = asyncio.get_running_loop()
valid, _ = await loop.run_in_executor(
None, verify_password_with_upgrade, data.old_password, current_user.password_hash
)
if not valid: if not valid:
await _record_failed_login(db, current_user) await _record_failed_login(db, current_user)
raise HTTPException(status_code=401, detail="Invalid current password") raise HTTPException(status_code=401, detail="Invalid current password")
@ -664,7 +676,7 @@ async def change_password(
if data.new_password == data.old_password: if data.new_password == data.old_password:
raise HTTPException(status_code=400, detail="New password must be different from your current password") raise HTTPException(status_code=400, detail="New password must be different from your current password")
current_user.password_hash = hash_password(data.new_password) current_user.password_hash = await loop.run_in_executor(None, hash_password, data.new_password)
current_user.last_password_change_at = datetime.now() current_user.last_password_change_at = datetime.now()
# Clear forced password change flag if set (SEC-12) # Clear forced password change flag if set (SEC-12)

View File

@ -294,16 +294,20 @@ async def reorder_tasks(
if not project: if not project:
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
for item in items: # AC-4: Batch-fetch all tasks in one query instead of N sequential queries
task_result = await db.execute( task_ids = [item.id for item in items]
select(ProjectTask).where( task_result = await db.execute(
ProjectTask.id == item.id, select(ProjectTask).where(
ProjectTask.project_id == project_id ProjectTask.id.in_(task_ids),
) ProjectTask.project_id == project_id,
) )
task = task_result.scalar_one_or_none() )
if task: tasks_by_id = {t.id: t for t in task_result.scalars().all()}
task.sort_order = item.sort_order
order_map = {item.id: item.sort_order for item in items}
for task_id, task in tasks_by_id.items():
if task_id in order_map:
task.sort_order = order_map[task_id]
await db.commit() await db.commit()