diff --git a/chain_service/controllers/run_chain.py b/chain_service/controllers/run_chain.py index 37b1fd5..69c4abc 100644 --- a/chain_service/controllers/run_chain.py +++ b/chain_service/controllers/run_chain.py @@ -1,4 +1,4 @@ -from chain_service.schema.run_chain import RunChainInput +from chain_service.schema.run_chain import RunChainInput, AbortChainInput from chain_service.dependencies.chain_repository import ChainRepositoryDependency @@ -10,6 +10,10 @@ from chain_service.dependencies.progress_chain_runner_service import ( ProgressChainRunnerServiceDependency, ) +from chain_service.dependencies.running_chain_repository import ( + RunningChainRepositoryDependency, +) + from chain_service.database.models.progress_chain import ProgressChain import asyncio @@ -25,10 +29,15 @@ async def run_chain_controller( chain_repository: ChainRepositoryDependency, progress_chain_repository: ProgressChainRepositoryDependency, progress_chain_runner_service: ProgressChainRunnerServiceDependency, + running_chain_repository: RunningChainRepositoryDependency, ): try: assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id)) + if await running_chain_repository.exists(str(chain.id)): + logger.error(f"Chain {chain.id} is already running") + raise HTTPException(status_code=409, detail="Chain is already running") + progress_chain = ProgressChain.create_from_chain( chain=chain, task_id=run_chain_input.task_id, @@ -38,6 +47,8 @@ async def run_chain_controller( ) progress_chain = await progress_chain_repository.upsert(progress_chain) + await running_chain_repository.add(str(chain.id)) + asyncio.create_task(progress_chain_runner_service.process(progress_chain)) return {} @@ -48,3 +59,25 @@ async def run_chain_controller( except Exception: logger.exception("Error during run chain") raise HTTPException(status_code=500, detail="Error during run chain") + + +@router.post("/abort_chain") +async def abort_chain_controller( + abort_chain_input: AbortChainInput, + running_chain_repository: RunningChainRepositoryDependency, +): + try: + assert await running_chain_repository.exists(abort_chain_input.chain_id) + await running_chain_repository.delete(abort_chain_input.chain_id) + return {} + + except AssertionError: + logger.warning("Cannot abort a chain that isn't running") + + raise HTTPException( + status_code=409, detail="Cannot abort a chain that isn't running" + ) + + except Exception: + logger.exception("Error during chain abortion") + raise HTTPException(status_code=400, detail="Chain abortion error :)") diff --git a/chain_service/schema/run_chain.py b/chain_service/schema/run_chain.py index 91ba262..38266ea 100644 --- a/chain_service/schema/run_chain.py +++ b/chain_service/schema/run_chain.py @@ -12,3 +12,7 @@ class RunChainInput(BaseModel): class Config(BaseConfig): pass + + +class AbortChainInput(BaseModel): + chain_id: str