+ Handle cancellation of async tasks
This commit is contained in:
parent
a737bfd154
commit
095896d847
|
|
@ -1,5 +1,7 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
@ -14,6 +16,24 @@ application = FastAPI()
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
@application.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
application.state.tasks = defaultdict(list)
|
||||||
|
application.state.tasks_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
@application.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
async with application.state.tasks_lock:
|
||||||
|
for task in application.state.tasks.values():
|
||||||
|
try:
|
||||||
|
task.cancel()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await asyncio.gather(*application.state.tasks.values(), return_exceptions=True)
|
||||||
|
application.state.tasks.clear()
|
||||||
|
|
||||||
|
|
||||||
# Subclass threading.Thread for logging
|
# Subclass threading.Thread for logging
|
||||||
class DebugThread(threading.Thread):
|
class DebugThread(threading.Thread):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -16,15 +16,17 @@ from chain_service.dependencies.running_chain_repository import (
|
||||||
|
|
||||||
from chain_service.database.models.progress_chain import ProgressChain
|
from chain_service.database.models.progress_chain import ProgressChain
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
||||||
|
from chain_service.utils.tasks import create_task, cancel_task
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/run_chain")
|
@router.post("/run_chain")
|
||||||
async def run_chain_controller(
|
async def run_chain_controller(
|
||||||
|
request: Request,
|
||||||
run_chain_input: RunChainInput,
|
run_chain_input: RunChainInput,
|
||||||
chain_repository: ChainRepositoryDependency,
|
chain_repository: ChainRepositoryDependency,
|
||||||
progress_chain_repository: ProgressChainRepositoryDependency,
|
progress_chain_repository: ProgressChainRepositoryDependency,
|
||||||
|
|
@ -52,7 +54,7 @@ async def run_chain_controller(
|
||||||
progress_chain_id=str(progress_chain.id),
|
progress_chain_id=str(progress_chain.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
await create_task(request.app, str(run_chain_input.task_id), progress_chain_runner_service.process, progress_chain)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
|
|
@ -69,12 +71,14 @@ async def run_chain_controller(
|
||||||
|
|
||||||
@router.post("/abort_chain")
|
@router.post("/abort_chain")
|
||||||
async def abort_chain_controller(
|
async def abort_chain_controller(
|
||||||
|
request: Request,
|
||||||
abort_chain_input: AbortChainInput,
|
abort_chain_input: AbortChainInput,
|
||||||
running_chain_repository: RunningChainRepositoryDependency,
|
running_chain_repository: RunningChainRepositoryDependency,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
assert await running_chain_repository.exists(str(abort_chain_input.task_id))
|
assert await running_chain_repository.exists(str(abort_chain_input.task_id))
|
||||||
await running_chain_repository.delete(str(abort_chain_input.task_id))
|
await running_chain_repository.delete(str(abort_chain_input.task_id))
|
||||||
|
await cancel_task(request.app, abort_chain_input.task_id)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from loguru import logger
|
||||||
from .base import BaseProgressActionService
|
from .base import BaseProgressActionService
|
||||||
|
|
||||||
|
|
||||||
class WaitProgressActionService(BaseProgressActionService):
|
class WaitProgressActionService(BaseProgressActionService):
|
||||||
|
|
||||||
async def process(self):
|
async def process(self):
|
||||||
|
logger.info('WaitProgressActionService task started')
|
||||||
await asyncio.sleep(self.progress_action.wait_for)
|
await asyncio.sleep(self.progress_action.wait_for)
|
||||||
|
logger.info('WaitProgressActionService task ended')
|
||||||
|
|
|
||||||
27
chain_service/utils/tasks.py
Normal file
27
chain_service/utils/tasks.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
import asyncio
|
||||||
|
from asyncio import Task
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
async def create_task(app: FastAPI, identifier, callable, *args, **kwargs):
|
||||||
|
async with app.state.tasks_lock:
|
||||||
|
task = asyncio.create_task(callable(*args, **kwargs))
|
||||||
|
app.state.tasks[identifier].append(task)
|
||||||
|
logger.info(f"Task {identifier} created")
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_task(app: FastAPI, identifier):
|
||||||
|
async with app.state.tasks_lock:
|
||||||
|
tasks: List[Task] = app.state.tasks.get(identifier)
|
||||||
|
if not tasks:
|
||||||
|
logger.info(f"Task {identifier} not found, can't cancel")
|
||||||
|
for task in tasks:
|
||||||
|
try:
|
||||||
|
task.cancel()
|
||||||
|
await task # Wait for the task to finish cancellation
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Task {identifier} cancelled")
|
||||||
|
del app.state.tasks[identifier]
|
||||||
Loading…
Reference in New Issue
Block a user