Compare commits
10 Commits
8d7fd0e36c
...
8446d83896
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8446d83896 | ||
|
|
bff53e6b07 | ||
|
|
5fe35e2a77 | ||
|
|
94c7f98ce8 | ||
|
|
fe87afe9d0 | ||
|
|
70ad69a5b0 | ||
|
|
2798e34e4f | ||
|
|
17f5b9d797 | ||
|
|
adadadb92e | ||
|
|
e30b9b3c14 |
|
|
@ -34,20 +34,23 @@ async def run_chain_controller(
|
||||||
try:
|
try:
|
||||||
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
|
assert (chain := await chain_repository.get_by_id(run_chain_input.chain_id))
|
||||||
|
|
||||||
if await running_chain_repository.exists(str(chain.id)):
|
if await running_chain_repository.exists(str(run_chain_input.task_id)):
|
||||||
logger.error(f"Chain {chain.id} is already running")
|
logger.error(f"Chain {chain.id} is already running")
|
||||||
raise HTTPException(status_code=409, detail="Chain is already running")
|
raise HTTPException(status_code=409, detail="Chain is already running")
|
||||||
|
|
||||||
progress_chain = ProgressChain.create_from_chain(
|
progress_chain = ProgressChain.create_from_chain(
|
||||||
chain=chain,
|
chain=chain,
|
||||||
task_id=run_chain_input.task_id,
|
task_id=run_chain_input.task_id,
|
||||||
namespace_id=chain.namespace_id,
|
|
||||||
recipients=run_chain_input.recipients,
|
recipients=run_chain_input.recipients,
|
||||||
variables=run_chain_input.variables,
|
variables=run_chain_input.variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_chain = await progress_chain_repository.upsert(progress_chain)
|
progress_chain = await progress_chain_repository.upsert(progress_chain)
|
||||||
await running_chain_repository.add(str(chain.id))
|
|
||||||
|
await running_chain_repository.add(
|
||||||
|
task_id=str(progress_chain.task_id),
|
||||||
|
progress_chain_id=str(progress_chain.id),
|
||||||
|
)
|
||||||
|
|
||||||
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
asyncio.create_task(progress_chain_runner_service.process(progress_chain))
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -56,6 +59,9 @@ async def run_chain_controller(
|
||||||
logger.warning(f"Chain not found {run_chain_input.chain_id}")
|
logger.warning(f"Chain not found {run_chain_input.chain_id}")
|
||||||
raise HTTPException(status_code=404, detail="Chain not found")
|
raise HTTPException(status_code=404, detail="Chain not found")
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error during run chain")
|
logger.exception("Error during run chain")
|
||||||
raise HTTPException(status_code=500, detail="Error during run chain")
|
raise HTTPException(status_code=500, detail="Error during run chain")
|
||||||
|
|
@ -67,8 +73,8 @@ async def abort_chain_controller(
|
||||||
running_chain_repository: RunningChainRepositoryDependency,
|
running_chain_repository: RunningChainRepositoryDependency,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
assert await running_chain_repository.exists(abort_chain_input.chain_id)
|
assert await running_chain_repository.exists(str(abort_chain_input.task_id))
|
||||||
await running_chain_repository.delete(abort_chain_input.chain_id)
|
await running_chain_repository.delete(str(abort_chain_input.task_id))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
|
|
@ -81,3 +87,16 @@ async def abort_chain_controller(
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error during chain abortion")
|
logger.exception("Error during chain abortion")
|
||||||
raise HTTPException(status_code=400, detail="Chain abortion error :)")
|
raise HTTPException(status_code=400, detail="Chain abortion error :)")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/abort_all_chains")
|
||||||
|
async def abort_all_chains_controller(
|
||||||
|
running_chain_repository: RunningChainRepositoryDependency,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await running_chain_repository.delete_all()
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error during all chains abortion")
|
||||||
|
raise HTTPException(status_code=400, detail="Chains abortion error :)")
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ class ProgressActionStatusEnum(str, Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
DONE = "done"
|
DONE = "done"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
ABORTED = "aborted"
|
||||||
|
|
||||||
|
|
||||||
class BaseProgressAction(BaseModel):
|
class BaseProgressAction(BaseModel):
|
||||||
|
|
@ -50,6 +51,7 @@ Action = Annotated[
|
||||||
|
|
||||||
class ProgressChain(BaseMongoModel):
|
class ProgressChain(BaseMongoModel):
|
||||||
task_id: int
|
task_id: int
|
||||||
|
chain_id: str
|
||||||
namespace_id: str
|
namespace_id: str
|
||||||
variables: Annotated[Optional[Dict], Field(default={})]
|
variables: Annotated[Optional[Dict], Field(default={})]
|
||||||
recipients: Annotated[Optional[List[int]], Field(default=[])]
|
recipients: Annotated[Optional[List[int]], Field(default=[])]
|
||||||
|
|
@ -62,7 +64,6 @@ class ProgressChain(BaseMongoModel):
|
||||||
cls,
|
cls,
|
||||||
chain: Chain,
|
chain: Chain,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
namespace_id: str,
|
|
||||||
variables: Optional[Dict] = {},
|
variables: Optional[Dict] = {},
|
||||||
recipients: Optional[List[int]] = [],
|
recipients: Optional[List[int]] = [],
|
||||||
):
|
):
|
||||||
|
|
@ -78,7 +79,8 @@ class ProgressChain(BaseMongoModel):
|
||||||
|
|
||||||
return ProgressChain(
|
return ProgressChain(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
namespace_id=namespace_id,
|
chain_id=str(chain.id),
|
||||||
|
namespace_id=chain.namespace_id,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
recipients=recipients,
|
recipients=recipients,
|
||||||
name=chain.name,
|
name=chain.name,
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
from .base import BaseConfig
|
from .base import BaseConfig
|
||||||
from pydantic import BaseModel
|
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class RunningChain(BaseModel):
|
class RunningChain(BaseModel):
|
||||||
chain_id: str
|
task_id: str
|
||||||
|
progress_chain_id: Annotated[Optional[str], Field(default=None)]
|
||||||
|
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,4 @@
|
||||||
from chain_service.database.database import Database
|
from chain_service.database.database import Database
|
||||||
from chain_service.database.models.running_chain import RunningChain
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
|
|
||||||
class RunningChainRepository:
|
class RunningChainRepository:
|
||||||
|
|
@ -10,14 +6,26 @@ class RunningChainRepository:
|
||||||
def __init__(self, database: Database):
|
def __init__(self, database: Database):
|
||||||
self.collection = database.get_collection("running_chains")
|
self.collection = database.get_collection("running_chains")
|
||||||
|
|
||||||
async def add(self, chain_id: str):
|
async def add(self, task_id: str, progress_chain_id: str):
|
||||||
query = payload = {"chainId": chain_id}
|
query = {"taskId": task_id}
|
||||||
|
payload = {"taskId": task_id, "progressChainId": progress_chain_id}
|
||||||
await self.collection.replace_one(query, payload, upsert=True)
|
await self.collection.replace_one(query, payload, upsert=True)
|
||||||
|
|
||||||
async def exists(self, chain_id: str) -> bool:
|
async def exists(self, task_id: str, progress_chain_id: str = None) -> bool:
|
||||||
query = {"chainId": chain_id}
|
query = {"taskId": task_id}
|
||||||
|
|
||||||
|
if progress_chain_id:
|
||||||
|
query = {"progressChainId": progress_chain_id}
|
||||||
|
|
||||||
return bool(await self.collection.find_one(query))
|
return bool(await self.collection.find_one(query))
|
||||||
|
|
||||||
async def delete(self, chain_id: str):
|
async def delete(self, task_id: str, progress_chain_id: str = None) -> bool:
|
||||||
query = {"chainId": chain_id}
|
query = {"taskId": task_id}
|
||||||
|
|
||||||
|
if progress_chain_id:
|
||||||
|
query = {"progressChainId": progress_chain_id}
|
||||||
|
|
||||||
await self.collection.delete_one(query)
|
await self.collection.delete_one(query)
|
||||||
|
|
||||||
|
async def delete_all(self):
|
||||||
|
await self.collection.delete_many({})
|
||||||
|
|
|
||||||
|
|
@ -15,4 +15,4 @@ class RunChainInput(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AbortChainInput(BaseModel):
|
class AbortChainInput(BaseModel):
|
||||||
chain_id: str
|
task_id: str
|
||||||
|
|
|
||||||
|
|
@ -10,21 +10,21 @@ from io import BytesIO
|
||||||
class AudioConverterService:
|
class AudioConverterService:
|
||||||
|
|
||||||
@sync_to_async
|
@sync_to_async
|
||||||
def __mp3_to_ogg(self, mp3: BytesIO) -> str:
|
def __audio_to_ogg(self, audio: BytesIO) -> str:
|
||||||
filename = f"./audios/{uuid4()}.ogg"
|
filename = f"./audios/{uuid4()}.ogg"
|
||||||
|
|
||||||
process = (
|
process = (
|
||||||
ffmpeg.input("pipe:")
|
ffmpeg.input("pipe:")
|
||||||
.output(filename, loglevel="quiet")
|
.output(filename, codec="libopus", loglevel="quiet")
|
||||||
.overwrite_output()
|
.overwrite_output()
|
||||||
.run_async(pipe_stdin=True)
|
.run_async(pipe_stdin=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
process.communicate(input=mp3.getbuffer())
|
process.communicate(input=audio.getbuffer())
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
async def mp3_to_ogg(self, mp3: BytesIO) -> BytesIO:
|
async def audio_to_ogg(self, audio: BytesIO) -> BytesIO:
|
||||||
filename = await self.__mp3_to_ogg(mp3)
|
filename = await self.__audio_to_ogg(audio)
|
||||||
|
|
||||||
async with aiofiles.open(filename, "rb") as file:
|
async with aiofiles.open(filename, "rb") as file:
|
||||||
content = BytesIO(initial_bytes=await file.read())
|
content = BytesIO(initial_bytes=await file.read())
|
||||||
|
|
|
||||||
|
|
@ -41,14 +41,13 @@ class FileUploaderService:
|
||||||
if uploaded_file:
|
if uploaded_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
if file_url.endswith(".mp3") or file_url.endswith(".ogg"):
|
if file_url.endswith(".ogg") or file_url.endswith(".mp3"):
|
||||||
|
content = await self.audio_converter_service.audio_to_ogg(
|
||||||
converted_content = await self.audio_converter_service.mp3_to_ogg(
|
audio=BytesIO((await self.client.get(file_url)).read())
|
||||||
mp3=BytesIO((await self.client.get(file_url)).read())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
uploaded_file_id = await self.planfix_client.upload_file(
|
uploaded_file_id = await self.planfix_client.upload_file(
|
||||||
file_content=converted_content
|
file_content=content
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from chain_service.database.models.progress_chain import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from chain_service.repositories.progress_chain import ProgressChainRepository
|
from chain_service.repositories.progress_chain import ProgressChainRepository
|
||||||
|
from chain_service.repositories.running_chain import RunningChainRepository
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
@ -18,9 +19,11 @@ class ProgressChainRunnerService:
|
||||||
self,
|
self,
|
||||||
progress_chain_repository: ProgressChainRepository,
|
progress_chain_repository: ProgressChainRepository,
|
||||||
progress_action_service_factory: ProgressActionServiceFactory,
|
progress_action_service_factory: ProgressActionServiceFactory,
|
||||||
|
running_chain_repository: RunningChainRepository,
|
||||||
):
|
):
|
||||||
self.progress_chain_repository = progress_chain_repository
|
self.progress_chain_repository = progress_chain_repository
|
||||||
self.progress_action_service_factory = progress_action_service_factory
|
self.progress_action_service_factory = progress_action_service_factory
|
||||||
|
self.running_chain_repository = running_chain_repository
|
||||||
|
|
||||||
async def process(self, progress_chain: ProgressChain):
|
async def process(self, progress_chain: ProgressChain):
|
||||||
|
|
||||||
|
|
@ -36,6 +39,11 @@ class ProgressChainRunnerService:
|
||||||
if not await self.process_action(progress_chain, progress_action):
|
if not await self.process_action(progress_chain, progress_action):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
await self.running_chain_repository.delete(
|
||||||
|
task_id=str(progress_chain.task_id),
|
||||||
|
progress_chain_id=str(progress_chain.id),
|
||||||
|
)
|
||||||
|
|
||||||
async def process_action(
|
async def process_action(
|
||||||
self, progress_chain: ProgressChain, progress_action: BaseProgressAction
|
self, progress_chain: ProgressChain, progress_action: BaseProgressAction
|
||||||
):
|
):
|
||||||
|
|
@ -48,9 +56,20 @@ class ProgressChainRunnerService:
|
||||||
)
|
)
|
||||||
|
|
||||||
await progress_action_service.process()
|
await progress_action_service.process()
|
||||||
|
|
||||||
|
assert await self.running_chain_repository.exists(
|
||||||
|
task_id=str(progress_chain.task_id),
|
||||||
|
progress_chain_id=str(progress_chain.id),
|
||||||
|
)
|
||||||
|
|
||||||
progress_action.status = ProgressActionStatusEnum.DONE
|
progress_action.status = ProgressActionStatusEnum.DONE
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
except AssertionError:
|
||||||
|
logger.info(f"Chain was aborted {progress_chain.chain_id}")
|
||||||
|
progress_action.status = ProgressActionStatusEnum.ABORTED
|
||||||
|
return False
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.exception(f"Error during action process for {progress_chain.id}")
|
logger.exception(f"Error during action process for {progress_chain.id}")
|
||||||
progress_action.status = ProgressActionStatusEnum.FAILED
|
progress_action.status = ProgressActionStatusEnum.FAILED
|
||||||
|
|
|
||||||
6
poetry.lock
generated
6
poetry.lock
generated
|
|
@ -348,13 +348,13 @@ zstd = ["pymongo[zstd] (>=4.5,<5)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "planfix-client"
|
name = "planfix-client"
|
||||||
version = "0.1.9"
|
version = "0.1.11"
|
||||||
description = ""
|
description = ""
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.12,<4.0"
|
python-versions = ">=3.12,<4.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "planfix_client-0.1.9-py3-none-any.whl", hash = "sha256:879cbe091446fb9e808c6f14e9dfeaf737ca3e58ea55b0008743d6b7fc079d3c"},
|
{file = "planfix_client-0.1.11-py3-none-any.whl", hash = "sha256:f127d2b3d44ac65a6e0f8f2eceefd524103d82b35a955a70efc951fccd67b4e8"},
|
||||||
{file = "planfix_client-0.1.9.tar.gz", hash = "sha256:44a5ec39ce0d55e309f50db7fd9aafb3ab25848743b5367ae99ababb025b1a75"},
|
{file = "planfix_client-0.1.11.tar.gz", hash = "sha256:f6f837349d1c64e6fd625139aa393d72466be93284abccf97ab07a513e4bf713"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user