"""User repository for CRUD operations.""" from __future__ import annotations from datetime import datetime, timezone from typing import Optional from sqlalchemy import func, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.models.user import User class UserRepository: """Repository for User CRUD operations.""" def __init__(self, session: AsyncSession): self.session = session async def count(self, include_deleted: bool = False) -> int: """Count total users.""" stmt = select(func.count(User.id)) if not include_deleted: stmt = stmt.where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) return result.scalar() or 0 async def list( self, limit: int = 100, offset: int = 0, include_deleted: bool = False ) -> list[User]: """List all users with pagination.""" stmt = select(User).order_by(User.created_at.desc()).offset(offset).limit(limit) if not include_deleted: stmt = stmt.where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get(self, user_id: int, include_deleted: bool = False) -> Optional[User]: """Get user by ID.""" stmt = select(User).where(User.id == user_id) if not include_deleted: stmt = stmt.where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_by_username( self, username: str, include_deleted: bool = False ) -> Optional[User]: """Get user by username.""" stmt = select(User).where(User.username == username) if not include_deleted: stmt = stmt.where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_by_email( self, email: str, include_deleted: bool = False ) -> Optional[User]: """Get user by email.""" stmt = select(User).where(User.email == email) if not include_deleted: stmt = stmt.where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def create( self, *, username: str, hashed_password: str, email: Optional[str] = None, display_name: Optional[str] = None, role: str = "admin", is_active: bool = True, is_superuser: bool = False, ) -> User: """Create a new user.""" user = User( username=username, hashed_password=hashed_password, email=email, display_name=display_name, role=role, is_active=is_active, is_superuser=is_superuser, password_changed_at=datetime.now(timezone.utc), ) self.session.add(user) await self.session.flush() return user async def update(self, user: User, **fields) -> User: """Update user fields.""" for key, value in fields.items(): if value is not None: setattr(user, key, value) await self.session.flush() return user async def update_password(self, user: User, hashed_password: str) -> User: """Update user password and timestamp.""" user.hashed_password = hashed_password user.password_changed_at = datetime.now(timezone.utc) await self.session.flush() return user async def update_last_login(self, user: User) -> User: """Update last login timestamp.""" user.last_login = datetime.now(timezone.utc) await self.session.flush() return user async def soft_delete(self, user_id: int) -> bool: """Soft delete a user.""" stmt = ( update(User) .where(User.id == user_id, User.deleted_at.is_(None)) .values(deleted_at=datetime.now(timezone.utc), is_active=False) ) result = await self.session.execute(stmt) return result.rowcount > 0 async def hard_delete(self, user_id: int) -> bool: """Permanently delete a user (use with caution).""" user = await self.get(user_id, include_deleted=True) if user: await self.session.delete(user) await self.session.flush() return True return False async def exists_any(self) -> bool: """Check if any user exists (for initial setup check).""" stmt = select(func.count(User.id)).where(User.deleted_at.is_(None)) result = await self.session.execute(stmt) count = result.scalar() or 0 return count > 0