diff --git a/apps/account/serializers/back.py b/apps/account/serializers/back.py index 2baaf656..4fba1c6f 100644 --- a/apps/account/serializers/back.py +++ b/apps/account/serializers/back.py @@ -2,7 +2,7 @@ from rest_framework import serializers from account import models -from account.serializers import RoleBaseSerializer, UserSerializer, subscriptions_handler +from account.serializers import RoleBaseSerializer, UserSerializer, subscriptions_handler, roles_handler from main.models import SiteSettings @@ -19,7 +19,7 @@ class _SiteSettingsSerializer(serializers.ModelSerializer): class BackUserSerializer(UserSerializer): last_country = _SiteSettingsSerializer(read_only=True) - roles = RoleBaseSerializer(many=True, read_only=True) + roles = RoleBaseSerializer(many=True) class Meta(UserSerializer.Meta): fields = ( @@ -115,6 +115,7 @@ class BackDetailUserSerializer(BackUserSerializer): def create(self, validated_data): subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') @@ -127,11 +128,16 @@ class BackDetailUserSerializer(BackUserSerializer): def update(self, instance, validated_data): subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') + if 'roles' in validated_data: + instance = roles_handler(validated_data.pop('roles'), instance) + instance = super().update(instance, validated_data) subscriptions_handler(subscriptions_list, instance) + return instance diff --git a/apps/account/serializers/common.py b/apps/account/serializers/common.py index 197c3a84..790cbfdf 100644 --- a/apps/account/serializers/common.py +++ b/apps/account/serializers/common.py @@ -8,6 +8,7 @@ from rest_framework import serializers from rest_framework import validators as rest_validators from account import models, tasks +from account.models import User, Role from main.serializers.common import NavigationBarPermissionBaseSerializer from notification.models import Subscribe, Subscriber from utils import exceptions as utils_exceptions @@ -16,6 +17,24 @@ from utils.methods import generate_string_code from phonenumber_field.serializerfields import PhoneNumberField +def roles_handler(roles_list: set, user: User) -> User: + """ + Sync roles for user + :param roles_list: list of user roles + :param user: user + :return: bool + """ + if not roles_list: + user.roles.clear() + return user + + ids = list(map(lambda role: role["id"] if "id" in role else None, roles_list)) + roles = Role.objects \ + .filter(id__in=ids) + user.roles.set(roles) + return user + + def subscriptions_handler(subscriptions_list, user): """ create or update subscriptions for user @@ -27,7 +46,7 @@ def subscriptions_handler(subscriptions_list, user): 'user': user, 'email': user.email, 'ip_address': user.last_ip, - 'country_code': user.last_country.country.code if user.last_country else None, + 'country_code': user.last_country.country.code if user.last_country and user.last_country.country else None, 'locale': user.locale, 'update_code': generate_string_code(), } @@ -42,6 +61,7 @@ def subscriptions_handler(subscriptions_list, user): class RoleBaseSerializer(serializers.ModelSerializer): """Serializer for model Role.""" + id = serializers.IntegerField() role_display = serializers.CharField(source='get_role_display', read_only=True) navigation_bar_permission = NavigationBarPermissionBaseSerializer(read_only=True) country_code = serializers.CharField(source='country.code', read_only=True, allow_null=True) @@ -133,6 +153,7 @@ class UserSerializer(serializers.ModelSerializer): def update(self, instance, validated_data): """Override update method""" subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') @@ -164,6 +185,7 @@ class UserSerializer(serializers.ModelSerializer): emails=[validated_data['email'], ]) subscriptions_handler(subscriptions_list, instance) + return instance