Some checks failed
Tests / Backend Tests (Python) (3.10) (push) Has been cancelled
Tests / Backend Tests (Python) (3.11) (push) Has been cancelled
Tests / Backend Tests (Python) (3.12) (push) Has been cancelled
Tests / Frontend Tests (JS) (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Tests / All Tests Passed (push) Has been cancelled
319 lines
11 KiB
Python
319 lines
11 KiB
Python
"""
|
|
CRUD operations for terminal sessions.
|
|
"""
|
|
from datetime import datetime, timedelta, timezone
from typing import List, Optional

from sqlalchemy import and_, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.terminal_session import (
    TerminalSession,
    SESSION_STATUS_ACTIVE,
    SESSION_STATUS_CLOSED,
    SESSION_STATUS_EXPIRED,
    SESSION_STATUS_CLOSING,
    CLOSE_REASON_USER,
    CLOSE_REASON_TTL,
    CLOSE_REASON_IDLE,
    CLOSE_REASON_CLIENT_LOST,
)
|
|
|
|
|
|
class TerminalSessionRepository:
    """Repository for TerminalSession CRUD operations.

    Every method runs against the ``AsyncSession`` handed to the constructor.
    Writes are flushed but never committed here, so the caller owns the
    transaction boundary.
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to an async database session."""
        self.session = session

    async def create(
        self,
        id: str,
        host_id: str,
        host_name: str,
        host_ip: str,
        token_hash: str,
        ttyd_port: int,
        expires_at: datetime,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
        ttyd_pid: Optional[int] = None,
        mode: str = "embedded",
    ) -> TerminalSession:
        """Create a new terminal session.

        The row is created with status ``SESSION_STATUS_ACTIVE`` and
        ``last_seen_at`` set to the current UTC time, then added to the
        session and flushed.

        Returns:
            The newly added ``TerminalSession`` instance.
        """
        now = datetime.now(timezone.utc)
        session = TerminalSession(
            id=id,
            host_id=host_id,
            host_name=host_name,
            host_ip=host_ip,
            user_id=user_id,
            username=username,
            token_hash=token_hash,
            ttyd_port=ttyd_port,
            ttyd_pid=ttyd_pid,
            mode=mode,
            status=SESSION_STATUS_ACTIVE,
            expires_at=expires_at,
            last_seen_at=now,
        )
        self.session.add(session)
        await self.session.flush()
        return session

    async def get(self, session_id: str) -> Optional[TerminalSession]:
        """Get a terminal session by ID, regardless of its status.

        Returns:
            The matching session, or ``None`` when no row has that ID.
        """
        result = await self.session.execute(
            select(TerminalSession).where(TerminalSession.id == session_id)
        )
        return result.scalar_one_or_none()

    async def get_active_by_id(self, session_id: str) -> Optional[TerminalSession]:
        """Get a terminal session by ID only if it is active and not expired.

        Returns:
            The session when its status is ``SESSION_STATUS_ACTIVE`` and
            ``expires_at`` is still in the future; otherwise ``None``.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.id == session_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
        )
        return result.scalar_one_or_none()

    async def find_reusable_session(
        self,
        user_id: str,
        host_id: str,
        mode: str,
        idle_timeout_seconds: int = 120
    ) -> Optional[TerminalSession]:
        """
        Find an existing active session that can be reused.

        A session is reusable if:
        - Same user, host, and mode
        - Status is 'active'
        - Not expired
        - Last seen within idle timeout (considered "healthy")

        Returns:
            The most recently created matching session, or ``None`` when
            there is none.
        """
        now = datetime.now(timezone.utc)
        min_last_seen = now - timedelta(seconds=idle_timeout_seconds)

        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.host_id == host_id,
                    TerminalSession.mode == mode,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                    TerminalSession.last_seen_at >= min_last_seen,
                )
            )
            .order_by(TerminalSession.created_at.desc())
            .limit(1)
        )
        # Several sessions may legitimately match; scalar_one_or_none() would
        # raise MultipleResultsFound in that case, so take the newest row.
        return result.scalars().first()

    async def list_active_for_user(self, user_id: str) -> List[TerminalSession]:
        """List all active, unexpired sessions for a user, newest first."""
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.desc())
        )
        return list(result.scalars().all())

    async def list_all_active(self) -> List[TerminalSession]:
        """List all active, unexpired sessions (for admin/GC), newest first."""
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.desc())
        )
        return list(result.scalars().all())

    async def count_active_for_user(self, user_id: str) -> int:
        """Count active, unexpired sessions for a user.

        Uses the same filter as ``list_active_for_user`` but lets the
        database do the counting instead of loading every row.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(func.count())
            .select_from(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
        )
        return result.scalar_one()

    async def list_expired(self) -> List[TerminalSession]:
        """List all expired but not yet closed sessions (TTL exceeded).

        These are rows still marked ``SESSION_STATUS_ACTIVE`` whose
        ``expires_at`` is at or before the current UTC time.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at <= now,
                )
            )
        )
        return list(result.scalars().all())

    async def list_idle(self, idle_timeout_seconds: int) -> List[TerminalSession]:
        """List all idle sessions (no heartbeat within timeout).

        These are rows still marked ``SESSION_STATUS_ACTIVE`` whose
        ``last_seen_at`` is older than ``idle_timeout_seconds`` ago.
        """
        now = datetime.now(timezone.utc)
        idle_cutoff = now - timedelta(seconds=idle_timeout_seconds)

        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.last_seen_at < idle_cutoff,
                )
            )
        )
        return list(result.scalars().all())

    async def list_stale_sessions(self, ttl_seconds: int, idle_timeout_seconds: int) -> List[TerminalSession]:
        """
        List all sessions that should be cleaned up.

        Returns sessions still marked active that are either:
        - Expired (past TTL)
        - Idle (no heartbeat within idle timeout)

        Note:
            ``ttl_seconds`` is accepted but not used by the query; expiry is
            judged from each row's stored ``expires_at``.
        """
        now = datetime.now(timezone.utc)
        idle_cutoff = now - timedelta(seconds=idle_timeout_seconds)

        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    or_(
                        TerminalSession.expires_at <= now,
                        TerminalSession.last_seen_at < idle_cutoff,
                    ),
                )
            )
        )
        return list(result.scalars().all())

    async def update_status(
        self,
        session_id: str,
        status: str,
        closed_at: Optional[datetime] = None,
        reason_closed: Optional[str] = None
    ) -> Optional[TerminalSession]:
        """Update session status.

        ``closed_at`` and ``reason_closed`` are written only when supplied;
        an empty ``reason_closed`` is ignored, leaving any stored reason
        untouched.

        Returns:
            The updated session, or ``None`` when the ID is unknown.
        """
        session = await self.get(session_id)
        if session is None:
            return None

        session.status = status
        if closed_at is not None:
            session.closed_at = closed_at
        if reason_closed:
            session.reason_closed = reason_closed
        await self.session.flush()
        return session

    async def update_last_seen(self, session_id: str) -> Optional[TerminalSession]:
        """Update last_seen_at timestamp (heartbeat).

        The timestamp is refreshed only while the session is active.

        Returns:
            The session as loaded (even if it was not active and therefore
            not modified), or ``None`` when the ID is unknown.
        """
        session = await self.get(session_id)
        if session is not None and session.status == SESSION_STATUS_ACTIVE:
            session.last_seen_at = datetime.now(timezone.utc)
            await self.session.flush()
        return session

    async def close_session(
        self,
        session_id: str,
        reason: str = CLOSE_REASON_USER
    ) -> Optional[TerminalSession]:
        """Close a terminal session with reason."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_CLOSED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=reason,
        )

    async def mark_expired(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as expired (TTL)."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_EXPIRED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_TTL,
        )

    async def mark_idle(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as expired due to idle timeout."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_EXPIRED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_IDLE,
        )

    async def mark_client_lost(self, session_id: str) -> Optional[TerminalSession]:
        """Mark a session as closed due to client disconnect."""
        return await self.update_status(
            session_id,
            status=SESSION_STATUS_CLOSED,
            closed_at=datetime.now(timezone.utc),
            reason_closed=CLOSE_REASON_CLIENT_LOST,
        )

    async def cleanup_old_sessions(self, days: int = 7) -> int:
        """Delete sessions older than specified days.

        Only rows whose status is closed, expired or ``"error"`` and whose
        ``created_at`` is older than the cutoff are removed. Rows are deleted
        one by one through the ORM session and then flushed.

        Returns:
            The number of sessions deleted.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)

        result = await self.session.execute(
            select(TerminalSession).where(
                and_(
                    TerminalSession.status.in_(
                        [SESSION_STATUS_CLOSED, SESSION_STATUS_EXPIRED, "error"]
                    ),
                    TerminalSession.created_at < cutoff,
                )
            )
        )
        sessions = result.scalars().all()
        count = len(sessions)

        for session in sessions:
            await self.session.delete(session)

        await self.session.flush()
        return count

    async def close_all_active(self, reason: str = CLOSE_REASON_USER) -> int:
        """Close all active sessions (for cleanup/reset).

        Every row with status ``SESSION_STATUS_ACTIVE`` is closed, whether or
        not it has already passed ``expires_at``; all of them receive the
        same ``closed_at`` timestamp.

        Returns:
            The number of sessions closed.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession).where(
                TerminalSession.status == SESSION_STATUS_ACTIVE
            )
        )
        sessions = list(result.scalars().all())
        count = len(sessions)

        for session in sessions:
            session.status = SESSION_STATUS_CLOSED
            session.closed_at = now
            session.reason_closed = reason

        await self.session.flush()
        return count

    async def get_oldest_active_for_user(self, user_id: str) -> Optional[TerminalSession]:
        """Get the oldest active, unexpired session for a user (for auto-close).

        Returns:
            The matching session with the earliest ``created_at``, or
            ``None`` when the user has no active session.
        """
        now = datetime.now(timezone.utc)
        result = await self.session.execute(
            select(TerminalSession)
            .where(
                and_(
                    TerminalSession.user_id == user_id,
                    TerminalSession.status == SESSION_STATUS_ACTIVE,
                    TerminalSession.expires_at > now,
                )
            )
            .order_by(TerminalSession.created_at.asc())
            .limit(1)
        )
        return result.scalars().first()
|