refactored: login, logout, refresh token, change email, confirm email, reset password endpoints

This commit is contained in:
Anatoly 2019-09-04 12:42:23 +03:00
parent 09cf0f1a06
commit 5403f0d325
15 changed files with 269 additions and 152 deletions

View File

@ -1,5 +1,4 @@
"""Account models""" """Account models"""
from typing import Union
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractUser, UserManager as BaseUserManager from django.contrib.auth.models import AbstractUser, UserManager as BaseUserManager
@ -14,8 +13,9 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from authorization.models import Application from authorization.models import Application
from utils.models import ImageMixin, ProjectBaseMixin, PlatformMixin
from utils.models import GMTokenGenerator from utils.models import GMTokenGenerator
from utils.models import ImageMixin, ProjectBaseMixin, PlatformMixin
from utils.tokens import GMRefreshToken
class UserManager(BaseUserManager): class UserManager(BaseUserManager):
@ -93,18 +93,21 @@ class User(ImageMixin, AbstractUser):
self.is_active = switcher self.is_active = switcher
self.save() self.save()
def expire_access_token(self, jti): def create_jwt_tokens(self, source: int):
# todo: add platform to func parameter """Create JWT tokens for user"""
platform = PlatformMixin.WEB token = GMRefreshToken.for_user_by_source(self, source)
access_token_qs = self.access_tokens.by_jti(jti=jti)\ return {
.by_platform(platform=platform) 'access_token': str(token.access_token),
if access_token_qs.exists(): 'refresh_token': str(token),
access_token_qs.first().expire() }
def expire_refresh_token(self, jti): def expire_access_tokens(self):
refresh_token_qs = self.refresh_tokens.by_jti(jti=jti) """Expire all access tokens"""
if refresh_token_qs.exists(): self.access_tokens.update(expires_at=timezone.now())
refresh_token_qs.first().expire()
def expire_refresh_tokens(self):
"""Expire all refresh tokens"""
self.refresh_tokens.update(expires_at=timezone.now())
def confirm_email(self): def confirm_email(self):
"""Method to confirm user email address""" """Method to confirm user email address"""

View File

