Fix QA review findings: race condition, detached session, validation

- C-01: Wrap connection request flush in IntegrityError handler for
  TOCTOU race on partial unique index
- W-02: Extract ntfy config into plain dict before commit to avoid
  DetachedInstanceError in background tasks
- W-04: Add integer range validation (1–2147483647) on notification IDs
- W-07: Add typed response models for respond_to_request endpoint
- W-09: Document resolved_at requirement for future cancel endpoint
- S-02: Use Literal type for ConnectionRequestResponse.status
- S-04: Check ntfy master switch in extract_ntfy_config
- S-05: Move date import to module level in connection service

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kyle 2026-03-04 06:36:14 +08:00
parent e27beb7736
commit 0e94b6e1f7
5 changed files with 78 additions and 22 deletions

View File

@ -277,7 +277,12 @@ async def _purge_old_notifications(db: AsyncSession) -> None:
async def _purge_resolved_requests(db: AsyncSession) -> None:
"""Remove rejected/cancelled connection requests older than 30 days."""
"""Remove rejected/cancelled connection requests older than 30 days.
Note: resolved_at must be set when changing status to rejected/cancelled.
Rows with NULL resolved_at are preserved (comparison with NULL yields NULL).
Any future cancel endpoint must set resolved_at = now on status change.
"""
cutoff = datetime.now() - timedelta(days=30)
await db.execute(
delete(ConnectionRequest).where(

View File

@ -13,6 +13,7 @@ from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request
from sqlalchemy import select, func, and_, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@ -27,6 +28,8 @@ from app.routers.auth import get_current_user
from app.schemas.connection import (
ConnectionRequestResponse,
ConnectionResponse,
RespondAcceptResponse,
RespondRejectResponse,
RespondRequest,
SendConnectionRequest,
SharingOverrideUpdate,
@ -38,6 +41,7 @@ from app.services.connection import (
SHAREABLE_FIELDS,
create_person_from_connection,
detach_umbral_contact,
extract_ntfy_config,
resolve_shared_profile,
send_connection_ntfy,
)
@ -190,13 +194,17 @@ async def send_connection_request(
if pending_count >= 5:
raise HTTPException(status_code=429, detail="Too many pending requests for this user")
# Create the request
# Create the request (IntegrityError guard for TOCTOU race on partial unique index)
conn_request = ConnectionRequest(
sender_id=current_user.id,
receiver_id=target.id,
)
db.add(conn_request)
try:
await db.flush() # populate conn_request.id for source_id
except IntegrityError:
await db.rollback()
raise HTTPException(status_code=409, detail="A pending request already exists")
# Create in-app notification for receiver (sender_settings already fetched above)
sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name
@ -221,13 +229,16 @@ async def send_connection_request(
ip=get_client_ip(request),
)
# Extract ntfy config before commit (avoids detached SA object in background task)
target_ntfy = extract_ntfy_config(target_settings) if target_settings else None
await db.commit()
await db.refresh(conn_request)
# ntfy push in background (non-blocking)
background_tasks.add_task(
send_connection_ntfy,
target_settings,
target_ntfy,
sender_display,
"request_received",
)
@ -319,7 +330,7 @@ async def get_outgoing_requests(
# ── PUT /requests/{id}/respond ──────────────────────────────────────
@router.put("/requests/{request_id}/respond")
@router.put("/requests/{request_id}/respond", response_model=RespondAcceptResponse | RespondRejectResponse)
async def respond_to_request(
body: RespondRequest,
request: Request,
@ -419,13 +430,15 @@ async def respond_to_request(
ip=get_client_ip(request),
)
# Extract ntfy config before commit (avoids detached SA object in background task)
sender_ntfy = extract_ntfy_config(sender_settings) if sender_settings else None
await db.commit()
# ntfy push in background
if sender_settings:
background_tasks.add_task(
send_connection_ntfy,
sender_settings,
sender_ntfy,
receiver_display,
"request_accepted",
)

View File

@ -45,7 +45,7 @@ class ConnectionRequestResponse(BaseModel):
sender_preferred_name: Optional[str] = None
receiver_umbral_name: str
receiver_preferred_name: Optional[str] = None
status: str
status: Literal["pending", "accepted", "rejected", "cancelled"]
created_at: datetime
@ -63,6 +63,15 @@ class ConnectionResponse(BaseModel):
created_at: datetime
class RespondAcceptResponse(BaseModel):
message: str
connection_id: int
class RespondRejectResponse(BaseModel):
message: str
class SharingOverrideUpdate(BaseModel):
model_config = ConfigDict(extra="forbid")
preferred_name: Optional[bool] = None

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from datetime import datetime
from typing import Optional
@ -27,4 +27,12 @@ class NotificationListResponse(BaseModel):
class MarkReadRequest(BaseModel):
model_config = ConfigDict(extra="forbid")
notification_ids: list[int] = Field(..., min_length=1, max_length=100)
notification_ids: list[int] = Field(..., min_length=1, max_length=100, json_schema_extra={"items": {"minimum": 1, "maximum": 2147483647}})
@field_validator('notification_ids')
@classmethod
def validate_ids(cls, v: list[int]) -> list[int]:
for i in v:
if i < 1 or i > 2147483647:
raise ValueError('Each notification ID must be between 1 and 2147483647')
return v

View File

@ -5,6 +5,8 @@ SHAREABLE_FIELDS is the single source of truth for which fields can be shared.
"""
import asyncio
import logging
from datetime import date as date_type
from types import SimpleNamespace
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
@ -87,7 +89,6 @@ def create_person_from_connection(
job_title = shared_profile.get("job_title")
birthday_str = shared_profile.get("birthday")
from datetime import date as date_type
birthday = None
if birthday_str:
try:
@ -125,13 +126,30 @@ async def detach_umbral_contact(person: Person) -> None:
person.first_name = person.name or None
def extract_ntfy_config(settings: Settings) -> dict | None:
"""Extract ntfy config values into a plain dict safe for use after session close."""
if not settings.ntfy_enabled or not settings.ntfy_connections_enabled:
return None
return {
"ntfy_enabled": True,
"ntfy_server_url": settings.ntfy_server_url,
"ntfy_topic": settings.ntfy_topic,
"ntfy_auth_token": settings.ntfy_auth_token,
"user_id": settings.user_id,
}
async def send_connection_ntfy(
settings: Settings,
ntfy_config: dict | None,
sender_name: str,
event_type: str,
) -> None:
"""Send ntfy push for connection events. Non-blocking with 3s timeout."""
if not settings.ntfy_connections_enabled:
"""Send ntfy push for connection events. Non-blocking with 3s timeout.
Accepts a plain dict (from extract_ntfy_config) to avoid accessing
detached SQLAlchemy objects after session close.
"""
if not ntfy_config:
return
title_map = {
@ -151,10 +169,13 @@ async def send_connection_ntfy(
message = message_map.get(event_type, f"Connection update from {sender_name}")
tags = tag_map.get(event_type, ["bell"])
# Build a settings-like object for send_ntfy_notification (avoids detached SA objects)
settings_proxy = SimpleNamespace(**ntfy_config)
try:
await asyncio.wait_for(
send_ntfy_notification(
settings=settings,
settings=settings_proxy,
title=title,
message=message,
tags=tags,
@ -163,6 +184,6 @@ async def send_connection_ntfy(
timeout=3.0,
)
except asyncio.TimeoutError:
logger.warning("ntfy connection push timed out for user_id=%s", settings.user_id)
logger.warning("ntfy connection push timed out for user_id=%s", ntfy_config["user_id"])
except Exception:
logger.warning("ntfy connection push failed for user_id=%s", settings.user_id)
logger.warning("ntfy connection push failed for user_id=%s", ntfy_config["user_id"])