# 68 lines
# 2.5 KiB
# Python

from __future__ import annotations
from datetime import datetime, timezone
from typing import Iterable, Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from models.host import Host
class HostRepository:
    """Async data-access helpers for ``Host`` rows.

    Hosts are soft-deleted: ``soft_delete`` stamps ``deleted_at`` instead of
    removing the row, and every read method hides such rows unless
    ``include_deleted=True`` is passed.

    Methods never commit; writes are at most flushed (or executed) on the
    supplied session, leaving transaction control to its owner.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def list(self, limit: int = 100, offset: int = 0, include_deleted: bool = False) -> list[Host]:
        """Return one page of hosts, newest first.

        Args:
            limit: Maximum number of hosts to return.
            offset: Number of hosts to skip before the page starts.
            include_deleted: Also return soft-deleted hosts.

        Returns:
            Hosts ordered by ``created_at`` descending, then ``id`` descending.
        """
        stmt = select(Host)
        if not include_deleted:
            stmt = stmt.where(Host.deleted_at.is_(None))
        # `id` is a tie-breaker: rows with equal created_at would otherwise come
        # back in arbitrary order, so consecutive offset pages could overlap or
        # skip hosts.
        stmt = stmt.order_by(Host.created_at.desc(), Host.id.desc()).offset(offset).limit(limit)
        result = await self.session.execute(stmt)
        # `list` here is the builtin: class-level names (including this method)
        # are not visible inside method bodies.
        return list(result.scalars().all())

    async def get(self, host_id: str, include_deleted: bool = False) -> Optional[Host]:
        """Fetch a host by id, eagerly loading ``bootstrap_statuses``.

        Args:
            host_id: Identifier of the host.
            include_deleted: Also match a soft-deleted host.

        Returns:
            The host, or ``None`` if no (visible) host has that id.
        """
        # Eager-load via a second SELECT; presumably so callers can read the
        # relationship without an implicit lazy load, which AsyncSession does
        # not support.
        stmt = select(Host).where(Host.id == host_id).options(selectinload(Host.bootstrap_statuses))
        if not include_deleted:
            stmt = stmt.where(Host.deleted_at.is_(None))
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_ip(self, ip_address: str, include_deleted: bool = False) -> Optional[Host]:
        """Fetch the host with the given IP address.

        Args:
            ip_address: Address to match exactly.
            include_deleted: Also consider soft-deleted hosts. Several rows may
                then share the IP (e.g. a host re-created after a soft delete).
                In that case the live host wins, otherwise the most recently
                created one is returned.

        Returns:
            The matching host, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one *live* host
                has this IP (only when ``include_deleted`` is False).
        """
        stmt = select(Host).where(Host.ip_address == ip_address)
        if not include_deleted:
            stmt = stmt.where(Host.deleted_at.is_(None))
            result = await self.session.execute(stmt)
            # Keep the strict check here: duplicate live IPs are a data problem.
            return result.scalar_one_or_none()

        # Soft-deleted rows legitimately share IPs, so a strict single-row
        # fetch would raise. `deleted_at IS NOT NULL` sorts false (live) first.
        stmt = stmt.order_by(Host.deleted_at.is_not(None), Host.created_at.desc()).limit(1)
        result = await self.session.execute(stmt)
        return result.scalars().first()

    async def create(self, *, id: str, name: str, ip_address: str, ansible_group: Optional[str] = None,
                     status: str = "unknown", reachable: bool = False, last_seen: Optional[datetime] = None) -> Host:
        """Add a new host to the session and flush it.

        Flushing sends the INSERT so database errors (e.g. constraint
        violations) surface here rather than at commit. The transaction is not
        committed.

        Returns:
            The newly added ``Host`` instance.
        """
        host = Host(
            id=id,
            name=name,
            ip_address=ip_address,
            ansible_group=ansible_group,
            status=status,
            reachable=reachable,
            last_seen=last_seen,
        )
        self.session.add(host)
        await self.session.flush()
        return host

    async def update(self, host: Host, **fields) -> Host:
        """Set the given attributes on ``host`` and flush.

        Args:
            host: The instance to modify.
            **fields: Attribute name -> new value. Entries whose value is
                ``None`` are skipped, so this method cannot set a field to
                NULL. Names are passed to ``setattr`` without validation.

        Returns:
            The same ``host`` instance.
        """
        for key, value in fields.items():
            if value is not None:
                setattr(host, key, value)
        await self.session.flush()
        return host

    async def soft_delete(self, host_id: str) -> bool:
        """Mark a live host as deleted by stamping ``deleted_at`` with UTC now.

        Args:
            host_id: Identifier of the host.

        Returns:
            True if a live host was marked deleted; False if no such host
            exists or it was already soft-deleted.
        """
        stmt = (
            update(Host)
            # Restricting to live rows keeps the original deletion timestamp
            # and makes the return value report whether anything changed.
            .where(Host.id == host_id, Host.deleted_at.is_(None))
            .values(deleted_at=datetime.now(timezone.utc))
        )
        result = await self.session.execute(stmt)
        return result.rowcount > 0