# backend/auth/jwt_handler.py
# JWT token generation, validation, and revocation.
# Secret key auto-generated on first startup and persisted to data/secret.key.
# Revoked token JTIs persisted to data/revoked_tokens.json.

import json
import logging
import os
import secrets
import time
import uuid
from pathlib import Path
from typing import Dict, Optional, Tuple

from jose import jwt, JWTError

logger = logging.getLogger("obsigate.auth.jwt")

# Paths relative to working directory (Docker: /app)
SECRET_KEY_FILE = Path("data/secret.key")
REVOKED_TOKENS_FILE = Path("data/revoked_tokens.json")

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_SECONDS = 3600  # 1 hour
REFRESH_TOKEN_EXPIRE_SECONDS = 604800  # 7 days

# In-memory revocation list (loaded from disk on first use). Maps each revoked
# JTI to the Unix timestamp after which the entry may be forgotten. Keeping the
# per-entry expiry (instead of a bare set) lets old entries actually age out
# rather than having their lifetime re-extended on every save.
_revoked_jtis: Dict[str, int] = {}
_revoked_loaded = False


def _create_secret_key_file() -> Optional[str]:
    """Create data/secret.key with a fresh key; return None if it already exists."""
    SECRET_KEY_FILE.parent.mkdir(parents=True, exist_ok=True)
    key = secrets.token_hex(64)  # 64 random bytes = 512 bits
    try:
        # O_EXCL: never clobber a key another worker has just written.
        # Passing 0o600 to open() means the file is never created with
        # broader permissions (umask can only tighten it), unlike the
        # write-then-chmod sequence, which leaves a readable window.
        fd = os.open(SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        return None
    with os.fdopen(fd, "w") as fh:
        fh.write(key)
    logger.info("Generated new JWT secret key")
    return key


def get_secret_key() -> str:
    """Read or generate the JWT secret key.

    On first call, generates a 512-bit random key (128 hex characters) and
    writes it to data/secret.key, created with 0o600 permissions. Subsequent
    calls read it back from disk (so the file is re-read on every call).

    Raises:
        RuntimeError: if the key file exists but is empty. Signing with an
            empty HMAC secret would let anyone forge tokens, so fail closed.
    """
    if not SECRET_KEY_FILE.exists():
        key = _create_secret_key_file()
        if key is not None:
            return key
        # Another process created the file between our check and open();
        # fall through and use its key.
    key = SECRET_KEY_FILE.read_text().strip()
    if not key:
        raise RuntimeError(f"JWT secret key file {SECRET_KEY_FILE} is empty")
    return key


def create_access_token(user: dict) -> str:
    """Create a JWT access token with user claims.

    Args:
        user: Mapping that must provide "username", "role" and "vaults"; these
            become the "sub", "role" and "vaults" claims.

    Returns:
        The encoded token, valid for ACCESS_TOKEN_EXPIRE_SECONDS, with a fresh
        random "jti" and "type" set to "access".
    """
    now = int(time.time())
    payload = {
        "sub": user["username"],
        "role": user["role"],
        "vaults": user["vaults"],
        "jti": str(uuid.uuid4()),
        "iat": now,
        "exp": now + ACCESS_TOKEN_EXPIRE_SECONDS,
        "type": "access",
    }
    return jwt.encode(payload, get_secret_key(), algorithm=ALGORITHM)


def create_refresh_token(username: str) -> Tuple[str, str]:
    """Create a JWT refresh token.

    Returns (token_string, jti). The token is valid for
    REFRESH_TOKEN_EXPIRE_SECONDS and carries "type" set to "refresh".
    """
    now = int(time.time())
    jti = str(uuid.uuid4())
    payload = {
        "sub": username,
        "jti": jti,
        "iat": now,
        "exp": now + REFRESH_TOKEN_EXPIRE_SECONDS,
        "type": "refresh",
    }
    return jwt.encode(payload, get_secret_key(), algorithm=ALGORITHM), jti


def decode_token(token: str) -> Optional[dict]:
    """Decode and validate a JWT. Returns None if invalid/expired.

    Only the signature and standard claims are validated by jose here; this
    function does not itself check the "type" claim or the revocation list.
    """
    try:
        return jwt.decode(token, get_secret_key(), algorithms=[ALGORITHM])
    except JWTError:
        return None


# ---------------------------------------------------------------------------
# Token revocation
# ---------------------------------------------------------------------------

def _load_revoked() -> None:
    """Load revoked token JTIs (with their expiries) from disk, once per process.

    Entries whose stored expiry has already passed are dropped. If the file
    cannot be read or parsed, a warning is logged and the list starts empty.
    """
    global _revoked_loaded, _revoked_jtis
    if _revoked_loaded:
        return
    if REVOKED_TOKENS_FILE.exists():
        try:
            data = json.loads(REVOKED_TOKENS_FILE.read_text())
            now = int(time.time())
            _revoked_jtis = {
                str(jti): int(exp) for jti, exp in data.items() if int(exp) > now
            }
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # ValueError covers malformed JSON; TypeError/AttributeError cover
            # a file whose JSON is not a {jti: timestamp} object.
            logger.warning("Failed to load revoked tokens: %s", e)
            _revoked_jtis = {}
    _revoked_loaded = True


def _save_revoked() -> None:
    """Persist revoked JTIs to disk, dropping entries whose expiry has passed.

    Writes to a sibling .tmp file and then replaces the target, so readers
    never see a half-written file.
    """
    REVOKED_TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
    now = int(time.time())
    expired = [jti for jti, exp in _revoked_jtis.items() if exp <= now]
    for jti in expired:
        del _revoked_jtis[jti]
    tmp = REVOKED_TOKENS_FILE.with_suffix(".tmp")
    tmp.write_text(json.dumps(_revoked_jtis))
    tmp.replace(REVOKED_TOKENS_FILE)


def revoke_token(jti: str, exp: Optional[int] = None) -> None:
    """Add a token JTI to the revocation list and persist it.

    Args:
        jti: The "jti" claim of the token to revoke.
        exp: Optional Unix timestamp after which this entry may be forgotten,
            typically the revoked token's own "exp" claim. Defaults to now plus
            REFRESH_TOKEN_EXPIRE_SECONDS, the longest lifetime this module
            issues, so the entry outlives any token it could refer to.
    """
    _load_revoked()
    if exp is None:
        exp = int(time.time()) + REFRESH_TOKEN_EXPIRE_SECONDS
    # Re-revoking must never shorten an existing entry's lifetime.
    _revoked_jtis[jti] = max(int(exp), _revoked_jtis.get(jti, 0))
    _save_revoked()
    logger.debug("Revoked token JTI: %s...", jti[:8])


def is_token_revoked(jti: str) -> bool:
    """Check if a token JTI has been revoked."""
    _load_revoked()
    return jti in _revoked_jtis