"""Database configuration and session management for Homelab Automation. Uses SQLAlchemy 2.x async engine with SQLite + aiosqlite driver. """ from __future__ import annotations import os from pathlib import Path from typing import AsyncGenerator from urllib.parse import urlparse from sqlalchemy import event, MetaData from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import declarative_base # Naming convention to keep Alembic happy with constraints NAMING_CONVENTION = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s", } metadata_obj = MetaData(naming_convention=NAMING_CONVENTION) Base = declarative_base(metadata=metadata_obj) # Resolve base path (project root) ROOT_DIR = Path(__file__).resolve().parents[2] DEFAULT_DB_PATH = Path(os.environ.get("DB_PATH") or (ROOT_DIR / "data" / "homelab.db")) DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}") # Ensure SQLite directory exists even if DATABASE_URL overrides DB_PATH def _ensure_sqlite_dir(db_url: str) -> None: if not db_url.startswith("sqlite"): return parsed = urlparse(db_url.replace("sqlite+aiosqlite", "sqlite")) if parsed.scheme != "sqlite": return db_path = Path(parsed.path) if db_path.parent: db_path.parent.mkdir(parents=True, exist_ok=True) DEFAULT_DB_PATH.parent.mkdir(parents=True, exist_ok=True) _ensure_sqlite_dir(DATABASE_URL) def _debug_db_paths() -> None: try: print( "[DB] DATABASE_URL=%s, DEFAULT_DB_PATH=%s, parent_exists=%s, parent=%s" % ( DATABASE_URL, DEFAULT_DB_PATH, DEFAULT_DB_PATH.parent.exists(), DEFAULT_DB_PATH.parent, ) ) except Exception: # Debug logging should never break startup pass _debug_db_paths() engine: AsyncEngine = create_async_engine( DATABASE_URL, echo=False, pool_pre_ping=True, future=True, ) # Ensure SQLite 
# Ensure SQLite pragmas (WAL + FK) when using SQLite
if engine.dialect.name == "sqlite":

    @event.listens_for(engine.sync_engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, connection_record):
        """Enable foreign-key enforcement and WAL journaling on each new DBAPI connection."""
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.execute("PRAGMA journal_mode=WAL")
        finally:
            # Close the cursor even if a PRAGMA fails.
            cursor.close()


async_session_maker = async_sessionmaker(
    bind=engine,
    autoflush=False,
    expire_on_commit=False,
    class_=AsyncSession,
)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields an ``AsyncSession``.

    If the consumer raises while the session is in use, the transaction is
    rolled back and the exception is re-raised. The ``async with`` block
    closes the session on exit, so no explicit ``close()`` is needed.
    """
    async with async_session_maker() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise


async def init_db() -> None:
    """Create all tables (mostly for dev/tests; migrations should be handled by Alembic)."""
    # Imported only for side effects (presumably registering the ORM models on
    # Base.metadata) so create_all below sees every table.
    from . import host, task, schedule, schedule_run, log  # noqa: F401

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)