+ Handle cancellation of async tasks
This commit is contained in:
parent
a737bfd154
commit
095896d847
|
|
@ -1,5 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
|
@ -14,6 +16,24 @@ application = FastAPI()
|
|||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
@application.on_event("startup")
|
||||
async def startup_event():
|
||||
application.state.tasks = defaultdict(list)
|
||||
application.state.tasks_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@application.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
async with application.state.tasks_lock:
|
||||
for task in application.state.tasks.values():
|
||||
try:
|
||||
task.cancel()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await asyncio.gather(*application.state.tasks.values(), return_exceptions=True)
|
||||
application.state.tasks.clear()
|
||||
|
||||
|
||||
# Subclass threading.Thread for logging
|
||||
class DebugThread(threading.Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -16,15 +16,17 @@ from chain_service.dependencies.running_chain_repository import (
|
|||
|
||||
from chain_service.database.models.progress_chain import ProgressChain
|
||||
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from chain_service.utils.tasks import create_task, cancel_task
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/run_chain")
|
||||
async def run_chain_controller(
|
||||
request: Request,
|
||||
run_chain_input: RunChainInput,
|
||||
chain_repository: ChainRepositoryDependency,
|
||||
progress_chain_repository: ProgressChainRepositoryDependency,
|
||||
|
|
@ -52,7 +54,7 @@ async def run_chain_controller(
|
|||
progress_chain_id=str(progress_chain.id),
|
||||
)
|
||||
|
||||
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
||||
await create_task(request.app, str(run_chain_input.task_id), progress_chain_runner_service.process, progress_chain)
|
||||
return {}
|
||||
|
||||
except AssertionError:
|
||||
|
|
@ -69,12 +71,14 @@ async def run_chain_controller(
|
|||
|
||||
@router.post("/abort_chain")
|
||||
async def abort_chain_controller(
|
||||
request: Request,
|
||||
abort_chain_input: AbortChainInput,
|
||||
running_chain_repository: RunningChainRepositoryDependency,
|
||||
):
|
||||
try:
|
||||
assert await running_chain_repository.exists(str(abort_chain_input.task_id))
|
||||
await running_chain_repository.delete(str(abort_chain_input.task_id))
|
||||
await cancel_task(request.app, abort_chain_input.task_id)
|
||||
return {}
|
||||
|
||||
except AssertionError:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import asyncio
|
||||
from loguru import logger
|
||||
from .base import BaseProgressActionService
|
||||
|
||||
|
||||
class WaitProgressActionService(BaseProgressActionService):
|
||||
|
||||
async def process(self):
|
||||
logger.info('WaitProgressActionService task started')
|
||||
await asyncio.sleep(self.progress_action.wait_for)
|
||||
logger.info('WaitProgressActionService task ended')
|
||||
|
|
|
|||
27
chain_service/utils/tasks.py
Normal file
27
chain_service/utils/tasks.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import asyncio
|
||||
from asyncio import Task
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def create_task(app: FastAPI, identifier, callable, *args, **kwargs):
|
||||
async with app.state.tasks_lock:
|
||||
task = asyncio.create_task(callable(*args, **kwargs))
|
||||
app.state.tasks[identifier].append(task)
|
||||
logger.info(f"Task {identifier} created")
|
||||
|
||||
|
||||
async def cancel_task(app: FastAPI, identifier):
|
||||
async with app.state.tasks_lock:
|
||||
tasks: List[Task] = app.state.tasks.get(identifier)
|
||||
if not tasks:
|
||||
logger.info(f"Task {identifier} not found, can't cancel")
|
||||
for task in tasks:
|
||||
try:
|
||||
task.cancel()
|
||||
await task # Wait for the task to finish cancellation
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Task {identifier} cancelled")
|
||||
del app.state.tasks[identifier]
|
||||
Loading…
Reference in New Issue
Block a user