From 095896d8472e3868e7a65ff38f617113aad00025 Mon Sep 17 00:00:00 2001 From: phzhik Date: Sun, 4 Aug 2024 23:48:39 +0400 Subject: [PATCH] + Handle cancellation of async tasks --- chain_service/app.py | 20 ++++++++++++++ chain_service/controllers/run_chain.py | 10 ++++--- .../progress_action/wait_progress_action.py | 3 +++ chain_service/utils/tasks.py | 27 +++++++++++++++++++ 4 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 chain_service/utils/tasks.py diff --git a/chain_service/app.py b/chain_service/app.py index 5085b9a..6f18f33 100644 --- a/chain_service/app.py +++ b/chain_service/app.py @@ -1,5 +1,7 @@ +import asyncio import logging import threading +from collections import defaultdict from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -14,6 +16,24 @@ application = FastAPI() logging.basicConfig(level=logging.DEBUG) +@application.on_event("startup") +async def startup_event(): + application.state.tasks = defaultdict(list) + application.state.tasks_lock = asyncio.Lock() + + +@application.on_event("shutdown") +async def shutdown_event(): + async with application.state.tasks_lock: + for task in application.state.tasks.values(): + try: + task.cancel() + except asyncio.CancelledError: + pass + await asyncio.gather(*application.state.tasks.values(), return_exceptions=True) + application.state.tasks.clear() + + # Subclass threading.Thread for logging class DebugThread(threading.Thread): def __init__(self, *args, **kwargs): diff --git a/chain_service/controllers/run_chain.py b/chain_service/controllers/run_chain.py index c1691f0..600befa 100644 --- a/chain_service/controllers/run_chain.py +++ b/chain_service/controllers/run_chain.py @@ -16,15 +16,17 @@ from chain_service.dependencies.running_chain_repository import ( from chain_service.database.models.progress_chain import ProgressChain -import asyncio from loguru import logger -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request + +from chain_service.utils.tasks import create_task, cancel_task router = APIRouter() @router.post("/run_chain") async def run_chain_controller( + request: Request, run_chain_input: RunChainInput, chain_repository: ChainRepositoryDependency, progress_chain_repository: ProgressChainRepositoryDependency, @@ -52,7 +54,7 @@ async def run_chain_controller( progress_chain_id=str(progress_chain.id), ) - asyncio.create_task(progress_chain_runner_service.process(progress_chain)) + await create_task(request.app, str(run_chain_input.task_id), progress_chain_runner_service.process, progress_chain) return {} except AssertionError: @@ -69,12 +71,14 @@ async def run_chain_controller( @router.post("/abort_chain") async def abort_chain_controller( + request: Request, abort_chain_input: AbortChainInput, running_chain_repository: RunningChainRepositoryDependency, ): try: assert await running_chain_repository.exists(str(abort_chain_input.task_id)) await running_chain_repository.delete(str(abort_chain_input.task_id)) + await cancel_task(request.app, abort_chain_input.task_id) return {} except AssertionError: diff --git a/chain_service/services/progress_action/wait_progress_action.py b/chain_service/services/progress_action/wait_progress_action.py index 4651e2c..0b4970e 100644 --- a/chain_service/services/progress_action/wait_progress_action.py +++ b/chain_service/services/progress_action/wait_progress_action.py @@ -1,8 +1,11 @@ import asyncio +from loguru import logger from .base import BaseProgressActionService class WaitProgressActionService(BaseProgressActionService): async def process(self): + logger.info('WaitProgressActionService task started') await asyncio.sleep(self.progress_action.wait_for) + logger.info('WaitProgressActionService task ended') diff --git a/chain_service/utils/tasks.py b/chain_service/utils/tasks.py new file mode 100644 index 0000000..9194c0d --- /dev/null +++ b/chain_service/utils/tasks.py @@ -0,0 +1,27 @@ +import asyncio +from asyncio import Task +from typing import List + +from fastapi import FastAPI +from loguru import logger + + +async def create_task(app: FastAPI, identifier, callable, *args, **kwargs): + async with app.state.tasks_lock: + task = asyncio.create_task(callable(*args, **kwargs)) + app.state.tasks[identifier].append(task) + logger.info(f"Task {identifier} created") + + +async def cancel_task(app: FastAPI, identifier): + async with app.state.tasks_lock: + tasks: List[Task] = app.state.tasks.get(identifier) + if not tasks: + logger.info(f"Task {identifier} not found, can't cancel") + for task in tasks: + try: + task.cancel() + await task # Wait for the task to finish cancellation + except asyncio.CancelledError: + logger.info(f"Task {identifier} cancelled") + del app.state.tasks[identifier] \ No newline at end of file