""" Terminal Command Logger - Captures and logs validated commands from terminal sessions. This module provides: - Command buffer management per session - Enter key detection to capture complete commands - Integration with CommandPolicy for validation - Async database logging Key features: - Handles backspace, Ctrl+U, and Enter sequences - Only logs commands that pass security validation - Thread-safe buffer management """ import asyncio import logging import os from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import Callable, Dict, Optional from app.security.command_policy import CommandPolicy, CommandPolicyResult, get_command_policy logger = logging.getLogger(__name__) # Configuration TERMINAL_COMMAND_RETENTION_DAYS = int(os.environ.get("TERMINAL_COMMAND_RETENTION_DAYS", "30")) @dataclass class SessionContext: """Context for a terminal session's command buffer.""" session_id: str host_id: str host_name: str user_id: Optional[str] username: Optional[str] # Command buffer (characters waiting for Enter) buffer: str = "" # Stats commands_logged: int = 0 commands_blocked: int = 0 created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) class TerminalCommandLogger: """ Manages command buffers for terminal sessions and logs validated commands. Usage: logger = TerminalCommandLogger() # Create session context logger.create_session(session_id, host_id, host_name, user_id, username) # Process input bytes from terminal async for result in logger.process_input(session_id, input_bytes): if result.should_log: # Command was logged to database pass # Cleanup when session ends logger.remove_session(session_id) """ def __init__(self, policy: Optional[CommandPolicy] = None): self._sessions: Dict[str, SessionContext] = {} self._lock = asyncio.Lock() self._policy = policy or get_command_policy() # Callback for logging to database (injected) self._log_callback: Optional[Callable] = None def set_log_callback(self, callback: Callable): """Set the callback function for logging commands to database.""" self._log_callback = callback async def create_session( self, session_id: str, host_id: str, host_name: str, user_id: Optional[str] = None, username: Optional[str] = None, ) -> SessionContext: """Create a new session context for command buffering.""" async with self._lock: ctx = SessionContext( session_id=session_id, host_id=host_id, host_name=host_name, user_id=user_id, username=username, ) self._sessions[session_id] = ctx logger.debug(f"Created command logger session: {session_id[:8]}... for {host_name}") return ctx async def remove_session(self, session_id: str) -> Optional[SessionContext]: """Remove a session context.""" async with self._lock: ctx = self._sessions.pop(session_id, None) if ctx: logger.debug( f"Removed command logger session: {session_id[:8]}... " f"(logged={ctx.commands_logged}, blocked={ctx.commands_blocked})" ) return ctx def get_session(self, session_id: str) -> Optional[SessionContext]: """Get a session context.""" return self._sessions.get(session_id) def _handle_control_char(self, ctx: SessionContext, char: bytes) -> bool: """ Handle control characters in the input stream. Returns True if the character was handled, False otherwise. """ if char in (b'\x7f', b'\b'): # Backspace (DEL or BS) if ctx.buffer: ctx.buffer = ctx.buffer[:-1] return True if char == b'\x15': # Ctrl+U - clear line ctx.buffer = "" return True if char == b'\x17': # Ctrl+W - delete word # Delete last word (back to space or start) ctx.buffer = ctx.buffer.rstrip() if ' ' in ctx.buffer: ctx.buffer = ctx.buffer[:ctx.buffer.rfind(' ') + 1] else: ctx.buffer = "" return True if char == b'\x03': # Ctrl+C - interrupt ctx.buffer = "" return True if char == b'\x04': # Ctrl+D - EOF ctx.buffer = "" return True return False async def process_input( self, session_id: str, data: bytes, ) -> list[CommandPolicyResult]: """ Process input bytes from a terminal session. This method: 1. Adds printable characters to the buffer 2. Handles control characters (backspace, Ctrl+U, etc.) 3. When Enter is detected, evaluates the command and logs if valid Args: session_id: The terminal session ID data: Raw bytes from the terminal input stream Returns: List of CommandPolicyResult for any commands that were evaluated """ ctx = self._sessions.get(session_id) if not ctx: return [] results = [] # Process byte by byte i = 0 while i < len(data): char = data[i:i+1] # Check for Enter (CR or LF) if char in (b'\r', b'\n'): # Skip CRLF sequence if char == b'\r' and i + 1 < len(data) and data[i+1:i+2] == b'\n': i += 1 # Process the command command = ctx.buffer.strip() ctx.buffer = "" if command: result = await self._evaluate_and_log(ctx, command) results.append(result) i += 1 continue # Handle control characters if self._handle_control_char(ctx, char): i += 1 continue # Skip other control characters and escape sequences if char[0] < 32 or char == b'\x7f': # Check for escape sequence if char == b'\x1b' and i + 1 < len(data): # Skip escape sequences (arrow keys, etc.) i += 1 while i < len(data) and data[i:i+1] not in (b'', b'\r', b'\n'): next_char = data[i:i+1] if next_char.isalpha() or next_char == b'~': i += 1 break i += 1 continue i += 1 continue # Add printable character to buffer try: ctx.buffer += char.decode('utf-8', errors='ignore') except Exception: pass i += 1 return results async def _evaluate_and_log( self, ctx: SessionContext, command: str, ) -> CommandPolicyResult: """Evaluate a command against policy and log if valid.""" result = self._policy.evaluate(command) if result.should_log: ctx.commands_logged += 1 logger.debug(f"Command logged for {ctx.host_name}: {result.masked_command[:50]}...") # Call the log callback if set if self._log_callback: try: await self._log_callback( host_id=ctx.host_id, host_name=ctx.host_name, user_id=ctx.user_id, username=ctx.username, session_id=ctx.session_id, command=result.masked_command, command_hash=result.command_hash, ) except Exception as e: logger.error(f"Failed to log command to database: {e}") elif result.is_blocked: ctx.commands_blocked += 1 logger.info(f"Command blocked for {ctx.host_name}: {result.reason}") # Optionally log blocked command metadata (not the command itself) if self._log_callback: try: await self._log_callback( host_id=ctx.host_id, host_name=ctx.host_name, user_id=ctx.user_id, username=ctx.username, session_id=ctx.session_id, command="[BLOCKED]", command_hash="blocked", is_blocked=True, blocked_reason=result.reason, ) except Exception as e: logger.error(f"Failed to log blocked command metadata: {e}") else: # Command not in allowlist - silently skip logger.debug(f"Command not in allowlist for {ctx.host_name}: {command[:30]}...") return result async def flush_buffer(self, session_id: str) -> Optional[CommandPolicyResult]: """ Flush any remaining buffer content as a command. Useful when session ends to capture any incomplete command. """ ctx = self._sessions.get(session_id) if not ctx or not ctx.buffer.strip(): return None command = ctx.buffer.strip() ctx.buffer = "" return await self._evaluate_and_log(ctx, command) def get_stats(self) -> dict: """Get statistics about command logging.""" total_logged = sum(ctx.commands_logged for ctx in self._sessions.values()) total_blocked = sum(ctx.commands_blocked for ctx in self._sessions.values()) return { "active_sessions": len(self._sessions), "total_commands_logged": total_logged, "total_commands_blocked": total_blocked, } # ============================================================================ # Global instance # ============================================================================ _command_logger_instance: Optional[TerminalCommandLogger] = None def get_command_logger() -> TerminalCommandLogger: """Get or create the global command logger instance.""" global _command_logger_instance if _command_logger_instance is None: _command_logger_instance = TerminalCommandLogger() return _command_logger_instance