Some checks failed
Tests / Backend Tests (Python) (3.10) (push) Has been cancelled
Tests / Backend Tests (Python) (3.11) (push) Has been cancelled
Tests / Backend Tests (Python) (3.12) (push) Has been cancelled
Tests / Frontend Tests (JS) (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / All Tests Passed (push) Has been cancelled
323 lines
11 KiB
Python
323 lines
11 KiB
Python
"""
|
|
Terminal Command Logger - Captures and logs validated commands from terminal sessions.

This module provides:
- Command buffer management per session
- Enter key detection to capture complete commands
- Integration with CommandPolicy for validation
- Async database logging

Key features:
- Handles backspace, Ctrl+U, Ctrl+W, Ctrl+C, Ctrl+D, and Enter sequences, and
  skips terminal escape sequences (arrow keys, function keys, etc.)
- Only logs the (masked) text of commands that pass security validation;
  blocked commands are recorded as metadata only
- Session registration is serialized with an asyncio.Lock (coroutine-safe, not thread-safe)
|
|
"""
|
|
import asyncio
import inspect
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, Optional

from app.security.command_policy import CommandPolicy, CommandPolicyResult, get_command_policy
|
|
|
|
logger = logging.getLogger(__name__)

# Configuration
# Number of days, read once at import time from the
# TERMINAL_COMMAND_RETENTION_DAYS environment variable (default "30").
# int() raises ValueError at import if the variable holds a non-integer value.
# NOTE(review): not referenced in the visible part of this module; judging by
# its name it is presumably consumed by a retention/cleanup job elsewhere.
TERMINAL_COMMAND_RETENTION_DAYS = int(
    os.getenv("TERMINAL_COMMAND_RETENTION_DAYS", "30")
)
|
|
|
|
|
|
@dataclass
class SessionContext:
    """Per-session state used by the terminal command logger.

    Holds the identity of the session and its host, the characters typed since
    the last Enter, and two counters plus a creation timestamp.
    """

    # Identity of the terminal session and the host it is attached to.
    session_id: str
    host_id: str
    host_name: str
    # Owner of the session; either value may be None.
    user_id: Optional[str]
    username: Optional[str]

    # Characters typed since the last Enter (or line-clearing key).
    buffer: str = ""

    # Counters: commands forwarded for logging vs. commands the policy blocked.
    commands_logged: int = 0
    commands_blocked: int = 0
    # Timezone-aware UTC timestamp taken when the context is constructed.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
|
|
|
|
|
|
class TerminalCommandLogger:
    """
    Manages command buffers for terminal sessions and logs validated commands.

    Raw input bytes are accumulated per session. When Enter arrives, the
    buffered line is evaluated by the CommandPolicy and, depending on the
    result, forwarded to the injected log callback.

    Usage:
        command_logger = TerminalCommandLogger()
        command_logger.set_log_callback(store_command)  # async or plain callable

        # Create session context
        await command_logger.create_session(session_id, host_id, host_name, user_id, username)

        # Process input bytes from terminal (returns a list, not an async iterator)
        for result in await command_logger.process_input(session_id, input_bytes):
            if result.should_log:
                # Command was forwarded to the log callback
                pass

        # Cleanup when session ends
        await command_logger.flush_buffer(session_id)
        await command_logger.remove_session(session_id)

    Limitations:
        Only the keystrokes seen here are replayed into the buffer. Shell-side
        editing (tab completion, history recall, cursor movement) is not
        modelled. Escape sequences or multi-byte UTF-8 characters split across
        two ``process_input`` calls are not reassembled. As a result, the
        captured line can differ from what the shell actually runs.
    """

    # Byte values used by the input parser.
    _CR = 0x0D
    _LF = 0x0A
    _ESC = 0x1B
    _DEL = 0x7F

    def __init__(self, policy: Optional[CommandPolicy] = None):
        self._sessions: Dict[str, SessionContext] = {}
        # Serializes create_session/remove_session between coroutines. This is
        # an asyncio lock, not a thread lock, and the other methods do not take it.
        self._lock = asyncio.Lock()
        self._policy = policy or get_command_policy()

        # Callback for logging to database (injected via set_log_callback).
        # It is called with keyword arguments and may return an awaitable.
        self._log_callback: Optional[Callable] = None

    def set_log_callback(self, callback: Callable):
        """
        Set the callback function for logging commands to database.

        The callback receives keyword arguments (host_id, host_name, user_id,
        username, session_id, command, command_hash, and for blocked commands
        is_blocked / blocked_reason). Both coroutine functions and plain
        callables are accepted.
        """
        self._log_callback = callback

    async def create_session(
        self,
        session_id: str,
        host_id: str,
        host_name: str,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
    ) -> SessionContext:
        """Create (or replace) the session context used for command buffering."""
        async with self._lock:
            ctx = SessionContext(
                session_id=session_id,
                host_id=host_id,
                host_name=host_name,
                user_id=user_id,
                username=username,
            )
            self._sessions[session_id] = ctx
            logger.debug(
                "Created command logger session: %s... for %s", session_id[:8], host_name
            )
            return ctx

    async def remove_session(self, session_id: str) -> Optional[SessionContext]:
        """Remove and return a session context, or None if it was not registered."""
        async with self._lock:
            ctx = self._sessions.pop(session_id, None)
            if ctx:
                logger.debug(
                    "Removed command logger session: %s... (logged=%d, blocked=%d)",
                    session_id[:8],
                    ctx.commands_logged,
                    ctx.commands_blocked,
                )
            return ctx

    def get_session(self, session_id: str) -> Optional[SessionContext]:
        """Get a session context, or None if it is not registered."""
        return self._sessions.get(session_id)

    def _handle_control_char(self, ctx: SessionContext, char: bytes) -> bool:
        """
        Apply a line-editing control character to the session buffer.

        Returns True if the character was handled, False otherwise.
        """
        if char in (b'\x7f', b'\b'):  # Backspace (DEL or BS): drop the last character
            ctx.buffer = ctx.buffer[:-1]
            return True

        if char == b'\x17':  # Ctrl+W: delete the last word (back to a space or start)
            trimmed = ctx.buffer.rstrip()
            # rfind returns -1 when there is no space, which yields "".
            ctx.buffer = trimmed[:trimmed.rfind(' ') + 1]
            return True

        # Ctrl+U clears the line; Ctrl+C (interrupt) / Ctrl+D (EOF) abandon it.
        if char in (b'\x15', b'\x03', b'\x04'):
            ctx.buffer = ""
            return True

        return False

    @staticmethod
    def _skip_escape_sequence(data: bytes, start: int) -> int:
        """
        Return the index just past the escape sequence whose ESC is ``data[start]``.

        Recognised forms:
        - CSI: ``ESC [``, then parameter/intermediate bytes (0x20-0x3F), then
          one final byte (0x40-0x7E). Examples: ``ESC [ A``,
          ``ESC [ 1 ; 5 C``, and the bracketed-paste markers ``ESC [ 200 ~``.
        - SS3: ``ESC O`` plus one final byte. Examples: ``ESC O A``
          (application-mode arrow) and ``ESC O P`` (F1).
        - Any other printable byte after ESC (e.g. Alt+key) forms a two-byte
          sequence.

        A control byte right after ESC (CR, LF, another ESC, DEL, ...) is not
        consumed, so the caller still processes it. A sequence truncated by the
        end of ``data`` is consumed up to the end of ``data``.
        """
        end = len(data)
        pos = start + 1
        if pos >= end:
            return end  # lone ESC at the end of the chunk

        intro = data[pos]
        if intro == 0x5B:  # '[' introduces a CSI sequence
            pos += 1
            while pos < end and 0x20 <= data[pos] <= 0x3F:
                pos += 1
            if pos < end and 0x40 <= data[pos] <= 0x7E:
                pos += 1
            return pos

        if intro == 0x4F:  # 'O' introduces an SS3 sequence (single final byte)
            pos += 1
            if pos < end and 0x40 <= data[pos] <= 0x7E:
                pos += 1
            return pos

        if 0x20 <= intro <= 0x7E:
            return pos + 1  # two-byte sequence, e.g. Alt+b or Alt+.

        return pos

    async def process_input(
        self,
        session_id: str,
        data: bytes,
    ) -> list[CommandPolicyResult]:
        """
        Process input bytes from a terminal session.

        This method:
        1. Adds printable characters to the buffer (decoded as UTF-8)
        2. Handles control characters (backspace, Ctrl+U, etc.) and skips
           escape sequences (arrow keys, function keys, Alt+key)
        3. When Enter is detected, evaluates the command and logs it if valid

        Args:
            session_id: The terminal session ID
            data: Raw bytes from the terminal input stream

        Returns:
            List of CommandPolicyResult for any commands that were evaluated
            (empty if the session is unknown).
        """
        ctx = self._sessions.get(session_id)
        if not ctx:
            return []

        results: list[CommandPolicyResult] = []
        end = len(data)
        pos = 0
        while pos < end:
            byte = data[pos]

            # Enter: CR, LF, or a CRLF pair counts as a single Enter.
            if byte in (self._CR, self._LF):
                if byte == self._CR and pos + 1 < end and data[pos + 1] == self._LF:
                    pos += 1
                command = ctx.buffer.strip()
                ctx.buffer = ""
                if command:
                    results.append(await self._evaluate_and_log(ctx, command))
                pos += 1
                continue

            if byte == self._ESC:
                pos = self._skip_escape_sequence(data, pos)
                continue

            # Remaining C0 controls and DEL: editing keys are applied, others ignored.
            if byte < 0x20 or byte == self._DEL:
                self._handle_control_char(ctx, bytes((byte,)))
                pos += 1
                continue

            # Decode a whole run of printable bytes at once so that multi-byte
            # UTF-8 characters survive (decoding byte-by-byte discarded them).
            run_end = pos + 1
            while run_end < end and data[run_end] >= 0x20 and data[run_end] != self._DEL:
                run_end += 1
            ctx.buffer += bytes(data[pos:run_end]).decode('utf-8', errors='ignore')
            pos = run_end

        return results

    async def _invoke_log_callback(self, failure_message: str, /, **fields) -> None:
        """
        Forward ``fields`` as keyword arguments to the log callback, if one is set.

        The callback's return value is awaited only when it is awaitable, so
        both coroutine functions and plain callables work. Exceptions are
        logged with a traceback and not propagated, so a failing sink never
        interrupts input processing.
        """
        callback = self._log_callback
        if not callback:
            return
        try:
            outcome = callback(**fields)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as exc:
            logger.error("%s: %s", failure_message, exc, exc_info=True)

    async def _evaluate_and_log(
        self,
        ctx: SessionContext,
        command: str,
    ) -> CommandPolicyResult:
        """Evaluate a command against policy, update counters, and log if valid."""
        result = self._policy.evaluate(command)

        if result.should_log:
            ctx.commands_logged += 1
            logger.debug(
                "Command logged for %s: %s...", ctx.host_name, result.masked_command[:50]
            )
            await self._invoke_log_callback(
                "Failed to log command to database",
                host_id=ctx.host_id,
                host_name=ctx.host_name,
                user_id=ctx.user_id,
                username=ctx.username,
                session_id=ctx.session_id,
                command=result.masked_command,
                command_hash=result.command_hash,
            )

        elif result.is_blocked:
            ctx.commands_blocked += 1
            logger.info("Command blocked for %s: %s", ctx.host_name, result.reason)

            # Record blocked-command metadata only; the command text is never forwarded.
            await self._invoke_log_callback(
                "Failed to log blocked command metadata",
                host_id=ctx.host_id,
                host_name=ctx.host_name,
                user_id=ctx.user_id,
                username=ctx.username,
                session_id=ctx.session_id,
                command="[BLOCKED]",
                command_hash="blocked",
                is_blocked=True,
                blocked_reason=result.reason,
            )

        else:
            # Command not in allowlist - skip it. Only its length is logged: the
            # raw (unmasked) text may contain secrets and must not reach the logs.
            logger.debug(
                "Command not in allowlist for %s (length=%d)", ctx.host_name, len(command)
            )

        return result

    async def flush_buffer(self, session_id: str) -> Optional[CommandPolicyResult]:
        """
        Evaluate whatever remains in the session buffer as if Enter were pressed.

        Useful when a session ends, to capture an incomplete command. Returns
        None when the session is unknown or its buffer is blank.
        """
        ctx = self._sessions.get(session_id)
        if not ctx:
            return None

        command = ctx.buffer.strip()
        if not command:
            return None

        ctx.buffer = ""
        return await self._evaluate_and_log(ctx, command)

    def get_stats(self) -> dict:
        """
        Get statistics about command logging.

        Totals cover only currently registered sessions; the counters of a
        removed session are discarded with it.
        """
        sessions = self._sessions.values()
        return {
            "active_sessions": len(self._sessions),
            "total_commands_logged": sum(ctx.commands_logged for ctx in sessions),
            "total_commands_blocked": sum(ctx.commands_blocked for ctx in sessions),
        }
|
|
|
|
|
|
# ============================================================================
|
|
# Global instance
|
|
# ============================================================================
|
|
_command_logger_instance: Optional[TerminalCommandLogger] = None


def get_command_logger() -> TerminalCommandLogger:
    """
    Return the shared, module-level TerminalCommandLogger.

    The instance is built lazily on the first call and reused afterwards.
    NOTE(review): the check-then-create is not guarded by a lock, so
    concurrent first calls from different threads could each build one.
    """
    global _command_logger_instance

    instance = _command_logger_instance
    if instance is None:
        instance = TerminalCommandLogger()
        _command_logger_instance = instance
    return instance
|