""" Event invitation service — send, respond, override, dismiss invitations. All functions accept an AsyncSession and do NOT commit — callers manage transactions. """ import logging from datetime import datetime from fastapi import HTTPException from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.models.calendar import Calendar from app.models.calendar_event import CalendarEvent from app.models.event_invitation import EventInvitation, EventInvitationOverride from app.models.user_connection import UserConnection from app.models.settings import Settings from app.models.user import User from app.services.notification import create_notification logger = logging.getLogger(__name__) async def validate_connections( db: AsyncSession, inviter_id: int, user_ids: list[int] ) -> None: """Verify bidirectional connections exist for all invitees. Raises 404 on failure.""" if not user_ids: return result = await db.execute( select(UserConnection.connected_user_id).where( UserConnection.user_id == inviter_id, UserConnection.connected_user_id.in_(user_ids), ) ) connected_ids = {r[0] for r in result.all()} missing = set(user_ids) - connected_ids if missing: raise HTTPException(status_code=404, detail="One or more users not found in your connections") async def send_event_invitations( db: AsyncSession, event_id: int, user_ids: list[int], invited_by: int, ) -> list[EventInvitation]: """ Bulk-insert invitations for an event. Skips self-invites and existing invitations. Creates in-app notifications for each invitee. """ # Remove self from list user_ids = [uid for uid in user_ids if uid != invited_by] if not user_ids: raise HTTPException(status_code=400, detail="Cannot invite yourself") # Validate connections await validate_connections(db, invited_by, user_ids) # Check existing invitations to skip duplicates existing_result = await db.execute( select(EventInvitation.user_id).where( EventInvitation.event_id == event_id, EventInvitation.user_id.in_(user_ids), ) ) existing_ids = {r[0] for r in existing_result.all()} # Cap: max 20 pending invitations per event count_result = await db.execute( select(EventInvitation.id).where(EventInvitation.event_id == event_id) ) current_count = len(count_result.all()) new_ids = [uid for uid in user_ids if uid not in existing_ids] if current_count + len(new_ids) > 20: raise HTTPException(status_code=400, detail="Maximum 20 invitations per event") if not new_ids: return [] # Fetch event title for notifications event_result = await db.execute( select(CalendarEvent.title, CalendarEvent.start_datetime).where( CalendarEvent.id == event_id ) ) event_row = event_result.one_or_none() event_title = event_row[0] if event_row else "an event" event_start = event_row[1] if event_row else None # Fetch inviter's name inviter_settings = await db.execute( select(Settings.preferred_name).where(Settings.user_id == invited_by) ) inviter_name_row = inviter_settings.one_or_none() inviter_name = inviter_name_row[0] if inviter_name_row and inviter_name_row[0] else "Someone" invitations = [] for uid in new_ids: inv = EventInvitation( event_id=event_id, user_id=uid, invited_by=invited_by, status="pending", ) db.add(inv) invitations.append(inv) # Flush to populate invitation IDs before creating notifications await db.flush() for inv in invitations: start_str = event_start.strftime("%b %d, %I:%M %p") if event_start else "" await create_notification( db=db, user_id=inv.user_id, type="event_invite", title="Event Invitation", message=f"{inviter_name} invited you to {event_title}" + (f" · {start_str}" if start_str else ""), data={"event_id": event_id, "event_title": event_title, "invitation_id": inv.id}, source_type="event_invitation", source_id=event_id, ) return invitations async def respond_to_invitation( db: AsyncSession, invitation_id: int, user_id: int, status: str, ) -> EventInvitation: """Update invitation status. Returns the updated invitation.""" result = await db.execute( select(EventInvitation) .options(selectinload(EventInvitation.event)) .where( EventInvitation.id == invitation_id, EventInvitation.user_id == user_id, ) ) invitation = result.scalar_one_or_none() if not invitation: raise HTTPException(status_code=404, detail="Invitation not found") # Build response data before modifying event_title = invitation.event.title old_status = invitation.status invitation.status = status invitation.responded_at = datetime.now() # Auto-assign display calendar on accept/tentative (atomic: only if not already set) if status in ("accepted", "tentative"): default_cal = await db.execute( select(Calendar.id).where( Calendar.user_id == user_id, Calendar.is_default == True, ).limit(1) ) default_cal_id = default_cal.scalar_one_or_none() if default_cal_id and invitation.display_calendar_id is None: # Atomic: only set if still NULL (race-safe) await db.execute( update(EventInvitation) .where( EventInvitation.id == invitation_id, EventInvitation.display_calendar_id == None, ) .values(display_calendar_id=default_cal_id) ) invitation.display_calendar_id = default_cal_id # Notify the inviter only if status actually changed (prevents duplicate notifications) if invitation.invited_by and old_status != status: status_label = {"accepted": "Going", "tentative": "Tentative", "declined": "Declined"} # Fetch responder name responder_settings = await db.execute( select(Settings.preferred_name).where(Settings.user_id == user_id) ) responder_row = responder_settings.one_or_none() responder_name = responder_row[0] if responder_row and responder_row[0] else "Someone" await create_notification( db=db, user_id=invitation.invited_by, type="event_invite_response", title="Event RSVP", message=f"{responder_name} is {status_label.get(status, status)} for {event_title}", data={"event_id": invitation.event_id, "status": status}, source_type="event_invitation", source_id=invitation.event_id, ) return invitation async def override_occurrence_status( db: AsyncSession, invitation_id: int, occurrence_id: int, user_id: int, status: str, ) -> EventInvitationOverride: """Create or update a per-occurrence status override.""" # Verify invitation belongs to user inv_result = await db.execute( select(EventInvitation).where( EventInvitation.id == invitation_id, EventInvitation.user_id == user_id, ) ) invitation = inv_result.scalar_one_or_none() if not invitation: raise HTTPException(status_code=404, detail="Invitation not found") # Verify occurrence belongs to the invited event's series occ_result = await db.execute( select(CalendarEvent).where(CalendarEvent.id == occurrence_id) ) occurrence = occ_result.scalar_one_or_none() if not occurrence: raise HTTPException(status_code=404, detail="Occurrence not found") # Occurrence must be the event itself OR a child of the invited event if occurrence.id != invitation.event_id and occurrence.parent_event_id != invitation.event_id: raise HTTPException(status_code=400, detail="Occurrence does not belong to this event series") # Upsert override existing = await db.execute( select(EventInvitationOverride).where( EventInvitationOverride.invitation_id == invitation_id, EventInvitationOverride.occurrence_id == occurrence_id, ) ) override = existing.scalar_one_or_none() if override: override.status = status override.responded_at = datetime.now() else: override = EventInvitationOverride( invitation_id=invitation_id, occurrence_id=occurrence_id, status=status, responded_at=datetime.now(), ) db.add(override) return override async def dismiss_invitation( db: AsyncSession, invitation_id: int, user_id: int, ) -> None: """Delete an invitation (invitee leaving or owner revoking).""" result = await db.execute( delete(EventInvitation).where( EventInvitation.id == invitation_id, EventInvitation.user_id == user_id, ) ) if result.rowcount == 0: raise HTTPException(status_code=404, detail="Invitation not found") async def dismiss_invitation_by_owner( db: AsyncSession, invitation_id: int, ) -> None: """Delete an invitation by the event owner (revoking).""" result = await db.execute( delete(EventInvitation).where(EventInvitation.id == invitation_id) ) if result.rowcount == 0: raise HTTPException(status_code=404, detail="Invitation not found") async def get_event_invitations( db: AsyncSession, event_id: int, ) -> list[dict]: """Get all invitations for an event with invitee names.""" result = await db.execute( select( EventInvitation, Settings.preferred_name, User.umbral_name, ) .join(User, EventInvitation.user_id == User.id) .outerjoin(Settings, Settings.user_id == User.id) .where(EventInvitation.event_id == event_id) .order_by(EventInvitation.invited_at.asc()) ) rows = result.all() return [ { "id": inv.id, "event_id": inv.event_id, "user_id": inv.user_id, "invited_by": inv.invited_by, "status": inv.status, "invited_at": inv.invited_at, "responded_at": inv.responded_at, "invitee_name": preferred_name or umbral_name or "Unknown", "invitee_umbral_name": umbral_name or "Unknown", } for inv, preferred_name, umbral_name in rows ] async def get_invited_event_ids( db: AsyncSession, user_id: int, ) -> list[int]: """Return event IDs where user has a non-declined invitation.""" result = await db.execute( select(EventInvitation.event_id).where( EventInvitation.user_id == user_id, EventInvitation.status != "declined", ) ) return [r[0] for r in result.all()] async def get_pending_invitations( db: AsyncSession, user_id: int, ) -> list[dict]: """Return pending invitations for the current user.""" result = await db.execute( select( EventInvitation, CalendarEvent.title, CalendarEvent.start_datetime, Settings.preferred_name, ) .join(CalendarEvent, EventInvitation.event_id == CalendarEvent.id) .outerjoin( User, EventInvitation.invited_by == User.id ) .outerjoin( Settings, Settings.user_id == User.id ) .where( EventInvitation.user_id == user_id, EventInvitation.status == "pending", ) .order_by(EventInvitation.invited_at.desc()) ) rows = result.all() return [ { "id": inv.id, "event_id": inv.event_id, "event_title": title, "event_start": start_dt, "invited_by_name": inviter_name or "Someone", "invited_at": inv.invited_at, "status": inv.status, } for inv, title, start_dt, inviter_name in rows ] async def get_invitation_overrides_for_user( db: AsyncSession, user_id: int, event_ids: list[int], ) -> dict[int, str]: """ For a list of occurrence event IDs, return a map of occurrence_id -> override status. Used to annotate event listings with per-occurrence invitation status. """ if not event_ids: return {} result = await db.execute( select( EventInvitationOverride.occurrence_id, EventInvitationOverride.status, ) .join(EventInvitation, EventInvitationOverride.invitation_id == EventInvitation.id) .where( EventInvitation.user_id == user_id, EventInvitationOverride.occurrence_id.in_(event_ids), ) ) return {r[0]: r[1] for r in result.all()} async def cascade_event_invitations_on_disconnect( db: AsyncSession, user_a_id: int, user_b_id: int, ) -> None: """Delete event invitations between two users when connection is severed.""" # Delete invitations where A invited B await db.execute( delete(EventInvitation).where( EventInvitation.invited_by == user_a_id, EventInvitation.user_id == user_b_id, ) ) # Delete invitations where B invited A await db.execute( delete(EventInvitation).where( EventInvitation.invited_by == user_b_id, EventInvitation.user_id == user_a_id, ) )