"""
Event invitation service — send, respond, override, dismiss invitations.

All functions accept an AsyncSession and do NOT commit — callers manage transactions.
"""
import logging
from datetime import datetime

from fastapi import HTTPException
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.calendar_event import CalendarEvent
from app.models.event_invitation import EventInvitation, EventInvitationOverride
from app.models.user_connection import UserConnection
from app.models.settings import Settings
from app.models.user import User
from app.services.notification import create_notification

logger = logging.getLogger(__name__)

# Upper bound on the total number of invitation rows (of any status) per event.
MAX_INVITATIONS_PER_EVENT = 20

# Human-readable RSVP labels used in the notification sent back to the inviter.
_RSVP_STATUS_LABELS = {"accepted": "Going", "tentative": "Tentative", "declined": "Declined"}


async def _get_display_name(db: AsyncSession, user_id: int) -> str:
    """Return the user's ``Settings.preferred_name``, or "Someone" if missing/empty."""
    result = await db.execute(
        select(Settings.preferred_name).where(Settings.user_id == user_id)
    )
    row = result.one_or_none()
    return row[0] if row and row[0] else "Someone"


async def validate_connections(
    db: AsyncSession, inviter_id: int, user_ids: list[int]
) -> None:
    """
    Verify the inviter has a connection row to every user in ``user_ids``.

    Only rows with ``UserConnection.user_id == inviter_id`` are queried; the
    reverse direction is not checked here.
    NOTE(review): presumably connections are stored as a pair of rows so this
    one-sided check is sufficient — confirm against the connection model.

    Raises:
        HTTPException: 404 if any id in ``user_ids`` is not a connection.
    """
    if not user_ids:
        return
    result = await db.execute(
        select(UserConnection.connected_user_id).where(
            UserConnection.user_id == inviter_id,
            UserConnection.connected_user_id.in_(user_ids),
        )
    )
    connected_ids = {r[0] for r in result.all()}
    missing = set(user_ids) - connected_ids
    if missing:
        raise HTTPException(status_code=404, detail="One or more users not found in your connections")


async def send_event_invitations(
    db: AsyncSession,
    event_id: int,
    user_ids: list[int],
    invited_by: int,
) -> list[EventInvitation]:
    """
    Bulk-insert invitations for an event.

    Skips self-invites, repeated ids within ``user_ids``, and users who already
    have an invitation (any status) for this event. Creates an in-app
    notification for each newly invited user. New rows are only added to the
    session, not flushed or committed.

    Returns:
        The newly created invitations; an empty list if everyone was already invited.

    Raises:
        HTTPException: 400 if no invitee remains after removing the inviter, or
            if the event would exceed ``MAX_INVITATIONS_PER_EVENT`` invitations;
            404 (via ``validate_connections``) if any invitee is not a connection.
    """
    # Order-preserving de-duplication plus removal of the inviter. Without the
    # de-dup, a repeated id would yield duplicate rows/notifications and
    # over-count toward the cap.
    user_ids = [uid for uid in dict.fromkeys(user_ids) if uid != invited_by]
    if not user_ids:
        raise HTTPException(status_code=400, detail="Cannot invite yourself")

    await validate_connections(db, invited_by, user_ids)

    # Skip users who already hold an invitation for this event.
    existing_result = await db.execute(
        select(EventInvitation.user_id).where(
            EventInvitation.event_id == event_id,
            EventInvitation.user_id.in_(user_ids),
        )
    )
    existing_ids = {r[0] for r in existing_result.all()}
    new_ids = [uid for uid in user_ids if uid not in existing_ids]

    # The cap counts every invitation row for the event, regardless of status.
    # NOTE(review): count-then-insert is not atomic; concurrent requests could
    # exceed the cap unless the caller serializes them.
    count_result = await db.execute(
        select(func.count(EventInvitation.id)).where(EventInvitation.event_id == event_id)
    )
    current_count = count_result.scalar_one()
    if current_count + len(new_ids) > MAX_INVITATIONS_PER_EVENT:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_INVITATIONS_PER_EVENT} invitations per event",
        )

    if not new_ids:
        return []

    # Event title/start for the notification text; fall back if the row is missing.
    event_result = await db.execute(
        select(CalendarEvent.title, CalendarEvent.start_datetime).where(
            CalendarEvent.id == event_id
        )
    )
    event_row = event_result.one_or_none()
    event_title = event_row[0] if event_row else "an event"
    event_start = event_row[1] if event_row else None

    inviter_name = await _get_display_name(db, invited_by)

    # The notification text is the same for every invitee, so build it once.
    start_str = event_start.strftime("%b %d, %I:%M %p") if event_start else ""
    message = f"{inviter_name} invited you to {event_title}" + (
        f" · {start_str}" if start_str else ""
    )

    invitations = []
    for uid in new_ids:
        inv = EventInvitation(
            event_id=event_id,
            user_id=uid,
            invited_by=invited_by,
            status="pending",
        )
        db.add(inv)
        invitations.append(inv)

        await create_notification(
            db=db,
            user_id=uid,
            type="event_invite",
            title="Event Invitation",
            message=message,
            data={"event_id": event_id, "event_title": event_title},
            source_type="event_invitation",
            source_id=event_id,
        )

    return invitations


