""" IP-based rate limiter for authentication endpoints. Tracks failed login attempts per IP address with automatic cleanup of expired entries. Complements the per-account lockout in user_store.py. Configuration via environment variables: OBSIGATE_LOGIN_MAX_ATTEMPTS Max failures per IP (default: 10) OBSIGATE_LOGIN_WINDOW_SECONDS Lockout window in seconds (default: 900 = 15min) """ import os import time import logging from collections import defaultdict from typing import Dict, Tuple logger = logging.getLogger("obsigate.ratelimit") # --- Configuration --- MAX_ATTEMPTS = int(os.environ.get("OBSIGATE_LOGIN_MAX_ATTEMPTS", "10")) WINDOW_SECONDS = int(os.environ.get("OBSIGATE_LOGIN_WINDOW_SECONDS", "900")) # 15 min # --- In-memory store: {ip: [(timestamp, success_bool), ...]} --- _ip_attempts: Dict[str, list] = defaultdict(list) _last_cleanup = time.time() CLEANUP_INTERVAL = 60 # seconds def _cleanup_expired(): """Remove entries older than the window.""" global _last_cleanup now = time.time() if now - _last_cleanup < CLEANUP_INTERVAL: return _last_cleanup = now cutoff = now - WINDOW_SECONDS expired_ips = [] for ip, attempts in _ip_attempts.items(): _ip_attempts[ip] = [a for a in attempts if a[0] > cutoff] if not _ip_attempts[ip]: expired_ips.append(ip) for ip in expired_ips: del _ip_attempts[ip] def record_failure(ip: str) -> Tuple[int, int]: """Record a failed login attempt from an IP. Returns: (current_failure_count, remaining_attempts) """ _cleanup_expired() _ip_attempts[ip].append((time.time(), False)) failures = sum(1 for _, success in _ip_attempts[ip] if not success) remaining = max(0, MAX_ATTEMPTS - failures) if failures >= MAX_ATTEMPTS: logger.warning(f"IP {ip} rate-limited after {failures} failed logins") return failures, remaining def record_success(ip: str): """Clear rate limit state for an IP after successful login.""" _cleanup_expired() _ip_attempts[ip] = [(time.time(), True)] def is_rate_limited(ip: str) -> bool: """Check if an IP has exceeded the rate limit.""" _cleanup_expired() failures = sum(1 for _, success in _ip_attempts.get(ip, []) if not success) return failures >= MAX_ATTEMPTS def get_status(ip: str | None = None) -> dict: """Get rate limit status for an IP (for diagnostics).""" _cleanup_expired() if ip: attempts = _ip_attempts.get(ip, []) failures = sum(1 for _, s in attempts if not s) return { "ip": ip, "failures": failures, "max": MAX_ATTEMPTS, "limited": failures >= MAX_ATTEMPTS, "window_seconds": WINDOW_SECONDS, } return { "tracked_ips": len(_ip_attempts), "max_attempts": MAX_ATTEMPTS, "window_seconds": WINDOW_SECONDS, "limited_ips": sum( 1 for ip_addr in _ip_attempts if sum(1 for _, s in _ip_attempts[ip_addr] if not s) >= MAX_ATTEMPTS ), }