119 lines
4.0 KiB
Python

"""Database configuration and session management for Homelab Automation.
Uses SQLAlchemy 2.x async engine with SQLite + aiosqlite driver.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import AsyncGenerator
from urllib.parse import urlparse
from sqlalchemy import event, MetaData
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base
# Explicit name templates for constraints and indexes, so that Alembic sees
# the same generated names on every run.
NAMING_CONVENTION = {
    # primary key
    "pk": "pk_%(table_name)s",
    # foreign key
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    # unique constraint
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    # check constraint
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    # index
    "ix": "ix_%(column_0_label)s",
}

# Shared MetaData carrying the naming convention; every model deriving from
# Base registers its table on it.
metadata_obj = MetaData(naming_convention=NAMING_CONVENTION)
Base = declarative_base(metadata=metadata_obj)
# Project root: two directory levels above the directory containing this file.
ROOT_DIR = Path(__file__).resolve().parents[2]

# SQLite file location. An unset *or empty* DB_PATH falls back to
# <ROOT_DIR>/data/homelab.db.
DEFAULT_DB_PATH = Path(os.environ.get("DB_PATH") or (ROOT_DIR / "data" / "homelab.db"))

# Engine URL. An unset *or empty* DATABASE_URL falls back to the SQLite file
# above (same rule as DB_PATH), so a blank variable never produces an empty URL.
DATABASE_URL = os.environ.get("DATABASE_URL") or f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}"
# Ensure SQLite directory exists even if DATABASE_URL overrides DB_PATH
def _ensure_sqlite_dir(db_url: str) -> None:
    """Create the parent directory of the SQLite file named by *db_url*.

    Does nothing when:
      * the URL is not a SQLite URL (does not start with ``sqlite``);
      * the URL has no ``:///`` separator (e.g. ``sqlite+aiosqlite://``);
      * the database is empty or ``:memory:``;
      * the database is given as a ``file:`` URI (left to the driver);
      * the file path has no directory component (parent is ``.``).

    Args:
        db_url: SQLAlchemy database URL, with any driver suffix
            (``sqlite:///...``, ``sqlite+aiosqlite:///...``).
    """
    if not db_url.startswith("sqlite"):
        return

    # The path is taken by plain string splitting rather than URL parsing:
    # on Windows it may start with a drive letter (C:\...), which a URL
    # parser would not treat as a file path.
    _scheme, separator, location = db_url.partition(":///")
    if not separator:
        return

    # Drop any "?key=value" driver options so they are not read as path parts.
    path_str = location.partition("?")[0]
    if not path_str or path_str == ":memory:" or path_str.startswith("file:"):
        return

    parent_dir = Path(path_str).parent
    if str(parent_dir) != ".":
        parent_dir.mkdir(parents=True, exist_ok=True)
# Import-time side effects: create the directory of the default database path,
# then the directory of the configured URL (the two can differ when
# DATABASE_URL is set explicitly).
# NOTE(review): the default directory is created even when DATABASE_URL points
# somewhere else — confirm this is intended.
DEFAULT_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
_ensure_sqlite_dir(DATABASE_URL)
def _debug_db_paths() -> None:
    """Print the resolved database URL and default-path details to stdout.

    Best-effort diagnostics: any exception raised while building or printing
    the message is swallowed so it cannot interrupt startup.
    """
    try:
        default_dir = DEFAULT_DB_PATH.parent
        print(
            f"[DB] DATABASE_URL={DATABASE_URL}, "
            f"DEFAULT_DB_PATH={DEFAULT_DB_PATH}, "
            f"parent_exists={default_dir.exists()}, "
            f"parent={default_dir}"
        )
    except Exception:
        # Diagnostics only; never let them break the import.
        return


_debug_db_paths()
# Module-wide async engine built from DATABASE_URL.
#   echo=False         -> SQL statements are not echoed to the log
#   pool_pre_ping=True -> pooled connections are checked before being handed out
#   future=True        -> 2.0-style engine API
engine: AsyncEngine = create_async_engine(DATABASE_URL, echo=False, pool_pre_ping=True, future=True)
# Ensure SQLite pragmas (WAL + FK) when using SQLite
if DATABASE_URL.startswith("sqlite"):

    @event.listens_for(engine.sync_engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, connection_record):  # type: ignore[override]
        """Configure each new DBAPI connection: foreign keys on, WAL journal.

        Registered on the engine's "connect" event, so it runs once per new
        connection. If switching to WAL raises, DELETE journal mode is tried
        instead; a failure of that fallback is ignored.

        Args:
            dbapi_connection: The newly created DBAPI connection.
            connection_record: Pool record for the connection (unused).
        """
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
            # WAL mode can fail on some Docker volume mounts (e.g. NFS, CIFS,
            # overlay issues); fall back to DELETE mode if it does.
            try:
                cursor.execute("PRAGMA journal_mode=WAL")
            except Exception:
                try:
                    cursor.execute("PRAGMA journal_mode=DELETE")
                except Exception:
                    # Deliberate best-effort: keep whatever journal mode the
                    # connection already has.
                    pass
        finally:
            # Release the cursor even if the foreign_keys PRAGMA raised.
            cursor.close()
# Session factory bound to the module engine.
#   autoflush=False        -> pending changes are flushed only on explicit flush/commit
#   expire_on_commit=False -> loaded attributes are not expired by commit()
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields an AsyncSession with automatic rollback on error.

    The session is opened from ``async_session_maker`` and closed by the
    ``async with`` block on exit, so no separate ``close()`` call is needed.

    Yields:
        The open ``AsyncSession``.

    Raises:
        Exception: Any exception raised by the consumer is re-raised after
            the session has been rolled back.
    """
    async with async_session_maker() as session:
        try:
            yield session
        except Exception:
            # Undo uncommitted work before propagating the error.
            await session.rollback()
            raise
async def init_db() -> None:
    """Create every table registered on ``Base.metadata``.

    Mostly for dev/tests; migrations should be handled by Alembic.
    """
    # Import the sibling modules for their side effects before create_all runs
    # (presumably they define the ORM models — verify against the package).
    from . import (  # noqa: F401
        host,
        task,
        schedule,
        schedule_run,
        log,
    )

    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        # create_all is synchronous, so it is run through run_sync.
        await connection.run_sync(create_tables)