async def respond_to_invitation(
    db: AsyncSession,
    invitation_id: int,
    user_id: int,
    status: str,
) -> EventInvitation:
    """
    Record the invitee's RSVP on an invitation and notify the inviter.

    Only the invitee (``user_id``) can respond. ``status`` is stored as given;
    values without a label in ``_RSVP_STATUS_LABELS`` appear verbatim in the
    notification text.

    Returns:
        The updated invitation (modified in the session, not committed).

    Raises:
        HTTPException: 404 if no invitation with this id belongs to ``user_id``.
    """
    result = await db.execute(
        select(EventInvitation)
        .options(selectinload(EventInvitation.event))
        .where(
            EventInvitation.id == invitation_id,
            EventInvitation.user_id == user_id,
        )
    )
    invitation = result.scalar_one_or_none()
    if not invitation:
        raise HTTPException(status_code=404, detail="Invitation not found")

    # Capture the eagerly-loaded title before modifying the invitation.
    event_title = invitation.event.title

    invitation.status = status
    # NOTE(review): naive local time, consistent with the rest of this module;
    # confirm the column's timezone expectations.
    invitation.responded_at = datetime.now()

    if invitation.invited_by:
        responder_name = await _get_display_name(db, user_id)
        status_text = _RSVP_STATUS_LABELS.get(status, status)
        await create_notification(
            db=db,
            user_id=invitation.invited_by,
            type="event_invite_response",
            title="Event RSVP",
            message=f"{responder_name} is {status_text} for {event_title}",
            data={"event_id": invitation.event_id, "status": status},
            source_type="event_invitation",
            source_id=invitation.event_id,
        )

    return invitation


async def override_occurrence_status(
    db: AsyncSession,
    invitation_id: int,
    occurrence_id: int,
    user_id: int,
    status: str,
) -> EventInvitationOverride:
    """
    Create or update a per-occurrence status override for an invitation.

    The occurrence must be either the invited event itself or an event whose
    ``parent_event_id`` is the invited event.

    Returns:
        The new or updated override (added/modified in the session, not committed).

    Raises:
        HTTPException: 404 if the invitation does not belong to ``user_id`` or
            the occurrence does not exist; 400 if the occurrence is outside the
            invited event's series.
    """
    inv_result = await db.execute(
        select(EventInvitation).where(
            EventInvitation.id == invitation_id,
            EventInvitation.user_id == user_id,
        )
    )
    invitation = inv_result.scalar_one_or_none()
    if not invitation:
        raise HTTPException(status_code=404, detail="Invitation not found")

    occ_result = await db.execute(
        select(CalendarEvent).where(CalendarEvent.id == occurrence_id)
    )
    occurrence = occ_result.scalar_one_or_none()
    if not occurrence:
        raise HTTPException(status_code=404, detail="Occurrence not found")

    # Occurrence must be the event itself OR a direct child of the invited event.
    if occurrence.id != invitation.event_id and occurrence.parent_event_id != invitation.event_id:
        raise HTTPException(status_code=400, detail="Occurrence does not belong to this event series")

    existing = await db.execute(
        select(EventInvitationOverride).where(
            EventInvitationOverride.invitation_id == invitation_id,
            EventInvitationOverride.occurrence_id == occurrence_id,
        )
    )
    override = existing.scalar_one_or_none()

    now = datetime.now()
    if override:
        override.status = status
        override.responded_at = now
    else:
        override = EventInvitationOverride(
            invitation_id=invitation_id,
            occurrence_id=occurrence_id,
            status=status,
            responded_at=now,
        )
        db.add(override)

    return override


async def dismiss_invitation(
    db: AsyncSession,
    invitation_id: int,
    user_id: int,
) -> None:
    """
    Delete an invitation on behalf of its invitee (the invitee leaving).

    Only deletes the row if ``user_id`` is the invitation's invitee; owner
    revocation goes through ``dismiss_invitation_by_owner``.

    Raises:
        HTTPException: 404 if no matching invitation was deleted.
    """
    result = await db.execute(
        delete(EventInvitation).where(
            EventInvitation.id == invitation_id,
            EventInvitation.user_id == user_id,
        )
    )
    if result.rowcount == 0:
        raise HTTPException(status_code=404, detail="Invitation not found")


