134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
# backend/auth/jwt_handler.py
#
# JWT token generation, validation, and revocation.
# The secret key is auto-generated on first startup and persisted to data/secret.key.
# Revoked token JTIs are persisted to data/revoked_tokens.json.
|
|
|
|
import json
|
|
import secrets
|
|
import uuid
|
|
import time
|
|
import logging
|
|
from pathlib import Path
|
|
from jose import jwt, JWTError
|
|
from typing import Optional
|
|
|
|
# Module logger, namespaced under the application's auth hierarchy.
logger = logging.getLogger("obsigate.auth.jwt")

# Persistent state lives under ./data, resolved against the process working
# directory (Docker: /app).
SECRET_KEY_FILE = Path("data") / "secret.key"
REVOKED_TOKENS_FILE = Path("data") / "revoked_tokens.json"

# Signing algorithm and token lifetimes, in seconds.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_SECONDS = 60 * 60  # 1 hour
REFRESH_TOKEN_EXPIRE_SECONDS = 7 * 24 * 60 * 60  # 7 days

# In-memory cache of revoked token JTIs; filled from disk on first use.
_revoked_jtis: set = set()
# Flips to True once the on-disk revocation list has been read.
_revoked_loaded: bool = False
|
|
|
|
|
|
def get_secret_key() -> str:
    """Read or generate the JWT secret key.

    If ``SECRET_KEY_FILE`` exists and holds a non-empty key, that key is
    returned with surrounding whitespace stripped. Otherwise a 512-bit random
    key (128 hex characters) is generated, written to ``SECRET_KEY_FILE`` and
    returned.

    The file is opened with mode 0o600 at creation time, so the key is never
    on disk with looser permissions, not even briefly. An existing but empty
    key file is treated as missing: an empty string must never be used as the
    signing secret.

    Returns:
        The secret key as a hex string.

    Raises:
        OSError: If the key file or its parent directory cannot be read,
            created, or written.
    """
    import os  # function-scope import: the module header does not import os

    if SECRET_KEY_FILE.exists():
        stored = SECRET_KEY_FILE.read_text().strip()
        if stored:
            return stored
        # An empty file (e.g. left behind by an interrupted first write) would
        # otherwise become the signing key; fall through and regenerate.
        logger.warning("JWT secret key file is empty; generating a new key")

    SECRET_KEY_FILE.parent.mkdir(parents=True, exist_ok=True)
    key = secrets.token_hex(64)  # 64 random bytes = 512 bits

    # Create the file with restrictive permissions from the start instead of
    # writing it with the default mode and tightening it afterwards.
    fd = os.open(SECRET_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as key_file:
        key_file.write(key)

    try:
        # os.open() only applies the mode to newly created files, so also
        # tighten a pre-existing (empty) file.
        SECRET_KEY_FILE.chmod(0o600)
    except OSError:
        pass  # Windows doesn't support Unix permissions

    logger.info("Generated new JWT secret key")
    return key
|
|
|
|
|
|
def create_access_token(user: dict) -> str:
    """Build and sign a short-lived access token for *user*.

    Args:
        user: Mapping that provides the ``"username"``, ``"role"`` and
            ``"vaults"`` keys; their values are copied into the token claims.

    Returns:
        The encoded JWT string, signed with the key from ``get_secret_key()``
        using ``ALGORITHM``.

    Raises:
        KeyError: If *user* lacks one of the required keys.
    """
    issued_at = int(time.time())
    expires_at = issued_at + ACCESS_TOKEN_EXPIRE_SECONDS

    claims = dict(
        sub=user["username"],
        role=user["role"],
        vaults=user["vaults"],
        jti=str(uuid.uuid4()),  # unique token id, usable for revocation
        iat=issued_at,
        exp=expires_at,
        type="access",
    )

    signing_key = get_secret_key()
    return jwt.encode(claims, signing_key, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(username: str) -> tuple:
    """Build and sign a refresh token for *username*.

    Args:
        username: Value stored in the token's ``sub`` claim.

    Returns:
        A ``(token_string, jti)`` pair, where ``jti`` is the freshly generated
        UUID4 string embedded in the token.
    """
    issued_at = int(time.time())
    token_id = str(uuid.uuid4())

    claims = dict(
        sub=username,
        jti=token_id,
        iat=issued_at,
        exp=issued_at + REFRESH_TOKEN_EXPIRE_SECONDS,
        type="refresh",
    )

    encoded = jwt.encode(claims, get_secret_key(), algorithm=ALGORITHM)
    return encoded, token_id
|
|
|
|
|
|
def decode_token(token: str) -> Optional[dict]:
    """Decode *token* and verify it against the secret key.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims, or ``None`` when decoding raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, get_secret_key(), algorithms=[ALGORITHM])
    except JWTError:
        claims = None
    return claims
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Token revocation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _load_revoked():
    """Populate the in-memory revocation set from disk, at most once.

    Reads ``REVOKED_TOKENS_FILE`` as a JSON object mapping JTI to expiry
    timestamp and keeps only the JTIs whose stored expiry is still in the
    future. Any failure while reading or parsing is logged as a warning and
    leaves the set empty. Either way the module is marked as loaded, so later
    calls return immediately.
    """
    global _revoked_loaded, _revoked_jtis

    if not _revoked_loaded:
        if REVOKED_TOKENS_FILE.exists():
            try:
                stored = json.loads(REVOKED_TOKENS_FILE.read_text())
                cutoff = int(time.time())
                # Drop entries whose stored expiry has already passed.
                still_valid = set()
                for token_id, expires_at in stored.items():
                    if expires_at > cutoff:
                        still_valid.add(token_id)
                _revoked_jtis = still_valid
            except Exception as exc:
                logger.warning(f"Failed to load revoked tokens: {exc}")
                _revoked_jtis = set()
        _revoked_loaded = True
|
|
|
|
|
|
def _save_revoked():
    """Persist revoked JTIs to disk as a ``{jti: expiry_timestamp}`` JSON object.

    A JTI that is already present in ``REVOKED_TOKENS_FILE`` keeps the expiry
    recorded there. Only JTIs not yet on disk receive a fresh expiry of
    ``now + REFRESH_TOKEN_EXPIRE_SECONDS``. Previously every save reset the
    expiry of *all* entries, so old revocations never aged out while new ones
    kept arriving.

    The data is written to a sibling ``.tmp`` file and then moved over the
    target, so the target is replaced rather than rewritten in place.

    Raises:
        OSError: If the directory or the file cannot be written.
    """
    REVOKED_TOKENS_FILE.parent.mkdir(parents=True, exist_ok=True)

    # Recover the expiries that earlier saves recorded. A missing or unreadable
    # file simply means there is nothing to preserve.
    try:
        previous = json.loads(REVOKED_TOKENS_FILE.read_text())
    except (OSError, ValueError):
        previous = {}
    if not isinstance(previous, dict):
        previous = {}

    # Newly revoked tokens are kept for 7 days max.
    fresh_expiry = int(time.time()) + REFRESH_TOKEN_EXPIRE_SECONDS

    data = {}
    for jti in _revoked_jtis:
        recorded = previous.get(jti)
        # Ignore non-numeric leftovers so a damaged entry cannot break loading.
        data[jti] = recorded if isinstance(recorded, (int, float)) else fresh_expiry

    tmp = REVOKED_TOKENS_FILE.with_suffix(".tmp")
    tmp.write_text(json.dumps(data))
    tmp.replace(REVOKED_TOKENS_FILE)
|
|
|
|
|
|
def revoke_token(jti: str):
    """Record *jti* as revoked, in memory and on disk.

    Args:
        jti: The token identifier to revoke.
    """
    # Make sure previously persisted revocations are in memory first;
    # otherwise the save below would write a list that lacks them.
    _load_revoked()

    _revoked_jtis.add(jti)
    _save_revoked()

    # Log only a short prefix of the identifier.
    short_id = jti[:8]
    logger.debug(f"Revoked token JTI: {short_id}...")
|
|
|
|
|
|
def is_token_revoked(jti: str) -> bool:
    """Report whether *jti* is on the revocation list.

    Args:
        jti: The token identifier to look up.

    Returns:
        ``True`` if the identifier is in the in-memory revocation set (loaded
        from disk on first use), ``False`` otherwise.
    """
    _load_revoked()
    if jti in _revoked_jtis:
        return True
    return False
|