fixed double aborted chain

This commit is contained in:
Robert 2024-04-02 20:33:59 +07:00
parent 5fe35e2a77
commit bff53e6b07
No known key found for this signature in database
GPG Key ID: F631C7FD957D5F22
4 changed files with 28 additions and 8 deletions

View File

@ -46,7 +46,11 @@ async def run_chain_controller(
) )
progress_chain = await progress_chain_repository.upsert(progress_chain) progress_chain = await progress_chain_repository.upsert(progress_chain)
await running_chain_repository.add(str(progress_chain.task_id))
await running_chain_repository.add(
task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
)
asyncio.create_task(progress_chain_runner_service.process(progress_chain)) asyncio.create_task(progress_chain_runner_service.process(progress_chain))
return {} return {}

View File

@ -1,9 +1,12 @@
from .base import BaseConfig from .base import BaseConfig
from pydantic import BaseModel
from typing import Annotated, Optional
from pydantic import BaseModel, Field
class RunningChain(BaseModel): class RunningChain(BaseModel):
task_id: str task_id: str
progress_chain_id: Annotated[Optional[str], Field(default=None)]
class Config(BaseConfig): class Config(BaseConfig):
pass pass

View File

@ -6,16 +6,25 @@ class RunningChainRepository:
def __init__(self, database: Database): def __init__(self, database: Database):
self.collection = database.get_collection("running_chains") self.collection = database.get_collection("running_chains")
async def add(self, task_id: str): async def add(self, task_id: str, progress_chain_id: str):
query = payload = {"taskId": task_id} query = {"taskId": task_id}
payload = {"taskId": task_id, "progressChainId": progress_chain_id}
await self.collection.replace_one(query, payload, upsert=True) await self.collection.replace_one(query, payload, upsert=True)
async def exists(self, task_id: str) -> bool: async def exists(self, task_id: str, progress_chain_id: str = None) -> bool:
query = {"taskId": task_id} query = {"taskId": task_id}
if progress_chain_id:
query = {"progressChainId": progress_chain_id}
return bool(await self.collection.find_one(query)) return bool(await self.collection.find_one(query))
async def delete(self, task_id: str): async def delete(self, task_id: str, progress_chain_id: str = None) -> bool:
query = {"taskId": task_id} query = {"taskId": task_id}
if progress_chain_id:
query = {"progressChainId": progress_chain_id}
await self.collection.delete_one(query) await self.collection.delete_one(query)
async def delete_all(self): async def delete_all(self):

View File

@ -39,7 +39,10 @@ class ProgressChainRunnerService:
if not await self.process_action(progress_chain, progress_action): if not await self.process_action(progress_chain, progress_action):
break break
await self.running_chain_repository.delete(str(progress_chain.task_id)) await self.running_chain_repository.delete(
task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
)
async def process_action( async def process_action(
self, progress_chain: ProgressChain, progress_action: BaseProgressAction self, progress_chain: ProgressChain, progress_action: BaseProgressAction
@ -55,7 +58,8 @@ class ProgressChainRunnerService:
await progress_action_service.process() await progress_action_service.process()
assert await self.running_chain_repository.exists( assert await self.running_chain_repository.exists(
str(progress_chain.task_id) task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
) )
progress_action.status = ProgressActionStatusEnum.DONE progress_action.status = ProgressActionStatusEnum.DONE