@ -2,12 +2,13 @@
from django.conf import settings from django.conf import settings
from django.contrib.auth import password_validation as password_validators from django.contrib.auth import password_validation as password_validators
from django.db.models import Q from django.db.models import Q
from django.utils import timezone from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from account import models, tasks from account import models, tasks
from authorization.models import JWTRefreshToken from authorization.models import JWTRefreshToken
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils.serializers import SourceSerializerMixin
from utils.tokens import GMRefreshToken from utils.tokens import GMRefreshToken
@ -26,9 +27,11 @@ class PasswordResetSerializer(serializers.ModelSerializer):
def validate(self, attrs): def validate(self, attrs):
"""Override validate method""" """Override validate method"""
user = self.context.get('request').user user = self.context.get('request').user
username_or_email = attrs.pop('username_or_email')
if user.is_anonymous: if user.is_anonymous:
username_or_email = attrs.get('username_or_email')
if not username_or_email:
raise serializers.ValidationError(_('Username or Email not requested'))
# Check user in DB # Check user in DB
user_qs = models.User.objects.filter(Q(email=username_or_email) | user_qs = models.User.objects.filter(Q(email=username_or_email) |
Q(username=username_or_email)) Q(username=username_or_email))
@ -122,6 +125,10 @@ class ChangePasswordSerializer(serializers.ModelSerializer):
# Update user password from instance # Update user password from instance
instance.set_password(validated_data.get('password')) instance.set_password(validated_data.get('password'))
instance.save() instance.save()
# Expire tokens
instance.expire_access_tokens()
instance.expire_refresh_tokens()
return instance return instance
@ -132,17 +139,12 @@ class ChangeEmailSerializer(serializers.ModelSerializer):
"""Meta class""" """Meta class"""
model = models.User model = models.User
fields = ( fields = (
'id',
'email', 'email',
) )
read_only_fields = (
'id',
)
def validate_email(self, value): def validate_email(self, value):
"""Validate email value""" """Validate email value"""
if value == self.instance.email: if value == self.instance.email:
# todo: add custom exception
raise serializers.ValidationError() raise serializers.ValidationError()
return value return value
@ -150,7 +152,6 @@ class ChangeEmailSerializer(serializers.ModelSerializer):
"""Override validate method""" """Override validate method"""
email_confirmed = self.instance.email_confirmed email_confirmed = self.instance.email_confirmed
if not email_confirmed: if not email_confirmed:
# todo: add custom exception
raise serializers.ValidationError() raise serializers.ValidationError()
return attrs return attrs
@ -166,11 +167,39 @@ class ChangeEmailSerializer(serializers.ModelSerializer):
tasks.confirm_new_email_address.delay(instance.id) tasks.confirm_new_email_address.delay(instance.id)
else: else:
tasks.confirm_new_email_address(instance.id) tasks.confirm_new_email_address(instance.id)
instance.revoke_access_token()
return instance return instance
class RefreshTokenSerializer(serializers.Serializer): class ConfirmEmailSerializer(serializers.ModelSerializer):
"""Confirm user email serializer"""
class Meta:
"""Meta class"""
model = models.User
fields = (
'email',
)
def validate(self, attrs):
"""Override validate method"""
email_confirmed = self.instance.email_confirmed
if email_confirmed:
raise serializers.ValidationError()
return attrs
def update(self, instance, validated_data):
"""
Override update method
"""
# Send verification link on user email for change email address
if settings.USE_CELERY:
tasks.confirm_new_email_address.delay(instance.id)
else:
tasks.confirm_new_email_address(instance.id)
return instance
class RefreshTokenSerializer(SourceSerializerMixin):
"""Serializer for refresh token view""" """Serializer for refresh token view"""
refresh_token = serializers.CharField(read_only=True) refresh_token = serializers.CharField(read_only=True)
access_token = serializers.CharField(read_only=True) access_token = serializers.CharField(read_only=True)
@ -181,34 +210,28 @@ class RefreshTokenSerializer(serializers.Serializer):
def validate(self, attrs): def validate(self, attrs):
"""Override validate method""" """Override validate method"""
cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
# Check if refresh_token in COOKIES # Check if refresh_token in COOKIES
if not cookie_refresh_token: if not cookie_refresh_token:
raise utils_exceptions.NotValidRefreshTokenError() raise utils_exceptions.NotValidRefreshTokenError()
refresh_token = GMRefreshToken(cookie_refresh_token) refresh_token = GMRefreshToken(cookie_refresh_token)
refresh_token_qs = JWTRefreshToken.objects.valid()\ refresh_token_qs = JWTRefreshToken.objects.valid() \
.by_jti(jti=refresh_token.payload.get('jti')) .by_jti(jti=refresh_token.payload.get('jti'))
# Check if the user has refresh token # Check if the user has refresh token
if not refresh_token_qs.exists(): if not refresh_token_qs.exists():
raise utils_exceptions.NotValidRefreshTokenError() raise utils_exceptions.NotValidRefreshTokenError()
# Expire existing refresh token
old_refresh_token = refresh_token_qs.first() old_refresh_token = refresh_token_qs.first()
source = old_refresh_token.source
user = old_refresh_token.user user = old_refresh_token.user
# Expire existing tokens # Expire existing tokens
old_refresh_token.expire() old_refresh_token.expire()
user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\ old_refresh_token.access_token.expire()
.valid()\
.update(expires_at=timezone.now())
# Create new one for user # Create new one for user
refresh_token = GMRefreshToken.for_user(user) response = user.create_jwt_tokens(source=source)
refresh_token['user'] = user.get_user_info()
return { return response
'access_token': str(refresh_token.access_token),
'refresh_token': str(refresh_token),
}

View File

@ -17,6 +17,9 @@ urlpatterns_api = [
path('change-email/', views.ChangeEmailView.as_view(), name='change-email'), path('change-email/', views.ChangeEmailView.as_view(), name='change-email'),
path('change-email/confirm/<uidb64>/<token>/', views.ChangeEmailConfirmView.as_view(), path('change-email/confirm/<uidb64>/<token>/', views.ChangeEmailConfirmView.as_view(),
name='change-email-confirm'), name='change-email-confirm'),
path('confirm-email/', views.ConfirmEmailView.as_view(), name='confirm-email'),
path('confirm-email/<uidb64>/<token>/', views.ConfirmInactiveEmailView.as_view(),
name='inactive-email-confirm'),
] ]
urlpatterns = urlpatterns_api + \ urlpatterns = urlpatterns_api + \

