update roles list

This commit is contained in:
a.gorbunov 2020-01-29 09:03:45 +00:00
parent f58ac2a1cc
commit 6a6d623ef3
2 changed files with 31 additions and 3 deletions

View File

@ -2,7 +2,7 @@
from rest_framework import serializers from rest_framework import serializers
from account import models from account import models
from account.serializers import RoleBaseSerializer, UserSerializer, subscriptions_handler from account.serializers import RoleBaseSerializer, UserSerializer, subscriptions_handler, roles_handler
from main.models import SiteSettings from main.models import SiteSettings
@ -19,7 +19,7 @@ class _SiteSettingsSerializer(serializers.ModelSerializer):
class BackUserSerializer(UserSerializer): class BackUserSerializer(UserSerializer):
last_country = _SiteSettingsSerializer(read_only=True) last_country = _SiteSettingsSerializer(read_only=True)
roles = RoleBaseSerializer(many=True, read_only=True) roles = RoleBaseSerializer(many=True)
class Meta(UserSerializer.Meta): class Meta(UserSerializer.Meta):
fields = ( fields = (
@ -115,6 +115,7 @@ class BackDetailUserSerializer(BackUserSerializer):
def create(self, validated_data): def create(self, validated_data):
subscriptions_list = [] subscriptions_list = []
if 'subscription_types' in validated_data: if 'subscription_types' in validated_data:
subscriptions_list = validated_data.pop('subscription_types') subscriptions_list = validated_data.pop('subscription_types')
@ -127,11 +128,16 @@ class BackDetailUserSerializer(BackUserSerializer):
def update(self, instance, validated_data): def update(self, instance, validated_data):
subscriptions_list = [] subscriptions_list = []
if 'subscription_types' in validated_data: if 'subscription_types' in validated_data:
subscriptions_list = validated_data.pop('subscription_types') subscriptions_list = validated_data.pop('subscription_types')
if 'roles' in validated_data:
instance = roles_handler(validated_data.pop('roles'), instance)
instance = super().update(instance, validated_data) instance = super().update(instance, validated_data)
subscriptions_handler(subscriptions_list, instance) subscriptions_handler(subscriptions_list, instance)
return instance return instance

View File

@ -8,6 +8,7 @@ from rest_framework import serializers
from rest_framework import validators as rest_validators from rest_framework import validators as rest_validators
from account import models, tasks from account import models, tasks
from account.models import User, Role
from main.serializers.common import NavigationBarPermissionBaseSerializer from main.serializers.common import NavigationBarPermissionBaseSerializer
from notification.models import Subscribe, Subscriber from notification.models import Subscribe, Subscriber
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
@ -16,6 +17,24 @@ from utils.methods import generate_string_code
from phonenumber_field.serializerfields import PhoneNumberField from phonenumber_field.serializerfields import PhoneNumberField
def roles_handler(roles_list: set, user: User) -> User:
"""
Sync roles for user
:param roles_list: list of user roles
:param user: user
:return: bool
"""
if not roles_list:
user.roles.clear()
return user
ids = list(map(lambda role: role["id"] if "id" in role else None, roles_list))
roles = Role.objects \
.filter(id__in=ids)
user.roles.set(roles)
return user
def subscriptions_handler(subscriptions_list, user): def subscriptions_handler(subscriptions_list, user):
""" """
create or update subscriptions for user create or update subscriptions for user
@ -27,7 +46,7 @@ def subscriptions_handler(subscriptions_list, user):
'user': user, 'user': user,
'email': user.email, 'email': user.email,
'ip_address': user.last_ip, 'ip_address': user.last_ip,
'country_code': user.last_country.country.code if user.last_country else None, 'country_code': user.last_country.country.code if user.last_country and user.last_country.country else None,
'locale': user.locale, 'locale': user.locale,
'update_code': generate_string_code(), 'update_code': generate_string_code(),
} }
@ -42,6 +61,7 @@ def subscriptions_handler(subscriptions_list, user):
class RoleBaseSerializer(serializers.ModelSerializer): class RoleBaseSerializer(serializers.ModelSerializer):
"""Serializer for model Role.""" """Serializer for model Role."""
id = serializers.IntegerField()
role_display = serializers.CharField(source='get_role_display', read_only=True) role_display = serializers.CharField(source='get_role_display', read_only=True)
navigation_bar_permission = NavigationBarPermissionBaseSerializer(read_only=True) navigation_bar_permission = NavigationBarPermissionBaseSerializer(read_only=True)
country_code = serializers.CharField(source='country.code', read_only=True, allow_null=True) country_code = serializers.CharField(source='country.code', read_only=True, allow_null=True)
@ -133,6 +153,7 @@ class UserSerializer(serializers.ModelSerializer):
def update(self, instance, validated_data): def update(self, instance, validated_data):
"""Override update method""" """Override update method"""
subscriptions_list = [] subscriptions_list = []
if 'subscription_types' in validated_data: if 'subscription_types' in validated_data:
subscriptions_list = validated_data.pop('subscription_types') subscriptions_list = validated_data.pop('subscription_types')
@ -164,6 +185,7 @@ class UserSerializer(serializers.ModelSerializer):
emails=[validated_data['email'], ]) emails=[validated_data['email'], ])
subscriptions_handler(subscriptions_list, instance) subscriptions_handler(subscriptions_list, instance)
return instance return instance