diff --git a/chain_service/controllers/run_chain.py b/chain_service/controllers/run_chain.py index b75b3e2..732e41e 100644 --- a/chain_service/controllers/run_chain.py +++ b/chain_service/controllers/run_chain.py @@ -34,7 +34,7 @@ async def run_chain_controller( try: assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id)) - if await running_chain_repository.exists(str(chain.id)): + if await running_chain_repository.exists(str(run_chain_input.task_id)): logger.error(f"Chain {chain.id} is already running") raise HTTPException(status_code=409, detail="Chain is already running") @@ -46,7 +46,7 @@ async def run_chain_controller( ) progress_chain = await progress_chain_repository.upsert(progress_chain) - await running_chain_repository.add(str(chain.id)) + await running_chain_repository.add(str(progress_chain.task_id)) asyncio.create_task(progress_chain_runner_service.process(progress_chain)) return {} @@ -55,6 +55,9 @@ async def run_chain_controller( logger.warning(f"Chain not found {run_chain_input.chain_id}") raise HTTPException(status_code=404, detail="Chain not found") + except HTTPException: + raise + except Exception: logger.exception("Error during run chain") raise HTTPException(status_code=500, detail="Error during run chain") @@ -66,8 +69,8 @@ async def abort_chain_controller( running_chain_repository: RunningChainRepositoryDependency, ): try: - assert await running_chain_repository.exists(abort_chain_input.chain_id) - await running_chain_repository.delete(abort_chain_input.chain_id) + assert await running_chain_repository.exists(str(abort_chain_input.task_id)) + await running_chain_repository.delete(str(abort_chain_input.task_id)) return {} except AssertionError: diff --git a/chain_service/database/models/running_chain.py b/chain_service/database/models/running_chain.py index 4fe8057..ebd8c75 100644 --- a/chain_service/database/models/running_chain.py +++ b/chain_service/database/models/running_chain.py @@ -3,7 +3,7 @@ from pydantic import BaseModel class RunningChain(BaseModel): - chain_id: str + task_id: str class Config(BaseConfig): pass diff --git a/chain_service/repositories/running_chain.py b/chain_service/repositories/running_chain.py index 472a237..80a908b 100644 --- a/chain_service/repositories/running_chain.py +++ b/chain_service/repositories/running_chain.py @@ -6,16 +6,16 @@ class RunningChainRepository: def __init__(self, database: Database): self.collection = database.get_collection("running_chains") - async def add(self, chain_id: str): - query = payload = {"chainId": chain_id} + async def add(self, task_id: str): + query = payload = {"taskId": task_id} await self.collection.replace_one(query, payload, upsert=True) - async def exists(self, chain_id: str) -> bool: - query = {"chainId": chain_id} + async def exists(self, task_id: str) -> bool: + query = {"taskId": task_id} return bool(await self.collection.find_one(query)) - async def delete(self, chain_id: str): - query = {"chainId": chain_id} + async def delete(self, task_id: str): + query = {"taskId": task_id} await self.collection.delete_one(query) async def delete_all(self): diff --git a/chain_service/schema/run_chain.py b/chain_service/schema/run_chain.py index 38266ea..e5693f2 100644 --- a/chain_service/schema/run_chain.py +++ b/chain_service/schema/run_chain.py @@ -15,4 +15,4 @@ class RunChainInput(BaseModel): class AbortChainInput(BaseModel): - chain_id: str + task_id: str diff --git a/chain_service/services/progress_chain_runner.py b/chain_service/services/progress_chain_runner.py index dc004e5..9d13df9 100644 --- a/chain_service/services/progress_chain_runner.py +++ b/chain_service/services/progress_chain_runner.py @@ -39,7 +39,7 @@ class ProgressChainRunnerService: if not await self.process_action(progress_chain, progress_action): break - await self.running_chain_repository.delete(progress_chain.chain_id) + await self.running_chain_repository.delete(str(progress_chain.task_id)) async def process_action( self, progress_chain: ProgressChain, progress_action: BaseProgressAction @@ -53,7 +53,11 @@ class ProgressChainRunnerService: ) await progress_action_service.process() - assert await self.running_chain_repository.exists(progress_chain.chain_id) + + assert await self.running_chain_repository.exists( + str(progress_chain.task_id) + ) + progress_action.status = ProgressActionStatusEnum.DONE return True