""" CRUD operations for terminal sessions. """ from datetime import datetime, timezone from typing import List, Optional from sqlalchemy import select, update, and_ from sqlalchemy.ext.asyncio import AsyncSession from app.models.terminal_session import TerminalSession class TerminalSessionRepository: """Repository for TerminalSession CRUD operations.""" def __init__(self, session: AsyncSession): self.session = session async def create( self, id: str, host_id: str, host_name: str, host_ip: str, token_hash: str, ttyd_port: int, expires_at: datetime, user_id: Optional[str] = None, username: Optional[str] = None, ttyd_pid: Optional[int] = None, mode: str = "embedded", ) -> TerminalSession: """Create a new terminal session.""" session = TerminalSession( id=id, host_id=host_id, host_name=host_name, host_ip=host_ip, user_id=user_id, username=username, token_hash=token_hash, ttyd_port=ttyd_port, ttyd_pid=ttyd_pid, mode=mode, status="active", expires_at=expires_at, ) self.session.add(session) await self.session.flush() return session async def get(self, session_id: str) -> Optional[TerminalSession]: """Get a terminal session by ID.""" result = await self.session.execute( select(TerminalSession).where(TerminalSession.id == session_id) ) return result.scalar_one_or_none() async def get_active_by_id(self, session_id: str) -> Optional[TerminalSession]: """Get an active terminal session by ID.""" now = datetime.now(timezone.utc) result = await self.session.execute( select(TerminalSession).where( and_( TerminalSession.id == session_id, TerminalSession.status == "active", TerminalSession.expires_at > now ) ) ) return result.scalar_one_or_none() async def list_active_for_user(self, user_id: str) -> List[TerminalSession]: """List all active sessions for a user.""" now = datetime.now(timezone.utc) result = await self.session.execute( select(TerminalSession).where( and_( TerminalSession.user_id == user_id, TerminalSession.status == "active", TerminalSession.expires_at > now ) ).order_by(TerminalSession.created_at.desc()) ) return list(result.scalars().all()) async def count_active_for_user(self, user_id: str) -> int: """Count active sessions for a user.""" sessions = await self.list_active_for_user(user_id) return len(sessions) async def list_expired(self) -> List[TerminalSession]: """List all expired but not yet closed sessions.""" now = datetime.now(timezone.utc) result = await self.session.execute( select(TerminalSession).where( and_( TerminalSession.status == "active", TerminalSession.expires_at <= now ) ) ) return list(result.scalars().all()) async def update_status( self, session_id: str, status: str, closed_at: Optional[datetime] = None ) -> Optional[TerminalSession]: """Update session status.""" session = await self.get(session_id) if session: session.status = status if closed_at: session.closed_at = closed_at await self.session.flush() return session async def close_session(self, session_id: str) -> Optional[TerminalSession]: """Close a terminal session.""" return await self.update_status( session_id, status="closed", closed_at=datetime.now(timezone.utc) ) async def mark_expired(self, session_id: str) -> Optional[TerminalSession]: """Mark a session as expired.""" return await self.update_status( session_id, status="expired", closed_at=datetime.now(timezone.utc) ) async def cleanup_old_sessions(self, days: int = 7) -> int: """Delete sessions older than specified days.""" from datetime import timedelta cutoff = datetime.now(timezone.utc) - timedelta(days=days) result = await self.session.execute( select(TerminalSession).where( and_( TerminalSession.status.in_(["closed", "expired", "error"]), TerminalSession.created_at < cutoff ) ) ) sessions = result.scalars().all() count = len(sessions) for session in sessions: await self.session.delete(session) await self.session.flush() return count async def close_all_active(self) -> int: """Close all active sessions (for cleanup/reset).""" now = datetime.now(timezone.utc) result = await self.session.execute( select(TerminalSession).where(TerminalSession.status == "active") ) sessions = list(result.scalars().all()) count = len(sessions) for session in sessions: session.status = "closed" session.closed_at = now await self.session.flush() return count