abort chain controller
This commit is contained in:
parent
36c36d2c89
commit
b14c8bf20e
|
|
@ -1,4 +1,4 @@
|
||||||
from chain_service.schema.run_chain import RunChainInput
|
from chain_service.schema.run_chain import RunChainInput, AbortChainInput
|
||||||
|
|
||||||
from chain_service.dependencies.chain_repository import ChainRepositoryDependency
|
from chain_service.dependencies.chain_repository import ChainRepositoryDependency
|
||||||
|
|
||||||
|
|
@ -10,6 +10,10 @@ from chain_service.dependencies.progress_chain_runner_service import (
|
||||||
ProgressChainRunnerServiceDependency,
|
ProgressChainRunnerServiceDependency,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from chain_service.dependencies.running_chain_repository import (
|
||||||
|
RunningChainRepositoryDependency,
|
||||||
|
)
|
||||||
|
|
||||||
from chain_service.database.models.progress_chain import ProgressChain
|
from chain_service.database.models.progress_chain import ProgressChain
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -25,10 +29,15 @@ async def run_chain_controller(
|
||||||
chain_repository: ChainRepositoryDependency,
|
chain_repository: ChainRepositoryDependency,
|
||||||
progress_chain_repository: ProgressChainRepositoryDependency,
|
progress_chain_repository: ProgressChainRepositoryDependency,
|
||||||
progress_chain_runner_service: ProgressChainRunnerServiceDependency,
|
progress_chain_runner_service: ProgressChainRunnerServiceDependency,
|
||||||
|
running_chain_repository: RunningChainRepositoryDependency,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
|
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
|
||||||
|
|
||||||
|
if await running_chain_repository.exists(str(chain.id)):
|
||||||
|
logger.error(f"Chain {chain.id} is already running")
|
||||||
|
raise HTTPException(status_code=409, detail="Chain is already running")
|
||||||
|
|
||||||
progress_chain = ProgressChain.create_from_chain(
|
progress_chain = ProgressChain.create_from_chain(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
task_id=run_chain_input.task_id,
|
task_id=run_chain_input.task_id,
|
||||||
|
|
@ -38,6 +47,8 @@ async def run_chain_controller(
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_chain = await progress_chain_repository.upsert(progress_chain)
|
progress_chain = await progress_chain_repository.upsert(progress_chain)
|
||||||
|
await running_chain_repository.add(str(chain.id))
|
||||||
|
|
||||||
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
@ -48,3 +59,25 @@ async def run_chain_controller(
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error during run chain")
|
logger.exception("Error during run chain")
|
||||||
raise HTTPException(status_code=500, detail="Error during run chain")
|
raise HTTPException(status_code=500, detail="Error during run chain")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/abort_chain")
|
||||||
|
async def abort_chain_controller(
|
||||||
|
abort_chain_input: AbortChainInput,
|
||||||
|
running_chain_repository: RunningChainRepositoryDependency,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
assert await running_chain_repository.exists(abort_chain_input.chain_id)
|
||||||
|
await running_chain_repository.delete(abort_chain_input.chain_id)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except AssertionError:
|
||||||
|
logger.warning("Cannot abort a chain that isn't running")
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409, detail="Cannot abort a chain that isn't running"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error during chain abortion")
|
||||||
|
raise HTTPException(status_code=400, detail="Chain abortion error :)")
|
||||||
|
|
|
||||||
|
|
@ -12,3 +12,7 @@ class RunChainInput(BaseModel):
|
||||||
|
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AbortChainInput(BaseModel):
|
||||||
|
chain_id: str
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user