from __future__ import annotations

from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.alert import Alert


class AlertRepository:
    """Async data-access layer for ``Alert`` rows.

    Every method runs against the ``AsyncSession`` given at construction.
    Write methods flush (or execute a bulk UPDATE) but never commit, so
    transaction boundaries remain the caller's responsibility.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        unread_only: bool = False,
        category: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> list[Alert]:
        """Return alerts newest-first, with optional filters and pagination.

        Args:
            limit: Maximum number of alerts to return.
            offset: Number of alerts to skip, applied after ordering.
            unread_only: If true, only alerts whose ``read_at`` is NULL.
            category: If truthy, only alerts in this category. An empty
                string is treated like ``None`` (no category filter).
            user_id: If given, only alerts belonging to this user.

        Returns:
            A list of matching ``Alert`` instances.
        """
        stmt = select(Alert)
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        if unread_only:
            stmt = stmt.where(Alert.read_at.is_(None))
        if category:
            stmt = stmt.where(Alert.category == category)
        # ``id`` breaks ties between alerts sharing a ``created_at`` value so
        # that consecutive offset/limit pages neither repeat nor skip rows.
        stmt = (
            stmt.order_by(Alert.created_at.desc(), Alert.id.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        # ``ScalarResult.all()`` is typed as a Sequence; build a real list to
        # match the annotation. Inside a method body ``list`` is the builtin,
        # not this method (class-scope names are not visible here).
        return list(result.scalars().all())

    async def get(
        self, alert_id: str, user_id: Optional[int] = None
    ) -> Optional[Alert]:
        """Fetch a single alert by primary key.

        Args:
            alert_id: The alert's ``id``.
            user_id: If given, the alert is only returned when it belongs to
                this user; otherwise ``None`` is returned.

        Returns:
            The matching ``Alert``, or ``None`` if no row matches.
        """
        stmt = select(Alert).where(Alert.id == alert_id)
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def count_unread(self, user_id: Optional[int] = None) -> int:
        """Count alerts whose ``read_at`` is NULL.

        Args:
            user_id: If given, only this user's alerts are counted.

        Returns:
            The number of unread alerts (0 if none).
        """
        stmt = select(func.count()).select_from(Alert).where(Alert.read_at.is_(None))
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return int(result.scalar() or 0)

    async def create(self, **fields) -> Alert:
        """Construct an ``Alert`` from ``fields``, add it to the session and flush.

        Flushing emits the INSERT immediately, so database-assigned values
        (e.g. an autoincrement primary key, if the model uses one) are
        available on the returned object. The transaction is not committed.
        """
        alert = Alert(**fields)
        self.session.add(alert)
        await self.session.flush()
        return alert

    async def mark_as_read(self, alert_id: str, user_id: Optional[int] = None) -> bool:
        """Set ``read_at`` to the current UTC time if the alert is still unread.

        Already-read alerts keep their original ``read_at``, so repeated calls
        are harmless.

        Args:
            alert_id: The alert's ``id``.
            user_id: If given, only an alert belonging to this user is touched.

        Returns:
            ``True`` if the alert exists (read before or marked now),
            ``False`` if no matching alert was found.
        """
        alert = await self.get(alert_id, user_id=user_id)
        if alert is None:
            return False
        if alert.read_at is None:
            alert.read_at = datetime.now(timezone.utc)
            await self.session.flush()
        return True

    async def mark_all_as_read(self, user_id: Optional[int] = None) -> int:
        """Bulk-set ``read_at`` on every unread alert with a single UPDATE.

        All affected rows receive the same timestamp, computed once.

        Args:
            user_id: If given, only this user's alerts are updated.

        Returns:
            The number of rows the UPDATE reports as affected (0 if unknown).
        """
        stmt = (
            update(Alert)
            .where(Alert.read_at.is_(None))
            .values(read_at=datetime.now(timezone.utc))
        )
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return int(result.rowcount or 0)

    async def delete(self, alert_id: str, user_id: Optional[int] = None) -> bool:
        """Delete an alert by primary key and flush.

        Args:
            alert_id: The alert's ``id``.
            user_id: If given, only an alert belonging to this user is deleted.

        Returns:
            ``True`` if an alert was found and deleted, ``False`` otherwise.
        """
        alert = await self.get(alert_id, user_id=user_id)
        if alert is None:
            return False
        await self.session.delete(alert)
        await self.session.flush()
        return True