from __future__ import annotations from datetime import datetime from uuid import UUID, uuid4 from fastapi import HTTPException from sqlalchemy import select, update from sqlalchemy.orm import Session from fastapi_demo.app.infrastructure.db.models import ( Asset as AssetORM, AssetEvent as AssetEventORM, ) from fastapi_demo.app.domain.status import AssetStatus from fastapi_demo.app.schemas.asset import AssetOut, AssetEventOut class SqlAssetsRepo: def __init__(self, db: Session) -> None: self.db = db def create(self, asset: AssetOut) -> None: row = AssetORM( id=str(asset.id), name=asset.name, serial=asset.serial, status=str(asset.status), revision=asset.revision, # <-- wichtig created_at=asset.updated_at, # MVP: created_at == updated_at updated_at=asset.updated_at, ) self.db.add(row) self.db.commit() def get(self, asset_id: UUID) -> AssetOut | None: row = self.db.get(AssetORM, str(asset_id)) if not row: return None return AssetOut( id=UUID(row.id), name=row.name, serial=row.serial, status=AssetStatus(row.status), revision=row.revision, # <-- wichtig updated_at=row.updated_at, ) def transition_with_revision( self, asset_id: UUID, expected_revision: int, to_status: AssetStatus, at: datetime, note: str | None, ) -> tuple[AssetOut, AssetEventOut]: with self.db.begin(): # <-- begin ganz nach oben current = self.db.get(AssetORM, str(asset_id)) if not current: raise HTTPException(status_code=404, detail="Asset nicht gefunden") from_status = AssetStatus(current.status) stmt = ( update(AssetORM) .where( AssetORM.id == str(asset_id), AssetORM.revision == expected_revision, ) .values( status=str(to_status), updated_at=at, revision=expected_revision + 1, ) ) res = self.db.execute(stmt) if res.rowcount != 1: # aktuelle Revision für saubere Fehlermeldung neu lesen latest = self.db.get(AssetORM, str(asset_id)) raise HTTPException( status_code=409, detail={ "message": "Revision-Konflikt", "expected_revision": expected_revision, "current_revision": latest.revision if latest else None, }, ) event_row = AssetEventORM( id=str(uuid4()), asset_id=str(asset_id), from_status=str(from_status), to_status=str(to_status), at=at, note=note, ) self.db.add(event_row) # nach Commit: updated Asset laden updated = self.db.get(AssetORM, str(asset_id)) assert updated is not None asset_out = AssetOut( id=UUID(updated.id), name=updated.name, serial=updated.serial, status=AssetStatus(updated.status), revision=updated.revision, updated_at=updated.updated_at, ) event_out = AssetEventOut( asset_id=asset_out.id, from_status=from_status, to_status=to_status, at=at, note=note, ) return asset_out, event_out def list_events(self, asset_id: UUID) -> list[AssetEventOut]: stmt = ( select(AssetEventORM) .where(AssetEventORM.asset_id == str(asset_id)) .order_by(AssetEventORM.at.asc()) ) rows = self.db.execute(stmt).scalars().all() return [ AssetEventOut( asset_id=UUID(r.asset_id), from_status=AssetStatus(r.from_status), to_status=AssetStatus(r.to_status), at=r.at, note=r.note, ) for r in rows ]