+ Handle cancellation of async tasks

This commit is contained in:
Phil Zhitnikov 2024-08-04 23:48:39 +04:00
parent a737bfd154
commit 095896d847
4 changed files with 57 additions and 3 deletions

View File

@ -1,5 +1,7 @@
import asyncio
import logging import logging
import threading import threading
from collections import defaultdict
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -14,6 +16,24 @@ application = FastAPI()
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@application.on_event("startup")
async def startup_event():
application.state.tasks = defaultdict(list)
application.state.tasks_lock = asyncio.Lock()
@application.on_event("shutdown")
async def shutdown_event():
async with application.state.tasks_lock:
for task in application.state.tasks.values():
try:
task.cancel()
except asyncio.CancelledError:
pass
await asyncio.gather(*application.state.tasks.values(), return_exceptions=True)
application.state.tasks.clear()
# Subclass threading.Thread for logging # Subclass threading.Thread for logging
class DebugThread(threading.Thread): class DebugThread(threading.Thread):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View File

@ -16,15 +16,17 @@ from chain_service.dependencies.running_chain_repository import (
from chain_service.database.models.progress_chain import ProgressChain from chain_service.database.models.progress_chain import ProgressChain
import asyncio
from loguru import logger from loguru import logger
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from chain_service.utils.tasks import create_task, cancel_task
router = APIRouter() router = APIRouter()
@router.post("/run_chain") @router.post("/run_chain")
async def run_chain_controller( async def run_chain_controller(
request: Request,
run_chain_input: RunChainInput, run_chain_input: RunChainInput,
chain_repository: ChainRepositoryDependency, chain_repository: ChainRepositoryDependency,
progress_chain_repository: ProgressChainRepositoryDependency, progress_chain_repository: ProgressChainRepositoryDependency,
@ -52,7 +54,7 @@ async def run_chain_controller(
progress_chain_id=str(progress_chain.id), progress_chain_id=str(progress_chain.id),
) )
asyncio.create_task(progress_chain_runner_service.process(progress_chain)) await create_task(request.app, str(run_chain_input.task_id), progress_chain_runner_service.process, progress_chain)
return {} return {}
except AssertionError: except AssertionError:
@ -69,12 +71,14 @@ async def run_chain_controller(
@router.post("/abort_chain") @router.post("/abort_chain")
async def abort_chain_controller( async def abort_chain_controller(
request: Request,
abort_chain_input: AbortChainInput, abort_chain_input: AbortChainInput,
running_chain_repository: RunningChainRepositoryDependency, running_chain_repository: RunningChainRepositoryDependency,
): ):
try: try:
assert await running_chain_repository.exists(str(abort_chain_input.task_id)) assert await running_chain_repository.exists(str(abort_chain_input.task_id))
await running_chain_repository.delete(str(abort_chain_input.task_id)) await running_chain_repository.delete(str(abort_chain_input.task_id))
await cancel_task(request.app, abort_chain_input.task_id)
return {} return {}
except AssertionError: except AssertionError:

View File

@ -1,8 +1,11 @@
import asyncio import asyncio
from loguru import logger
from .base import BaseProgressActionService from .base import BaseProgressActionService
class WaitProgressActionService(BaseProgressActionService): class WaitProgressActionService(BaseProgressActionService):
async def process(self): async def process(self):
logger.info('WaitProgressActionService task started')
await asyncio.sleep(self.progress_action.wait_for) await asyncio.sleep(self.progress_action.wait_for)
logger.info('WaitProgressActionService task ended')

View File

@ -0,0 +1,27 @@
import asyncio
from asyncio import Task
from typing import List
from fastapi import FastAPI
from loguru import logger
async def create_task(app: FastAPI, identifier, callable, *args, **kwargs):
async with app.state.tasks_lock:
task = asyncio.create_task(callable(*args, **kwargs))
app.state.tasks[identifier].append(task)
logger.info(f"Task {identifier} created")
async def cancel_task(app: FastAPI, identifier):
async with app.state.tasks_lock:
tasks: List[Task] = app.state.tasks.get(identifier)
if not tasks:
logger.info(f"Task {identifier} not found, can't cancel")
for task in tasks:
try:
task.cancel()
await task # Wait for the task to finish cancellation
except asyncio.CancelledError:
logger.info(f"Task {identifier} cancelled")
del app.state.tasks[identifier]