77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select, func, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.alert import Alert
|
|
|
|
|
|
class AlertRepository:
    """Async data-access helpers for ``Alert`` rows.

    Every method runs against the ``AsyncSession`` given at construction.
    Write methods only ``flush``; committing or rolling back the transaction
    is left to the caller.
    """

    def __init__(self, session: AsyncSession):
        # Stored as-is; this class never commits or closes the session.
        self.session = session

    async def list(
        self,
        limit: int = 100,
        offset: int = 0,
        unread_only: bool = False,
        category: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> list[Alert]:
        """Return one page of alerts, newest first.

        Args:
            limit: Maximum number of alerts to return.
            offset: Number of alerts to skip before the page starts.
            unread_only: If true, only alerts whose ``read_at`` is NULL.
            category: If truthy, only alerts in this category (an empty
                string is treated the same as ``None``: no filter).
            user_id: If given, only alerts whose ``user_id`` matches.

        Returns:
            Matching alerts ordered by ``created_at`` descending, with ``id``
            descending as a tiebreaker.
        """
        stmt = select(Alert)
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        if unread_only:
            stmt = stmt.where(Alert.read_at.is_(None))
        if category:
            stmt = stmt.where(Alert.category == category)
        # created_at alone is not unique: without a tiebreaker, rows sharing a
        # timestamp have no defined order, so OFFSET/LIMIT paging could skip
        # or repeat them across pages.
        stmt = (
            stmt.order_by(Alert.created_at.desc(), Alert.id.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        # Inside a method body ``list`` is still the builtin (class scope is
        # not visible here). ScalarResult.all() is typed as a Sequence, so
        # materialize a real list to match the declared return type.
        return list(result.scalars().all())

    async def get(self, alert_id: str) -> Optional[Alert]:
        """Return the alert whose ``id`` equals ``alert_id``, or ``None``.

        Uses ``scalar_one_or_none``, which raises if more than one row
        matches.
        """
        stmt = select(Alert).where(Alert.id == alert_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def count_unread(self, user_id: Optional[int] = None) -> int:
        """Count alerts whose ``read_at`` is NULL, optionally for one user."""
        stmt = select(func.count()).select_from(Alert).where(Alert.read_at.is_(None))
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt)
        return int(result.scalar() or 0)

    async def create(self, **fields) -> Alert:
        """Build an ``Alert`` from ``fields``, add it to the session and flush.

        The flush writes the row inside the current transaction without
        committing it.
        """
        alert = Alert(**fields)
        self.session.add(alert)
        await self.session.flush()
        return alert

    async def mark_as_read(self, alert_id: str) -> bool:
        """Set ``read_at`` to now (UTC) on one alert if it is still unread.

        Returns:
            ``False`` if no alert has this id. Otherwise ``True``, including
            when the alert was already read (its existing ``read_at`` is kept).
        """
        # Reuse get() so the lookup logic lives in one place, as delete() does.
        alert = await self.get(alert_id)
        if alert is None:
            return False
        if alert.read_at is None:
            alert.read_at = datetime.now(timezone.utc)
        await self.session.flush()
        return True

    async def mark_all_as_read(self, user_id: Optional[int] = None) -> int:
        """Bulk-set ``read_at`` on every unread alert, optionally for one user.

        Returns:
            The UPDATE's ``rowcount`` as reported by the driver (0 if it
            reported none).
        """
        # A single timestamp is used for the whole batch.
        now = datetime.now(timezone.utc)
        stmt = update(Alert).where(Alert.read_at.is_(None))
        if user_id is not None:
            stmt = stmt.where(Alert.user_id == user_id)
        result = await self.session.execute(stmt.values(read_at=now))
        return int(result.rowcount or 0)

    async def delete(self, alert_id: str) -> bool:
        """Delete the alert with this id and flush.

        Returns ``True`` if an alert was found and deleted, ``False`` if none
        exists.
        """
        alert = await self.get(alert_id)
        if alert is None:
            return False
        await self.session.delete(alert)
        await self.session.flush()
        return True