"""
Notification centre router — in-app notifications.

All endpoints are scoped by current_user.id to prevent IDOR.
"""

from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy import select, func, update, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.notification import Notification
from app.models.user import User
from app.routers.auth import get_current_user
from app.schemas.notification import (
    NotificationResponse,
    NotificationListResponse,
    MarkReadRequest,
)

router = APIRouter()


async def _count_unread(db: AsyncSession, user_id) -> int:
    """Return how many unread notifications belong to ``user_id``.

    Shared by the list endpoint and the dedicated unread-count endpoint so
    the two can never drift apart.

    Args:
        db: Async session used to run the COUNT query.
        user_id: Owner whose notifications are counted (compared against
            ``Notification.user_id``).

    Returns:
        The unread count; 0 when the scalar query yields ``None``.
    """
    unread = await db.scalar(
        select(func.count())
        .select_from(Notification)
        .where(
            Notification.user_id == user_id,
            Notification.is_read == False,  # noqa: E712
        )
    )
    return unread or 0


@router.get("/", response_model=NotificationListResponse)
async def list_notifications(
    unread_only: bool = Query(False),
    # NOTE: shadows the builtin, but the name is the public query-string key
    # (?type=...), so it is kept for backward compatibility.
    type: str | None = Query(None, max_length=50),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one page of the current user's notifications.

    Args:
        unread_only: When true, only notifications with ``is_read == False``
            are listed.
        type: Optional exact-match filter on ``Notification.type``; ``None``
            or an empty string means "no filter".
        page: 1-based page number.
        per_page: Page size (1–100).
        db: Async database session.
        current_user: Authenticated user; every query is scoped to their id.

    Returns:
        NotificationListResponse with the page of notifications, ``total``
        (count of rows matching the filters) and ``unread_count`` (always the
        user's full unread count, regardless of filters).
    """
    base = select(Notification).where(Notification.user_id == current_user.id)
    if unread_only:
        base = base.where(Notification.is_read == False)  # noqa: E712
    if type:
        base = base.where(Notification.type == type)

    # Total rows matching the active filters (before pagination).
    total = await db.scalar(select(func.count()).select_from(base.subquery())) or 0

    # Unread count is deliberately independent of the filters above.
    unread_count = await _count_unread(db, current_user.id)

    # created_at alone is not unique; without a tiebreaker, rows sharing a
    # timestamp have no defined relative order and can repeat or vanish
    # between pages. The id makes the ordering total and deterministic.
    offset = (page - 1) * per_page
    page_query = (
        base.order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(page_query)
    notifications = result.scalars().all()

    return NotificationListResponse(
        notifications=notifications,
        unread_count=unread_count,
        total=total,
    )


@router.get("/unread-count")
async def get_unread_count(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Lightweight unread count endpoint.

    The original author notes this query uses a partial index; the index
    definition is not visible in this module.

    Returns:
        ``{"count": <int>}`` — number of unread notifications for the
        current user.
    """
    count = await _count_unread(db, current_user.id)
    return {"count": count}


@router.put("/read")
async def mark_read(
    body: MarkReadRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Mark specific notification IDs as read (user_id scoped — IDOR prevention).

    IDs that do not belong to the current user are silently ignored because
    the UPDATE is filtered on ``user_id``.

    Args:
        body: Request payload carrying ``notification_ids``.
        db: Async database session.
        current_user: Authenticated user whose notifications may be updated.

    Returns:
        A confirmation message dict.
    """
    await db.execute(
        update(Notification)
        .where(
            and_(
                Notification.id.in_(body.notification_ids),
                Notification.user_id == current_user.id,
                # Only touch rows that actually change, consistent with
                # mark_all_read; the resulting state is identical.
                Notification.is_read == False,  # noqa: E712
            )
        )
        .values(is_read=True)
    )
    await db.commit()
    return {"message": "Notifications marked as read"}


@router.put("/read-all")
async def mark_all_read(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Mark all notifications as read for current user.

    Returns:
        A confirmation message dict.
    """
    await db.execute(
        update(Notification)
        .where(
            Notification.user_id == current_user.id,
            Notification.is_read == False,  # noqa: E712
        )
        .values(is_read=True)
    )
    await db.commit()
    return {"message": "All notifications marked as read"}


@router.delete("/{notification_id}", status_code=204)
async def delete_notification(
    notification_id: int = Path(ge=1, le=2147483647),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a single notification (user_id scoped).

    Args:
        notification_id: Primary key of the notification (1..2**31-1).
        db: Async database session.
        current_user: Authenticated user; only their notifications can be
            deleted.

    Raises:
        HTTPException: 404 when no notification with that id belongs to the
            current user (also the response for another user's id, so
            existence is not leaked).
    """
    result = await db.execute(
        select(Notification).where(
            Notification.id == notification_id,
            Notification.user_id == current_user.id,
        )
    )
    notification = result.scalar_one_or_none()
    if notification is None:
        raise HTTPException(status_code=404, detail="Notification not found")

    # Deleted through the session (not a bulk DELETE) so any ORM-level
    # cascades configured on the model still apply.
    await db.delete(notification)
    await db.commit()
    return None