Some checks failed
Tests / Backend Tests (Python) (3.10) (push) Has been cancelled
Tests / Backend Tests (Python) (3.11) (push) Has been cancelled
Tests / Backend Tests (Python) (3.12) (push) Has been cancelled
Tests / Frontend Tests (JS) (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / All Tests Passed (push) Has been cancelled
598 lines
21 KiB
Python
598 lines
21 KiB
Python
"""
|
|
Terminal Service - Manages SSH terminal sessions via ttyd.
|
|
|
|
This service handles:
|
|
- Creating terminal sessions with unique tokens
|
|
- Spawning ttyd processes for SSH connections
|
|
- Managing session lifecycle (creation, expiration, cleanup)
|
|
- Port allocation for ttyd instances
|
|
- Garbage collection of stale/idle sessions
|
|
- Session reuse for same user/host/mode
|
|
"""
|
|
import asyncio
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import shutil
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, Any, Tuple, Optional
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# --- Session lifetime and per-user limits -----------------------------------
# Hard TTL is configured in minutes; the seconds value is derived from it.
TERMINAL_SESSION_TTL_MINUTES = int(os.getenv("TERMINAL_SESSION_TTL_MINUTES", "30"))
TERMINAL_SESSION_TTL_SECONDS = 60 * TERMINAL_SESSION_TTL_MINUTES
TERMINAL_MAX_SESSIONS_PER_USER = int(os.getenv("TERMINAL_MAX_SESSIONS_PER_USER", "3"))

# --- Liveness ---------------------------------------------------------------
# A session with no heartbeat for this many seconds is treated as dead.
TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS = int(
    os.getenv("TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS", "120")
)
# How often the client is expected to send a heartbeat.
# NOTE(review): presumably this must stay below the idle timeout above - confirm.
TERMINAL_HEARTBEAT_INTERVAL_SECONDS = int(
    os.getenv("TERMINAL_HEARTBEAT_INTERVAL_SECONDS", "15")
)
# Pause between two garbage-collection cycles.
TERMINAL_GC_INTERVAL_SECONDS = int(os.getenv("TERMINAL_GC_INTERVAL_SECONDS", "30"))

# --- ttyd port pool (both ends inclusive) -----------------------------------
TERMINAL_PORT_RANGE_START = int(os.getenv("TERMINAL_PORT_RANGE_START", "7680"))
TERMINAL_PORT_RANGE_END = int(os.getenv("TERMINAL_PORT_RANGE_END", "7700"))

# --- SSH / ttyd binaries ----------------------------------------------------
# Remote account used for the SSH connection, and the ttyd executable to run.
SSH_USER = os.getenv("TERMINAL_SSH_USER", "automation")
TTYD_PATH = os.getenv("TTYD_PATH", "ttyd")

# Optional value for ttyd's --interface flag (important for WSL/remote access);
# None when the variable is unset, in which case the flag is not passed.
TERMINAL_TTYD_INTERFACE = os.getenv("TERMINAL_TTYD_INTERFACE")
|
|
|
|
|
|
class TerminalServiceError(Exception):
    """Common base class for every error raised by the terminal service."""
|
|
|
|
|
|
class HostNotReadyError(TerminalServiceError):
    """The target host is not ready to accept a terminal connection."""
|
|
|
|
|
|
class SessionLimitExceededError(TerminalServiceError):
    """The user already has too many active terminal sessions."""
|
|
|
|
|
|
class TtydNotAvailableError(TerminalServiceError):
    """The ttyd binary could not be located, so no terminal can be spawned."""
|
|
|
|
|
|
class TerminalService:
    """
    Manages terminal sessions and the ttyd processes that back them.

    This service is responsible for:
    - Generating secure session IDs and tokens
    - Allocating ports for ttyd instances
    - Spawning and terminating ttyd processes
    - Cleaning up expired/idle sessions via a background GC task
    - Counting created/reused/closed sessions for observability
      (the decision to reuse a session is not made in this class)

    The GC task runs periodically to clean up:
    - Sessions that exceeded their TTL
    - Sessions that haven't received a heartbeat within the idle timeout

    Process and port bookkeeping lives in memory on the instance; session
    records themselves are read and updated through TerminalSessionRepository.
    """

    def __init__(self):
        # Active ttyd processes: session_id -> subprocess.Popen
        self._processes: Dict[str, subprocess.Popen] = {}
        # Ports handed out by allocate_port(): port -> session_id
        self._allocated_ports: Dict[int, str] = {}
        # asyncio lock guarding the two dicts above between coroutines
        # (an asyncio.Lock is not a thread lock)
        self._lock = asyncio.Lock()
        # Cached result of the ttyd lookup; None means "not checked yet"
        self._ttyd_available: Optional[bool] = None
        # GC task handle and the flag its loop polls
        self._gc_task: Optional[asyncio.Task] = None
        self._gc_running: bool = False
        # Database session factory for GC (set via set_db_session_factory)
        self._db_session_factory = None
        # Counters exposed through get_metrics()
        self._metrics = {
            "sessions_created": 0,
            "sessions_reused": 0,
            "sessions_closed_user": 0,
            "sessions_gc_expired": 0,
            "sessions_gc_idle": 0,
            "session_limit_hits": 0,
        }

    def check_ttyd_available(self) -> bool:
        """
        Check if the ttyd binary is available.

        The lookup runs once; the result (positive or negative) is cached on
        the instance and returned by every later call.
        """
        if self._ttyd_available is not None:
            return self._ttyd_available

        # Resolve once and reuse the path for the log line.
        ttyd_location = shutil.which(TTYD_PATH)
        self._ttyd_available = ttyd_location is not None
        if self._ttyd_available:
            logger.info(f"ttyd found at: {ttyd_location}")
        else:
            logger.warning(f"ttyd not found at '{TTYD_PATH}'. Terminal feature will be unavailable.")

        return self._ttyd_available

    def set_db_session_factory(self, factory):
        """Set the database session factory used by the GC loop."""
        self._db_session_factory = factory

    async def start_gc_task(self):
        """
        Start the garbage collection background task.

        Does nothing while a GC task is still running. A task object that has
        already finished does not block a restart.
        """
        if self._gc_task is not None and not self._gc_task.done():
            return

        self._gc_running = True
        self._gc_task = asyncio.create_task(self._gc_loop())
        logger.info(
            f"Terminal session GC started (interval={TERMINAL_GC_INTERVAL_SECONDS}s, "
            f"idle_timeout={TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS}s)"
        )

    async def stop_gc_task(self):
        """Stop the garbage collection background task and wait for it to exit."""
        self._gc_running = False
        if self._gc_task:
            self._gc_task.cancel()
            try:
                await self._gc_task
            except asyncio.CancelledError:
                # Expected: we cancelled the task ourselves.
                pass
            self._gc_task = None
            logger.info("Terminal session GC stopped")

    async def _gc_loop(self):
        """
        Background loop that periodically cleans up stale sessions.

        Sleeps first, then runs one GC cycle if a DB session factory has been
        set. Errors in a cycle are logged and the loop keeps going.
        """
        while self._gc_running:
            try:
                await asyncio.sleep(TERMINAL_GC_INTERVAL_SECONDS)
                if self._db_session_factory:
                    await self._run_gc_cycle()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception(f"Error in terminal GC cycle: {e}")

    async def _reap_session(self, session) -> None:
        """Stop the ttyd process of a session record and free its port."""
        await self.terminate_session(session.id)
        await self.release_port(session.ttyd_port)

    async def _run_gc_cycle(self):
        """
        Run a single GC cycle to clean up stale sessions.

        Each stale session has its process terminated and its port released,
        and is then marked expired (TTL reached) or idle in the database.
        All marks are committed together at the end of the cycle.
        """
        from app.crud.terminal_session import TerminalSessionRepository

        async with self._db_session_factory() as db_session:
            repo = TerminalSessionRepository(db_session)

            # Stale = TTL exceeded or no heartbeat within the idle timeout.
            stale_sessions = await repo.list_stale_sessions(
                ttl_seconds=TERMINAL_SESSION_TTL_SECONDS,
                idle_timeout_seconds=TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS,
            )
            if not stale_sessions:
                return

            now = datetime.now(timezone.utc)
            expired_count = 0
            idle_count = 0

            for session in stale_sessions:
                # Decide the close reason before touching the process.
                expires_at = self._to_utc_aware(session.expires_at)
                is_expired = expires_at is not None and expires_at <= now

                await self._reap_session(session)

                if is_expired:
                    await repo.mark_expired(session.id)
                    expired_count += 1
                    self._metrics["sessions_gc_expired"] += 1
                    logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=ttl")
                else:
                    await repo.mark_idle(session.id)
                    idle_count += 1
                    self._metrics["sessions_gc_idle"] += 1
                    logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=idle")

            await db_session.commit()

        if expired_count or idle_count:
            logger.info(f"Terminal GC cycle: cleaned {expired_count} expired, {idle_count} idle sessions")

    def _to_utc_aware(self, dt: Optional[datetime]) -> Optional[datetime]:
        """
        Normalize datetimes from DB (SQLite may return naive) into UTC-aware.

        None passes through; naive values are labelled as UTC without being
        shifted; aware values are converted to UTC.
        """
        if dt is None:
            return None
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the counters plus live process/port counts and the GC flag."""
        return {
            **self._metrics,
            "active_processes": len(self._processes),
            "allocated_ports": len(self._allocated_ports),
            "gc_running": self._gc_running,
        }

    def generate_session_id(self) -> str:
        """Generate a random session ID (64 hexadecimal characters)."""
        return secrets.token_hex(32)

    def generate_session_token(self) -> Tuple[str, str]:
        """
        Generate a session token and its hash.

        Returns:
            Tuple of (plain_token, token_hash), where token_hash is the
            SHA-256 hex digest of the plain token.
        """
        token = secrets.token_urlsafe(48)
        token_hash = hashlib.sha256(token.encode()).hexdigest()
        return token, token_hash

    def verify_token(self, token: str, token_hash: str) -> bool:
        """
        Verify a token against its SHA-256 hex digest.

        Missing, empty or non-string inputs yield False instead of raising.
        The comparison is constant-time.
        """
        if not isinstance(token, str) or not isinstance(token_hash, str):
            return False
        if not token or not token_hash:
            return False
        computed_hash = hashlib.sha256(token.encode()).hexdigest()
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(computed_hash.encode(), token_hash.encode())

    def _is_port_available(self, port: int) -> bool:
        """Check if a TCP port can currently be bound on all IPv4 interfaces."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                # SO_REUSEADDR is explicitly off for this probe socket.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)
                sock.bind(("0.0.0.0", port))
            except OSError:
                return False
            return True

    async def is_session_process_alive(self, session_id: str) -> bool:
        """Check if the ttyd process for a session is still running."""
        async with self._lock:
            process = self._processes.get(session_id)
            if process is None:
                return False
            return process.poll() is None

    async def allocate_port(self, session_id: str) -> int:
        """
        Allocate an available port for a ttyd instance.

        Scans the configured range (inclusive) and picks the first port that
        is neither reserved by this service nor bound by something else.

        Returns:
            Available port number

        Raises:
            TerminalServiceError if no ports available
        """
        async with self._lock:
            for port in range(TERMINAL_PORT_RANGE_START, TERMINAL_PORT_RANGE_END + 1):
                if port in self._allocated_ports:
                    continue
                if not self._is_port_available(port):
                    continue

                self._allocated_ports[port] = session_id
                logger.debug(f"Allocated port {port} for session {session_id[:8]}...")
                return port

            raise TerminalServiceError("No available ports for terminal session")

    async def release_port(self, port: int) -> None:
        """Release an allocated port; unknown ports are ignored."""
        async with self._lock:
            if port in self._allocated_ports:
                session_id = self._allocated_ports.pop(port)
                logger.debug(f"Released port {port} from session {session_id[:8]}...")

    def _find_ssh_key(self) -> Optional[str]:
        """Return the first existing SSH key path among the known candidates."""
        candidates = [
            os.environ.get("TERMINAL_SSH_KEY_PATH"),
            getattr(settings, "ssh_key_path", None),
            str(os.path.expanduser("~/.ssh/id_automation_ansible")),
            str(os.path.expanduser("~/.ssh/id_rsa")),
        ]
        return next((p for p in candidates if p and os.path.exists(p)), None)

    async def _discard_failed_spawn(self, session_id: str, port: int) -> None:
        """Forget a ttyd that failed to start and give its port back."""
        async with self._lock:
            process = self._processes.pop(session_id, None)
        if process is not None and process.poll() is None:
            # Never leave an untracked ttyd running on a port we are freeing.
            process.kill()
        await self.release_port(port)

    async def spawn_ttyd(
        self,
        session_id: str,
        host_ip: str,
        port: int,
        token: str,
    ) -> Optional[int]:
        """
        Spawn a ttyd process for SSH connection.

        Args:
            session_id: Unique session identifier
            host_ip: Target host IP address
            port: Port to run ttyd on
            token: Session token. Currently unused here: ttyd is started
                without --credential (see the note in the body).

        Returns:
            Process ID of the spawned ttyd

        Raises:
            TtydNotAvailableError: ttyd binary was not found
            TerminalServiceError: no SSH key was found, ttyd exited right
                after start, or spawning failed. In the last two cases the
                port is released and the process entry is dropped.
        """
        if not self.check_ttyd_available():
            raise TtydNotAvailableError("ttyd is not installed or not in PATH")

        ssh_key_path = self._find_ssh_key()
        if not ssh_key_path:
            raise TerminalServiceError(
                "No SSH key found for terminal sessions. Set SSH_KEY_PATH or TERMINAL_SSH_KEY_PATH "
                "(expected ~/.ssh/id_automation_ansible)."
            )

        # Command that ttyd runs for the client: an SSH login to the target.
        # NOTE(review): UserKnownHostsFile=/dev/null means host keys are never
        # remembered, so accept-new accepts any key on every connection.
        ssh_cmd = [
            "ssh",
            "-i",
            ssh_key_path,
            "-o",
            "BatchMode=no",
            "-o",
            "PreferredAuthentications=publickey,keyboard-interactive,password",
            "-o",
            "ConnectTimeout=10",
            "-o",
            "ServerAliveInterval=30",
            "-o",
            "ServerAliveCountMax=2",
            "-o",
            "StrictHostKeyChecking=accept-new",
            "-o",
            "UserKnownHostsFile=/dev/null",
            f"{SSH_USER}@{host_ip}",
        ]

        # ttyd flags:
        #   --once:      exit after the single client disconnects
        #   --port:      listen port
        #   --writable:  allow client input
        #   --interface: optional bind interface
        # Note: --credential removed to avoid HTTP basic auth popup
        # Security is handled by session token validation in the connect/popout routes
        cmd = [
            TTYD_PATH,
            "--once",
            f"--port={port}",
            "--writable",
        ]
        if TERMINAL_TTYD_INTERFACE:
            cmd.extend(["--interface", TERMINAL_TTYD_INTERFACE])
        cmd.extend(ssh_cmd)

        logger.info(
            f"Spawning ttyd for session {session_id[:8]}... on port {port} -> {SSH_USER}@{host_ip} (key={ssh_key_path})"
        )

        try:
            # NOTE(review): stdout/stderr are pipes that are only read when
            # ttyd dies at startup; a very chatty ttyd could fill them.
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                start_new_session=True,  # Detach from parent process group
            )

            async with self._lock:
                self._processes[session_id] = process

            # Wait briefly to check if process started successfully
            await asyncio.sleep(0.5)

            if process.poll() is not None:
                # Process exited immediately - likely an error
                stdout, stderr = process.communicate(timeout=1)
                raw_output = stderr or stdout
                error_msg = raw_output.decode(errors="replace") if raw_output else "Unknown error"
                logger.error(f"ttyd failed to start: {error_msg}")
                await self._discard_failed_spawn(session_id, port)
                raise TerminalServiceError(f"ttyd failed to start: {error_msg.strip()}")

            logger.info(f"ttyd started with PID {process.pid} for session {session_id[:8]}...")
            return process.pid

        except TerminalServiceError:
            # Preserve detailed error message for API layer
            raise
        except Exception as e:
            logger.exception(f"Failed to spawn ttyd: {e}")
            await self._discard_failed_spawn(session_id, port)
            raise TerminalServiceError(str(e)) from e

    async def terminate_session(self, session_id: str) -> bool:
        """
        Terminate a terminal session and its ttyd process.

        Sends a graceful terminate first and escalates to kill after 5
        seconds. The waits run in a worker thread so the event loop stays
        responsive.

        Returns:
            True if session was terminated, False if not found or if
            terminating the process raised an error
        """
        async with self._lock:
            process = self._processes.pop(session_id, None)

        if process is None:
            return False

        try:
            if process.poll() is None:
                # Try graceful termination first
                process.terminate()
                try:
                    await asyncio.to_thread(process.wait, timeout=5)
                except subprocess.TimeoutExpired:
                    # Force kill if graceful termination failed
                    process.kill()
                    await asyncio.to_thread(process.wait, timeout=2)

            logger.info(f"Terminated ttyd process for session {session_id[:8]}...")
            return True

        except Exception as e:
            logger.exception(f"Error terminating ttyd process: {e}")
            return False

    async def cleanup_expired_sessions(
        self,
        get_expired_sessions_func,
        mark_expired_func,
        db_session,
    ) -> int:
        """
        Clean up expired terminal sessions.

        Args:
            get_expired_sessions_func: Not used by this implementation; the
                expired sessions come from TerminalSessionRepository.
            mark_expired_func: Not used by this implementation; sessions are
                marked through TerminalSessionRepository.
            db_session: Database session handed to the repository. This
                method does not commit it.

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        expired = await repo.list_expired()

        count = 0
        for session in expired:
            await self._reap_session(session)
            await repo.mark_expired(session.id)
            count += 1
            self._metrics["sessions_gc_expired"] += 1
            logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=ttl")

        return count

    async def cleanup_idle_sessions(self, db_session) -> int:
        """
        Clean up idle terminal sessions (no heartbeat within timeout).

        Args:
            db_session: Database session handed to the repository. This
                method does not commit it.

        Returns:
            Number of sessions cleaned up
        """
        from app.crud.terminal_session import TerminalSessionRepository

        repo = TerminalSessionRepository(db_session)
        idle_sessions = await repo.list_idle(TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS)

        count = 0
        for session in idle_sessions:
            await self._reap_session(session)
            await repo.mark_idle(session.id)
            count += 1
            self._metrics["sessions_gc_idle"] += 1
            logger.info(f"session_gc_closed session={session.id[:8]}... host={session.host_name} reason=idle")

        return count

    async def close_session_with_cleanup(
        self,
        session_id: str,
        port: int,
        reason: str = "user_close",
    ) -> bool:
        """
        Close a session with full cleanup (process + port).

        Only the "user_close" reason increments a metric.

        Returns:
            True if a tracked ttyd process was terminated
        """
        terminated = await self.terminate_session(session_id)
        await self.release_port(port)

        if reason == "user_close":
            self._metrics["sessions_closed_user"] += 1

        logger.info(f"session_closed session={session_id[:8]}... reason={reason}")
        return terminated

    def record_session_created(self, reused: bool = False):
        """Count a session as reused or as newly created."""
        if reused:
            self._metrics["sessions_reused"] += 1
        else:
            self._metrics["sessions_created"] += 1

    def record_session_limit_hit(self):
        """Record a session limit hit for metrics."""
        self._metrics["session_limit_hits"] += 1

    def get_active_session_count(self) -> int:
        """Get the number of tracked ttyd processes."""
        return len(self._processes)

    def get_session_url(self, port: int, token: str, base_url: str = "") -> str:
        """
        Get the URL to access a terminal session.

        Args:
            port: ttyd port
            token: Session token; only included in the proxied URL
            base_url: Optional base URL prefix; trailing slashes are dropped

        Returns:
            The proxied URL when base_url is given, otherwise the direct
            localhost URL of the ttyd port
        """
        if base_url:
            # Strip trailing "/" so the path never starts with "//".
            return f"{base_url.rstrip('/')}/terminal/proxy/{port}?token={token}"
        return f"http://localhost:{port}/"

    def get_websocket_url(self, port: int) -> str:
        """Get the localhost WebSocket URL for a ttyd port."""
        return f"ws://localhost:{port}/ws"
|
|
|
|
|
|
# Process-wide singleton, created when this module is imported.
terminal_service: TerminalService = TerminalService()
|
|
|
|
|
|
def get_terminal_service() -> TerminalService:
    """Return the module-level TerminalService instance (always the same object)."""
    service = terminal_service
    return service
|
|
|
|
|
|
# Public API of this module: the service, its exceptions, and the
# configuration values exported for use in routes.
__all__ = [
    # Service singleton, accessor and class
    "terminal_service",
    "get_terminal_service",
    "TerminalService",
    # Exceptions
    "TerminalServiceError",
    "HostNotReadyError",
    "SessionLimitExceededError",
    "TtydNotAvailableError",
    # Configuration
    "TERMINAL_SESSION_TTL_MINUTES",
    "TERMINAL_SESSION_TTL_SECONDS",
    "TERMINAL_MAX_SESSIONS_PER_USER",
    "TERMINAL_SESSION_IDLE_TIMEOUT_SECONDS",
    "TERMINAL_HEARTBEAT_INTERVAL_SECONDS",
    "TERMINAL_GC_INTERVAL_SECONDS",
]
|