View File

@ -12,9 +12,11 @@ from django.utils.translation import gettext_lazy as _
from django.views.decorators.cache import never_cache from django.views.decorators.cache import never_cache
from django.views.decorators.debug import sensitive_post_parameters from django.views.decorators.debug import sensitive_post_parameters
from django.views.generic.edit import FormView from django.views.generic.edit import FormView
from rest_framework import generics
from rest_framework import permissions from rest_framework import permissions
from rest_framework import status from rest_framework import status
from rest_framework import views from rest_framework import views
from rest_framework.permissions import AllowAny
from rest_framework.response import Response from rest_framework.response import Response
from account import models from account import models
@ -22,7 +24,6 @@ from account.forms import SetPasswordForm
from account.serializers import web as serializers from account.serializers import web as serializers
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils.models import GMTokenGenerator from utils.models import GMTokenGenerator
from utils.permissions import IsRefreshTokenValid
from utils.views import (JWTCreateAPIView, from utils.views import (JWTCreateAPIView,
JWTUpdateAPIView, JWTUpdateAPIView,
JWTGenericViewMixin) JWTGenericViewMixin)
@ -90,7 +91,7 @@ class ChangePasswordView(JWTUpdateAPIView):
class ChangeEmailView(JWTGenericViewMixin): class ChangeEmailView(JWTGenericViewMixin):
"""Change user email view""" """Change user email view."""
serializer_class = serializers.ChangeEmailSerializer serializer_class = serializers.ChangeEmailSerializer
queryset = models.User.objects.all() queryset = models.User.objects.all()
@ -105,11 +106,43 @@ class ChangeEmailView(JWTGenericViewMixin):
return Response(status=status.HTTP_200_OK) return Response(status=status.HTTP_200_OK)
class ConfirmEmailView(ChangeEmailView):
"""Confirm email view."""
serializer_class = serializers.ConfirmEmailSerializer
class ChangeEmailConfirmView(JWTGenericViewMixin): class ChangeEmailConfirmView(JWTGenericViewMixin):
"""View for confirm changing email""" """View for confirm changing email"""
permission_classes = (permissions.AllowAny,) permission_classes = (permissions.AllowAny,)
def get(self, request, *args, **kwargs):
"""Implement GET-method"""
uidb64 = kwargs.get('uidb64')
token = kwargs.get('token')
uid = force_text(urlsafe_base64_decode(uidb64))
user_qs = models.User.objects.filter(pk=uid)
if user_qs.exists():
user = user_qs.first()
if not GMTokenGenerator(GMTokenGenerator.CHANGE_EMAIL).check_token(
user, token):
raise utils_exceptions.NotValidTokenError()
# Approve email status
user.confirm_email()
# Expire user tokens
user.expire_access_tokens()
user.expire_refresh_tokens()
return Response(status=status.HTTP_200_OK)
else:
raise utils_exceptions.UserNotFoundError()
class ConfirmInactiveEmailView(generics.GenericAPIView):
"""View for confirm inactive email"""
permission_classes = (permissions.AllowAny,)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Implement GET-method""" """Implement GET-method"""
uidb64 = kwargs.get('uidb64') uidb64 = kwargs.get('uidb64')
@ -130,7 +163,7 @@ class ChangeEmailConfirmView(JWTGenericViewMixin):
class RefreshTokenView(JWTGenericViewMixin): class RefreshTokenView(JWTGenericViewMixin):
"""Refresh access_token""" """Refresh access_token"""
permission_classes = (IsRefreshTokenValid, ) permission_classes = (AllowAny, )
serializer_class = serializers.RefreshTokenSerializer serializer_class = serializers.RefreshTokenSerializer
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@ -227,6 +260,12 @@ class FormPasswordResetConfirmView(PasswordContextMixin, FormView):
def form_valid(self, form): def form_valid(self, form):
# Saving form # Saving form
form.save() form.save()
user = form.user
# Expire user tokens
user.expire_access_tokens()
user.expire_refresh_tokens()
# Pop session token # Pop session token
del self.request.session[self.INTERNAL_RESET_SESSION_TOKEN] del self.request.session[self.INTERNAL_RESET_SESSION_TOKEN]
return super().form_valid(form) return super().form_valid(form)

View File

@ -0,0 +1,22 @@
# Generated by Django 2.2.4 on 2019-09-04 08:21
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authorization', '0004_delete_blacklistedaccesstoken'),
]
operations = [
migrations.RemoveField(
model_name='jwtaccesstoken',
name='source',
),
migrations.AddField(
model_name='jwtrefreshtoken',
name='source',
field=models.PositiveSmallIntegerField(choices=[(0, 'Mobile'), (1, 'Web')], default=0, verbose_name='Source'),
),
]

View File

@ -0,0 +1,17 @@
# Generated by Django 2.2.4 on 2019-09-04 08:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('authorization', '0005_auto_20190904_0821'),
]
operations = [
migrations.RemoveField(
model_name='jwtaccesstoken',
name='refresh_token',
),
]

View File

@ -0,0 +1,19 @@
# Generated by Django 2.2.4 on 2019-09-04 08:24
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('authorization', '0006_remove_jwtaccesstoken_refresh_token'),
]
operations = [
migrations.AddField(
model_name='jwtaccesstoken',
name='refresh_token',
field=models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to='authorization.JWTRefreshToken'),
),
]

View File

@ -38,7 +38,8 @@ class Application(PlatformMixin, AbstractApplication):
class JWTAccessTokenManager(models.Manager): class JWTAccessTokenManager(models.Manager):
"""Manager for AccessToken model.""" """Manager for AccessToken model."""
def add_to_db(self, user, access_token: AccessToken, refresh_token: RefreshToken): def add_to_db(self, user, access_token: AccessToken,
refresh_token: RefreshToken):
"""Create generated tokens to DB""" """Create generated tokens to DB"""
refresh_token_qs = JWTRefreshToken.objects.filter(user=user, refresh_token_qs = JWTRefreshToken.objects.filter(user=user,
jti=refresh_token.payload.get('jti')) jti=refresh_token.payload.get('jti'))
@ -72,14 +73,15 @@ class JWTAccessTokenQuerySet(models.QuerySet):
return self.filter(refresh_token__jti=jti) return self.filter(refresh_token__jti=jti)
class JWTAccessToken(PlatformMixin, ProjectBaseMixin): class JWTAccessToken(ProjectBaseMixin):
"""GM access token model.""" """GM access token model."""
user = models.ForeignKey('account.User', user = models.ForeignKey('account.User',
related_name='access_tokens', related_name='access_tokens',
on_delete=models.CASCADE) on_delete=models.CASCADE)
refresh_token = models.ForeignKey('JWTRefreshToken', refresh_token = models.OneToOneField('JWTRefreshToken',
related_name='access_tokens', related_name='access_token',
on_delete=models.DO_NOTHING) on_delete=models.CASCADE,
null=True, blank=True, default=None)
created_at = models.DateTimeField(null=True, blank=True) created_at = models.DateTimeField(null=True, blank=True)
expires_at = models.DateTimeField(verbose_name=_('Expiration datetime')) expires_at = models.DateTimeField(verbose_name=_('Expiration datetime'))
jti = models.CharField(unique=True, max_length=255) jti = models.CharField(unique=True, max_length=255)
@ -105,13 +107,14 @@ class JWTAccessToken(PlatformMixin, ProjectBaseMixin):
class JWTRefreshTokenManager(models.Manager): class JWTRefreshTokenManager(models.Manager):
"""Manager for model RefreshToken.""" """Manager for model RefreshToken."""
def add_to_db(self, user, token: RefreshToken): def add_to_db(self, user, token: RefreshToken, source: int):
"""Added generated refresh token to db""" """Added generated refresh token to db"""
jti = token[settings.SIMPLE_JWT.get('JTI_CLAIM')] jti = token[settings.SIMPLE_JWT.get('JTI_CLAIM')]
exp = token['exp'] exp = token['exp']
obj = self.model( obj = self.model(
user=user, user=user,
jti=jti, jti=jti,
source=source,
created_at=token.current_time, created_at=token.current_time,
expires_at=utils.datetime_from_epoch(exp), expires_at=utils.datetime_from_epoch(exp),
) )
@ -123,15 +126,19 @@ class JWTRefreshTokenQuerySet(models.QuerySet):
"""QuerySets for model RefreshToken.""" """QuerySets for model RefreshToken."""
def valid(self): def valid(self):
"""Return only balid refresh tokens""" """Return only valid refresh tokens"""
return self.filter(expires_at__gte=timezone.now()) return self.filter(expires_at__gte=timezone.now())
def by_jti(self, jti: str): def by_jti(self, jti: str):
"""Filter by jti field""" """Filter by jti field"""
return self.filter(jti=jti) return self.filter(jti=jti)
def by_source(self, source):
"""Return access tokens by source"""
return self.filter(source=source)
class JWTRefreshToken(ProjectBaseMixin):
class JWTRefreshToken(PlatformMixin, ProjectBaseMixin):
"""GM refresh token model.""" """GM refresh token model."""
user = models.ForeignKey('account.User', user = models.ForeignKey('account.User',
related_name='refresh_tokens', related_name='refresh_tokens',

View File

@ -8,43 +8,9 @@ from rest_framework import validators as rest_validators
from account import models as account_models from account import models as account_models
from authorization import tasks from authorization import tasks
from authorization.models import Application
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils import methods as utils_methods from utils import methods as utils_methods
from utils import tokens from utils.serializers import SourceSerializerMixin
# Mixins
class BaseAuthSerializerMixin(serializers.Serializer):
"""Base authorization serializer mixin"""
source = serializers.ChoiceField(choices=Application.SOURCES)
class JWTBaseSerializerMixin(serializers.Serializer):
"""
Mixin for JWT authentication.
Uses in serializers when need give in response access and refresh token
"""
# RESPONSE
refresh_token = serializers.CharField(read_only=True)
access_token = serializers.CharField(read_only=True)
def get_tokens(self):
"""Create JWT token"""
user = self.instance
token = tokens.GMRefreshToken.for_user(user)
token['user'] = user.get_user_info()
return {
'access_token': str(token.access_token),
'refresh_token': str(token),
}
def to_representation(self, instance):
"""Override to_representation method"""
tokens = self.get_tokens()
setattr(instance, 'access_token', tokens.get('access_token'))
setattr(instance, 'refresh_token', tokens.get('refresh_token'))
return super().to_representation(instance)
# Serializers # Serializers
@ -101,14 +67,20 @@ class SignupSerializer(serializers.ModelSerializer):
return obj return obj
class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.ModelSerializer): class LoginByUsernameOrEmailSerializer(SourceSerializerMixin,
serializers.ModelSerializer):
"""Serializer for login user""" """Serializer for login user"""
# REQUEST
username_or_email = serializers.CharField(write_only=True) username_or_email = serializers.CharField(write_only=True)
password = serializers.CharField(write_only=True) password = serializers.CharField(write_only=True)
# for cookie properties (Max-Age) # For cookie properties (Max-Age)
remember = serializers.BooleanField(write_only=True) remember = serializers.BooleanField(write_only=True)
# RESPONSE
refresh_token = serializers.CharField(read_only=True)
access_token = serializers.CharField(read_only=True)
class Meta: class Meta:
"""Meta-class""" """Meta-class"""
model = account_models.User model = account_models.User
@ -116,8 +88,9 @@ class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.Model
'username_or_email', 'username_or_email',
'password', 'password',
'remember', 'remember',
'source',
'refresh_token', 'refresh_token',
'access_token' 'access_token',
) )
def validate(self, attrs): def validate(self, attrs):
@ -138,12 +111,23 @@ class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.Model
self.instance = user self.instance = user
return attrs return attrs
def to_representation(self, instance):
"""Override to_representation method"""
tokens = instance.create_jwt_tokens(source=self.validated_data.get('source'))
setattr(instance, 'access_token', tokens.get('access_token'))
setattr(instance, 'refresh_token', tokens.get('refresh_token'))
return super().to_representation(instance)
class LogoutSerializer(SourceSerializerMixin):
"""Serializer for Logout endpoint."""
# OAuth # OAuth
class OAuth2Serialzier(BaseAuthSerializerMixin): class OAuth2Serialzier(SourceSerializerMixin):
"""Serializer OAuth2 authorization""" """Serializer OAuth2 authorization"""
token = serializers.CharField(max_length=255) token = serializers.CharField(max_length=255)
class OAuth2LogoutSerializer(BaseAuthSerializerMixin): class OAuth2LogoutSerializer(SourceSerializerMixin):
"""Serializer for logout""" """Serializer for logout"""

View File

@ -3,7 +3,6 @@ import json
from braces.views import CsrfExemptMixin from braces.views import CsrfExemptMixin
from django.conf import settings from django.conf import settings
from django.utils import timezone
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.http import urlsafe_base64_decode from django.utils.http import urlsafe_base64_decode
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -14,16 +13,17 @@ from rest_framework import generics
from rest_framework import permissions from rest_framework import permissions
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_simplejwt.tokens import AccessToken
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
from account.models import User from account.models import User
from authorization.models import Application from authorization.models import Application
from authorization.models import JWTRefreshToken from authorization.models import JWTAccessToken
from authorization.serializers import common as serializers from authorization.serializers import common as serializers
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils import tokens as jwt_tokens
from utils.models import GMTokenGenerator from utils.models import GMTokenGenerator
from utils.permissions import IsAuthenticatedAndTokenIsValid
from utils.views import (JWTGenericViewMixin, from utils.views import (JWTGenericViewMixin,
JWTCreateAPIView) JWTCreateAPIView)
@ -109,18 +109,12 @@ class OAuth2SignUpView(OAuth2ViewMixin, JWTGenericViewMixin):
permission_classes = (permissions.AllowAny, ) permission_classes = (permissions.AllowAny, )
serializer_class = serializers.OAuth2Serialzier serializer_class = serializers.OAuth2Serialzier
def get_jwt_token(self, user: User):
"""Get JWT token"""
token = jwt_tokens.GMRefreshToken.for_user(user)
# Adding additional information about user to payload
token['user'] = user.get_user_info()
return token
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
# Preparing request data # Preparing request data
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
request_data = self.prepare_request_data(serializer.validated_data) request_data = self.prepare_request_data(serializer.validated_data)
source = serializer.validated_data.get('source')
request_data.update({ request_data.update({
'grant_type': settings.OAUTH2_SOCIAL_AUTH_GRANT_TYPE, 'grant_type': settings.OAUTH2_SOCIAL_AUTH_GRANT_TYPE,
'backend': settings.OAUTH2_SOCIAL_AUTH_BACKEND_NAME 'backend': settings.OAUTH2_SOCIAL_AUTH_BACKEND_NAME
@ -135,18 +129,17 @@ class OAuth2SignUpView(OAuth2ViewMixin, JWTGenericViewMixin):
url, headers, body, oauth2_status = self.create_token_response(request._request) url, headers, body, oauth2_status = self.create_token_response(request._request)
body = json.loads(body) body = json.loads(body)
# Get JWT token # Check OAuth2 response
if oauth2_status != status.HTTP_200_OK: if oauth2_status != status.HTTP_200_OK:
raise ValueError('status isn\'t 200') raise utils_exceptions.OAuth2Error()
# Get authenticated user # Get authenticated user
user = User.objects.by_oauth2_access_token(token=body.get('access_token'))\ user = User.objects.by_oauth2_access_token(token=body.get('access_token'))\
.first() .first()
# Create JWT token and put oauth2 token (access, refresh tokens) in payload # Create JWT token
token = self.get_jwt_token(user=user) tokens = user.create_jwt_tokens(source)
access_token = str(token.access_token) access_token, refresh_token = tokens.get('access_token'), tokens.get('refresh_token')
refresh_token = str(token)
response = Response(data={'access_token': access_token, response = Response(data={'access_token': access_token,
'refresh_token': refresh_token}, 'refresh_token': refresh_token},
status=status.HTTP_200_OK) status=status.HTTP_200_OK)
@ -220,22 +213,16 @@ class LoginByUsernameOrEmailView(JWTAuthViewMixin):
# Logout # Logout
class LogoutView(JWTGenericViewMixin): class LogoutView(JWTGenericViewMixin):
"""Logout user""" """Logout user"""
permission_classes = (IsAuthenticatedAndTokenIsValid, )
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Override create method""" """Override create method"""
# Get token objs by JTI # Get access token objs by JTI
refresh_token_key = request.COOKIES.get('refresh_token') access_token = AccessToken(request.COOKIES.get('access_token'))
refresh_token = jwt_tokens.GMRefreshToken(refresh_token_key) access_token_obj = JWTAccessToken.objects.get(jti=access_token.payload.get('jti'))
refresh_token_qs = JWTRefreshToken.objects.valid()\
.by_jti(jti=refresh_token.payload.get('jti'))
if not refresh_token_qs.exists():
raise utils_exceptions.NotValidRefreshTokenError()
refresh_token_obj = refresh_token_qs.first()
access_token_qs = refresh_token_obj.access_tokens
# Expire tokens # Expire tokens
refresh_token_obj.expire() access_token_obj.expire()
access_token_qs.update(expires_at=timezone.now()) access_token_obj.refresh_token.expire()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)

