From 040dacf2f59e596771c23ca6402cc3d67b47476e Mon Sep 17 00:00:00 2001 From: Robert Date: Wed, 13 Mar 2024 03:01:13 +0700 Subject: [PATCH] namespace id --- chain_service/controllers/chain.py | 9 +++++++++ chain_service/controllers/run_chain.py | 1 + chain_service/database/models/chain.py | 1 + chain_service/database/models/progress_chain.py | 8 +++++++- 4 files changed, 18 insertions(+), 1 deletion(-) diff --git a/chain_service/controllers/chain.py b/chain_service/controllers/chain.py index 2ac9747..931010b 100644 --- a/chain_service/controllers/chain.py +++ b/chain_service/controllers/chain.py @@ -4,6 +4,9 @@ from fastapi import APIRouter, HTTPException from chain_service.database.models.chain import Chain from chain_service.dependencies.chain_repository import ChainRepositoryDependency +from chain_service.dependencies.namespace_repository import ( + NamespaceRepositoryDependency, +) from chain_service.dependencies.file_uploader_service import ( FileUploaderServiceDependency, ) @@ -15,13 +18,19 @@ router = APIRouter(prefix="/chain") async def chain_upsert_controller( chain: Chain, chain_repository: ChainRepositoryDependency, + namespace_repository: NamespaceRepositoryDependency, file_uploader_service: FileUploaderServiceDependency, ): try: + assert await namespace_repository.get_by_id(namespace_id=chain.namespace_id) upserted_chain = await chain_repository.upsert(chain) await file_uploader_service.upload_from_chain(upserted_chain) return upserted_chain + except AssertionError: + logger.exception(f"Unknown namespace_id {chain.namespace_id}") + return HTTPException(status_code=400, detail="Wrong namespace_id") + except Exception: logger.exception(f"Error during chain upsert {chain.model_dump_json()}") return HTTPException(status_code=500, detail="Error during chain upsert") diff --git a/chain_service/controllers/run_chain.py b/chain_service/controllers/run_chain.py index 4c9684d..bd3a359 100644 --- a/chain_service/controllers/run_chain.py +++ b/chain_service/controllers/run_chain.py @@ -32,6 +32,7 @@ async def run_chain_controller( progress_chain = ProgressChain.create_from_chain( chain=chain, task_id=run_chain_input.task_id, + namespace_id=chain.namespace_id, recipients=run_chain_input.recipients, ) diff --git a/chain_service/database/models/chain.py b/chain_service/database/models/chain.py index 9c6b1bf..9a2560b 100644 --- a/chain_service/database/models/chain.py +++ b/chain_service/database/models/chain.py @@ -27,6 +27,7 @@ Action = Annotated[Union[WaitAction, CommentAction], Field(description="action_t class Chain(BaseMongoModel): + namespace_id: str name: Annotated[Optional[str], Field(default=None)] actions: Annotated[Optional[List[Action]], Field(default=[])] last_modified: Annotated[datetime, Field(default_factory=datetime.utcnow)] diff --git a/chain_service/database/models/progress_chain.py b/chain_service/database/models/progress_chain.py index ee6618e..2813c54 100644 --- a/chain_service/database/models/progress_chain.py +++ b/chain_service/database/models/progress_chain.py @@ -50,6 +50,7 @@ Action = Annotated[ class ProgressChain(BaseMongoModel): task_id: int + namespace_id: str recipients: Annotated[Optional[List[int]], Field(default=[])] name: Annotated[Optional[str], Field(default=None)] actions: Annotated[Optional[List[Action]], Field(default=[])] @@ -57,10 +58,15 @@ class ProgressChain(BaseMongoModel): @classmethod def create_from_chain( - cls, chain: Chain, task_id: int, recipients: Optional[List[int]] = [] + cls, + chain: Chain, + task_id: int, + namespace_id: str, + recipients: Optional[List[int]] = [], ): return ProgressChain( task_id=task_id, + namespace_id=namespace_id, recipients=recipients, name=chain.name, actions=map(Chain.model_dump, chain.actions),