diff --git a/apps/account/models.py b/apps/account/models.py index 08944f0c..e67fef49 100644 --- a/apps/account/models.py +++ b/apps/account/models.py @@ -13,6 +13,7 @@ from django.utils.http import urlsafe_base64_encode from django.utils.translation import ugettext_lazy as _ from rest_framework.authtoken.models import Token from collections import Counter +from typing import List from authorization.models import Application from establishment.models import Establishment, EstablishmentSubType @@ -435,6 +436,15 @@ class User(AbstractUser): result.append(item.id) return set(result) + def set_roles(self, ids: List(int)): + """ + Set user roles + :param ids: list of role ids + :return: bool + """ + self.roles.set(Role.objects.filter(id__in=ids)) + return self + class UserRoleQueryset(models.QuerySet): """QuerySet for model UserRole.""" diff --git a/apps/account/serializers/back.py b/apps/account/serializers/back.py index 2baaf656..93e33fa1 100644 --- a/apps/account/serializers/back.py +++ b/apps/account/serializers/back.py @@ -19,7 +19,7 @@ class _SiteSettingsSerializer(serializers.ModelSerializer): class BackUserSerializer(UserSerializer): last_country = _SiteSettingsSerializer(read_only=True) - roles = RoleBaseSerializer(many=True, read_only=True) + roles = RoleBaseSerializer(many=True) class Meta(UserSerializer.Meta): fields = ( @@ -115,6 +115,7 @@ class BackDetailUserSerializer(BackUserSerializer): def create(self, validated_data): subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') @@ -127,11 +128,17 @@ class BackDetailUserSerializer(BackUserSerializer): def update(self, instance, validated_data): subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') + if 'roles' in validated_data: + roles_ids = [role['id'] for role in validated_data.pop('roles') if 'id' in role] + instance.set_roles(roles_ids) + instance = super().update(instance, validated_data) subscriptions_handler(subscriptions_list, instance) + return instance diff --git a/apps/account/serializers/common.py b/apps/account/serializers/common.py index 197c3a84..7ae6cb70 100644 --- a/apps/account/serializers/common.py +++ b/apps/account/serializers/common.py @@ -8,6 +8,7 @@ from rest_framework import serializers from rest_framework import validators as rest_validators from account import models, tasks +from account.models import User, Role from main.serializers.common import NavigationBarPermissionBaseSerializer from notification.models import Subscribe, Subscriber from utils import exceptions as utils_exceptions @@ -27,7 +28,7 @@ def subscriptions_handler(subscriptions_list, user): 'user': user, 'email': user.email, 'ip_address': user.last_ip, - 'country_code': user.last_country.country.code if user.last_country else None, + 'country_code': user.last_country.country.code if user.last_country and user.last_country.country else None, 'locale': user.locale, 'update_code': generate_string_code(), } @@ -42,6 +43,7 @@ def subscriptions_handler(subscriptions_list, user): class RoleBaseSerializer(serializers.ModelSerializer): """Serializer for model Role.""" + id = serializers.IntegerField() role_display = serializers.CharField(source='get_role_display', read_only=True) navigation_bar_permission = NavigationBarPermissionBaseSerializer(read_only=True) country_code = serializers.CharField(source='country.code', read_only=True, allow_null=True) @@ -133,6 +135,7 @@ class UserSerializer(serializers.ModelSerializer): def update(self, instance, validated_data): """Override update method""" subscriptions_list = [] + if 'subscription_types' in validated_data: subscriptions_list = validated_data.pop('subscription_types') @@ -164,6 +167,7 @@ class UserSerializer(serializers.ModelSerializer): emails=[validated_data['email'], ]) subscriptions_handler(subscriptions_list, instance) + return instance