diff --git a/chain_service/controllers/run_chain.py b/chain_service/controllers/run_chain.py index 732e41e..c1691f0 100644 --- a/chain_service/controllers/run_chain.py +++ b/chain_service/controllers/run_chain.py @@ -46,7 +46,11 @@ async def run_chain_controller( ) progress_chain = await progress_chain_repository.upsert(progress_chain) - await running_chain_repository.add(str(progress_chain.task_id)) + + await running_chain_repository.add( + task_id=str(progress_chain.task_id), + progress_chain_id=str(progress_chain.id), + ) asyncio.create_task(progress_chain_runner_service.process(progress_chain)) return {} diff --git a/chain_service/database/models/running_chain.py b/chain_service/database/models/running_chain.py index ebd8c75..7ac14fb 100644 --- a/chain_service/database/models/running_chain.py +++ b/chain_service/database/models/running_chain.py @@ -1,9 +1,12 @@ from .base import BaseConfig -from pydantic import BaseModel + +from typing import Annotated, Optional +from pydantic import BaseModel, Field class RunningChain(BaseModel): task_id: str + progress_chain_id: Annotated[Optional[str], Field(default=None)] class Config(BaseConfig): pass diff --git a/chain_service/repositories/running_chain.py b/chain_service/repositories/running_chain.py index 80a908b..d147e94 100644 --- a/chain_service/repositories/running_chain.py +++ b/chain_service/repositories/running_chain.py @@ -6,16 +6,25 @@ class RunningChainRepository: def __init__(self, database: Database): self.collection = database.get_collection("running_chains") - async def add(self, task_id: str): - query = payload = {"taskId": task_id} + async def add(self, task_id: str, progress_chain_id: str): + query = {"taskId": task_id} + payload = {"taskId": task_id, "progressChainId": progress_chain_id} await self.collection.replace_one(query, payload, upsert=True) - async def exists(self, task_id: str) -> bool: + async def exists(self, task_id: str, progress_chain_id: str = None) -> bool: query = {"taskId": task_id} + + if progress_chain_id: + query = {"progressChainId": progress_chain_id} + return bool(await self.collection.find_one(query)) - async def delete(self, task_id: str): + async def delete(self, task_id: str, progress_chain_id: str = None) -> bool: query = {"taskId": task_id} + + if progress_chain_id: + query = {"progressChainId": progress_chain_id} + await self.collection.delete_one(query) async def delete_all(self): diff --git a/chain_service/services/progress_chain_runner.py b/chain_service/services/progress_chain_runner.py index 9d13df9..2a17e45 100644 --- a/chain_service/services/progress_chain_runner.py +++ b/chain_service/services/progress_chain_runner.py @@ -39,7 +39,10 @@ class ProgressChainRunnerService: if not await self.process_action(progress_chain, progress_action): break - await self.running_chain_repository.delete(str(progress_chain.task_id)) + await self.running_chain_repository.delete( + task_id=str(progress_chain.task_id), + progress_chain_id=str(progress_chain.id), + ) async def process_action( self, progress_chain: ProgressChain, progress_action: BaseProgressAction @@ -55,7 +58,8 @@ class ProgressChainRunnerService: await progress_action_service.process() assert await self.running_chain_repository.exists( - str(progress_chain.task_id) + task_id=str(progress_chain.task_id), + progress_chain_id=str(progress_chain.id), ) progress_action.status = ProgressActionStatusEnum.DONE