homelab_automation/app/services/terminal_command_logger.py
Bruno Charest 5bc12d0729
Some checks failed
Tests / Backend Tests (Python) (3.10) (push) Has been cancelled
Tests / Backend Tests (Python) (3.11) (push) Has been cancelled
Tests / Backend Tests (Python) (3.12) (push) Has been cancelled
Tests / Frontend Tests (JS) (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / All Tests Passed (push) Has been cancelled
Add terminal session management with heartbeat monitoring, idle timeout detection, session reuse logic, and command history panel UI with search and filtering capabilities
2025-12-18 13:49:40 -05:00

323 lines
11 KiB
Python

"""
Terminal Command Logger - Captures and logs validated commands from terminal sessions.
This module provides:
- Command buffer management per session
- Enter key detection to capture complete commands
- Integration with CommandPolicy for validation
- Async database logging
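Key features:
- Handles backspace, Ctrl+U, Ctrl+W, Ctrl+C, Ctrl+D and Enter, and skips terminal escape sequences
- Only logs commands that pass security validation (blocked commands are recorded as metadata with a "[BLOCKED]" placeholder)
- Per-session buffers; session creation/removal is serialized with an asyncio.Lock (coroutine-safe on one event loop, not thread-safe)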
"""
import asyncio
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, Optional
from app.security.command_policy import CommandPolicy, CommandPolicyResult, get_command_policy
logger = logging.getLogger(__name__)

# Configuration
# Retention window (in days) for terminal command history. Read once at import
# time from the TERMINAL_COMMAND_RETENTION_DAYS environment variable, default 30.
# The consumer of this value is not in this module.
TERMINAL_COMMAND_RETENTION_DAYS = int(os.getenv("TERMINAL_COMMAND_RETENTION_DAYS", "30"))
@dataclass
class SessionContext:
    """Mutable per-session state kept by TerminalCommandLogger.

    Identifies the terminal session (host plus optional user), holds the text
    typed since the buffer was last cleared, and keeps simple counters.
    """

    # Session identity; user fields are optional.
    session_id: str
    host_id: str
    host_name: str
    user_id: Optional[str]
    username: Optional[str]

    # Characters typed since the buffer was last cleared (not yet evaluated).
    buffer: str = ""

    # Per-session counters.
    commands_logged: int = 0
    commands_blocked: int = 0

    # Timezone-aware UTC timestamp, computed separately for each instance.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
class TerminalCommandLogger:
    """
    Manages command buffers for terminal sessions and logs validated commands.

    Keystrokes are accumulated per session. When Enter is seen, the buffered
    line is evaluated by the CommandPolicy and, depending on the verdict,
    passed to the injected log callback.

    Usage:
        cmd_logger = TerminalCommandLogger()
        cmd_logger.set_log_callback(save_command)  # async callable

        # Create session context
        await cmd_logger.create_session(session_id, host_id, host_name, user_id, username)

        # Process input bytes from terminal
        results = await cmd_logger.process_input(session_id, input_bytes)
        for result in results:
            if result.should_log:
                # Command was passed to the log callback
                pass

        # Cleanup when session ends
        await cmd_logger.flush_buffer(session_id)
        await cmd_logger.remove_session(session_id)
    """

    def __init__(self, policy: Optional[CommandPolicy] = None):
        self._sessions: Dict[str, SessionContext] = {}
        # Taken by create_session/remove_session only. get_session and
        # process_input read _sessions without it.
        self._lock = asyncio.Lock()
        self._policy = policy or get_command_policy()
        # Callback for logging to database (injected via set_log_callback).
        self._log_callback: Optional[Callable] = None

    def set_log_callback(self, callback: Callable):
        """
        Set the callback used to persist commands.

        The callback is awaited with keyword arguments host_id, host_name,
        user_id, username, session_id, command and command_hash. For blocked
        commands it also receives is_blocked=True and blocked_reason.
        """
        self._log_callback = callback

    async def create_session(
        self,
        session_id: str,
        host_id: str,
        host_name: str,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
    ) -> SessionContext:
        """Create (or replace) the command-buffer context for ``session_id``."""
        async with self._lock:
            ctx = SessionContext(
                session_id=session_id,
                host_id=host_id,
                host_name=host_name,
                user_id=user_id,
                username=username,
            )
            self._sessions[session_id] = ctx
            logger.debug(f"Created command logger session: {session_id[:8]}... for {host_name}")
            return ctx

    async def remove_session(self, session_id: str) -> Optional[SessionContext]:
        """Remove and return the context for ``session_id``, or None if unknown."""
        async with self._lock:
            ctx = self._sessions.pop(session_id, None)
            if ctx:
                logger.debug(
                    f"Removed command logger session: {session_id[:8]}... "
                    f"(logged={ctx.commands_logged}, blocked={ctx.commands_blocked})"
                )
            return ctx

    def get_session(self, session_id: str) -> Optional[SessionContext]:
        """Return the context for ``session_id``, or None if unknown."""
        return self._sessions.get(session_id)

    def _handle_control_char(self, ctx: SessionContext, char: bytes) -> bool:
        """
        Apply a line-editing control character to ``ctx.buffer``.

        Handles Backspace (DEL/BS), Ctrl+U, Ctrl+W, Ctrl+C and Ctrl+D.
        Returns True if ``char`` was one of these, False otherwise.
        """
        if char in (b'\x7f', b'\b'):  # Backspace (DEL or BS)
            if ctx.buffer:
                ctx.buffer = ctx.buffer[:-1]
            return True

        if char == b'\x15':  # Ctrl+U - clear line
            ctx.buffer = ""
            return True

        if char == b'\x17':  # Ctrl+W - delete word
            # Delete last word (back to the previous space or start of line)
            ctx.buffer = ctx.buffer.rstrip()
            if ' ' in ctx.buffer:
                ctx.buffer = ctx.buffer[:ctx.buffer.rfind(' ') + 1]
            else:
                ctx.buffer = ""
            return True

        if char == b'\x03':  # Ctrl+C - interrupt
            ctx.buffer = ""
            return True

        if char == b'\x04':  # Ctrl+D - EOF
            ctx.buffer = ""
            return True

        return False

    @staticmethod
    def _skip_escape_sequence(data: bytes, start: int) -> int:
        """
        Return the index just past the escape sequence starting at ``data[start]`` (ESC).

        Recognized forms:
        - CSI: ``ESC [`` + parameter/intermediate bytes + one final byte in
          0x40-0x7E (cursor keys, ``ESC [ 3 ~``, bracketed-paste markers, ...)
        - SS3: ``ESC O`` + one byte (application-mode cursor keys, F1-F4)
        - ``ESC`` + any other single byte (e.g. Alt+key)

        CR/LF is never consumed, so an Enter right after an incomplete sequence
        is still seen by the caller. A second ESC is not consumed, so it starts
        its own sequence.
        """
        size = len(data)
        i = start + 1
        if i >= size or data[i] in (0x0A, 0x0D, 0x1B):
            return i

        introducer = data[i]
        i += 1
        if introducer == 0x5B:  # '[' -> CSI
            while i < size and data[i] not in (0x0A, 0x0D):
                is_final = 0x40 <= data[i] <= 0x7E
                i += 1
                if is_final:
                    break
        elif introducer == 0x4F:  # 'O' -> SS3: exactly one more byte
            if i < size and data[i] not in (0x0A, 0x0D):
                i += 1
        return i

    async def process_input(
        self,
        session_id: str,
        data: bytes,
    ) -> list[CommandPolicyResult]:
        """
        Process input bytes from a terminal session.

        This method:
        1. Appends printable text to the session buffer. Consecutive printable
           bytes are decoded together so multi-byte UTF-8 characters survive.
        2. Handles line-editing control characters (backspace, Ctrl+U, etc.)
        3. Skips terminal escape sequences (arrow keys, function keys, Alt+key)
        4. On Enter (CR, LF or CRLF), evaluates the buffered command and logs
           it according to the policy verdict

        Args:
            session_id: The terminal session ID
            data: Raw bytes from the terminal input stream

        Returns:
            List of CommandPolicyResult for any commands that were evaluated
            (empty if the session is unknown or no non-blank line was completed)
        """
        ctx = self._sessions.get(session_id)
        if not ctx:
            return []

        results: list[CommandPolicyResult] = []
        size = len(data)
        i = 0
        while i < size:
            byte = data[i]

            # Printable run: decode it as one chunk. Decoding byte-by-byte would
            # drop every non-ASCII character, because a lone byte of a
            # multi-byte UTF-8 sequence is invalid on its own.
            if byte >= 0x20 and byte != 0x7F:
                start = i
                while i < size and data[i] >= 0x20 and data[i] != 0x7F:
                    i += 1
                ctx.buffer += data[start:i].decode("utf-8", errors="ignore")
                continue

            char = data[i:i + 1]

            # Enter (CR or LF); a CRLF pair counts as a single Enter.
            if char in (b"\r", b"\n"):
                if char == b"\r" and data[i + 1:i + 2] == b"\n":
                    i += 1
                i += 1
                command = ctx.buffer.strip()
                ctx.buffer = ""
                if command:
                    results.append(await self._evaluate_and_log(ctx, command))
                continue

            # Line-editing control characters.
            if self._handle_control_char(ctx, char):
                i += 1
                continue

            # Escape sequences never reach the buffer.
            if char == b"\x1b":
                i = self._skip_escape_sequence(data, i)
                continue

            # Any other control byte (e.g. Tab) is ignored. Completion done by
            # the remote shell is therefore not reflected in the buffer.
            i += 1

        return results

    @staticmethod
    def _session_fields(ctx: SessionContext) -> dict:
        """Build the keyword arguments identifying ``ctx``'s session for the log callback."""
        return {
            "host_id": ctx.host_id,
            "host_name": ctx.host_name,
            "user_id": ctx.user_id,
            "username": ctx.username,
            "session_id": ctx.session_id,
        }

    async def _invoke_log_callback(self, failure_message: str, **fields) -> None:
        """
        Await the injected log callback with ``fields`` as keyword arguments.

        Does nothing when no callback is set. Exceptions (not BaseException)
        are logged with their traceback and not propagated to the caller.
        """
        if not self._log_callback:
            return
        try:
            await self._log_callback(**fields)
        except Exception as e:
            logger.exception(f"{failure_message}: {e}")

    async def _evaluate_and_log(
        self,
        ctx: SessionContext,
        command: str,
    ) -> CommandPolicyResult:
        """
        Evaluate ``command`` against the policy and record the outcome.

        - should_log: increments commands_logged and passes the masked command
          and its hash to the log callback.
        - is_blocked: increments commands_blocked and passes only metadata
          (a "[BLOCKED]" placeholder and the policy reason) to the log callback.
        - otherwise: the command is skipped.

        Returns the policy result.
        """
        result = self._policy.evaluate(command)

        if result.should_log:
            ctx.commands_logged += 1
            logger.debug(f"Command logged for {ctx.host_name}: {result.masked_command[:50]}...")
            await self._invoke_log_callback(
                "Failed to log command to database",
                **self._session_fields(ctx),
                command=result.masked_command,
                command_hash=result.command_hash,
            )
        elif result.is_blocked:
            ctx.commands_blocked += 1
            logger.info(f"Command blocked for {ctx.host_name}: {result.reason}")
            # Log blocked command metadata (not the command itself)
            await self._invoke_log_callback(
                "Failed to log blocked command metadata",
                **self._session_fields(ctx),
                command="[BLOCKED]",
                command_hash="blocked",
                is_blocked=True,
                blocked_reason=result.reason,
            )
        else:
            # Not in the allowlist: silently skip. The raw text was not masked
            # by the policy and may contain secrets, so it is not echoed.
            logger.debug(
                f"Command not in allowlist for {ctx.host_name} ({len(command)} chars), skipped"
            )

        return result

    async def flush_buffer(self, session_id: str) -> Optional[CommandPolicyResult]:
        """
        Evaluate any remaining non-blank buffer content as a command.

        Useful when a session ends, to capture an incomplete command. Returns
        None if the session is unknown or its buffer is blank.
        """
        ctx = self._sessions.get(session_id)
        if not ctx or not ctx.buffer.strip():
            return None

        command = ctx.buffer.strip()
        ctx.buffer = ""
        return await self._evaluate_and_log(ctx, command)

    def get_stats(self) -> dict:
        """
        Return aggregate counters across currently registered sessions.

        Removed sessions no longer contribute to the totals.
        """
        total_logged = sum(ctx.commands_logged for ctx in self._sessions.values())
        total_blocked = sum(ctx.commands_blocked for ctx in self._sessions.values())

        return {
            "active_sessions": len(self._sessions),
            "total_commands_logged": total_logged,
            "total_commands_blocked": total_blocked,
        }
# ============================================================================
# Global instance
# ============================================================================
# Lazily created process-wide instance (see get_command_logger).
_command_logger_instance: Optional[TerminalCommandLogger] = None


def get_command_logger() -> TerminalCommandLogger:
    """Return the shared TerminalCommandLogger, creating it on first use.

    The instance is built with the default policy and cached in the
    module-level ``_command_logger_instance`` for subsequent calls.
    """
    global _command_logger_instance
    instance = _command_logger_instance
    if instance is None:
        instance = TerminalCommandLogger()
        _command_logger_instance = instance
    return instance