Fix QA review findings: race condition, detached session, validation

- C-01: Wrap connection request flush in IntegrityError handler for
  TOCTOU race on partial unique index
- W-02: Extract ntfy config into plain dict before commit to avoid
  DetachedInstanceError in background tasks
- W-04: Add integer range validation (1–2147483647) on notification IDs
- W-07: Add typed response models for respond_to_request endpoint
- W-09: Document resolved_at requirement for future cancel endpoint
- S-02: Use Literal type for ConnectionRequestResponse.status
- S-04: Check ntfy master switch in extract_ntfy_config
- S-05: Move date import to module level in connection service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kyle 2026-03-04 06:36:14 +08:00
parent e27beb7736
commit 0e94b6e1f7
5 changed files with 78 additions and 22 deletions

View File

@ -277,7 +277,12 @@ async def _purge_old_notifications(db: AsyncSession) -> None:
async def _purge_resolved_requests(db: AsyncSession) -> None: async def _purge_resolved_requests(db: AsyncSession) -> None:
"""Remove rejected/cancelled connection requests older than 30 days.""" """Remove rejected/cancelled connection requests older than 30 days.
Note: resolved_at must be set when changing status to rejected/cancelled.
Rows with NULL resolved_at are preserved (comparison with NULL yields NULL).
Any future cancel endpoint must set resolved_at = now on status change.
"""
cutoff = datetime.now() - timedelta(days=30) cutoff = datetime.now() - timedelta(days=30)
await db.execute( await db.execute(
delete(ConnectionRequest).where( delete(ConnectionRequest).where(

View File

@ -13,6 +13,7 @@ from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request
from sqlalchemy import select, func, and_, update from sqlalchemy import select, func, and_, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@ -27,6 +28,8 @@ from app.routers.auth import get_current_user
from app.schemas.connection import ( from app.schemas.connection import (
ConnectionRequestResponse, ConnectionRequestResponse,
ConnectionResponse, ConnectionResponse,
RespondAcceptResponse,
RespondRejectResponse,
RespondRequest, RespondRequest,
SendConnectionRequest, SendConnectionRequest,
SharingOverrideUpdate, SharingOverrideUpdate,
@ -38,6 +41,7 @@ from app.services.connection import (
SHAREABLE_FIELDS, SHAREABLE_FIELDS,
create_person_from_connection, create_person_from_connection,
detach_umbral_contact, detach_umbral_contact,
extract_ntfy_config,
resolve_shared_profile, resolve_shared_profile,
send_connection_ntfy, send_connection_ntfy,
) )
@ -190,13 +194,17 @@ async def send_connection_request(
if pending_count >= 5: if pending_count >= 5:
raise HTTPException(status_code=429, detail="Too many pending requests for this user") raise HTTPException(status_code=429, detail="Too many pending requests for this user")
# Create the request # Create the request (IntegrityError guard for TOCTOU race on partial unique index)
conn_request = ConnectionRequest( conn_request = ConnectionRequest(
sender_id=current_user.id, sender_id=current_user.id,
receiver_id=target.id, receiver_id=target.id,
) )
db.add(conn_request) db.add(conn_request)
await db.flush() # populate conn_request.id for source_id try:
await db.flush() # populate conn_request.id for source_id
except IntegrityError:
await db.rollback()
raise HTTPException(status_code=409, detail="A pending request already exists")
# Create in-app notification for receiver (sender_settings already fetched above) # Create in-app notification for receiver (sender_settings already fetched above)
sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name
@ -221,13 +229,16 @@ async def send_connection_request(
ip=get_client_ip(request), ip=get_client_ip(request),
) )
# Extract ntfy config before commit (avoids detached SA object in background task)
target_ntfy = extract_ntfy_config(target_settings) if target_settings else None
await db.commit() await db.commit()
await db.refresh(conn_request) await db.refresh(conn_request)
# ntfy push in background (non-blocking) # ntfy push in background (non-blocking)
background_tasks.add_task( background_tasks.add_task(
send_connection_ntfy, send_connection_ntfy,
target_settings, target_ntfy,
sender_display, sender_display,
"request_received", "request_received",
) )
@ -319,7 +330,7 @@ async def get_outgoing_requests(
# ── PUT /requests/{id}/respond ────────────────────────────────────── # ── PUT /requests/{id}/respond ──────────────────────────────────────
@router.put("/requests/{request_id}/respond") @router.put("/requests/{request_id}/respond", response_model=RespondAcceptResponse | RespondRejectResponse)
async def respond_to_request( async def respond_to_request(
body: RespondRequest, body: RespondRequest,
request: Request, request: Request,
@ -419,16 +430,18 @@ async def respond_to_request(
ip=get_client_ip(request), ip=get_client_ip(request),
) )
# Extract ntfy config before commit (avoids detached SA object in background task)
sender_ntfy = extract_ntfy_config(sender_settings) if sender_settings else None
await db.commit() await db.commit()
# ntfy push in background # ntfy push in background
if sender_settings: background_tasks.add_task(
background_tasks.add_task( send_connection_ntfy,
send_connection_ntfy, sender_ntfy,
sender_settings, receiver_display,
receiver_display, "request_accepted",
"request_accepted", )
)
return {"message": "Connection accepted", "connection_id": conn_a.id} return {"message": "Connection accepted", "connection_id": conn_a.id}

View File

@ -45,7 +45,7 @@ class ConnectionRequestResponse(BaseModel):
sender_preferred_name: Optional[str] = None sender_preferred_name: Optional[str] = None
receiver_umbral_name: str receiver_umbral_name: str
receiver_preferred_name: Optional[str] = None receiver_preferred_name: Optional[str] = None
status: str status: Literal["pending", "accepted", "rejected", "cancelled"]
created_at: datetime created_at: datetime
@ -63,6 +63,15 @@ class ConnectionResponse(BaseModel):
created_at: datetime created_at: datetime
class RespondAcceptResponse(BaseModel):
message: str
connection_id: int
class RespondRejectResponse(BaseModel):
message: str
class SharingOverrideUpdate(BaseModel): class SharingOverrideUpdate(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
preferred_name: Optional[bool] = None preferred_name: Optional[bool] = None

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, field_validator
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@ -27,4 +27,12 @@ class NotificationListResponse(BaseModel):
class MarkReadRequest(BaseModel): class MarkReadRequest(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
notification_ids: list[int] = Field(..., min_length=1, max_length=100) notification_ids: list[int] = Field(..., min_length=1, max_length=100, json_schema_extra={"items": {"minimum": 1, "maximum": 2147483647}})
@field_validator('notification_ids')
@classmethod
def validate_ids(cls, v: list[int]) -> list[int]:
for i in v:
if i < 1 or i > 2147483647:
raise ValueError('Each notification ID must be between 1 and 2147483647')
return v

View File

@ -5,6 +5,8 @@ SHAREABLE_FIELDS is the single source of truth for which fields can be shared.
""" """
import asyncio import asyncio
import logging import logging
from datetime import date as date_type
from types import SimpleNamespace
from typing import Optional from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -87,7 +89,6 @@ def create_person_from_connection(
job_title = shared_profile.get("job_title") job_title = shared_profile.get("job_title")
birthday_str = shared_profile.get("birthday") birthday_str = shared_profile.get("birthday")
from datetime import date as date_type
birthday = None birthday = None
if birthday_str: if birthday_str:
try: try:
@ -125,13 +126,30 @@ async def detach_umbral_contact(person: Person) -> None:
person.first_name = person.name or None person.first_name = person.name or None
def extract_ntfy_config(settings: Settings) -> dict | None:
"""Extract ntfy config values into a plain dict safe for use after session close."""
if not settings.ntfy_enabled or not settings.ntfy_connections_enabled:
return None
return {
"ntfy_enabled": True,
"ntfy_server_url": settings.ntfy_server_url,
"ntfy_topic": settings.ntfy_topic,
"ntfy_auth_token": settings.ntfy_auth_token,
"user_id": settings.user_id,
}
async def send_connection_ntfy( async def send_connection_ntfy(
settings: Settings, ntfy_config: dict | None,
sender_name: str, sender_name: str,
event_type: str, event_type: str,
) -> None: ) -> None:
"""Send ntfy push for connection events. Non-blocking with 3s timeout.""" """Send ntfy push for connection events. Non-blocking with 3s timeout.
if not settings.ntfy_connections_enabled:
Accepts a plain dict (from extract_ntfy_config) to avoid accessing
detached SQLAlchemy objects after session close.
"""
if not ntfy_config:
return return
title_map = { title_map = {
@ -151,10 +169,13 @@ async def send_connection_ntfy(
message = message_map.get(event_type, f"Connection update from {sender_name}") message = message_map.get(event_type, f"Connection update from {sender_name}")
tags = tag_map.get(event_type, ["bell"]) tags = tag_map.get(event_type, ["bell"])
# Build a settings-like object for send_ntfy_notification (avoids detached SA objects)
settings_proxy = SimpleNamespace(**ntfy_config)
try: try:
await asyncio.wait_for( await asyncio.wait_for(
send_ntfy_notification( send_ntfy_notification(
settings=settings, settings=settings_proxy,
title=title, title=title,
message=message, message=message,
tags=tags, tags=tags,
@ -163,6 +184,6 @@ async def send_connection_ntfy(
timeout=3.0, timeout=3.0,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning("ntfy connection push timed out for user_id=%s", settings.user_id) logger.warning("ntfy connection push timed out for user_id=%s", ntfy_config["user_id"])
except Exception: except Exception:
logger.warning("ntfy connection push failed for user_id=%s", settings.user_id) logger.warning("ntfy connection push failed for user_id=%s", ntfy_config["user_id"])