View File

@ -16,20 +16,25 @@ class ProjectBaseException(exceptions.APIException):
super().__init__() super().__init__()
class AuthErrorMixin(exceptions.APIException):
"""Authentication exception error mixin."""
status_code = status.HTTP_401_UNAUTHORIZED
class ServiceError(ProjectBaseException): class ServiceError(ProjectBaseException):
"""Service error.""" """Service error."""
status_code = status.HTTP_503_SERVICE_UNAVAILABLE status_code = status.HTTP_503_SERVICE_UNAVAILABLE
default_detail = _('Service is temporarily unavailable') default_detail = _('Service is temporarily unavailable')
class UserNotFoundError(ProjectBaseException): class UserNotFoundError(AuthErrorMixin, ProjectBaseException):
"""The exception should be thrown when the user cannot get""" """The exception should be thrown when the user cannot get"""
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('User not found') default_detail = _('User not found')
class PasswordRequestResetExists(ProjectBaseException): class PasswordRequestResetExists(ProjectBaseException):
"""The exception should be thrown when request for reset password """
The exception should be thrown when request for reset password
is already exists and valid is already exists and valid
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
@ -50,7 +55,8 @@ class EmailSendingError(exceptions.APIException):
class LocaleNotExisted(exceptions.APIException): class LocaleNotExisted(exceptions.APIException):
"""The exception should be thrown when passed locale isn't in model Language """
The exception should be thrown when passed locale isn't in model Language
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Locale not found in database (%s)') default_detail = _('Locale not found in database (%s)')
@ -64,49 +70,58 @@ class LocaleNotExisted(exceptions.APIException):
class NotValidUsernameError(exceptions.APIException): class NotValidUsernameError(exceptions.APIException):
"""The exception should be thrown when passed username has @ symbol """
The exception should be thrown when passed username has @ symbol
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Wrong username') default_detail = _('Wrong username')
class NotValidTokenError(exceptions.APIException): class NotValidTokenError(exceptions.APIException):
"""The exception should be thrown when token in url is not valid """
The exception should be thrown when token in url is not valid
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Not valid token') default_detail = _('Not valid token')
class NotValidAccessTokenError(exceptions.APIException): class NotValidAccessTokenError(AuthErrorMixin):
"""The exception should be thrown when access token in url is not valid """
The exception should be thrown when access token in url is not valid
""" """
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Not valid access token') default_detail = _('Not valid access token')
class NotValidRefreshTokenError(exceptions.APIException): class NotValidRefreshTokenError(AuthErrorMixin):
"""The exception should be thrown when refresh token is not valid """
The exception should be thrown when refresh token is not valid
""" """
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Not valid refresh token') default_detail = _('Not valid refresh token')
class OAuth2Error(AuthErrorMixin):
"""OAuth2 error"""
default_detail = _('OAuth2 Error')
class PasswordsAreEqual(exceptions.APIException): class PasswordsAreEqual(exceptions.APIException):
"""The exception should be raised when passed password is the same as old ones """
The exception should be raised when passed password is the same as old ones
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Password is already in use') default_detail = _('Password is already in use')
class EmailConfirmedError(exceptions.APIException): class EmailConfirmedError(exceptions.APIException):
"""The exception should be raised when user email status is already confirmed """
The exception should be raised when user email status is already confirmed
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Email address is already confirmed') default_detail = _('Email address is already confirmed')
class WrongAuthCredentials(exceptions.APIException): class WrongAuthCredentials(AuthErrorMixin):
"""The exception should be raised when credentials is not valid for this user """
The exception should be raised when credentials is not valid for this user
""" """
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Wrong authorization credentials') default_detail = _('Wrong authorization credentials')

