abort chain by task id

This commit is contained in:
Robert 2024-03-24 01:16:50 +07:00
parent 94c7f98ce8
commit 5fe35e2a77
No known key found for this signature in database
GPG Key ID: F631C7FD957D5F22
5 changed files with 21 additions and 14 deletions

View File

@ -34,7 +34,7 @@ async def run_chain_controller(
try: try:
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id)) assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
if await running_chain_repository.exists(str(chain.id)): if await running_chain_repository.exists(str(run_chain_input.task_id)):
logger.error(f"Chain {chain.id} is already running") logger.error(f"Chain {chain.id} is already running")
raise HTTPException(status_code=409, detail="Chain is already running") raise HTTPException(status_code=409, detail="Chain is already running")
@ -46,7 +46,7 @@ async def run_chain_controller(
) )
progress_chain = await progress_chain_repository.upsert(progress_chain) progress_chain = await progress_chain_repository.upsert(progress_chain)
await running_chain_repository.add(str(chain.id)) await running_chain_repository.add(str(progress_chain.task_id))
asyncio.create_task(progress_chain_runner_service.process(progress_chain)) asyncio.create_task(progress_chain_runner_service.process(progress_chain))
return {} return {}
@ -55,6 +55,9 @@ async def run_chain_controller(
logger.warning(f"Chain not found {run_chain_input.chain_id}") logger.warning(f"Chain not found {run_chain_input.chain_id}")
raise HTTPException(status_code=404, detail="Chain not found") raise HTTPException(status_code=404, detail="Chain not found")
except HTTPException:
raise
except Exception: except Exception:
logger.exception("Error during run chain") logger.exception("Error during run chain")
raise HTTPException(status_code=500, detail="Error during run chain") raise HTTPException(status_code=500, detail="Error during run chain")
@ -66,8 +69,8 @@ async def abort_chain_controller(
running_chain_repository: RunningChainRepositoryDependency, running_chain_repository: RunningChainRepositoryDependency,
): ):
try: try:
assert await running_chain_repository.exists(abort_chain_input.chain_id) assert await running_chain_repository.exists(str(abort_chain_input.task_id))
await running_chain_repository.delete(abort_chain_input.chain_id) await running_chain_repository.delete(str(abort_chain_input.task_id))
return {} return {}
except AssertionError: except AssertionError:

View File

@ -3,7 +3,7 @@ from pydantic import BaseModel
class RunningChain(BaseModel): class RunningChain(BaseModel):
chain_id: str task_id: str
class Config(BaseConfig): class Config(BaseConfig):
pass pass

View File

@ -6,16 +6,16 @@ class RunningChainRepository:
def __init__(self, database: Database): def __init__(self, database: Database):
self.collection = database.get_collection("running_chains") self.collection = database.get_collection("running_chains")
async def add(self, chain_id: str): async def add(self, task_id: str):
query = payload = {"chainId": chain_id} query = payload = {"taskId": task_id}
await self.collection.replace_one(query, payload, upsert=True) await self.collection.replace_one(query, payload, upsert=True)
async def exists(self, chain_id: str) -> bool: async def exists(self, task_id: str) -> bool:
query = {"chainId": chain_id} query = {"taskId": task_id}
return bool(await self.collection.find_one(query)) return bool(await self.collection.find_one(query))
async def delete(self, chain_id: str): async def delete(self, task_id: str):
query = {"chainId": chain_id} query = {"taskId": task_id}
await self.collection.delete_one(query) await self.collection.delete_one(query)
async def delete_all(self): async def delete_all(self):

View File

@ -15,4 +15,4 @@ class RunChainInput(BaseModel):
class AbortChainInput(BaseModel): class AbortChainInput(BaseModel):
chain_id: str task_id: str

View File

@ -39,7 +39,7 @@ class ProgressChainRunnerService:
if not await self.process_action(progress_chain, progress_action): if not await self.process_action(progress_chain, progress_action):
break break
await self.running_chain_repository.delete(progress_chain.chain_id) await self.running_chain_repository.delete(str(progress_chain.task_id))
async def process_action( async def process_action(
self, progress_chain: ProgressChain, progress_action: BaseProgressAction self, progress_chain: ProgressChain, progress_action: BaseProgressAction
@ -53,7 +53,11 @@ class ProgressChainRunnerService:
) )
await progress_action_service.process() await progress_action_service.process()
assert await self.running_chain_repository.exists(progress_chain.chain_id)
assert await self.running_chain_repository.exists(
str(progress_chain.task_id)
)
progress_action.status = ProgressActionStatusEnum.DONE progress_action.status = ProgressActionStatusEnum.DONE
return True return True