Fix QA review findings: race condition, detached session, validation
- C-01: Wrap connection request flush in IntegrityError handler for TOCTOU race on partial unique index - W-02: Extract ntfy config into plain dict before commit to avoid DetachedInstanceError in background tasks - W-04: Add integer range validation (1–2147483647) on notification IDs - W-07: Add typed response models for respond_to_request endpoint - W-09: Document resolved_at requirement for future cancel endpoint - S-02: Use Literal type for ConnectionRequestResponse.status - S-04: Check ntfy master switch in extract_ntfy_config - S-05: Move date import to module level in connection service Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e27beb7736
commit
0e94b6e1f7
@ -277,7 +277,12 @@ async def _purge_old_notifications(db: AsyncSession) -> None:
|
||||
|
||||
|
||||
async def _purge_resolved_requests(db: AsyncSession) -> None:
|
||||
"""Remove rejected/cancelled connection requests older than 30 days."""
|
||||
"""Remove rejected/cancelled connection requests older than 30 days.
|
||||
|
||||
Note: resolved_at must be set when changing status to rejected/cancelled.
|
||||
Rows with NULL resolved_at are preserved (comparison with NULL yields NULL).
|
||||
Any future cancel endpoint must set resolved_at = now on status change.
|
||||
"""
|
||||
cutoff = datetime.now() - timedelta(days=30)
|
||||
await db.execute(
|
||||
delete(ConnectionRequest).where(
|
||||
|
||||
@ -13,6 +13,7 @@ from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request
|
||||
from sqlalchemy import select, func, and_, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@ -27,6 +28,8 @@ from app.routers.auth import get_current_user
|
||||
from app.schemas.connection import (
|
||||
ConnectionRequestResponse,
|
||||
ConnectionResponse,
|
||||
RespondAcceptResponse,
|
||||
RespondRejectResponse,
|
||||
RespondRequest,
|
||||
SendConnectionRequest,
|
||||
SharingOverrideUpdate,
|
||||
@ -38,6 +41,7 @@ from app.services.connection import (
|
||||
SHAREABLE_FIELDS,
|
||||
create_person_from_connection,
|
||||
detach_umbral_contact,
|
||||
extract_ntfy_config,
|
||||
resolve_shared_profile,
|
||||
send_connection_ntfy,
|
||||
)
|
||||
@ -190,13 +194,17 @@ async def send_connection_request(
|
||||
if pending_count >= 5:
|
||||
raise HTTPException(status_code=429, detail="Too many pending requests for this user")
|
||||
|
||||
# Create the request
|
||||
# Create the request (IntegrityError guard for TOCTOU race on partial unique index)
|
||||
conn_request = ConnectionRequest(
|
||||
sender_id=current_user.id,
|
||||
receiver_id=target.id,
|
||||
)
|
||||
db.add(conn_request)
|
||||
await db.flush() # populate conn_request.id for source_id
|
||||
try:
|
||||
await db.flush() # populate conn_request.id for source_id
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=409, detail="A pending request already exists")
|
||||
|
||||
# Create in-app notification for receiver (sender_settings already fetched above)
|
||||
sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name
|
||||
@ -221,13 +229,16 @@ async def send_connection_request(
|
||||
ip=get_client_ip(request),
|
||||
)
|
||||
|
||||
# Extract ntfy config before commit (avoids detached SA object in background task)
|
||||
target_ntfy = extract_ntfy_config(target_settings) if target_settings else None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(conn_request)
|
||||
|
||||
# ntfy push in background (non-blocking)
|
||||
background_tasks.add_task(
|
||||
send_connection_ntfy,
|
||||
target_settings,
|
||||
target_ntfy,
|
||||
sender_display,
|
||||
"request_received",
|
||||
)
|
||||
@ -319,7 +330,7 @@ async def get_outgoing_requests(
|
||||
|
||||
# ── PUT /requests/{id}/respond ──────────────────────────────────────
|
||||
|
||||
@router.put("/requests/{request_id}/respond")
|
||||
@router.put("/requests/{request_id}/respond", response_model=RespondAcceptResponse | RespondRejectResponse)
|
||||
async def respond_to_request(
|
||||
body: RespondRequest,
|
||||
request: Request,
|
||||
@ -419,16 +430,18 @@ async def respond_to_request(
|
||||
ip=get_client_ip(request),
|
||||
)
|
||||
|
||||
# Extract ntfy config before commit (avoids detached SA object in background task)
|
||||
sender_ntfy = extract_ntfy_config(sender_settings) if sender_settings else None
|
||||
|
||||
await db.commit()
|
||||
|
||||
# ntfy push in background
|
||||
if sender_settings:
|
||||
background_tasks.add_task(
|
||||
send_connection_ntfy,
|
||||
sender_settings,
|
||||
receiver_display,
|
||||
"request_accepted",
|
||||
)
|
||||
background_tasks.add_task(
|
||||
send_connection_ntfy,
|
||||
sender_ntfy,
|
||||
receiver_display,
|
||||
"request_accepted",
|
||||
)
|
||||
|
||||
return {"message": "Connection accepted", "connection_id": conn_a.id}
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class ConnectionRequestResponse(BaseModel):
|
||||
sender_preferred_name: Optional[str] = None
|
||||
receiver_umbral_name: str
|
||||
receiver_preferred_name: Optional[str] = None
|
||||
status: str
|
||||
status: Literal["pending", "accepted", "rejected", "cancelled"]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@ -63,6 +63,15 @@ class ConnectionResponse(BaseModel):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RespondAcceptResponse(BaseModel):
|
||||
message: str
|
||||
connection_id: int
|
||||
|
||||
|
||||
class RespondRejectResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class SharingOverrideUpdate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
preferred_name: Optional[bool] = None
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
@ -27,4 +27,12 @@ class NotificationListResponse(BaseModel):
|
||||
class MarkReadRequest(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
notification_ids: list[int] = Field(..., min_length=1, max_length=100)
|
||||
notification_ids: list[int] = Field(..., min_length=1, max_length=100, json_schema_extra={"items": {"minimum": 1, "maximum": 2147483647}})
|
||||
|
||||
@field_validator('notification_ids')
|
||||
@classmethod
|
||||
def validate_ids(cls, v: list[int]) -> list[int]:
|
||||
for i in v:
|
||||
if i < 1 or i > 2147483647:
|
||||
raise ValueError('Each notification ID must be between 1 and 2147483647')
|
||||
return v
|
||||
|
||||
@ -5,6 +5,8 @@ SHAREABLE_FIELDS is the single source of truth for which fields can be shared.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import date as date_type
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -87,7 +89,6 @@ def create_person_from_connection(
|
||||
job_title = shared_profile.get("job_title")
|
||||
birthday_str = shared_profile.get("birthday")
|
||||
|
||||
from datetime import date as date_type
|
||||
birthday = None
|
||||
if birthday_str:
|
||||
try:
|
||||
@ -125,13 +126,30 @@ async def detach_umbral_contact(person: Person) -> None:
|
||||
person.first_name = person.name or None
|
||||
|
||||
|
||||
def extract_ntfy_config(settings: Settings) -> dict | None:
|
||||
"""Extract ntfy config values into a plain dict safe for use after session close."""
|
||||
if not settings.ntfy_enabled or not settings.ntfy_connections_enabled:
|
||||
return None
|
||||
return {
|
||||
"ntfy_enabled": True,
|
||||
"ntfy_server_url": settings.ntfy_server_url,
|
||||
"ntfy_topic": settings.ntfy_topic,
|
||||
"ntfy_auth_token": settings.ntfy_auth_token,
|
||||
"user_id": settings.user_id,
|
||||
}
|
||||
|
||||
|
||||
async def send_connection_ntfy(
|
||||
settings: Settings,
|
||||
ntfy_config: dict | None,
|
||||
sender_name: str,
|
||||
event_type: str,
|
||||
) -> None:
|
||||
"""Send ntfy push for connection events. Non-blocking with 3s timeout."""
|
||||
if not settings.ntfy_connections_enabled:
|
||||
"""Send ntfy push for connection events. Non-blocking with 3s timeout.
|
||||
|
||||
Accepts a plain dict (from extract_ntfy_config) to avoid accessing
|
||||
detached SQLAlchemy objects after session close.
|
||||
"""
|
||||
if not ntfy_config:
|
||||
return
|
||||
|
||||
title_map = {
|
||||
@ -151,10 +169,13 @@ async def send_connection_ntfy(
|
||||
message = message_map.get(event_type, f"Connection update from {sender_name}")
|
||||
tags = tag_map.get(event_type, ["bell"])
|
||||
|
||||
# Build a settings-like object for send_ntfy_notification (avoids detached SA objects)
|
||||
settings_proxy = SimpleNamespace(**ntfy_config)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
send_ntfy_notification(
|
||||
settings=settings,
|
||||
settings=settings_proxy,
|
||||
title=title,
|
||||
message=message,
|
||||
tags=tags,
|
||||
@ -163,6 +184,6 @@ async def send_connection_ntfy(
|
||||
timeout=3.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("ntfy connection push timed out for user_id=%s", settings.user_id)
|
||||
logger.warning("ntfy connection push timed out for user_id=%s", ntfy_config["user_id"])
|
||||
except Exception:
|
||||
logger.warning("ntfy connection push failed for user_id=%s", settings.user_id)
|
||||
logger.warning("ntfy connection push failed for user_id=%s", ntfy_config["user_id"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user