View File

@ -1,9 +1,9 @@
"""Project custom permissions""" """Project custom permissions"""
from rest_framework.permissions import BasePermission from rest_framework.permissions import BasePermission
from authorization.models import JWTRefreshToken
from rest_framework_simplejwt.tokens import AccessToken from rest_framework_simplejwt.tokens import AccessToken
from utils.tokens import RefreshToken
from authorization.models import JWTRefreshToken
from utils.tokens import GMRefreshToken
class IsAuthenticatedAndTokenIsValid(BasePermission): class IsAuthenticatedAndTokenIsValid(BasePermission):
@ -32,7 +32,7 @@ class IsRefreshTokenValid(BasePermission):
"""Check permissions by refresh token and default REST permission IsAuthenticated""" """Check permissions by refresh token and default REST permission IsAuthenticated"""
refresh_token = request.COOKIES.get('refresh_token') refresh_token = request.COOKIES.get('refresh_token')
if refresh_token: if refresh_token:
refresh_token = RefreshToken(refresh_token) refresh_token = GMRefreshToken(refresh_token)
refresh_token_qs = JWTRefreshToken.objects.valid()\ refresh_token_qs = JWTRefreshToken.objects.valid()\
.by_jti(jti=refresh_token.payload.get('jti')) .by_jti(jti=refresh_token.payload.get('jti'))
return refresh_token_qs.exists() return refresh_token_qs.exists()

