homelab_automation/app/services/terminal_service.py
Bruno Charest 5bc12d0729
Some checks failed
Tests / Backend Tests (Python) (3.10) (push) Has been cancelled
Tests / Backend Tests (Python) (3.11) (push) Has been cancelled
Tests / Backend Tests (Python) (3.12) (push) Has been cancelled
Tests / Frontend Tests (JS) (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / All Tests Passed (push) Has been cancelled
Add terminal session management with heartbeat monitoring, idle timeout detection, session reuse logic, and command history panel UI with search and filtering capabilities
2025-12-18 13:49:40 -05:00

598 lines
21 KiB
Python

"""
Terminal Service - Manages SSH terminal sessions via ttyd.
This service handles:
- Creating terminal sessions with unique tokens
- Spawning ttyd processes for SSH connections
- Managing session lifecycle (creation, expiration, cleanup)
- Port allocation for ttyd instances
- Garbage collection of stale/idle sessions
- Session reuse for same user/host/mode
"""
import asyncio
import hashlib
import logging
import os
import secrets
import shutil
import signal
import socket
import subprocess
from datetime import datetime, timezone
from typing import Dict, Any, Tuple, Optional
from app.core.config import settings
logger = logging.getLogger(__name__)


def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int, or *default* when unset."""
    return int(os.environ.get(name, default))


# --- Session lifetime & limits ------------------------------------------------
# Absolute lifetime of a session, configured in minutes and also exposed in seconds.
TERMINAL_SESSION_TTL_MINUTES = _env_int("TERMINAL_SESSION_TTL_MINUTES", "30")
TERMINAL_SESSION_TTL_SECONDS = 60 * TERMINAL_SESSION_TTL_MINUTES
# Upper bound on concurrent terminal sessions per user.
TERMINAL_MAX_SESSIONS_PER_USER = _env_int("TERMINAL_MAX_SESSIONS_PER_USER", "3")

# --- Liveness -----------------------------------------------------------------
# A session whose last heartbeat is older than this many seconds is reaped as idle.
TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS = _env_int("TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS", "120")
# How often clients are expected to send a heartbeat. Keep it well below the
# idle timeout so that a few missed beats do not get a live session reaped.
TERMINAL_HEARTBEAT_INTERVAL_SECONDS = _env_int("TERMINAL_HEARTBEAT_INTERVAL_SECONDS", "15")
# Delay between two garbage-collection passes.
TERMINAL_GC_INTERVAL_SECONDS = _env_int("TERMINAL_GC_INTERVAL_SECONDS", "30")

# --- ttyd / SSH ---------------------------------------------------------------
# Inclusive range of ports handed out to ttyd instances.
TERMINAL_PORT_RANGE_START = _env_int("TERMINAL_PORT_RANGE_START", "7680")
TERMINAL_PORT_RANGE_END = _env_int("TERMINAL_PORT_RANGE_END", "7700")
# Remote account used for the SSH connection, and the ttyd executable to launch.
SSH_USER = os.environ.get("TERMINAL_SSH_USER", "automation")
TTYD_PATH = os.environ.get("TTYD_PATH", "ttyd")
# Optional interface for ttyd to bind to (matters for WSL / remote access).
# When unset (None), no --interface flag is passed and ttyd uses its default.
TERMINAL_TTYD_INTERFACE = os.environ.get("TERMINAL_TTYD_INTERFACE")
class TerminalServiceError(Exception):
    """Common base class for the exceptions raised by the terminal service."""
class HostNotReadyError(TerminalServiceError):
    """Signals that the target host cannot accept a terminal connection yet."""
class SessionLimitExceededError(TerminalServiceError):
    """Signals that a user already holds the maximum number of active sessions."""
class TtydNotAvailableError(TerminalServiceError):
    """Signals that the ttyd executable could not be located."""
class TerminalService:
    """
    Manages terminal sessions and ttyd processes.

    This service is responsible for:
    - Generating secure session tokens
    - Allocating ports for ttyd instances
    - Spawning and managing ttyd processes
    - Cleaning up expired/idle sessions via GC
    - Recording session metrics (creation, reuse, limit hits); the reuse
      decision itself is not made in this class

    The GC task runs periodically to clean up:
    - Sessions that exceeded their TTL
    - Sessions that haven't received a heartbeat within idle timeout

    In-memory bookkeeping (``_processes`` and ``_allocated_ports``) is guarded
    by a single ``asyncio.Lock``. The lock is never held while waiting for a
    child process to exit, so a slow ttyd shutdown cannot stall port
    allocation or other session operations.
    """

    def __init__(self):
        # Active ttyd processes: session_id -> subprocess.Popen
        self._processes: Dict[str, subprocess.Popen] = {}
        # Allocated ports: port -> session_id
        self._allocated_ports: Dict[int, str] = {}
        # Serialises coroutine access to the two dicts above
        # (asyncio.Lock coordinates coroutines; it is not a thread lock).
        self._lock = asyncio.Lock()
        # Cached result of check_ttyd_available (None = not checked yet)
        self._ttyd_available: Optional[bool] = None
        # GC task handle and run flag
        self._gc_task: Optional[asyncio.Task] = None
        self._gc_running: bool = False
        # Factory returning an async context manager that yields a DB session;
        # used by the GC and set during app startup.
        self._db_session_factory = None
        # Counters exposed through get_metrics()
        self._metrics = {
            "sessions_created": 0,
            "sessions_reused": 0,
            "sessions_closed_user": 0,
            "sessions_gc_expired": 0,
            "sessions_gc_idle": 0,
            "session_limit_hits": 0,
        }

    # ------------------------------------------------------------------ #
    # Environment / configuration
    # ------------------------------------------------------------------ #

    def check_ttyd_available(self) -> bool:
        """
        Check whether the ttyd binary can be found on PATH.

        The result is cached on the instance after the first call.
        """
        if self._ttyd_available is not None:
            return self._ttyd_available
        resolved_path = shutil.which(TTYD_PATH)
        self._ttyd_available = resolved_path is not None
        if resolved_path is None:
            logger.warning(f"ttyd not found at '{TTYD_PATH}'. Terminal feature will be unavailable.")
        else:
            logger.info(f"ttyd found at: {resolved_path}")
        return self._ttyd_available

    def set_db_session_factory(self, factory):
        """Set the database session factory used by GC cycles."""
        self._db_session_factory = factory

    # ------------------------------------------------------------------ #
    # Garbage collection
    # ------------------------------------------------------------------ #

    async def start_gc_task(self):
        """
        Start the garbage collection background task.

        This is a no-op while a GC task is still running. A task that has
        already finished is replaced, so the GC can be restarted.
        """
        if self._gc_task is not None and not self._gc_task.done():
            return
        self._gc_running = True
        self._gc_task = asyncio.create_task(self._gc_loop())
        logger.info(f"Terminal session GC started (interval={TERMINAL_GC_INTERVAL_SECONDS}s, idle_timeout={TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS}s)")

    async def stop_gc_task(self):
        """Stop the garbage collection background task and wait for it to finish."""
        self._gc_running = False
        if self._gc_task:
            self._gc_task.cancel()
            try:
                await self._gc_task
            except asyncio.CancelledError:
                pass
            self._gc_task = None
            logger.info("Terminal session GC stopped")

    async def _gc_loop(self):
        """
        Periodically run GC cycles until stopped.

        Errors in a cycle are logged and do not stop the loop. Cycles are
        skipped while no DB session factory has been configured.
        """
        while self._gc_running:
            try:
                await asyncio.sleep(TERMINAL_GC_INTERVAL_SECONDS)
                if self._db_session_factory:
                    await self._run_gc_cycle()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception(f"Error in terminal GC cycle: {e}")

    async def _run_gc_cycle(self):
        """
        Run a single GC cycle: reap every stale session, then commit.

        A stale session whose ``expires_at`` is in the past is closed with the
        TTL reason. Every other stale session returned by the repository is
        closed as idle.
        """
        from app.crud.terminal_session import TerminalSessionRepository

        async with self._db_session_factory() as db_session:
            repo = TerminalSessionRepository(db_session)
            stale_sessions = await repo.list_stale_sessions(
                ttl_seconds=TERMINAL_SESSION_TTL_SECONDS,
                idle_timeout_seconds=TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS,
            )
            if not stale_sessions:
                return

            now = datetime.now(timezone.utc)
            expired_count = 0
            idle_count = 0
            for session in stale_sessions:
                # Classify before tearing down: TTL expiry takes precedence over idleness.
                expires_at = self._to_utc_aware(session.expires_at)
                is_expired = bool(expires_at and expires_at <= now)

                await self._release_session_resources(session.id, session.ttyd_port)

                if is_expired:
                    await repo.mark_expired(session.id)
                    expired_count += 1
                    self._metrics["sessions_gc_expired"] += 1
                    logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=ttl")
                else:
                    await repo.mark_idle(session.id)
                    idle_count += 1
                    self._metrics["sessions_gc_idle"] += 1
                    logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=idle")

            await db_session.commit()

            if expired_count or idle_count:
                logger.info(f"Terminal GC cycle: cleaned {expired_count} expired, {idle_count} idle sessions")

    def _to_utc_aware(self, dt: Optional[datetime]) -> Optional[datetime]:
        """
        Normalize a datetime into a UTC-aware datetime.

        Naive values (SQLite may return them) are assumed to already be UTC.
        None is passed through unchanged.
        """
        if dt is None:
            return None
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus current process/port/GC state."""
        return {
            **self._metrics,
            "active_processes": len(self._processes),
            "allocated_ports": len(self._allocated_ports),
            "gc_running": self._gc_running,
        }

    # ------------------------------------------------------------------ #
    # Tokens
    # ------------------------------------------------------------------ #

    def generate_session_id(self) -> str:
        """Generate a unique session ID (64 hex characters)."""
        return secrets.token_hex(32)

    def generate_session_token(self) -> Tuple[str, str]:
        """
        Generate a session token and its SHA-256 hash.

        Returns:
            Tuple of (plain_token, token_hash). Only the hash should be persisted.
        """
        token = secrets.token_urlsafe(48)
        token_hash = hashlib.sha256(token.encode()).hexdigest()
        return token, token_hash

    def verify_token(self, token: str, token_hash: str) -> bool:
        """Verify a token against its hash using a constant-time comparison."""
        computed_hash = hashlib.sha256(token.encode()).hexdigest()
        return secrets.compare_digest(computed_hash, token_hash)

    # ------------------------------------------------------------------ #
    # Ports
    # ------------------------------------------------------------------ #

    def _is_port_available(self, port: int) -> bool:
        """
        Probe whether a TCP port can currently be bound on all IPv4 interfaces.

        SO_REUSEADDR is explicitly left disabled, which keeps the probe
        conservative.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
                sock.bind(("0.0.0.0", port))
            except OSError:
                return False
            return True

    async def is_session_process_alive(self, session_id: str) -> bool:
        """Return True if a ttyd process is tracked for the session and still running."""
        async with self._lock:
            process = self._processes.get(session_id)
            if process is None:
                return False
            return process.poll() is None

    async def allocate_port(self, session_id: str) -> int:
        """
        Allocate the first free port in the configured (inclusive) range.

        A port is free when it is not already allocated by this service and
        the OS lets us bind it.

        Returns:
            Available port number

        Raises:
            TerminalServiceError: if no ports are available
        """
        async with self._lock:
            for port in range(TERMINAL_PORT_RANGE_START, TERMINAL_PORT_RANGE_END + 1):
                if port in self._allocated_ports or not self._is_port_available(port):
                    continue
                self._allocated_ports[port] = session_id
                logger.debug(f"Allocated port {port} for session {session_id[:8]}...")
                return port
        raise TerminalServiceError("No available ports for terminal session")

    async def release_port(self, port: int) -> None:
        """Release an allocated port. Unknown ports are ignored."""
        async with self._lock:
            if port in self._allocated_ports:
                session_id = self._allocated_ports.pop(port)
                logger.debug(f"Released port {port} from session {session_id[:8]}...")

    # ------------------------------------------------------------------ #
    # ttyd processes
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resolve_ssh_key_path() -> Optional[str]:
        """
        Return the first existing SSH private key among the configured candidates.

        Candidates, in priority order: $TERMINAL_SSH_KEY_PATH,
        settings.ssh_key_path, ~/.ssh/id_automation_ansible, ~/.ssh/id_rsa.
        Returns None if none of them exists.
        """
        candidates = [
            os.environ.get("TERMINAL_SSH_KEY_PATH"),
            getattr(settings, "ssh_key_path", None),
            os.path.expanduser("~/.ssh/id_automation_ansible"),
            os.path.expanduser("~/.ssh/id_rsa"),
        ]
        return next((p for p in candidates if p and os.path.exists(p)), None)

    @staticmethod
    async def _wait_for_exit(
        process: subprocess.Popen,
        timeout: float,
        poll_interval: float = 0.1,
    ) -> bool:
        """
        Wait for ``process`` to exit without blocking the event loop.

        The process is polled every ``poll_interval`` seconds using
        ``asyncio.sleep``; ``Popen.wait`` would block the whole loop.

        Returns:
            True if the process exited within ``timeout`` seconds, else False.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while process.poll() is None:
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(poll_interval)
        return True

    async def _unregister_process(self, session_id: str, process: subprocess.Popen) -> None:
        """Remove ``session_id`` from the registry if it still maps to ``process``."""
        async with self._lock:
            if self._processes.get(session_id) is process:
                del self._processes[session_id]

    async def spawn_ttyd(
        self,
        session_id: str,
        host_ip: str,
        port: int,
        token: str,
    ) -> Optional[int]:
        """
        Spawn a ttyd process that runs an SSH client to the target host.

        Args:
            session_id: Unique session identifier (key of the process registry)
            host_ip: Target host IP address
            port: Port to run ttyd on
            token: Session token. It is not used here: ttyd runs without
                --credential and token validation happens in the routes.

        Returns:
            Process ID of the spawned ttyd

        Raises:
            TtydNotAvailableError: ttyd is not installed or not in PATH.
            TerminalServiceError: no SSH key was found, or ttyd failed to
                start. On a start failure the port is released and the
                process is unregistered or terminated before raising.
        """
        import tempfile  # local import, consistent with the lazy imports used by the GC helpers

        if not self.check_ttyd_available():
            raise TtydNotAvailableError("ttyd is not installed or not in PATH")

        ssh_key_path = self._resolve_ssh_key_path()
        if not ssh_key_path:
            raise TerminalServiceError(
                "No SSH key found for terminal sessions. Set SSH_KEY_PATH or TERMINAL_SSH_KEY_PATH (expected ~/.ssh/id_automation_ansible)."
            )

        # Command executed by ttyd inside its terminal: an SSH session to the host.
        # NOTE(review): with UserKnownHostsFile=/dev/null no host key is ever
        # remembered, so StrictHostKeyChecking=accept-new accepts whatever key
        # the host presents on every connection (no MITM protection). Consider
        # a persistent known_hosts file.
        ssh_cmd = [
            "ssh",
            "-i", ssh_key_path,
            "-o", "BatchMode=no",
            "-o", "PreferredAuthentications=publickey,keyboard-interactive,password",
            "-o", "ConnectTimeout=10",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=2",
            "-o", "StrictHostKeyChecking=accept-new",
            "-o", "UserKnownHostsFile=/dev/null",
            f"{SSH_USER}@{host_ip}",
        ]

        # ttyd flags:
        #   --once      exit after the single client disconnects
        #   --port      listen port
        #   --writable  let the client send input to the terminal
        # --credential (HTTP basic auth) is intentionally not used to avoid the
        # browser auth popup; the session token is validated by the connect/popout routes.
        cmd = [TTYD_PATH, "--once", f"--port={port}", "--writable"]
        if TERMINAL_TTYD_INTERFACE:
            cmd.extend(["--interface", TERMINAL_TTYD_INTERFACE])
        cmd.extend(ssh_cmd)

        logger.info(
            f"Spawning ttyd for session {session_id[:8]}... on port {port} -> {SSH_USER}@{host_ip} (key={ssh_key_path})"
        )

        log_file = None
        process: Optional[subprocess.Popen] = None
        try:
            # Send ttyd's output to an unlinked temp file instead of pipes.
            # Nobody reads pipes for a long-lived ttyd, so a full pipe buffer
            # could block it. The file still lets us report early-exit errors.
            log_file = tempfile.TemporaryFile()
            process = subprocess.Popen(
                cmd,
                stdout=log_file,
                stderr=subprocess.STDOUT,
                start_new_session=True,  # setsid: detach from our process group
            )
            async with self._lock:
                self._processes[session_id] = process

            # Give ttyd a moment; an immediate exit usually means a bad flag or a busy port.
            await asyncio.sleep(0.5)
            if process.poll() is not None:
                log_file.seek(0)
                error_msg = log_file.read().decode(errors="replace").strip() or "Unknown error"
                logger.error(f"ttyd failed to start: {error_msg}")
                await self._unregister_process(session_id, process)
                await self.release_port(port)
                raise TerminalServiceError(f"ttyd failed to start: {error_msg}")

            logger.info(f"ttyd started with PID {process.pid} for session {session_id[:8]}...")
            return process.pid
        except TerminalServiceError:
            # Preserve detailed error message for API layer
            raise
        except Exception as e:
            logger.exception(f"Failed to spawn ttyd: {e}")
            if process is not None:
                # Do not leave an orphan ttyd running behind a failed spawn.
                await self.terminate_session(session_id)
            await self.release_port(port)
            raise TerminalServiceError(str(e)) from e
        finally:
            # The child keeps its own descriptor; we only drop ours.
            if log_file is not None:
                log_file.close()

    async def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a terminal session's ttyd process.

        Calls ``Popen.terminate()`` and waits up to 5 s. If the process is
        still alive, calls ``Popen.kill()`` and waits up to 2 more seconds.
        Waiting does not block the event loop and does not hold the lock.

        Returns:
            True if a tracked process was found and is no longer running.
            False if no process was tracked or termination failed.
        """
        async with self._lock:
            process = self._processes.pop(session_id, None)
        if process is None:
            return False

        try:
            if process.poll() is None:
                process.terminate()
                if not await self._wait_for_exit(process, timeout=5):
                    # Graceful termination failed: force kill.
                    process.kill()
                    if not await self._wait_for_exit(process, timeout=2):
                        raise subprocess.TimeoutExpired(process.args, 2)
            logger.info(f"Terminated ttyd process for session {session_id[:8]}...")
            return True
        except Exception as e:
            logger.exception(f"Error terminating ttyd process: {e}")
            return False

    async def _release_session_resources(self, session_id: str, port: int) -> bool:
        """
        Terminate the session's ttyd process (if tracked) and free its port.

        Returns:
            The result of ``terminate_session``.
        """
        terminated = await self.terminate_session(session_id)
        await self.release_port(port)
        return terminated

    # ------------------------------------------------------------------ #
    # Cleanup entry points
    # ------------------------------------------------------------------ #

    async def cleanup_expired_sessions(
        self,
        get_expired_sessions_func,
        mark_expired_func,
        db_session
    ) -> int:
        """
        Clean up expired terminal sessions.

        Args:
            get_expired_sessions_func: Unused; kept for backward compatibility.
                Expired sessions are read via ``TerminalSessionRepository.list_expired``.
            mark_expired_func: Unused; kept for backward compatibility.
                Sessions are marked via ``TerminalSessionRepository.mark_expired``.
            db_session: Database session used to build the repository. This
                method does not commit; the caller owns the transaction.

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        expired = await repo.list_expired()
        for session in expired:
            await self._release_session_resources(session.id, session.ttyd_port)
            await repo.mark_expired(session.id)
            self._metrics["sessions_gc_expired"] += 1
            logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=ttl")
        return len(expired)

    async def cleanup_idle_sessions(self, db_session) -> int:
        """
        Clean up idle terminal sessions (no heartbeat within the idle timeout).

        This method does not commit; the caller owns the transaction.

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        idle_sessions = await repo.list_idle(TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS)
        for session in idle_sessions:
            await self._release_session_resources(session.id, session.ttyd_port)
            await repo.mark_idle(session.id)
            self._metrics["sessions_gc_idle"] += 1
            logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=idle")
        return len(idle_sessions)

    async def close_session_with_cleanup(
        self,
        session_id: str,
        port: int,
        reason: str = "user_close"
    ) -> bool:
        """
        Close a session with full cleanup (process + port).

        Only ``reason == "user_close"`` increments the user-close metric.

        Returns:
            True if a tracked ttyd process was terminated
        """
        terminated = await self._release_session_resources(session_id, port)
        if reason == "user_close":
            self._metrics["sessions_closed_user"] += 1
        logger.info(f"session_closed session={session_id[:8]}... reason={reason}")
        return terminated

    # ------------------------------------------------------------------ #
    # Metrics & URLs
    # ------------------------------------------------------------------ #

    def record_session_created(self, reused: bool = False):
        """Record a session creation (or reuse, when ``reused`` is True) for metrics."""
        if reused:
            self._metrics["sessions_reused"] += 1
        else:
            self._metrics["sessions_created"] += 1

    def record_session_limit_hit(self):
        """Record a session limit hit for metrics."""
        self._metrics["session_limit_hits"] += 1

    def get_active_session_count(self) -> int:
        """
        Return the number of tracked ttyd processes.

        This may include processes that already exited (e.g. after --once)
        and have not yet been cleaned up.
        """
        return len(self._processes)

    def get_session_url(self, port: int, token: str, base_url: str = "") -> str:
        """
        Get the URL to access a terminal session.

        Args:
            port: ttyd port
            token: Session token (only embedded in the proxied URL)
            base_url: Optional base URL prefix; when set, a proxied URL is returned

        Returns:
            Full URL to access the terminal
        """
        if base_url:
            return f"{base_url}/terminal/proxy/{port}?token={token}"
        return f"http://localhost:{port}/"

    def get_websocket_url(self, port: int) -> str:
        """Get the direct (localhost) WebSocket URL for a ttyd instance."""
        return f"ws://localhost:{port}/ws"
# Module-level singleton: every importer of this module shares this instance.
terminal_service = TerminalService()


def get_terminal_service() -> TerminalService:
    """Return the shared module-level ``terminal_service`` instance."""
    shared_instance = terminal_service
    return shared_instance
# Public API of this module: the service, its exceptions, and the
# configuration values the routes rely on.
__all__ = [
    # Shared instance and accessor
    "terminal_service",
    "get_terminal_service",
    # Service class and exception hierarchy
    "TerminalService",
    "TerminalServiceError",
    "HostNotReadyError",
    "SessionLimitExceededError",
    "TtydNotAvailableError",
] + [
    # Configuration exported for use in routes
    "TERMINAL_SESSION_TTL_MINUTES",
    "TERMINAL_SESSION_TTL_SECONDS",
    "TERMINAL_MAX_SESSIONS_PER_USER",
    "TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS",
    "TERMINAL_HEARTBEAT_INTERVAL_SECONDS",
    "TERMINAL_GC_INTERVAL_SECONDS",
]