+ Handle cancellation of async tasks

This commit is contained in:
Phil Zhitnikov 2024-08-04 23:48:39 +04:00
parent a737bfd154
commit 095896d847
4 changed files with 57 additions and 3 deletions

View File

@ -1,5 +1,7 @@
import asyncio
import logging
import threading
from collections import defaultdict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -14,6 +16,24 @@ application = FastAPI()
logging.basicConfig(level=logging.DEBUG)
@application.on_event("startup")
async def startup_event():
application.state.tasks = defaultdict(list)
application.state.tasks_lock = asyncio.Lock()
@application.on_event("shutdown")
async def shutdown_event():
async with application.state.tasks_lock:
for task in application.state.tasks.values():
try:
task.cancel()
except asyncio.CancelledError:
pass
await asyncio.gather(*application.state.tasks.values(), return_exceptions=True)
application.state.tasks.clear()
# Subclass threading.Thread for logging
class DebugThread(threading.Thread):
def __init__(self, *args, **kwargs):

View File

@ -16,15 +16,17 @@ from chain_service.dependencies.running_chain_repository import (
from chain_service.database.models.progress_chain import ProgressChain
import asyncio
from loguru import logger
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from chain_service.utils.tasks import create_task, cancel_task
router = APIRouter()
@router.post("/run_chain")
async def run_chain_controller(
request: Request,
run_chain_input: RunChainInput,
chain_repository: ChainRepositoryDependency,
progress_chain_repository: ProgressChainRepositoryDependency,
@ -52,7 +54,7 @@ async def run_chain_controller(
progress_chain_id=str(progress_chain.id),
)
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
await create_task(request.app, str(run_chain_input.task_id), progress_chain_runner_service.process, progress_chain)
return {}
except AssertionError:
@ -69,12 +71,14 @@ async def run_chain_controller(
@router.post("/abort_chain")
async def abort_chain_controller(
request: Request,
abort_chain_input: AbortChainInput,
running_chain_repository: RunningChainRepositoryDependency,
):
try:
assert await running_chain_repository.exists(str(abort_chain_input.task_id))
await running_chain_repository.delete(str(abort_chain_input.task_id))
await cancel_task(request.app, abort_chain_input.task_id)
return {}
except AssertionError:

View File

@ -1,8 +1,11 @@
import asyncio
from loguru import logger
from .base import BaseProgressActionService
class WaitProgressActionService(BaseProgressActionService):
async def process(self):
logger.info('WaitProgressActionService task started')
await asyncio.sleep(self.progress_action.wait_for)
logger.info('WaitProgressActionService task ended')

View File

@ -0,0 +1,27 @@
import asyncio
from asyncio import Task
from typing import List
from fastapi import FastAPI
from loguru import logger
async def create_task(app: FastAPI, identifier, callable, *args, **kwargs):
async with app.state.tasks_lock:
task = asyncio.create_task(callable(*args, **kwargs))
app.state.tasks[identifier].append(task)
logger.info(f"Task {identifier} created")
async def cancel_task(app: FastAPI, identifier):
async with app.state.tasks_lock:
tasks: List[Task] = app.state.tasks.get(identifier)
if not tasks:
logger.info(f"Task {identifier} not found, can't cancel")
for task in tasks:
try:
task.cancel()
await task # Wait for the task to finish cancellation
except asyncio.CancelledError:
logger.info(f"Task {identifier} cancelled")
del app.state.tasks[identifier]