import asyncio

from fastapi import APIRouter, HTTPException
from loguru import logger

from chain_service.database.models.progress_chain import ProgressChain
from chain_service.dependencies.chain_repository import ChainRepositoryDependency
from chain_service.dependencies.progress_chain_repository import (
    ProgressChainRepositoryDependency,
)
from chain_service.dependencies.progress_chain_runner_service import (
    ProgressChainRunnerServiceDependency,
)
from chain_service.schema.run_chain import RunChainInput

router = APIRouter()

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so without this set a running task may be garbage
# collected before it finishes. Entries are removed when the task completes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Drop the finished task's reference and log any failure it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Otherwise the error would only surface as "Task exception was never
        # retrieved" at garbage-collection time, if at all.
        logger.opt(exception=exc).error("Background progress chain run failed")


@router.post("/run_chain")
async def run_chain_controller(
    run_chain_input: RunChainInput,
    chain_repository: ChainRepositoryDependency,
    progress_chain_repository: ProgressChainRepositoryDependency,
    progress_chain_runner_service: ProgressChainRunnerServiceDependency,
):
    """Start running a chain in the background and return immediately.

    Looks up the chain by ``run_chain_input.chain_id``, creates a
    ``ProgressChain`` from it (task id, namespace id and recipients taken from
    the request / chain), upserts it, and schedules
    ``progress_chain_runner_service.process`` on the event loop without
    awaiting it.

    Returns:
        An empty dict once the run has been scheduled.

    Raises:
        HTTPException: 404 if no chain exists for the given id, 500 for any
            other error while loading, creating or persisting the run.
    """
    try:
        chain = await chain_repository.get_by_id(run_chain_input.chain_id)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not chain:
            logger.warning(f"Chain not found {run_chain_input.chain_id}")
            raise HTTPException(status_code=404, detail="Chain not found")

        progress_chain = ProgressChain.create_from_chain(
            chain=chain,
            task_id=run_chain_input.task_id,
            namespace_id=chain.namespace_id,
            recipients=run_chain_input.recipients,
        )
        progress_chain = await progress_chain_repository.upsert(progress_chain)

        task = asyncio.create_task(progress_chain_runner_service.process(progress_chain))
        _background_tasks.add(task)
        task.add_done_callback(_on_background_task_done)
        return {}
    except HTTPException:
        # Already an HTTP error (e.g. the 404 above); don't turn it into a 500.
        raise
    except Exception as exc:
        logger.exception("Error during run chain")
        raise HTTPException(status_code=500, detail="Error during run chain") from exc