77 lines
2.6 KiB
Python

from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import select, func, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.alert import Alert
class AlertRepository:
    """Async data-access helpers for ``Alert`` rows.

    Every method works through the injected ``AsyncSession``. Writes are
    flushed but never committed here; committing or rolling back is left to
    whoever owns the session/transaction.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        unread_only: bool = False,
        category: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> list[Alert]:
        """Return a page of alerts, newest first.

        Args:
            limit: Maximum number of alerts to return.
            offset: Number of matching alerts to skip before the page starts.
            unread_only: If true, only alerts whose ``read_at`` is NULL.
            category: If given and non-empty, only alerts in this category.
            user_id: If given, only alerts whose ``user_id`` matches.

        Returns:
            Matching alerts ordered by ``created_at`` descending, with
            ``id`` descending as a tie-breaker.
        """
        stmt = select(Alert)
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        if unread_only:
            stmt = stmt.where(Alert.read_at.is_(None))
        if category:
            stmt = stmt.where(Alert.category == category)
        # ``created_at`` alone is not guaranteed unique; without a
        # tie-breaker the database may order equal timestamps differently
        # between queries, so offset pagination could repeat or skip rows.
        stmt = (
            stmt.order_by(Alert.created_at.desc(), Alert.id.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        # ``ScalarResult.all()`` is typed as a Sequence; build a real list to
        # honour the annotation. (``list`` here is the builtin: the method
        # name lives in the class namespace, not in this function's scope.)
        return list(result.scalars().all())

    async def get(
        self, alert_id: str, user_id: Optional[int] = None
    ) -> Optional[Alert]:
        """Fetch a single alert by its ``id``.

        Args:
            alert_id: Value to match against ``Alert.id``.
            user_id: If given, the alert is only returned when its
                ``user_id`` matches, mirroring the optional scoping offered
                by ``list``, ``count_unread`` and ``mark_all_as_read``.

        Returns:
            The alert, or ``None`` if no matching alert exists.
        """
        stmt = select(Alert).where(Alert.id == alert_id)
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def count_unread(self, user_id: Optional[int] = None) -> int:
        """Count alerts whose ``read_at`` is NULL.

        Args:
            user_id: If given, only count alerts whose ``user_id`` matches.

        Returns:
            The number of unread alerts (0 if the query yields no value).
        """
        stmt = select(func.count()).select_from(Alert).where(Alert.read_at.is_(None))
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return int(result.scalar() or 0)

    async def create(self, **fields) -> Alert:
        """Create an alert from keyword fields, add it to the session and flush.

        ``fields`` are passed straight to the ``Alert`` constructor. The
        flush presumably exists so database-generated columns (e.g. the
        id) are populated on the returned object — confirm against callers.
        """
        alert = Alert(**fields)
        self.session.add(alert)
        await self.session.flush()
        return alert

    async def mark_as_read(
        self, alert_id: str, user_id: Optional[int] = None
    ) -> bool:
        """Set ``read_at`` to the current UTC time if the alert is unread.

        An alert that is already read keeps its original ``read_at``.

        Args:
            alert_id: Value to match against ``Alert.id``.
            user_id: If given, only an alert with this ``user_id`` is touched.

        Returns:
            ``True`` if the alert exists (whether or not it was already
            read), ``False`` if no matching alert was found.
        """
        alert = await self.get(alert_id, user_id=user_id)
        if alert is None:
            return False
        if alert.read_at is None:
            alert.read_at = datetime.now(timezone.utc)
            # Only flush when something actually changed.
            await self.session.flush()
        return True

    async def mark_all_as_read(self, user_id: Optional[int] = None) -> int:
        """Mark every unread alert as read with a single bulk UPDATE.

        All affected rows receive the same timestamp, computed once here in
        UTC.

        Args:
            user_id: If given, only alerts whose ``user_id`` matches.

        Returns:
            The row count reported for the UPDATE (0 if none is reported).
        """
        stmt = (
            update(Alert)
            .where(Alert.read_at.is_(None))
            .values(read_at=datetime.now(timezone.utc))
        )
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return int(result.rowcount or 0)

    async def delete(self, alert_id: str, user_id: Optional[int] = None) -> bool:
        """Delete an alert by ``id`` and flush.

        Args:
            alert_id: Value to match against ``Alert.id``.
            user_id: If given, only an alert with this ``user_id`` is deleted.

        Returns:
            ``True`` if a matching alert was found and deleted, else ``False``.
        """
        alert = await self.get(alert_id, user_id=user_id)
        if alert is None:
            return False
        await self.session.delete(alert)
        await self.session.flush()
        return True