"""
Terminal Service - Manages SSH terminal sessions via ttyd.

This service handles:
- Creating terminal sessions with unique tokens
- Spawning ttyd processes for SSH connections
- Managing session lifecycle (creation, expiration, cleanup)
- Port allocation for ttyd instances
- Garbage collection of stale/idle sessions
- Session reuse for same user/host/mode
"""

import asyncio
import hashlib
import logging
import os
import secrets
import shutil
import signal
import socket
import subprocess
from datetime import datetime, timezone
from typing import Dict, Any, Tuple, Optional

from app.core.config import settings

logger = logging.getLogger(__name__)

# Configuration - Session TTL and limits
TERMINAL_SESSION_TTL_MINUTES = int(os.environ.get("TERMINAL_SESSION_TTL_MINUTES", "30"))
TERMINAL_SESSION_TTL_SECONDS = TERMINAL_SESSION_TTL_MINUTES * 60
TERMINAL_MAX_SESSIONS_PER_USER = int(os.environ.get("TERMINAL_MAX_SESSIONS_PER_USER", "3"))

# Idle timeout - sessions without heartbeat for this long are considered dead
TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS = int(os.environ.get("TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS", "120"))

# Heartbeat interval - how often the client should send a heartbeat.
# Must stay below the idle timeout above, otherwise live sessions get reaped.
TERMINAL_HEARTBEAT_INTERVAL_SECONDS = int(os.environ.get("TERMINAL_HEARTBEAT_INTERVAL_SECONDS", "15"))

# GC interval - how often to run garbage collection
TERMINAL_GC_INTERVAL_SECONDS = int(os.environ.get("TERMINAL_GC_INTERVAL_SECONDS", "30"))

# Port range for ttyd instances (both ends inclusive)
TERMINAL_PORT_RANGE_START = int(os.environ.get("TERMINAL_PORT_RANGE_START", "7680"))
TERMINAL_PORT_RANGE_END = int(os.environ.get("TERMINAL_PORT_RANGE_END", "7700"))

# SSH configuration
SSH_USER = os.environ.get("TERMINAL_SSH_USER", "automation")
TTYD_PATH = os.environ.get("TTYD_PATH", "ttyd")

# ttyd bind interface (important for WSL/remote access)
TERMINAL_TTYD_INTERFACE = os.environ.get("TERMINAL_TTYD_INTERFACE")


class TerminalServiceError(Exception):
    """Base exception for terminal service errors."""


class HostNotReadyError(TerminalServiceError):
    """Host is not ready for terminal connection."""


class SessionLimitExceededError(TerminalServiceError):
    """User has too many active sessions."""


class TtydNotAvailableError(TerminalServiceError):
    """ttyd binary is not available."""


