ObsiGate/backend/auth/jwt_handler.py

134 lines
4.1 KiB
Python

# backend/auth/jwt_handler.py
# JWT token generation, validation, and revocation.
# Secret key auto-generated on first startup and persisted to data/secret.key.
# Revoked token JTIs persisted to data/revoked_tokens.json.
import json
import logging
import os
import secrets
import time
import uuid
from pathlib import Path
from typing import Optional

from jose import jwt, JWTError
logger = logging.getLogger("obsigate.auth.jwt")

# Persistent state lives under ./data, resolved against the process working
# directory (/app when running inside the Docker image).
SECRET_KEY_FILE = Path("data") / "secret.key"
REVOKED_TOKENS_FILE = Path("data") / "revoked_tokens.json"

# JWS signing algorithm: HMAC with SHA-256 over the shared secret key.
ALGORITHM = "HS256"

# Token lifetimes, in seconds.
ACCESS_TOKEN_EXPIRE_SECONDS = 60 * 60  # 1 hour
REFRESH_TOKEN_EXPIRE_SECONDS = 7 * 24 * 60 * 60  # 7 days

# In-memory cache of revoked token JTIs, filled from REVOKED_TOKENS_FILE the
# first time revocation state is consulted (see _load_revoked).
_revoked_jtis: set = set()
_revoked_loaded = False
def get_secret_key() -> str:
    """Read or generate the JWT secret key.

    On first call, generates a 512-bit random key and writes it to
    SECRET_KEY_FILE (data/secret.key). Every later call reads the key back
    from disk.

    The file is created by ``os.open`` with mode 0o600 and ``O_EXCL``:
      * the permissions are applied atomically at creation, so there is no
        window in which the key sits on disk with the default (possibly
        world-readable) permissions;
      * if another process creates the file first, this call does not
        overwrite it, and instead reads and returns the other process's key.

    Returns:
        The hex-encoded secret key, with surrounding whitespace stripped.

    Raises:
        RuntimeError: if the key file exists but is empty or whitespace-only.
            Signing with an empty HMAC secret would let anyone forge tokens.
            Delete the file to have a new key generated.
        OSError: if the key file cannot be created, written, or read.
    """
    if not SECRET_KEY_FILE.exists():
        SECRET_KEY_FILE.parent.mkdir(parents=True, exist_ok=True)
        key = secrets.token_hex(64)  # 512 bits
        try:
            fd = os.open(
                SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600
            )
        except FileExistsError:
            # Lost a creation race with another worker: fall through and
            # use the key it wrote rather than clobbering it.
            pass
        else:
            with os.fdopen(fd, "w") as fh:
                fh.write(key)
            try:
                # umask can only clear bits; force exactly owner read/write.
                SECRET_KEY_FILE.chmod(0o600)
            except OSError:
                pass  # Windows doesn't support Unix permissions
            logger.info("Generated new JWT secret key")
            return key
    key = SECRET_KEY_FILE.read_text().strip()
    if not key:
        # Also reached if a previous write was interrupted, or (briefly) if a
        # concurrent creator has not finished writing yet.
        raise RuntimeError(
            f"JWT secret key file {SECRET_KEY_FILE} is empty; "
            "delete it to regenerate a key"
        )
    return key
def create_access_token(user: dict) -> str:
    """Build and sign a short-lived access token for *user*.

    Claims written into the token:
      * ``sub``    - ``user["username"]``
      * ``role``   - ``user["role"]``
      * ``vaults`` - ``user["vaults"]``
      * ``jti``    - a fresh random UUID4 string
      * ``iat`` / ``exp`` - issue time and issue time plus
        ACCESS_TOKEN_EXPIRE_SECONDS, as integer Unix seconds
      * ``type``   - the literal ``"access"``

    Raises:
        KeyError: if *user* lacks ``username``, ``role`` or ``vaults``.
    """
    issued_at = int(time.time())
    claims = dict(
        sub=user["username"],
        role=user["role"],
        vaults=user["vaults"],
        jti=str(uuid.uuid4()),
        iat=issued_at,
        exp=issued_at + ACCESS_TOKEN_EXPIRE_SECONDS,
        type="access",
    )
    signing_key = get_secret_key()
    return jwt.encode(claims, signing_key, algorithm=ALGORITHM)
def create_refresh_token(username: str) -> tuple:
    """Build and sign a long-lived refresh token for *username*.

    The token's ``exp`` is REFRESH_TOKEN_EXPIRE_SECONDS after its ``iat``,
    and its ``type`` claim is ``"refresh"``.

    Returns:
        A ``(token, jti)`` pair: the encoded JWT string and the random UUID4
        placed in its ``jti`` claim. The jti is the value that
        revoke_token() takes.
    """
    token_id = str(uuid.uuid4())
    issued_at = int(time.time())
    claims = dict(
        sub=username,
        jti=token_id,
        iat=issued_at,
        exp=issued_at + REFRESH_TOKEN_EXPIRE_SECONDS,
        type="refresh",
    )
    encoded = jwt.encode(claims, get_secret_key(), algorithm=ALGORITHM)
    return encoded, token_id
