"""CRUD operations for terminal sessions."""

from datetime import datetime, timedelta, timezone
from typing import List, Optional

from sqlalchemy import select, update, and_, or_, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.terminal_session import (
    TerminalSession,
    SESSION_STATUS_ACTIVE,
    SESSION_STATUS_CLOSED,
    SESSION_STATUS_EXPIRED,
    SESSION_STATUS_CLOSING,
    CLOSE_REASON_USER,
    CLOSE_REASON_TTL,
    CLOSE_REASON_IDLE,
    CLOSE_REASON_CLIENT_LOST,
)


class TerminalSessionRepository:
    """Repository for TerminalSession CRUD operations.

    Every method works through the injected ``AsyncSession``. Writes are
    flushed but never committed here; committing the transaction is the
    caller's responsibility.

    Timestamps are produced with ``datetime.now(timezone.utc)`` (timezone-aware
    UTC). NOTE(review): this assumes the timestamp columns on TerminalSession
    store/compare tz-aware values -- confirm against the model/DB backend.
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to an async SQLAlchemy session.

        Args:
            session: The session used for all queries and writes.
        """
        self.session = session

    async def create(
        self,
        id: str,
        host_id: str,
        host_name: str,
        host_ip: str,
        token_hash: str,
        ttyd_port: int,
        expires_at: datetime,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
        ttyd_pid: Optional[int] = None,
        mode: str = "embedded",
    ) -> TerminalSession:
        """Create a new terminal session in the active state.

        The new row starts with ``status=SESSION_STATUS_ACTIVE`` and
        ``last_seen_at`` set to the current UTC time.

        Returns:
            The newly added (and flushed) TerminalSession.
        """
        now = datetime.now(timezone.utc)
        session = TerminalSession(
            id=id,
            host_id=host_id,
            host_name=host_name,
            host_ip=host_ip,
            user_id=user_id,
            username=username,
            token_hash=token_hash,
            ttyd_port=ttyd_port,
            ttyd_pid=ttyd_pid,
            mode=mode,
            status=SESSION_STATUS_ACTIVE,
            expires_at=expires_at,
            last_seen_at=now,
        )
        self.session.add(session)
        # Flush so DB-side defaults/constraints are applied before returning.
        await self.session.flush()
        return session

    async def get(self, session_id: str) -> Optional[TerminalSession]:
        """Return the terminal session with this ID, or None if absent."""
        result = await self.session.execute(
            select(TerminalSession).where(TerminalSession.id == session_id)
        )
        return result.scalar_one_or_none()

    async def get_active_by_id(self, session_id: str) -> Optional[TerminalSession]:
        """Return the session with this ID if it is active and not yet expired.

        Returns:
            The session, or None if it does not exist, is not in the active
            status, or its ``expires_at`` is in the past.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.id == session_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
        )
        return result.scalar_one_or_none()

    async def find_reusable_session(
        self,
        user_id: str,
        host_id: str,
        mode: str,
        idle_timeout_seconds: int = 120,
    ) -> Optional[TerminalSession]:
        """Find an existing active session that can be reused.

        A session is reusable if:
        - it has the same user, host, and mode;
        - its status is 'active';
        - it has not expired;
        - it was last seen within ``idle_timeout_seconds`` (considered
          "healthy").

        Returns:
            The most recently created matching session, or None.
        """
        now = datetime.now(timezone.utc)
        min_last_seen = now - timedelta(seconds=idle_timeout_seconds)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.host_id == host_id,
                    TerminalSession.mode == mode,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                    TerminalSession.last_seen_at >= min_last_seen,
                )
            )
            .order_by(TerminalSession.created_at.desc())
            .limit(1)
        )
        # Several healthy sessions can match the same user/host/mode, and
        # scalar_one_or_none() would raise MultipleResultsFound in that case.
        # Take the newest one explicitly instead.
        return result.scalars().first()

    async def list_active_for_user(self, user_id: str) -> List[TerminalSession]:
        """List the user's active, unexpired sessions, newest first."""
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.desc())
        )
        return list(result.scalars().all())

    async def list_all_active(self) -> List[TerminalSession]:
        """List all active, unexpired sessions, newest first (for admin/GC)."""
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.desc())
        )
        return list(result.scalars().all())

    async def count_active_for_user(self, user_id: str) -> int:
        """Count the user's active, unexpired sessions.

        Uses the same filter as ``list_active_for_user``, but counts in SQL
        instead of loading every row into memory.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(func.count())
            .select_from(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
        )
        return int(result.scalar_one())

    async def list_expired(self) -> List[TerminalSession]:
        """List sessions still marked active whose TTL (``expires_at``) has passed."""
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at <= now,
                )
            )
        )
        return list(result.scalars().all())

    async def list_idle(self, idle_timeout_seconds: int) -> List[TerminalSession]:
        """List active sessions with no heartbeat within ``idle_timeout_seconds``."""
        now = datetime.now(timezone.utc)
        idle_cutoff = now - timedelta(seconds=idle_timeout_seconds)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.last_seen_at < idle_cutoff,
                )
            )
        )
        return list(result.scalars().all())

    async def list_stale_sessions(
        self, ttl_seconds: int, idle_timeout_seconds: int
    ) -> List[TerminalSession]:
        """List all active sessions that should be cleaned up.

        A session is returned if either:
        - it has expired (``expires_at`` is in the past), or
        - it is idle (no heartbeat within ``idle_timeout_seconds``).

        Args:
            ttl_seconds: Currently unused. Expiry is decided by each row's
                stored ``expires_at``. Kept for interface compatibility.
            idle_timeout_seconds: Heartbeat window in seconds.
        """
        now = datetime.now(timezone.utc)
        idle_cutoff = now - timedelta(seconds=idle_timeout_seconds)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    or_(
                        TerminalSession.expires_at <= now,
                        TerminalSession.last_seen_at < idle_cutoff,
                    ),
                )
            )
        )
        return list(result.scalars().all())

    async def update_status(
        self,
        session_id: str,
        status: str,
        closed_at: Optional[datetime] = None,
        reason_closed: Optional[str] = None,
    ) -> Optional[TerminalSession]:
        """Set a session's status and, optionally, its close metadata.

        ``closed_at`` and ``reason_closed`` are only written when a value is
        given (not None). Omitting them leaves the stored values untouched.

        Returns:
            The updated session, or None if no session has this ID.
        """
        session = await self.get(session_id)
        if session is not None:
            session.status = status
            if closed_at is not None:
                session.closed_at = closed_at
            if reason_closed is not None:
                session.reason_closed = reason_closed
            await self.session.flush()
        return session

    async def update_last_seen(self, session_id: str) -> Optional[TerminalSession]:
        """Record a heartbeat by setting ``last_seen_at`` to now.

        The timestamp is only updated if the session is in the active status.
        Otherwise the session is returned unchanged.

        Returns:
            The session (updated or not), or None if no session has this ID.
        """
        session = await self.get(session_id)
        if session is not None and session.status == SESSION_STATUS_ACTIVE:
            session.last_seen_at = datetime.now(timezone.utc)
            await self.session.flush()
        return session

    async def close_session(
        self,
        session_id: str,
        reason: str = CLOSE_REASON_USER,
    ) -> Optional[TerminalSession]:
        """Close a terminal session, recording the close time and ``reason``."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_CLOSED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=reason,
        )

    async def mark_expired(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as expired because its TTL elapsed."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_EXPIRED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_TTL,
        )

    async def mark_idle(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as expired because of the idle timeout."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_EXPIRED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_IDLE,
        )

    async def mark_client_lost(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as closed because the client disconnected."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_CLOSED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_CLIENT_LOST,
        )

    async def cleanup_old_sessions(self, days: int = 7) -> int:
        """Delete finished sessions created more than ``days`` days ago.

        Only sessions whose status is closed, expired, or "error" are deleted.
        Active sessions are never removed here.

        Returns:
            The number of sessions deleted.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status.in_(
                        [SESSION_STATUS_CLOSED, SESSION_STATUS_EXPIRED, "error"]
                    ),
                    TerminalSession.created_at < cutoff,
                )
            )
        )
        sessions = result.scalars().all()
        # Delete through the ORM one row at a time rather than with a bulk
        # DELETE. This presumably keeps ORM-level cascades/events in effect.
        for session in sessions:
            await self.session.delete(session)
        await self.session.flush()
        return len(sessions)

    async def close_all_active(self, reason: str = CLOSE_REASON_USER) -> int:
        """Close every session in the active status (for cleanup/reset).

        This does not check ``expires_at``: expired rows that are still marked
        active are closed as well. All rows share one ``closed_at`` timestamp.

        Returns:
            The number of sessions closed.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                TerminalSession.status == SESSION_STATUS_ACTIVE
            )
        )
        sessions = list(result.scalars().all())
        for session in sessions:
            session.status = SESSION_STATUS_CLOSED
            session.closed_at = now
            session.reason_closed = reason
        await self.session.flush()
        return len(sessions)

    async def get_oldest_active_for_user(
        self, user_id: str
    ) -> Optional[TerminalSession]:
        """Return the user's oldest active, unexpired session (for auto-close).

        Returns:
            The session with the earliest ``created_at``, or None if the user
            has no active sessions.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.asc())
            .limit(1)
        )
        return result.scalars().first()