View File

@ -1,6 +1,15 @@
"""Utils app serializer.""" """Utils app serializer."""
from rest_framework import serializers from rest_framework import serializers
from utils.models import PlatformMixin
class EmptySerializer(serializers.Serializer): class EmptySerializer(serializers.Serializer):
"""Empty Serializer""" """Empty Serializer"""
class SourceSerializerMixin(serializers.Serializer):
"""Base authorization serializer mixin"""
source = serializers.ChoiceField(choices=PlatformMixin.SOURCES,
default=PlatformMixin.WEB,
write_only=True)

View File

@ -2,7 +2,6 @@
from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin
from account.models import User
from authorization.models import JWTRefreshToken, JWTAccessToken from authorization.models import JWTRefreshToken, JWTAccessToken
@ -34,13 +33,12 @@ class GMBlacklistMixin(BlacklistMixin):
""" """
@classmethod @classmethod
def for_user(cls, user): def for_user_by_source(cls, user, source: int):
""" """Create a refresh token."""
Adds this token to the outstanding token list.
"""
token = super().for_user(user) token = super().for_user(user)
token['user'] = user.get_user_info()
# Create a record in DB # Create a record in DB
JWTRefreshToken.objects.add_to_db(user=user, token=token) JWTRefreshToken.objects.add_to_db(user=user, token=token, source=source)
return token return token
@ -54,6 +52,8 @@ class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
claims present in this refresh token to the new access token except claims present in this refresh token to the new access token except
those claims listed in the `no_copy_claims` attribute. those claims listed in the `no_copy_claims` attribute.
""" """
from account.models import User
access_token = AccessToken() access_token = AccessToken()
# Use instantiation time of refresh token as relative timestamp for # Use instantiation time of refresh token as relative timestamp for
@ -74,4 +74,3 @@ class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
access_token=access_token, access_token=access_token,
refresh_token=self) refresh_token=self)
return access_token return access_token

View File

@ -4,7 +4,6 @@ from django.conf import settings
from rest_framework import generics from rest_framework import generics
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from utils import tokens
# JWT # JWT
@ -19,15 +18,6 @@ class JWTGenericViewMixin(generics.GenericAPIView):
REFRESH_TOKEN_SECURE = False REFRESH_TOKEN_SECURE = False
COOKIE = namedtuple('COOKIE', ['key', 'value', 'http_only', 'secure', 'max_age']) COOKIE = namedtuple('COOKIE', ['key', 'value', 'http_only', 'secure', 'max_age'])
def _create_jwt_token(self, user) -> dict:
"""Return dictionary with pairs access and refresh tokens"""
token = tokens.GMRefreshToken.for_user(user)
token['user'] = user.get_user_info()
return {
'access_token': str(token.access_token),
'refresh_token': str(token),
}
def _put_data_in_cookies(self, def _put_data_in_cookies(self,
access_token: str = None, access_token: str = None,
refresh_token: str = None, refresh_token: str = None,