224 lines
7.3 KiB
Python
224 lines
7.3 KiB
Python
"""Common serializer for application authorization"""
|
|
from django.conf import settings
|
|
from django.contrib.auth import authenticate
|
|
from django.contrib.auth import password_validation as password_validators
|
|
from django.db.models import Q
|
|
from rest_framework import serializers
|
|
from rest_framework import validators as rest_validators
|
|
# JWT
|
|
from rest_framework_simplejwt import tokens
|
|
|
|
from account import models as account_models
|
|
from authorization.models import Application, BlacklistedAccessToken
|
|
from utils import exceptions as utils_exceptions
|
|
from utils import methods as utils_methods
|
|
|
|
# Project-level SimpleJWT options. Fall back to an empty mapping when the
# SIMPLE_JWT setting is not defined, so importing this module does not raise
# AttributeError; the code below only reads this value through ``.get()``.
JWT_SETTINGS = getattr(settings, 'SIMPLE_JWT', {})
|
|
|
|
|
|
# Mixins
|
|
class BaseAuthSerializerMixin(serializers.Serializer):
    """Shared base for the authorization serializers.

    Declares the ``source`` field, whose accepted values are the choices
    listed in ``Application.SOURCES``.
    """

    source = serializers.ChoiceField(choices=Application.SOURCES)
|
|
|
|
|
|
class JWTBaseSerializerMixin(serializers.Serializer):
    """Mixin for serializers whose response carries a JWT pair.

    Adds the read-only ``refresh_token`` and ``access_token`` fields and
    fills them in while the instance is being rendered.
    """

    # RESPONSE
    refresh_token = serializers.CharField(read_only=True)
    access_token = serializers.CharField(read_only=True)

    def get_token(self):
        """Return a refresh token issued for ``self.instance``.

        The value returned by ``self.instance.get_user_info()`` is stored
        in the token under the ``user`` key.
        """
        owner = self.instance
        refresh = tokens.RefreshToken.for_user(owner)
        refresh['user'] = owner.get_user_info()
        return refresh

    def to_representation(self, instance):
        """Attach freshly issued tokens to ``instance`` and render it."""
        # The token is built from ``self.instance`` (see ``get_token``),
        # not from the ``instance`` argument.
        refresh = self.get_token()
        instance.access_token = str(refresh.access_token)
        instance.refresh_token = str(refresh)
        return super().to_representation(instance)
|
|
|
|
|
|
class LoginSerializerMixin(BaseAuthSerializerMixin):
    """Mixin for login serializers.

    Extends the base mixin (``source``) with a write-only ``password``.
    """

    password = serializers.CharField(write_only=True)
|
|
|
|
|
|
class ClassicAuthSerializerMixin(BaseAuthSerializerMixin):
    """Mixin for classic (password based) authorization.

    Extends the base mixin (``source``) with a write-only ``password``
    and a boolean ``newsletter`` flag.
    """

    password = serializers.CharField(write_only=True)
    newsletter = serializers.BooleanField()
|
|
|
|
|
|
# Serializers
|
|
class SignupSerializer(serializers.ModelSerializer):
    """Serializer that registers a new user.

    Every declared field is write-only. ``username`` and ``email`` must be
    unique among ``account_models.User`` rows.
    """

    # REQUEST
    username = serializers.CharField(
        validators=(rest_validators.UniqueValidator(queryset=account_models.User.objects.all()),),
        write_only=True
    )
    password = serializers.CharField(write_only=True)
    email = serializers.EmailField(
        validators=(rest_validators.UniqueValidator(queryset=account_models.User.objects.all()),),
        write_only=True)
    newsletter = serializers.BooleanField(write_only=True)

    class Meta:
        model = account_models.User
        fields = (
            'username',
            'password',
            'email',
            'newsletter'
        )

    def validate_username(self, data):
        """Check the username with ``utils_methods.username_validator``.

        Returns the username unchanged when the validator accepts it.

        Raises:
            utils_exceptions.NotValidUsernameError: the validator returned
                a falsy result.
        """
        valid = utils_methods.username_validator(username=data)
        if not valid:
            raise utils_exceptions.NotValidUsernameError()
        return data

    def validate_password(self, data):
        """Run Django's configured password validators on the password.

        Returns the password unchanged when every validator accepts it.

        Raises:
            serializers.ValidationError: carries the list of messages
                reported by the password validators.
        """
        # Django's password validators raise Django's own ValidationError,
        # not the REST framework one, so that is the type to catch here.
        from django.core.exceptions import ValidationError as DjangoValidationError

        try:
            password_validators.validate_password(password=data)
        except DjangoValidationError as exc:
            # ``messages`` is the flat list of human readable errors.
            raise serializers.ValidationError(exc.messages) from exc
        return data

    def create(self, validated_data):
        """Create the user by delegating to ``User.objects.make``."""
        obj = account_models.User.objects.make(
            username=validated_data.get('username'),
            password=validated_data.get('password'),
            email=validated_data.get('email'),
            newsletter=validated_data.get('newsletter')
        )
        return obj
