Compare commits

..

No commits in common. "8446d83896e17efe8518bb91a0ec43a1e90b5f8a" and "8d7fd0e36cfe87849a4e4762a8e0cbb69fbae6b6" have entirely different histories.

9 changed files with 33 additions and 83 deletions

View File

@ -34,23 +34,20 @@ async def run_chain_controller(
try: try:
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id)) assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
if await running_chain_repository.exists(str(run_chain_input.task_id)): if await running_chain_repository.exists(str(chain.id)):
logger.error(f"Chain {chain.id} is already running") logger.error(f"Chain {chain.id} is already running")
raise HTTPException(status_code=409, detail="Chain is already running") raise HTTPException(status_code=409, detail="Chain is already running")
progress_chain = ProgressChain.create_from_chain( progress_chain = ProgressChain.create_from_chain(
chain=chain, chain=chain,
task_id=run_chain_input.task_id, task_id=run_chain_input.task_id,
namespace_id=chain.namespace_id,
recipients=run_chain_input.recipients, recipients=run_chain_input.recipients,
variables=run_chain_input.variables, variables=run_chain_input.variables,
) )
progress_chain = await progress_chain_repository.upsert(progress_chain) progress_chain = await progress_chain_repository.upsert(progress_chain)
await running_chain_repository.add(str(chain.id))
await running_chain_repository.add(
task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
)
asyncio.create_task(progress_chain_runner_service.process(progress_chain)) asyncio.create_task(progress_chain_runner_service.process(progress_chain))
return {} return {}
@ -59,9 +56,6 @@ async def run_chain_controller(
logger.warning(f"Chain not found {run_chain_input.chain_id}") logger.warning(f"Chain not found {run_chain_input.chain_id}")
raise HTTPException(status_code=404, detail="Chain not found") raise HTTPException(status_code=404, detail="Chain not found")
except HTTPException:
raise
except Exception: except Exception:
logger.exception("Error during run chain") logger.exception("Error during run chain")
raise HTTPException(status_code=500, detail="Error during run chain") raise HTTPException(status_code=500, detail="Error during run chain")
@ -73,8 +67,8 @@ async def abort_chain_controller(
running_chain_repository: RunningChainRepositoryDependency, running_chain_repository: RunningChainRepositoryDependency,
): ):
try: try:
assert await running_chain_repository.exists(str(abort_chain_input.task_id)) assert await running_chain_repository.exists(abort_chain_input.chain_id)
await running_chain_repository.delete(str(abort_chain_input.task_id)) await running_chain_repository.delete(abort_chain_input.chain_id)
return {} return {}
except AssertionError: except AssertionError:
@ -87,16 +81,3 @@ async def abort_chain_controller(
except Exception: except Exception:
logger.exception("Error during chain abortion") logger.exception("Error during chain abortion")
raise HTTPException(status_code=400, detail="Chain abortion error :)") raise HTTPException(status_code=400, detail="Chain abortion error :)")
@router.post("/abort_all_chains")
async def abort_all_chains_controller(
running_chain_repository: RunningChainRepositoryDependency,
):
try:
await running_chain_repository.delete_all()
return {}
except Exception:
logger.exception("Error during all chains abortion")
raise HTTPException(status_code=400, detail="Chains abortion error :)")

View File

@ -11,7 +11,6 @@ class ProgressActionStatusEnum(str, Enum):
PENDING = "pending" PENDING = "pending"
DONE = "done" DONE = "done"
FAILED = "failed" FAILED = "failed"
ABORTED = "aborted"
class BaseProgressAction(BaseModel): class BaseProgressAction(BaseModel):
@ -51,7 +50,6 @@ Action = Annotated[
class ProgressChain(BaseMongoModel): class ProgressChain(BaseMongoModel):
task_id: int task_id: int
chain_id: str
namespace_id: str namespace_id: str
variables: Annotated[Optional[Dict], Field(default={})] variables: Annotated[Optional[Dict], Field(default={})]
recipients: Annotated[Optional[List[int]], Field(default=[])] recipients: Annotated[Optional[List[int]], Field(default=[])]
@ -64,6 +62,7 @@ class ProgressChain(BaseMongoModel):
cls, cls,
chain: Chain, chain: Chain,
task_id: int, task_id: int,
namespace_id: str,
variables: Optional[Dict] = {}, variables: Optional[Dict] = {},
recipients: Optional[List[int]] = [], recipients: Optional[List[int]] = [],
): ):
@ -79,8 +78,7 @@ class ProgressChain(BaseMongoModel):
return ProgressChain( return ProgressChain(
task_id=task_id, task_id=task_id,
chain_id=str(chain.id), namespace_id=namespace_id,
namespace_id=chain.namespace_id,
variables=variables, variables=variables,
recipients=recipients, recipients=recipients,
name=chain.name, name=chain.name,

View File

@ -1,12 +1,9 @@
from .base import BaseConfig from .base import BaseConfig
from pydantic import BaseModel
from typing import Annotated, Optional
from pydantic import BaseModel, Field
class RunningChain(BaseModel): class RunningChain(BaseModel):
task_id: str chain_id: str
progress_chain_id: Annotated[Optional[str], Field(default=None)]
class Config(BaseConfig): class Config(BaseConfig):
pass pass

View File

@ -1,4 +1,8 @@
from chain_service.database.database import Database from chain_service.database.database import Database
from chain_service.database.models.running_chain import RunningChain
from uuid import UUID
from loguru import logger
class RunningChainRepository: class RunningChainRepository:
@ -6,26 +10,14 @@ class RunningChainRepository:
def __init__(self, database: Database): def __init__(self, database: Database):
self.collection = database.get_collection("running_chains") self.collection = database.get_collection("running_chains")
async def add(self, task_id: str, progress_chain_id: str): async def add(self, chain_id: str):
query = {"taskId": task_id} query = payload = {"chainId": chain_id}
payload = {"taskId": task_id, "progressChainId": progress_chain_id}
await self.collection.replace_one(query, payload, upsert=True) await self.collection.replace_one(query, payload, upsert=True)
async def exists(self, task_id: str, progress_chain_id: str = None) -> bool: async def exists(self, chain_id: str) -> bool:
query = {"taskId": task_id} query = {"chainId": chain_id}
if progress_chain_id:
query = {"progressChainId": progress_chain_id}
return bool(await self.collection.find_one(query)) return bool(await self.collection.find_one(query))
async def delete(self, task_id: str, progress_chain_id: str = None) -> bool: async def delete(self, chain_id: str):
query = {"taskId": task_id} query = {"chainId": chain_id}
if progress_chain_id:
query = {"progressChainId": progress_chain_id}
await self.collection.delete_one(query) await self.collection.delete_one(query)
async def delete_all(self):
await self.collection.delete_many({})

View File

@ -15,4 +15,4 @@ class RunChainInput(BaseModel):
class AbortChainInput(BaseModel): class AbortChainInput(BaseModel):
task_id: str chain_id: str

View File

@ -10,21 +10,21 @@ from io import BytesIO
class AudioConverterService: class AudioConverterService:
@sync_to_async @sync_to_async
def __audio_to_ogg(self, audio: BytesIO) -> str: def __mp3_to_ogg(self, mp3: BytesIO) -> str:
filename = f"./audios/{uuid4()}.ogg" filename = f"./audios/{uuid4()}.ogg"
process = ( process = (
ffmpeg.input("pipe:") ffmpeg.input("pipe:")
.output(filename, codec="libopus", loglevel="quiet") .output(filename, loglevel="quiet")
.overwrite_output() .overwrite_output()
.run_async(pipe_stdin=True) .run_async(pipe_stdin=True)
) )
process.communicate(input=audio.getbuffer()) process.communicate(input=mp3.getbuffer())
return filename return filename
async def audio_to_ogg(self, audio: BytesIO) -> BytesIO: async def mp3_to_ogg(self, mp3: BytesIO) -> BytesIO:
filename = await self.__audio_to_ogg(audio) filename = await self.__mp3_to_ogg(mp3)
async with aiofiles.open(filename, "rb") as file: async with aiofiles.open(filename, "rb") as file:
content = BytesIO(initial_bytes=await file.read()) content = BytesIO(initial_bytes=await file.read())

View File

@ -41,13 +41,14 @@ class FileUploaderService:
if uploaded_file: if uploaded_file:
return return
if file_url.endswith(".ogg") or file_url.endswith(".mp3"): if file_url.endswith(".mp3") or file_url.endswith(".ogg"):
content = await self.audio_converter_service.audio_to_ogg(
audio=BytesIO((await self.client.get(file_url)).read()) converted_content = await self.audio_converter_service.mp3_to_ogg(
mp3=BytesIO((await self.client.get(file_url)).read())
) )
uploaded_file_id = await self.planfix_client.upload_file( uploaded_file_id = await self.planfix_client.upload_file(
file_content=content file_content=converted_content
) )
else: else:

View File

@ -7,7 +7,6 @@ from chain_service.database.models.progress_chain import (
) )
from chain_service.repositories.progress_chain import ProgressChainRepository from chain_service.repositories.progress_chain import ProgressChainRepository
from chain_service.repositories.running_chain import RunningChainRepository
from loguru import logger from loguru import logger
from datetime import datetime from datetime import datetime
@ -19,11 +18,9 @@ class ProgressChainRunnerService:
self, self,
progress_chain_repository: ProgressChainRepository, progress_chain_repository: ProgressChainRepository,
progress_action_service_factory: ProgressActionServiceFactory, progress_action_service_factory: ProgressActionServiceFactory,
running_chain_repository: RunningChainRepository,
): ):
self.progress_chain_repository = progress_chain_repository self.progress_chain_repository = progress_chain_repository
self.progress_action_service_factory = progress_action_service_factory self.progress_action_service_factory = progress_action_service_factory
self.running_chain_repository = running_chain_repository
async def process(self, progress_chain: ProgressChain): async def process(self, progress_chain: ProgressChain):
@ -39,11 +36,6 @@ class ProgressChainRunnerService:
if not await self.process_action(progress_chain, progress_action): if not await self.process_action(progress_chain, progress_action):
break break
await self.running_chain_repository.delete(
task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
)
async def process_action( async def process_action(
self, progress_chain: ProgressChain, progress_action: BaseProgressAction self, progress_chain: ProgressChain, progress_action: BaseProgressAction
): ):
@ -56,20 +48,9 @@ class ProgressChainRunnerService:
) )
await progress_action_service.process() await progress_action_service.process()
assert await self.running_chain_repository.exists(
task_id=str(progress_chain.task_id),
progress_chain_id=str(progress_chain.id),
)
progress_action.status = ProgressActionStatusEnum.DONE progress_action.status = ProgressActionStatusEnum.DONE
return True return True
except AssertionError:
logger.info(f"Chain was aborted {progress_chain.chain_id}")
progress_action.status = ProgressActionStatusEnum.ABORTED
return False
except Exception as error: except Exception as error:
logger.exception(f"Error during action process for {progress_chain.id}") logger.exception(f"Error during action process for {progress_chain.id}")
progress_action.status = ProgressActionStatusEnum.FAILED progress_action.status = ProgressActionStatusEnum.FAILED

6
poetry.lock generated
View File

@ -348,13 +348,13 @@ zstd = ["pymongo[zstd] (>=4.5,<5)"]
[[package]] [[package]]
name = "planfix-client" name = "planfix-client"
version = "0.1.11" version = "0.1.9"
description = "" description = ""
optional = false optional = false
python-versions = ">=3.12,<4.0" python-versions = ">=3.12,<4.0"
files = [ files = [
{file = "planfix_client-0.1.11-py3-none-any.whl", hash = "sha256:f127d2b3d44ac65a6e0f8f2eceefd524103d82b35a955a70efc951fccd67b4e8"}, {file = "planfix_client-0.1.9-py3-none-any.whl", hash = "sha256:879cbe091446fb9e808c6f14e9dfeaf737ca3e58ea55b0008743d6b7fc079d3c"},
{file = "planfix_client-0.1.11.tar.gz", hash = "sha256:f6f837349d1c64e6fd625139aa393d72466be93284abccf97ab07a513e4bf713"}, {file = "planfix_client-0.1.9.tar.gz", hash = "sha256:44a5ec39ce0d55e309f50db7fd9aafb3ab25848743b5367ae99ababb025b1a75"},
] ]
[package.dependencies] [package.dependencies]