async def dismiss_invitation_by_owner(
    db: AsyncSession,
    invitation_id: int,
) -> None:
    """
    Delete an invitation by id (event owner revoking it).

    NOTE(review): no ownership check happens here; presumably the caller has
    already verified that the requester owns the event — confirm.

    Raises:
        HTTPException: 404 if no invitation with this id exists.
    """
    result = await db.execute(
        delete(EventInvitation).where(EventInvitation.id == invitation_id)
    )
    if result.rowcount == 0:
        raise HTTPException(status_code=404, detail="Invitation not found")


async def get_event_invitations(
    db: AsyncSession,
    event_id: int,
) -> list[dict]:
    """
    Return all invitations for an event, oldest first, with invitee names.

    ``invitee_name`` prefers the invitee's preferred name, then their
    ``umbral_name``, then "Unknown".
    """
    result = await db.execute(
        select(
            EventInvitation,
            Settings.preferred_name,
            User.umbral_name,
        )
        .join(User, EventInvitation.user_id == User.id)
        .outerjoin(Settings, Settings.user_id == User.id)
        .where(EventInvitation.event_id == event_id)
        .order_by(EventInvitation.invited_at.asc())
    )
    rows = result.all()
    return [
        {
            "id": inv.id,
            "event_id": inv.event_id,
            "user_id": inv.user_id,
            "invited_by": inv.invited_by,
            "status": inv.status,
            "invited_at": inv.invited_at,
            "responded_at": inv.responded_at,
            "invitee_name": preferred_name or umbral_name or "Unknown",
            "invitee_umbral_name": umbral_name or "Unknown",
        }
        for inv, preferred_name, umbral_name in rows
    ]


async def get_invited_event_ids(
    db: AsyncSession,
    user_id: int,
) -> list[int]:
    """Return event IDs where the user has an invitation whose status is not "declined"."""
    result = await db.execute(
        select(EventInvitation.event_id).where(
            EventInvitation.user_id == user_id,
            EventInvitation.status != "declined",
        )
    )
    return [r[0] for r in result.all()]


async def get_pending_invitations(
    db: AsyncSession,
    user_id: int,
) -> list[dict]:
    """
    Return the user's "pending" invitations, newest first.

    Each entry carries the event title/start and the inviter's preferred name
    ("Someone" if unavailable).
    """
    result = await db.execute(
        select(
            EventInvitation,
            CalendarEvent.title,
            CalendarEvent.start_datetime,
            Settings.preferred_name,
        )
        .join(CalendarEvent, EventInvitation.event_id == CalendarEvent.id)
        .outerjoin(User, EventInvitation.invited_by == User.id)
        .outerjoin(Settings, Settings.user_id == User.id)
        .where(
            EventInvitation.user_id == user_id,
            EventInvitation.status == "pending",
        )
        .order_by(EventInvitation.invited_at.desc())
    )
    rows = result.all()
    return [
        {
            "id": inv.id,
            "event_id": inv.event_id,
            "event_title": title,
            "event_start": start_dt,
            "invited_by_name": inviter_name or "Someone",
            "invited_at": inv.invited_at,
            "status": inv.status,
        }
        for inv, title, start_dt, inviter_name in rows
    ]


async def get_invitation_overrides_for_user(
    db: AsyncSession,
    user_id: int,
    event_ids: list[int],
) -> dict[int, str]:
    """
    For a list of occurrence event IDs, return a map of occurrence_id -> override status.

    Only overrides attached to invitations belonging to ``user_id`` are
    included. Used to annotate event listings with per-occurrence invitation status.
    """
    if not event_ids:
        return {}
    result = await db.execute(
        select(
            EventInvitationOverride.occurrence_id,
            EventInvitationOverride.status,
        )
        .join(EventInvitation, EventInvitationOverride.invitation_id == EventInvitation.id)
        .where(
            EventInvitation.user_id == user_id,
            EventInvitationOverride.occurrence_id.in_(event_ids),
        )
    )
    return {r[0]: r[1] for r in result.all()}


async def cascade_event_invitations_on_disconnect(
    db: AsyncSession,
    user_a_id: int,
    user_b_id: int,
) -> None:
    """Delete event invitations between two users (in both directions) when their connection is severed."""
    # Invitations where A invited B.
    await db.execute(
        delete(EventInvitation).where(
            EventInvitation.invited_by == user_a_id,
            EventInvitation.user_id == user_b_id,
        )
    )
    # Invitations where B invited A.
    await db.execute(
        delete(EventInvitation).where(
            EventInvitation.invited_by == user_b_id,
            EventInvitation.user_id == user_a_id,
        )
    )