# Imago/app/services/pipeline.py

"""
Pipeline de traitement AI — orchestration des 3 étapes
Chaque étape est indépendante : un échec partiel n'arrête pas le pipeline.
Publie des événements Redis (si disponible) pour le suivi en temps réel.
"""
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.image import Image, ProcessingStatus
from app.services.exif_service import extract_exif
from app.services.ocr_service import extract_text
from app.services.ai_vision import analyze_image, extract_text_with_ai
import asyncio
logger = logging.getLogger(__name__)
async def _publish_event(
    redis: Any, image_id: int, event: str, data: dict | None = None
) -> None:
    """Publish a progress event on the Redis channel ``pipeline:{image_id}``.

    Best-effort by design: the pipeline must never be blocked by the event
    bus, so a missing client (``redis is None``) is a no-op and any publish
    failure is swallowed — but logged at debug level so it leaves a trace.

    Args:
        redis: Async Redis client exposing ``publish`` (or None to disable).
        image_id: Primary key of the image being processed.
        event: Event name, e.g. "pipeline.started" or "step.completed".
        data: Optional extra payload, attached under the "data" key.
    """
    if redis is None:
        return
    payload: dict[str, Any] = {
        "event": event,
        "image_id": image_id,
        "timestamp": time.time(),
    }
    if data:
        payload["data"] = data
    try:
        # Serialization stays inside the try: a non-JSON-serializable payload
        # must not crash the pipeline either.
        message = json.dumps(payload)
        await redis.publish(f"pipeline:{image_id}", message)
    except Exception:
        # Pub/Sub is non-critical — never let a broker hiccup break the
        # pipeline, but record it for debugging instead of a bare `pass`.
        logger.debug(
            "pipeline.publish_failed",
            extra={"image_id": image_id, "event": event},
        )
# Image column suffixes copied one-for-one from the EXIF extraction result
# (column is "exif_" + suffix, dict key is the bare suffix).
_EXIF_FIELDS = (
    "make", "model", "lens", "taken_at", "gps_lat", "gps_lon", "altitude",
    "iso", "aperture", "shutter", "focal", "flash", "orientation", "software",
)


async def _recover_session(db: AsyncSession, image: Image) -> None:
    """Roll back a failed step and reload *image* from the database.

    Without the rollback, a commit/flush error leaves the session in a
    pending-rollback state and every later ``commit()`` (next steps,
    finalization) would raise as well. The refresh re-loads the attributes
    expired by the rollback so the rest of the pipeline can keep reading
    and writing the row safely in async context.
    """
    await db.rollback()
    await db.refresh(image)


async def _step_exif(
    image_id: int, image: Image, file_path: str, db: AsyncSession,
    redis: Any, errors: list[str],
) -> None:
    """Step 1/3 — extract EXIF metadata and persist it on the image row."""
    try:
        logger.info("pipeline.step.start", extra={"image_id": image_id, "step": "exif", "step_num": "1/3"})
        t0 = time.time()
        exif = await extract_exif(file_path)
        image.exif_raw = exif.get("raw")
        for field in _EXIF_FIELDS:
            setattr(image, f"exif_{field}", exif.get(field))
        await db.commit()
        elapsed = int((time.time() - t0) * 1000)
        logger.info("pipeline.step.done", extra={"image_id": image_id, "step": "exif", "duration_ms": elapsed, "camera": image.exif_make})
        await _publish_event(redis, image_id, "step.completed", {
            "step": "exif", "duration_ms": elapsed, "camera": image.exif_make,
        })
    except Exception as e:
        await _recover_session(db, image)
        errors.append(f"EXIF : {e}")
        logger.error("pipeline.step.error", extra={"image_id": image_id, "step": "exif", "error": str(e)})


async def _step_ocr(
    image_id: int, image: Image, file_path: str, db: AsyncSession,
    redis: Any, errors: list[str],
) -> None:
    """Step 2/3 — OCR text extraction, with an AI fallback when empty."""
    try:
        logger.info("pipeline.step.start", extra={"image_id": image_id, "step": "ocr", "step_num": "2/3"})
        t0 = time.time()
        ocr = await extract_text(file_path)
        if not ocr.get("has_text", False):
            # Classic OCR found nothing — fall back to AI-based extraction.
            logger.info("pipeline.ocr.fallback", extra={"image_id": image_id, "reason": "tesseract_empty"})
            ai_ocr = await extract_text_with_ai(file_path)
            if ai_ocr.get("has_text"):
                ocr = ai_ocr
                logger.info("pipeline.ocr.fallback_success", extra={"image_id": image_id, "chars": len(ocr.get("text", ""))})
            else:
                logger.info("pipeline.ocr.fallback_empty", extra={"image_id": image_id})
        image.ocr_text = ocr.get("text")
        image.ocr_language = ocr.get("language")
        image.ocr_confidence = ocr.get("confidence")
        image.ocr_has_text = ocr.get("has_text", False)
        await db.commit()
        elapsed = int((time.time() - t0) * 1000)
        logger.info("pipeline.step.done", extra={"image_id": image_id, "step": "ocr", "duration_ms": elapsed, "has_text": image.ocr_has_text})
        await _publish_event(redis, image_id, "step.completed", {
            "step": "ocr", "duration_ms": elapsed, "has_text": image.ocr_has_text,
        })
    except Exception as e:
        await _recover_session(db, image)
        errors.append(f"OCR : {e}")
        logger.error("pipeline.step.error", extra={"image_id": image_id, "step": "ocr", "error": str(e)})


