abort chain controller

This commit is contained in:
Robert 2024-03-21 02:53:00 +07:00
parent 36c36d2c89
commit b14c8bf20e
No known key found for this signature in database
GPG Key ID: F631C7FD957D5F22
2 changed files with 38 additions and 1 deletions

View File

@ -1,4 +1,4 @@
from chain_service.schema.run_chain import RunChainInput from chain_service.schema.run_chain import RunChainInput, AbortChainInput
from chain_service.dependencies.chain_repository import ChainRepositoryDependency from chain_service.dependencies.chain_repository import ChainRepositoryDependency
@ -10,6 +10,10 @@ from chain_service.dependencies.progress_chain_runner_service import (
ProgressChainRunnerServiceDependency, ProgressChainRunnerServiceDependency,
) )
from chain_service.dependencies.running_chain_repository import (
RunningChainRepositoryDependency,
)
from chain_service.database.models.progress_chain import ProgressChain from chain_service.database.models.progress_chain import ProgressChain
import asyncio import asyncio
@ -25,10 +29,15 @@ async def run_chain_controller(
chain_repository: ChainRepositoryDependency, chain_repository: ChainRepositoryDependency,
progress_chain_repository: ProgressChainRepositoryDependency, progress_chain_repository: ProgressChainRepositoryDependency,
progress_chain_runner_service: ProgressChainRunnerServiceDependency, progress_chain_runner_service: ProgressChainRunnerServiceDependency,
running_chain_repository: RunningChainRepositoryDependency,
): ):
try: try:
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id)) assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
if await running_chain_repository.exists(str(chain.id)):
logger.error(f"Chain {chain.id} is already running")
raise HTTPException(status_code=409, detail="Chain is already running")
progress_chain = ProgressChain.create_from_chain( progress_chain = ProgressChain.create_from_chain(
chain=chain, chain=chain,
task_id=run_chain_input.task_id, task_id=run_chain_input.task_id,
@ -38,6 +47,8 @@ async def run_chain_controller(
) )
progress_chain = await progress_chain_repository.upsert(progress_chain) progress_chain = await progress_chain_repository.upsert(progress_chain)
await running_chain_repository.add(str(chain.id))
asyncio.create_task(progress_chain_runner_service.process(progress_chain)) asyncio.create_task(progress_chain_runner_service.process(progress_chain))
return {} return {}
@ -48,3 +59,25 @@ async def run_chain_controller(
except Exception: except Exception:
logger.exception("Error during run chain") logger.exception("Error during run chain")
raise HTTPException(status_code=500, detail="Error during run chain") raise HTTPException(status_code=500, detail="Error during run chain")
@router.post("/abort_chain")
async def abort_chain_controller(
abort_chain_input: AbortChainInput,
running_chain_repository: RunningChainRepositoryDependency,
):
try:
assert await running_chain_repository.exists(abort_chain_input.chain_id)
await running_chain_repository.delete(abort_chain_input.chain_id)
return {}
except AssertionError:
logger.warning("Cannot abort a chain that isn't running")
raise HTTPException(
status_code=409, detail="Cannot abort a chain that isn't running"
)
except Exception:
logger.exception("Error during chain abortion")
raise HTTPException(status_code=400, detail="Chain abortion error :)")

View File

@ -12,3 +12,7 @@ class RunChainInput(BaseModel):
class Config(BaseConfig): class Config(BaseConfig):
pass pass
class AbortChainInput(BaseModel):
chain_id: str