""" CRUD operations for terminal command logs. Provides repository pattern for storing and retrieving terminal command history. """ import os from datetime import datetime, timedelta, timezone from typing import List, Optional from sqlalchemy import and_, delete, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from app.models.terminal_command_log import TerminalCommandLog # Configuration TERMINAL_COMMAND_RETENTION_DAYS = int(os.environ.get("TERMINAL_COMMAND_RETENTION_DAYS", "30")) class TerminalCommandLogRepository: """Repository for TerminalCommandLog CRUD operations.""" def __init__(self, session: AsyncSession): self.session = session async def create( self, host_id: str, command: str, command_hash: str, host_name: Optional[str] = None, user_id: Optional[str] = None, username: Optional[str] = None, terminal_session_id: Optional[str] = None, source: str = "terminal", is_blocked: bool = False, blocked_reason: Optional[str] = None, ) -> TerminalCommandLog: """Create a new command log entry.""" log = TerminalCommandLog( host_id=host_id, host_name=host_name, user_id=user_id, username=username, terminal_session_id=terminal_session_id, command=command, command_hash=command_hash, source=source, is_blocked=is_blocked, blocked_reason=blocked_reason, ) self.session.add(log) await self.session.flush() return log async def get(self, log_id: int) -> Optional[TerminalCommandLog]: """Get a command log by ID.""" result = await self.session.execute( select(TerminalCommandLog).where(TerminalCommandLog.id == log_id) ) return result.scalar_one_or_none() async def list_for_host( self, host_id: str, query: Optional[str] = None, limit: int = 50, offset: int = 0, include_blocked: bool = False, ) -> List[TerminalCommandLog]: """ List command logs for a specific host. Args: host_id: The host ID to filter by query: Optional search query (filters command text) limit: Maximum number of results offset: Number of results to skip include_blocked: Whether to include blocked commands """ conditions = [TerminalCommandLog.host_id == host_id] if not include_blocked: conditions.append(TerminalCommandLog.is_blocked == False) if query: conditions.append(TerminalCommandLog.command.ilike(f"%{query}%")) stmt = ( select(TerminalCommandLog) .where(and_(*conditions)) .order_by(TerminalCommandLog.created_at.desc()) .limit(limit) .offset(offset) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def list_for_user( self, user_id: str, query: Optional[str] = None, host_id: Optional[str] = None, limit: int = 50, offset: int = 0, ) -> List[TerminalCommandLog]: """ List command logs for a specific user. Args: user_id: The user ID to filter by query: Optional search query host_id: Optional host ID to filter by limit: Maximum number of results offset: Number of results to skip """ conditions = [ TerminalCommandLog.user_id == user_id, TerminalCommandLog.is_blocked == False, ] if host_id: conditions.append(TerminalCommandLog.host_id == host_id) if query: conditions.append(TerminalCommandLog.command.ilike(f"%{query}%")) stmt = ( select(TerminalCommandLog) .where(and_(*conditions)) .order_by(TerminalCommandLog.created_at.desc()) .limit(limit) .offset(offset) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def list_global( self, query: Optional[str] = None, host_id: Optional[str] = None, user_id: Optional[str] = None, limit: int = 50, offset: int = 0, ) -> List[TerminalCommandLog]: """ List command logs globally with optional filters. Args: query: Optional search query host_id: Optional host ID to filter by user_id: Optional user ID to filter by limit: Maximum number of results offset: Number of results to skip """ conditions = [TerminalCommandLog.is_blocked == False] if host_id: conditions.append(TerminalCommandLog.host_id == host_id) if user_id: conditions.append(TerminalCommandLog.user_id == user_id) if query: conditions.append(TerminalCommandLog.command.ilike(f"%{query}%")) stmt = ( select(TerminalCommandLog) .where(and_(*conditions)) .order_by(TerminalCommandLog.created_at.desc()) .limit(limit) .offset(offset) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get_unique_commands_for_host( self, host_id: str, query: Optional[str] = None, limit: int = 50, ) -> List[dict]: """ Get unique commands for a host (deduped by command_hash). Returns commands with their most recent execution time. """ conditions = [ TerminalCommandLog.host_id == host_id, TerminalCommandLog.is_blocked == False, ] if query: conditions.append(TerminalCommandLog.command.ilike(f"%{query}%")) # Subquery to get the max created_at for each command_hash subq = ( select( TerminalCommandLog.command_hash, func.max(TerminalCommandLog.created_at).label("max_created"), func.count(TerminalCommandLog.id).label("execution_count"), ) .where(and_(*conditions)) .group_by(TerminalCommandLog.command_hash) .subquery() ) # Join to get the actual command text stmt = ( select( TerminalCommandLog.command, TerminalCommandLog.command_hash, subq.c.max_created, subq.c.execution_count, ) .join(subq, and_( TerminalCommandLog.command_hash == subq.c.command_hash, TerminalCommandLog.created_at == subq.c.max_created, )) .where(TerminalCommandLog.host_id == host_id) .order_by(subq.c.max_created.desc()) .limit(limit) ) result = await self.session.execute(stmt) rows = result.all() return [ { "command": row.command, "command_hash": row.command_hash, "last_used": row.max_created, "execution_count": row.execution_count, } for row in rows ] async def count_for_host(self, host_id: str) -> int: """Count command logs for a host.""" result = await self.session.execute( select(func.count(TerminalCommandLog.id)) .where(and_( TerminalCommandLog.host_id == host_id, TerminalCommandLog.is_blocked == False, )) ) return result.scalar() or 0 async def purge_old_logs(self, days: Optional[int] = None) -> int: """ Purge logs older than the specified retention period. Args: days: Number of days to retain (defaults to TERMINAL_COMMAND_RETENTION_DAYS) Returns: Number of logs deleted """ if days is None: days = TERMINAL_COMMAND_RETENTION_DAYS cutoff = datetime.now(timezone.utc) - timedelta(days=days) # First count how many will be deleted count_result = await self.session.execute( select(func.count(TerminalCommandLog.id)) .where(TerminalCommandLog.created_at < cutoff) ) count = count_result.scalar() or 0 # Delete old logs await self.session.execute( delete(TerminalCommandLog) .where(TerminalCommandLog.created_at < cutoff) ) await self.session.flush() return count async def delete_for_host(self, host_id: str) -> int: """Delete all command logs for a host.""" count_result = await self.session.execute( select(func.count(TerminalCommandLog.id)) .where(TerminalCommandLog.host_id == host_id) ) count = count_result.scalar() or 0 await self.session.execute( delete(TerminalCommandLog) .where(TerminalCommandLog.host_id == host_id) ) await self.session.flush() return count async def delete_for_session(self, session_id: str) -> int: """Delete all command logs for a terminal session.""" count_result = await self.session.execute( select(func.count(TerminalCommandLog.id)) .where(TerminalCommandLog.terminal_session_id == session_id) ) count = count_result.scalar() or 0 await self.session.execute( delete(TerminalCommandLog) .where(TerminalCommandLog.terminal_session_id == session_id) ) await self.session.flush() return count