async def _step_ai(
    image_id: int, image: Image, file_path: str, db: AsyncSession,
    redis: Any, errors: list[str],
) -> None:
    """Step 3/3 — AI vision: description, tags and token accounting."""
    try:
        logger.info("pipeline.step.start", extra={"image_id": image_id, "step": "ai", "step_num": "3/3"})
        t0 = time.time()
        # The OCR text (if any) is passed as a hint to ground the analysis.
        ai = await analyze_image(
            file_path=file_path,
            ocr_hint=image.ocr_text,
        )
        image.ai_description = ai.get("description")
        image.ai_tags = ai.get("tags", [])
        image.ai_confidence = ai.get("confidence")
        image.ai_model_used = ai.get("model")
        image.ai_processed_at = datetime.now(timezone.utc)
        image.ai_prompt_tokens = ai.get("prompt_tokens")
        image.ai_output_tokens = ai.get("output_tokens")
        await db.commit()
        elapsed = int((time.time() - t0) * 1000)
        logger.info("pipeline.step.done", extra={"image_id": image_id, "step": "ai", "duration_ms": elapsed, "tags_count": len(image.ai_tags or [])})
        await _publish_event(redis, image_id, "step.completed", {
            "step": "ai", "duration_ms": elapsed, "tags_count": len(image.ai_tags or []),
        })
    except Exception as e:
        await _recover_session(db, image)
        errors.append(f"AI Vision : {e}")
        logger.error("pipeline.step.error", extra={"image_id": image_id, "step": "ai", "error": str(e)})


async def _finalize(
    image_id: int, image: Image, db: AsyncSession, redis: Any,
    errors: list[str],
) -> None:
    """Persist the terminal status and publish the closing Redis event.

    A partial failure still counts as DONE (with a warning message) as long
    as the AI description was produced; otherwise the row is marked ERROR.
    """
    image.processing_done_at = datetime.now(timezone.utc)
    if errors:
        if image.ai_description:
            image.processing_status = ProcessingStatus.DONE
            image.processing_error = f"Avertissements : {'; '.join(errors)}"
        else:
            image.processing_status = ProcessingStatus.ERROR
            image.processing_error = "; ".join(errors)
    else:
        image.processing_status = ProcessingStatus.DONE
        image.processing_error = None
    await db.commit()
    logger.info("pipeline.completed", extra={
        "image_id": image_id,
        "status": image.processing_status.value,
        "errors": len(errors),
    })
    if errors:
        await _publish_event(redis, image_id, "pipeline.error", {"errors": errors})
    else:
        await _publish_event(redis, image_id, "pipeline.done")


async def process_image_pipeline(
    image_id: int, db: AsyncSession, redis: Any = None
) -> None:
    """Run the full processing pipeline for one image.

    1. EXIF extraction (async)
    2. OCR — text extraction, with AI fallback (async)
    3. AI vision — description + tags (async)
    4. Final status persisted to the DB

    Each step is independent: a partial failure never aborts the pipeline
    (errors are collected and reported at finalization). The status row is
    committed after every step to allow polling, and progress events are
    published on the Redis channel ``pipeline:{image_id}``.

    Args:
        image_id: Primary key of the Image row to process.
        db: Open async SQLAlchemy session (committed after each step).
        redis: Optional async Redis client for progress events.
    """
    result = await db.execute(select(Image).where(Image.id == image_id))
    image = result.scalar_one_or_none()
    if not image:
        logger.warning("pipeline.image_not_found", extra={"image_id": image_id})
        return

    # Mark the row as in-progress before any heavy work starts.
    image.processing_status = ProcessingStatus.PROCESSING
    image.processing_started_at = datetime.now(timezone.utc)
    await db.commit()
    await db.refresh(image)
    await _publish_event(redis, image_id, "pipeline.started")

    errors: list[str] = []
    file_path = image.file_path

    await _step_exif(image_id, image, file_path, db, redis, errors)
    await _step_ocr(image_id, image, file_path, db, redis, errors)
    await _step_ai(image_id, image, file_path, db, redis, errors)
    await _finalize(image_id, image, db, redis, errors)