def decode_token(token: str) -> Optional[dict]:
    """Verify *token* and return its claims, or ``None`` if it is invalid.

    Only ALGORITHM is accepted, so a token declaring any other algorithm is
    rejected. Any ``JWTError`` raised by python-jose (bad signature,
    malformed token, expired ``exp``, ...) is reported as ``None`` rather
    than propagated.
    """
    try:
        claims = jwt.decode(token, get_secret_key(), algorithms=[ALGORITHM])
    except JWTError:
        return None
    return claims
# ---------------------------------------------------------------------------
# Token revocation
# ---------------------------------------------------------------------------
# Expiry timestamp (Unix seconds) recorded for each revoked JTI. It is kept
# next to _revoked_jtis so that an entry keeps the expiry it was first stored
# with, instead of being re-stamped with a fresh window on every save (which
# would stop entries from ever ageing out).
_revoked_expiry: dict = {}


def _load_revoked():
    """Load revoked token JTIs from disk into memory (once per process).

    REVOKED_TOKENS_FILE holds a JSON object mapping JTI -> expiry timestamp.
    Only entries whose expiry is still at or after the current second are
    kept. A missing file leaves the in-memory state untouched. An unreadable
    or malformed file is logged and treated as empty.
    """
    global _revoked_loaded, _revoked_jtis
    if _revoked_loaded:
        return
    if REVOKED_TOKENS_FILE.exists():
        try:
            data = json.loads(REVOKED_TOKENS_FILE.read_text())
            now = int(time.time())
            # Keep exp == now: python-jose still accepts a token in the
            # second its exp names, so its revocation must still apply.
            live = {jti: exp for jti, exp in data.items() if exp >= now}
        except Exception as e:
            logger.warning(f"Failed to load revoked tokens: {e}")
            live = {}
        _revoked_expiry.clear()
        _revoked_expiry.update(live)
        _revoked_jtis = set(live)
    _revoked_loaded = True


def _prune_expired(now: int) -> None:
    """Forget revoked JTIs whose recorded expiry is before *now*."""
    for jti in [j for j, exp in _revoked_expiry.items() if exp < now]:
        del _revoked_expiry[jti]
        _revoked_jtis.discard(jti)


def _save_revoked():
    """Persist revoked JTIs, with their expiry timestamps, to disk.

    Any JTI in _revoked_jtis that has no recorded expiry is given
    now + REFRESH_TOKEN_EXPIRE_SECONDS. Existing entries keep their original
    expiry. Entries that have already expired are dropped from memory and
    from the file. The JSON is written to a sibling ``.tmp`` file and then
    moved over REVOKED_TOKENS_FILE with Path.replace.
    """
    REVOKED_TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)
    now = int(time.time())
    for jti in _revoked_jtis:
        _revoked_expiry.setdefault(jti, now + REFRESH_TOKEN_EXPIRE_SECONDS)
    _prune_expired(now)
    data = {jti: _revoked_expiry[jti] for jti in _revoked_jtis}
    tmp = REVOKED_TOKENS_FILE.with_suffix(".tmp")
    tmp.write_text(json.dumps(data))
    tmp.replace(REVOKED_TOKENS_FILE)


def revoke_token(jti: str, exp: Optional[int] = None):
    """Add a token JTI to the revocation list and persist it.

    Args:
        jti: The ``jti`` claim of the token to revoke.
        exp: Optional ``exp`` claim of that token (Unix seconds). When given,
            the entry is kept until then. Otherwise it is kept for
            REFRESH_TOKEN_EXPIRE_SECONDS from now, which is the longest
            lifetime of any token this module issues. If the JTI is already
            revoked, the later of the two expiries wins. An ``exp`` that is
            already in the past is dropped when the list is saved.
    """
    _load_revoked()
    now = int(time.time())
    expiry = now + REFRESH_TOKEN_EXPIRE_SECONDS if exp is None else int(exp)
    _revoked_expiry[jti] = max(expiry, _revoked_expiry.get(jti, expiry))
    _revoked_jtis.add(jti)
    _save_revoked()
    logger.debug(f"Revoked token JTI: {jti[:8]}...")


def is_token_revoked(jti: str) -> bool:
    """Return True if *jti* is on the revocation list."""
    _load_revoked()
    return jti in _revoked_jtis