|
|
|
|
|
|
class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.ModelSerializer):
    """Serializer that logs a user in by username or e-mail.

    The request carries ``username_or_email``, ``password`` and
    ``remember``; the response carries ``refresh_token`` and
    ``access_token``. The tokens are attached during rendering by
    ``JWTBaseSerializerMixin.to_representation``, which this class inherits.
    """

    username_or_email = serializers.CharField(write_only=True)
    password = serializers.CharField(write_only=True)

    # for cookie properties (Max-Age)
    remember = serializers.BooleanField(write_only=True)

    class Meta:
        """Meta-class"""
        model = account_models.User
        fields = (
            'username_or_email',
            'password',
            'remember',
            'refresh_token',
            'access_token'
        )

    def validate(self, attrs):
        """Resolve the user and verify the supplied password.

        ``username_or_email`` and ``password`` are removed from ``attrs``.
        On success the matched user is stored in ``self.instance`` and the
        remaining ``attrs`` are returned.

        Raises:
            utils_exceptions.UserNotFoundError: no user has this username
                or e-mail.
            utils_exceptions.WrongAuthCredentials: ``authenticate``
                rejected the username/password pair.
        """
        identifier = attrs.pop('username_or_email')
        password = attrs.pop('password')

        # A single query: ``first()`` yields None when nothing matches, so
        # a separate ``exists()`` round trip is not needed.
        user = account_models.User.objects.filter(
            Q(username=identifier) | Q(email=identifier)
        ).first()
        if user is None:
            raise utils_exceptions.UserNotFoundError()

        # Always authenticate with the real username, even when the client
        # identified the account by e-mail.
        if not authenticate(username=user.get_username(), password=password):
            raise utils_exceptions.WrongAuthCredentials()

        self.instance = user
        return attrs
|
|
|
|
|
|
class RefreshTokenSerializer(serializers.Serializer):
    """Serializer backing the refresh-token view.

    Takes a ``refresh_token`` and returns a new ``access_token``; when
    rotation is enabled in ``JWT_SETTINGS`` a new ``refresh_token`` is
    returned as well.
    """

    refresh_token = serializers.CharField()
    access_token = serializers.CharField(read_only=True)

    def validate(self, attrs):
        """Build the response payload from the submitted refresh token."""
        refresh = tokens.RefreshToken(attrs['refresh_token'])
        payload = {'access_token': str(refresh.access_token)}

        # Without rotation only the access token is handed back.
        if not JWT_SETTINGS.get('ROTATE_REFRESH_TOKENS'):
            return payload

        if JWT_SETTINGS.get('BLACKLIST_AFTER_ROTATION'):
            try:
                # Try to blacklist the refresh token that was just used.
                refresh.blacklist()
            except AttributeError:
                # The ``blacklist`` method is absent when the blacklist
                # app is not installed; nothing to do in that case.
                pass

        refresh.set_jti()
        refresh.set_exp()
        payload['refresh_token'] = str(refresh)
        return payload
|
|
|
|
|
|
class LogoutSerializer(serializers.ModelSerializer):
    """Serializer for the ``BlacklistedAccessToken`` model, used on logout.

    All model fields are read-only; their values are derived from the
    request inside ``validate``.
    """

    class Meta:
        model = BlacklistedAccessToken
        fields = ('user', 'token', 'jti')
        read_only_fields = ('user', 'token', 'jti')

    def validate(self, attrs):
        """Fill ``attrs`` with the user, access token and its ``jti``."""
        request = self.context.get('request')

        # The cookie helper returns bytes shaped like b'Bearer <token>'.
        cookie_value = utils_methods.get_token_from_cookies(request).decode()
        # The token itself is the last space-separated piece.
        encoded = cookie_value.split(' ')[-1]
        parsed = tokens.AccessToken(encoded)

        # Prepare validated data
        attrs['user'] = request.user
        attrs['token'] = parsed.token
        attrs['jti'] = parsed.payload.get('jti')
        return attrs
|
|
|
|
|
|
# OAuth
|
|
class OAuth2Serialzier(BaseAuthSerializerMixin):
    """Serializer for OAuth2 authorization.

    Extends the base mixin (``source``) with a ``token`` string of at most
    255 characters.
    """

    token = serializers.CharField(max_length=255)
|
|
|
|
|
|
class OAuth2LogoutSerializer(BaseAuthSerializerMixin):
    """Serializer for OAuth2 logout.

    Adds no fields of its own; only the inherited ``source`` is accepted.
    """
|