class TerminalService:
    """
    Manages terminal sessions and ttyd processes.

    This service is responsible for:
    - Generating secure session tokens
    - Allocating ports for ttyd instances
    - Spawning and managing ttyd processes
    - Cleaning up expired/idle sessions via GC
    - Session reuse for same user/host/mode

    The GC task runs periodically to clean up:
    - Sessions that exceeded their TTL
    - Sessions that haven't received a heartbeat within idle timeout
    """

    def __init__(self):
        # Track active ttyd processes: session_id -> subprocess.Popen
        self._processes: Dict[str, subprocess.Popen] = {}
        # Track allocated ports: port -> session_id
        self._allocated_ports: Dict[int, str] = {}
        # Lock guarding the two registries above across coroutines
        self._lock = asyncio.Lock()
        # Cached result of the ttyd lookup (None = not checked yet)
        self._ttyd_available: Optional[bool] = None
        # GC task handle
        self._gc_task: Optional[asyncio.Task] = None
        self._gc_running: bool = False
        # Database session factory for GC (set during app startup)
        self._db_session_factory = None
        # Metrics for observability
        self._metrics = {
            "sessions_created": 0,
            "sessions_reused": 0,
            "sessions_closed_user": 0,
            "sessions_gc_expired": 0,
            "sessions_gc_idle": 0,
            "session_limit_hits": 0,
        }

    def check_ttyd_available(self) -> bool:
        """Return whether the ttyd binary can be found; the result is cached."""
        if self._ttyd_available is None:
            ttyd_location = shutil.which(TTYD_PATH)
            self._ttyd_available = ttyd_location is not None
            if self._ttyd_available:
                logger.info("ttyd found at: %s", ttyd_location)
            else:
                logger.warning(
                    "ttyd not found at '%s'. Terminal feature will be unavailable.",
                    TTYD_PATH,
                )
        return self._ttyd_available

    def set_db_session_factory(self, factory):
        """Set the database session factory for GC operations."""
        self._db_session_factory = factory

    async def start_gc_task(self):
        """Start the garbage collection background task (no-op if already started)."""
        if self._gc_task is not None:
            return

        self._gc_running = True
        self._gc_task = asyncio.create_task(self._gc_loop())
        logger.info(
            "Terminal session GC started (interval=%ss, idle_timeout=%ss)",
            TERMINAL_GC_INTERVAL_SECONDS,
            TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS,
        )

    async def stop_gc_task(self):
        """Stop the garbage collection background task and wait for it to finish."""
        self._gc_running = False
        if self._gc_task:
            self._gc_task.cancel()
            try:
                await self._gc_task
            except asyncio.CancelledError:
                # Expected: we cancelled the task ourselves.
                pass
            self._gc_task = None
            logger.info("Terminal session GC stopped")

    async def _gc_loop(self):
        """Background loop that periodically cleans up stale sessions."""
        while self._gc_running:
            try:
                await asyncio.sleep(TERMINAL_GC_INTERVAL_SECONDS)
                # GC needs DB access; skip the cycle until a factory is set.
                if self._db_session_factory:
                    await self._run_gc_cycle()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive: a failed cycle is retried on the next tick.
                logger.exception("Error in terminal GC cycle: %s", e)

    async def _close_stale_session(self, session, mark_closed, metric_key: str, reason: str) -> None:
        """
        Tear down one stale session: process, port, DB state, metric, log line.

        Args:
            session: Session record exposing ``id``, ``ttyd_port`` and ``host_name``
            mark_closed: Async callable taking the session id that records the
                closure in the database (e.g. ``repo.mark_expired``)
            metric_key: Key in ``self._metrics`` to increment
            reason: Short reason tag written to the log (``ttl`` / ``idle``)
        """
        await self.terminate_session(session.id)
        await self.release_port(session.ttyd_port)
        await mark_closed(session.id)
        self._metrics[metric_key] += 1
        logger.info(
            "session_gc_closed session=%s... host=%s reason=%s",
            session.id[:8],
            session.host_name,
            reason,
        )

    async def _run_gc_cycle(self):
        """Run a single GC cycle to clean up stale (expired or idle) sessions."""
        # Imported lazily, as in the rest of this module's DB-facing methods.
        from app.crud.terminal_session import TerminalSessionRepository

        async with self._db_session_factory() as db_session:
            repo = TerminalSessionRepository(db_session)

            # Get all stale sessions (expired or idle)
            stale_sessions = await repo.list_stale_sessions(
                ttl_seconds=TERMINAL_SESSION_TTL_SECONDS,
                idle_timeout_seconds=TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS,
            )
            if not stale_sessions:
                return

            now = datetime.now(timezone.utc)
            expired_count = 0
            idle_count = 0

            for session in stale_sessions:
                # A session past its expiry is closed as "ttl"; any other stale
                # session is treated as idle.
                expires_at = self._to_utc_aware(session.expires_at)
                if expires_at is not None and expires_at <= now:
                    await self._close_stale_session(
                        session, repo.mark_expired, "sessions_gc_expired", "ttl"
                    )
                    expired_count += 1
                else:
                    await self._close_stale_session(
                        session, repo.mark_idle, "sessions_gc_idle", "idle"
                    )
                    idle_count += 1

            await db_session.commit()

            if expired_count or idle_count:
                logger.info(
                    "Terminal GC cycle: cleaned %s expired, %s idle sessions",
                    expired_count,
                    idle_count,
                )

    def _to_utc_aware(self, dt: Optional[datetime]) -> Optional[datetime]:
        """Normalize datetimes from DB (SQLite may return naive) into UTC-aware."""
        if dt is None:
            return None
        if dt.tzinfo is None:
            # Naive values are taken to already be in UTC.
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    def get_metrics(self) -> Dict[str, Any]:
        """Get service metrics for observability (counters plus live registry sizes)."""
        return {
            **self._metrics,
            "active_processes": len(self._processes),
            "allocated_ports": len(self._allocated_ports),
            "gc_running": self._gc_running,
        }

    def generate_session_id(self) -> str:
        """Generate a unique session ID (64 hex characters)."""
        return secrets.token_hex(32)

    def generate_session_token(self) -> Tuple[str, str]:
        """
        Generate a session token and its hash.

        Returns:
            Tuple of (plain_token, token_hash) where token_hash is the SHA-256
            hex digest of the plain token
        """
        token = secrets.token_urlsafe(48)
        token_hash = hashlib.sha256(token.encode()).hexdigest()
        return token, token_hash

    def verify_token(self, token: str, token_hash: str) -> bool:
        """Verify a token against its SHA-256 hex digest using a constant-time compare."""
        computed_hash = hashlib.sha256(token.encode()).hexdigest()
        return secrets.compare_digest(computed_hash, token_hash)

    def _is_port_available(self, port: int) -> bool:
        """Check if a TCP port is available for binding on all IPv4 interfaces."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                # SO_REUSEADDR off so a port still held by another socket fails the bind.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
                sock.bind(("0.0.0.0", port))
            except OSError:
                return False
            return True

    async def is_session_process_alive(self, session_id: str) -> bool:
        """Check if the ttyd process for a session is still running."""
        async with self._lock:
            process = self._processes.get(session_id)
            if process is None:
                return False
            return process.poll() is None

    async def allocate_port(self, session_id: str) -> int:
        """
        Allocate an available port for a ttyd instance.

        Args:
            session_id: Session the port is reserved for

        Returns:
            Available port number

        Raises:
            TerminalServiceError if no ports available
        """
        async with self._lock:
            for port in range(TERMINAL_PORT_RANGE_START, TERMINAL_PORT_RANGE_END + 1):
                # Skip ports we already handed out and ports held by someone else.
                if port in self._allocated_ports or not self._is_port_available(port):
                    continue
                self._allocated_ports[port] = session_id
                logger.debug("Allocated port %s for session %s...", port, session_id[:8])
                return port

            raise TerminalServiceError("No available ports for terminal session")

    async def release_port(self, port: int) -> None:
        """Release an allocated port; a port that is not allocated is ignored."""
        async with self._lock:
            session_id = self._allocated_ports.pop(port, None)
            if session_id is not None:
                logger.debug("Released port %s from session %s...", port, session_id[:8])

    @staticmethod
    def _resolve_ssh_key_path() -> Optional[str]:
        """Return the first existing SSH key among the configured candidates, or None."""
        candidates = [
            os.environ.get("TERMINAL_SSH_KEY_PATH"),
            getattr(settings, "ssh_key_path", None),
            os.path.expanduser("~/.ssh/id_automation_ansible"),
            os.path.expanduser("~/.ssh/id_rsa"),
        ]
        return next((path for path in candidates if path and os.path.exists(path)), None)

    @staticmethod
    def _build_ttyd_command(port: int, ssh_key_path: str, host_ip: str) -> list:
        """Build the ttyd argv: ttyd options followed by the ssh command it runs."""
        # NOTE(review): host key verification is disabled below
        # (StrictHostKeyChecking=no, UserKnownHostsFile=/dev/null).
        ssh_cmd = [
            "ssh",
            "-i", ssh_key_path,
            "-o", "BatchMode=no",
            "-o", "PreferredAuthentications=publickey,keyboard-interactive,password",
            "-o", "ConnectTimeout=10",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=2",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            f"{SSH_USER}@{host_ip}",
        ]

        # NOTE: Do not use --once because it prevents browser reconnection.
        # Use --max-clients=1 to keep a single active client while allowing reconnect.
        # --port: Listen port
        # --writable: Allow the client to send input
        # --credential is deliberately not used, to avoid the HTTP basic auth
        # popup; security is handled by session token validation in the
        # connect/popout routes.
        # No --interface is passed: ttyd always binds to its default (0.0.0.0) so
        # the local proxy can connect via 127.0.0.1 regardless of which
        # interface is primary (eth0, etc).
        return [
            TTYD_PATH,
            "--max-clients=1",
            f"--port={port}",
            "--writable",
            *ssh_cmd,
        ]

    async def _discard_dead_process(self, session_id: str, process: Optional[subprocess.Popen]) -> None:
        """Drop the registry entry for session_id if it points at this exited process."""
        # A process that is still running stays tracked so terminate_session can stop it.
        if process is None or process.poll() is None:
            return
        async with self._lock:
            if self._processes.get(session_id) is process:
                del self._processes[session_id]

    async def spawn_ttyd(
        self,
        session_id: str,
        host_ip: str,
        port: int,
        token: str,
    ) -> Optional[int]:
        """
        Spawn a ttyd process for SSH connection.

        Args:
            session_id: Unique session identifier
            host_ip: Target host IP address
            port: Port to run ttyd on
            token: Session token. Currently unused here (ttyd is started without
                --credential); kept for interface compatibility.

        Returns:
            Process ID of the spawned ttyd

        Raises:
            TtydNotAvailableError: ttyd is not installed or not in PATH
            TerminalServiceError: no SSH key was found, the process could not be
                started, or it exited within the startup grace period. For
                failures after the spawn attempt the port is released and the
                dead process is removed from the registry; the two pre-flight
                errors leave the port allocation untouched.
        """
        if not self.check_ttyd_available():
            raise TtydNotAvailableError("ttyd is not installed or not in PATH")

        ssh_key_path = self._resolve_ssh_key_path()
        if not ssh_key_path:
            raise TerminalServiceError(
                "No SSH key found for terminal sessions. Set SSH_KEY_PATH or TERMINAL_SSH_KEY_PATH (expected ~/.ssh/id_automation_ansible)."
            )

        # NOTE: TERMINAL_TTYD_INTERFACE is never passed to ttyd (see
        # _build_ttyd_command); a loopback value only triggers this warning.
        ttyd_interface = TERMINAL_TTYD_INTERFACE
        if ttyd_interface and ttyd_interface.strip().lower() in {"127.0.0.1", "localhost", "lo"}:
            logger.warning(
                "Ignoring TERMINAL_TTYD_INTERFACE=%s because binding ttyd to loopback breaks remote access/reconnect",
                ttyd_interface,
            )

        cmd = self._build_ttyd_command(port, ssh_key_path, host_ip)

        logger.info(
            "Spawning ttyd for session %s... on port %s -> %s@%s (key=%s)",
            session_id[:8],
            port,
            SSH_USER,
            host_ip,
            ssh_key_path,
        )

        process: Optional[subprocess.Popen] = None
        try:
            # NOTE(review): stdout/stderr are pipes that are only read on the
            # immediate-failure path below; a long-running, chatty ttyd could
            # fill the pipe buffer and block. Confirm ttyd's log volume.
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                start_new_session=True,  # Detach from parent process group
            )

            # Store the process
            async with self._lock:
                self._processes[session_id] = process

            # Wait briefly to check if process started successfully
            await asyncio.sleep(0.5)

            if process.poll() is not None:
                # Process exited immediately - likely an error
                stdout, stderr = process.communicate(timeout=1)
                raw_output = stderr or stdout
                # errors="replace": undecodable bytes must not mask the real failure.
                error_msg = raw_output.decode(errors="replace") if raw_output else "Unknown error"
                logger.error("ttyd failed to start: %s", error_msg)
                raise TerminalServiceError(f"ttyd failed to start: {error_msg.strip()}")

            logger.info("ttyd started with PID %s for session %s...", process.pid, session_id[:8])
            return process.pid

        except TerminalServiceError:
            # Preserve detailed error message for API layer
            await self._discard_dead_process(session_id, process)
            await self.release_port(port)
            raise
        except Exception as e:
            logger.exception("Failed to spawn ttyd: %s", e)
            await self._discard_dead_process(session_id, process)
            await self.release_port(port)
            raise TerminalServiceError(str(e)) from e

    @staticmethod
    def _stop_process(process: subprocess.Popen) -> None:
        """Blocking helper: terminate the process, escalating to kill after 5 seconds."""
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Force kill if graceful termination failed
            process.kill()
            process.wait(timeout=2)

    async def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a terminal session and its ttyd process.

        Returns:
            True if a tracked process was found and stopped (or had already
            exited); False if no process is tracked for the session or stopping
            it raised an error
        """
        async with self._lock:
            process = self._processes.pop(session_id, None)

        if process is None:
            return False

        try:
            if process.poll() is None:
                # The waits can block for several seconds; run them off the
                # event loop so other coroutines keep being served.
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, self._stop_process, process)

            logger.info("Terminated ttyd process for session %s...", session_id[:8])
            return True
        except Exception as e:
            logger.exception("Error terminating ttyd process: %s", e)
            return False

    async def cleanup_expired_sessions(
        self,
        get_expired_sessions_func,
        mark_expired_func,
        db_session
    ) -> int:
        """
        Clean up expired terminal sessions.

        Args:
            get_expired_sessions_func: Unused; kept for backward compatibility.
                Expired sessions are read through TerminalSessionRepository.
            mark_expired_func: Unused; kept for backward compatibility.
                Sessions are marked through TerminalSessionRepository.
            db_session: Database session (not committed here)

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        expired = await repo.list_expired()

        count = 0
        for session in expired:
            await self._close_stale_session(
                session, repo.mark_expired, "sessions_gc_expired", "ttl"
            )
            count += 1

        return count

    async def cleanup_idle_sessions(self, db_session) -> int:
        """
        Clean up idle terminal sessions (no heartbeat within timeout).

        Args:
            db_session: Database session (not committed here)

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        idle_sessions = await repo.list_idle(TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS)

        count = 0
        for session in idle_sessions:
            await self._close_stale_session(
                session, repo.mark_idle, "sessions_gc_idle", "idle"
            )
            count += 1

        return count

    async def close_session_with_cleanup(
        self,
        session_id: str,
        port: int,
        reason: str = "user_close"
    ) -> bool:
        """
        Close a session with full cleanup (process + port).

        Args:
            session_id: Session whose ttyd process should be stopped
            port: Port to release
            reason: Reason tag for the log; "user_close" also bumps the
                sessions_closed_user metric

        Returns:
            Result of terminate_session for this session
        """
        terminated = await self.terminate_session(session_id)
        await self.release_port(port)

        if reason == "user_close":
            self._metrics["sessions_closed_user"] += 1

        logger.info("session_closed session=%s... reason=%s", session_id[:8], reason)
        return terminated

    def record_session_created(self, reused: bool = False):
        """Record a session creation (or reuse, when reused=True) for metrics."""
        if reused:
            self._metrics["sessions_reused"] += 1
        else:
            self._metrics["sessions_created"] += 1

    def record_session_limit_hit(self):
        """Record a session limit hit for metrics."""
        self._metrics["session_limit_hits"] += 1

    def get_active_session_count(self) -> int:
        """Get the number of tracked ttyd processes."""
        return len(self._processes)

    def get_session_url(self, port: int, token: str, base_url: str = "") -> str:
        """
        Get the URL to access a terminal session.

        Args:
            port: ttyd port
            token: Session token (only embedded in the proxied URL)
            base_url: Optional base URL prefix

        Returns:
            Full URL to access the terminal
        """
        # If we have a reverse proxy, use the proxied URL
        # Otherwise, direct access to ttyd port
        if base_url:
            return f"{base_url}/terminal/proxy/{port}?token={token}"
        return f"http://localhost:{port}/"

    def get_websocket_url(self, port: int) -> str:
        """Get the WebSocket URL for terminal connection."""
        return f"ws://localhost:{port}/ws"


# Global service instance
terminal_service = TerminalService()


def get_terminal_service() -> TerminalService:
    """Get the global terminal service instance."""
    return terminal_service


# Export configuration for use in routes
__all__ = [
    "terminal_service",
    "get_terminal_service",
    "TerminalService",
    "TerminalServiceError",
    "HostNotReadyError",
    "SessionLimitExceededError",
    "TtydNotAvailableError",
    "TERMINAL_SESSION_TTL_MINUTES",
    "TERMINAL_SESSION_TTL_SECONDS",
    "TERMINAL_MAX_SESSIONS_PER_USER",
    "TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS",
    "TERMINAL_HEARTBEAT_INTERVAL_SECONDS",
    "TERMINAL_GC_INTERVAL_SECONDS",
]