Fix QA review findings: race condition, detached session, validation
- C-01: Wrap connection request flush in IntegrityError handler for TOCTOU race on partial unique index - W-02: Extract ntfy config into plain dict before commit to avoid DetachedInstanceError in background tasks - W-04: Add integer range validation (1–2147483647) on notification IDs - W-07: Add typed response models for respond_to_request endpoint - W-09: Document resolved_at requirement for future cancel endpoint - S-02: Use Literal type for ConnectionRequestResponse.status - S-04: Check ntfy master switch in extract_ntfy_config - S-05: Move date import to module level in connection service Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e27beb7736
commit
0e94b6e1f7
@ -277,7 +277,12 @@ async def _purge_old_notifications(db: AsyncSession) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _purge_resolved_requests(db: AsyncSession) -> None:
|
async def _purge_resolved_requests(db: AsyncSession) -> None:
|
||||||
"""Remove rejected/cancelled connection requests older than 30 days."""
|
"""Remove rejected/cancelled connection requests older than 30 days.
|
||||||
|
|
||||||
|
Note: resolved_at must be set when changing status to rejected/cancelled.
|
||||||
|
Rows with NULL resolved_at are preserved (comparison with NULL yields NULL).
|
||||||
|
Any future cancel endpoint must set resolved_at = now on status change.
|
||||||
|
"""
|
||||||
cutoff = datetime.now() - timedelta(days=30)
|
cutoff = datetime.now() - timedelta(days=30)
|
||||||
await db.execute(
|
await db.execute(
|
||||||
delete(ConnectionRequest).where(
|
delete(ConnectionRequest).where(
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from datetime import datetime, timedelta, timezone
|
|||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Request
|
||||||
from sqlalchemy import select, func, and_, update
|
from sqlalchemy import select, func, and_, update
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
@ -27,6 +28,8 @@ from app.routers.auth import get_current_user
|
|||||||
from app.schemas.connection import (
|
from app.schemas.connection import (
|
||||||
ConnectionRequestResponse,
|
ConnectionRequestResponse,
|
||||||
ConnectionResponse,
|
ConnectionResponse,
|
||||||
|
RespondAcceptResponse,
|
||||||
|
RespondRejectResponse,
|
||||||
RespondRequest,
|
RespondRequest,
|
||||||
SendConnectionRequest,
|
SendConnectionRequest,
|
||||||
SharingOverrideUpdate,
|
SharingOverrideUpdate,
|
||||||
@ -38,6 +41,7 @@ from app.services.connection import (
|
|||||||
SHAREABLE_FIELDS,
|
SHAREABLE_FIELDS,
|
||||||
create_person_from_connection,
|
create_person_from_connection,
|
||||||
detach_umbral_contact,
|
detach_umbral_contact,
|
||||||
|
extract_ntfy_config,
|
||||||
resolve_shared_profile,
|
resolve_shared_profile,
|
||||||
send_connection_ntfy,
|
send_connection_ntfy,
|
||||||
)
|
)
|
||||||
@ -190,13 +194,17 @@ async def send_connection_request(
|
|||||||
if pending_count >= 5:
|
if pending_count >= 5:
|
||||||
raise HTTPException(status_code=429, detail="Too many pending requests for this user")
|
raise HTTPException(status_code=429, detail="Too many pending requests for this user")
|
||||||
|
|
||||||
# Create the request
|
# Create the request (IntegrityError guard for TOCTOU race on partial unique index)
|
||||||
conn_request = ConnectionRequest(
|
conn_request = ConnectionRequest(
|
||||||
sender_id=current_user.id,
|
sender_id=current_user.id,
|
||||||
receiver_id=target.id,
|
receiver_id=target.id,
|
||||||
)
|
)
|
||||||
db.add(conn_request)
|
db.add(conn_request)
|
||||||
await db.flush() # populate conn_request.id for source_id
|
try:
|
||||||
|
await db.flush() # populate conn_request.id for source_id
|
||||||
|
except IntegrityError:
|
||||||
|
await db.rollback()
|
||||||
|
raise HTTPException(status_code=409, detail="A pending request already exists")
|
||||||
|
|
||||||
# Create in-app notification for receiver (sender_settings already fetched above)
|
# Create in-app notification for receiver (sender_settings already fetched above)
|
||||||
sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name
|
sender_display = (sender_settings.preferred_name if sender_settings else None) or current_user.umbral_name
|
||||||
@ -221,13 +229,16 @@ async def send_connection_request(
|
|||||||
ip=get_client_ip(request),
|
ip=get_client_ip(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract ntfy config before commit (avoids detached SA object in background task)
|
||||||
|
target_ntfy = extract_ntfy_config(target_settings) if target_settings else None
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(conn_request)
|
await db.refresh(conn_request)
|
||||||
|
|
||||||
# ntfy push in background (non-blocking)
|
# ntfy push in background (non-blocking)
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
send_connection_ntfy,
|
send_connection_ntfy,
|
||||||
target_settings,
|
target_ntfy,
|
||||||
sender_display,
|
sender_display,
|
||||||
"request_received",
|
"request_received",
|
||||||
)
|
)
|
||||||
@ -319,7 +330,7 @@ async def get_outgoing_requests(
|
|||||||
|
|
||||||
# ── PUT /requests/{id}/respond ──────────────────────────────────────
|
# ── PUT /requests/{id}/respond ──────────────────────────────────────
|
||||||
|
|
||||||
@router.put("/requests/{request_id}/respond")
|
@router.put("/requests/{request_id}/respond", response_model=RespondAcceptResponse | RespondRejectResponse)
|
||||||
async def respond_to_request(
|
async def respond_to_request(
|
||||||
body: RespondRequest,
|
body: RespondRequest,
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -419,16 +430,18 @@ async def respond_to_request(
|
|||||||
ip=get_client_ip(request),
|
ip=get_client_ip(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract ntfy config before commit (avoids detached SA object in background task)
|
||||||
|
sender_ntfy = extract_ntfy_config(sender_settings) if sender_settings else None
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# ntfy push in background
|
# ntfy push in background
|
||||||
if sender_settings:
|
background_tasks.add_task(
|
||||||
background_tasks.add_task(
|
send_connection_ntfy,
|
||||||
send_connection_ntfy,
|
sender_ntfy,
|
||||||
sender_settings,
|
receiver_display,
|
||||||
receiver_display,
|
"request_accepted",
|
||||||
"request_accepted",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return {"message": "Connection accepted", "connection_id": conn_a.id}
|
return {"message": "Connection accepted", "connection_id": conn_a.id}
|
||||||
|
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class ConnectionRequestResponse(BaseModel):
|
|||||||
sender_preferred_name: Optional[str] = None
|
sender_preferred_name: Optional[str] = None
|
||||||
receiver_umbral_name: str
|
receiver_umbral_name: str
|
||||||
receiver_preferred_name: Optional[str] = None
|
receiver_preferred_name: Optional[str] = None
|
||||||
status: str
|
status: Literal["pending", "accepted", "rejected", "cancelled"]
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +63,15 @@ class ConnectionResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class RespondAcceptResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
connection_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class RespondRejectResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
class SharingOverrideUpdate(BaseModel):
|
class SharingOverrideUpdate(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
preferred_name: Optional[bool] = None
|
preferred_name: Optional[bool] = None
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -27,4 +27,12 @@ class NotificationListResponse(BaseModel):
|
|||||||
class MarkReadRequest(BaseModel):
|
class MarkReadRequest(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
notification_ids: list[int] = Field(..., min_length=1, max_length=100)
|
notification_ids: list[int] = Field(..., min_length=1, max_length=100, json_schema_extra={"items": {"minimum": 1, "maximum": 2147483647}})
|
||||||
|
|
||||||
|
@field_validator('notification_ids')
|
||||||
|
@classmethod
|
||||||
|
def validate_ids(cls, v: list[int]) -> list[int]:
|
||||||
|
for i in v:
|
||||||
|
if i < 1 or i > 2147483647:
|
||||||
|
raise ValueError('Each notification ID must be between 1 and 2147483647')
|
||||||
|
return v
|
||||||
|
|||||||
@ -5,6 +5,8 @@ SHAREABLE_FIELDS is the single source of truth for which fields can be shared.
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import date as date_type
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@ -87,7 +89,6 @@ def create_person_from_connection(
|
|||||||
job_title = shared_profile.get("job_title")
|
job_title = shared_profile.get("job_title")
|
||||||
birthday_str = shared_profile.get("birthday")
|
birthday_str = shared_profile.get("birthday")
|
||||||
|
|
||||||
from datetime import date as date_type
|
|
||||||
birthday = None
|
birthday = None
|
||||||
if birthday_str:
|
if birthday_str:
|
||||||
try:
|
try:
|
||||||
@ -125,13 +126,30 @@ async def detach_umbral_contact(person: Person) -> None:
|
|||||||
person.first_name = person.name or None
|
person.first_name = person.name or None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_ntfy_config(settings: Settings) -> dict | None:
|
||||||
|
"""Extract ntfy config values into a plain dict safe for use after session close."""
|
||||||
|
if not settings.ntfy_enabled or not settings.ntfy_connections_enabled:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"ntfy_enabled": True,
|
||||||
|
"ntfy_server_url": settings.ntfy_server_url,
|
||||||
|
"ntfy_topic": settings.ntfy_topic,
|
||||||
|
"ntfy_auth_token": settings.ntfy_auth_token,
|
||||||
|
"user_id": settings.user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def send_connection_ntfy(
|
async def send_connection_ntfy(
|
||||||
settings: Settings,
|
ntfy_config: dict | None,
|
||||||
sender_name: str,
|
sender_name: str,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send ntfy push for connection events. Non-blocking with 3s timeout."""
|
"""Send ntfy push for connection events. Non-blocking with 3s timeout.
|
||||||
if not settings.ntfy_connections_enabled:
|
|
||||||
|
Accepts a plain dict (from extract_ntfy_config) to avoid accessing
|
||||||
|
detached SQLAlchemy objects after session close.
|
||||||
|
"""
|
||||||
|
if not ntfy_config:
|
||||||
return
|
return
|
||||||
|
|
||||||
title_map = {
|
title_map = {
|
||||||
@ -151,10 +169,13 @@ async def send_connection_ntfy(
|
|||||||
message = message_map.get(event_type, f"Connection update from {sender_name}")
|
message = message_map.get(event_type, f"Connection update from {sender_name}")
|
||||||
tags = tag_map.get(event_type, ["bell"])
|
tags = tag_map.get(event_type, ["bell"])
|
||||||
|
|
||||||
|
# Build a settings-like object for send_ntfy_notification (avoids detached SA objects)
|
||||||
|
settings_proxy = SimpleNamespace(**ntfy_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
send_ntfy_notification(
|
send_ntfy_notification(
|
||||||
settings=settings,
|
settings=settings_proxy,
|
||||||
title=title,
|
title=title,
|
||||||
message=message,
|
message=message,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
@ -163,6 +184,6 @@ async def send_connection_ntfy(
|
|||||||
timeout=3.0,
|
timeout=3.0,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning("ntfy connection push timed out for user_id=%s", settings.user_id)
|
logger.warning("ntfy connection push timed out for user_id=%s", ntfy_config["user_id"])
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("ntfy connection push failed for user_id=%s", settings.user_id)
|
logger.warning("ntfy connection push failed for user_id=%s", ntfy_config["user_id"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user