(In progress) refactored logout and refresh token endpoints
This commit is contained in:
parent
06e563b77a
commit
09cf0f1a06
|
|
@ -94,10 +94,18 @@ class User(ImageMixin, AbstractUser):
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def expire_access_token(self, jti):
|
def expire_access_token(self, jti):
|
||||||
access_token_qs = self.access_tokens.by_jti(jti=jti)
|
# todo: add platform to func parameter
|
||||||
|
platform = PlatformMixin.WEB
|
||||||
|
access_token_qs = self.access_tokens.by_jti(jti=jti)\
|
||||||
|
.by_platform(platform=platform)
|
||||||
if access_token_qs.exists():
|
if access_token_qs.exists():
|
||||||
access_token_qs.first().expire()
|
access_token_qs.first().expire()
|
||||||
|
|
||||||
|
def expire_refresh_token(self, jti):
|
||||||
|
refresh_token_qs = self.refresh_tokens.by_jti(jti=jti)
|
||||||
|
if refresh_token_qs.exists():
|
||||||
|
refresh_token_qs.first().expire()
|
||||||
|
|
||||||
def confirm_email(self):
|
def confirm_email(self):
|
||||||
"""Method to confirm user email address"""
|
"""Method to confirm user email address"""
|
||||||
self.email_confirmed = True
|
self.email_confirmed = True
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import password_validation as password_validators
|
from django.contrib.auth import password_validation as password_validators
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from rest_framework import serializers
|
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
from account import models, tasks
|
from account import models, tasks
|
||||||
|
from authorization.models import JWTRefreshToken
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
from utils.tokens import GMRefreshToken
|
from utils.tokens import GMRefreshToken
|
||||||
|
|
||||||
|
|
@ -179,7 +181,6 @@ class RefreshTokenSerializer(serializers.Serializer):
|
||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
"""Override validate method"""
|
"""Override validate method"""
|
||||||
user = self.get_request().user
|
|
||||||
cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
|
cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
|
||||||
|
|
||||||
# Check if refresh_token in COOKIES
|
# Check if refresh_token in COOKIES
|
||||||
|
|
@ -187,17 +188,18 @@ class RefreshTokenSerializer(serializers.Serializer):
|
||||||
raise utils_exceptions.NotValidRefreshTokenError()
|
raise utils_exceptions.NotValidRefreshTokenError()
|
||||||
|
|
||||||
refresh_token = GMRefreshToken(cookie_refresh_token)
|
refresh_token = GMRefreshToken(cookie_refresh_token)
|
||||||
refresh_token_qs = user.refresh_tokens.valid()\
|
refresh_token_qs = JWTRefreshToken.objects.valid()\
|
||||||
.by_jti(refresh_token.payload.get('jti'))
|
.by_jti(jti=refresh_token.payload.get('jti'))
|
||||||
# Check if the user has refresh token
|
# Check if the user has refresh token
|
||||||
if not refresh_token_qs.exists():
|
if not refresh_token_qs.exists():
|
||||||
raise utils_exceptions.NotValidRefreshTokenError()
|
raise utils_exceptions.NotValidRefreshTokenError()
|
||||||
|
|
||||||
# Expire existing refresh token
|
# Expire existing refresh token
|
||||||
old_refresh_token = refresh_token_qs.first()
|
old_refresh_token = refresh_token_qs.first()
|
||||||
old_refresh_token.expire()
|
user = old_refresh_token.user
|
||||||
|
|
||||||
# Expire existing access tokens
|
# Expire existing tokens
|
||||||
|
old_refresh_token.expire()
|
||||||
user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\
|
user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\
|
||||||
.valid()\
|
.valid()\
|
||||||
.update(expires_at=timezone.now())
|
.update(expires_at=timezone.now())
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from account.forms import SetPasswordForm
|
||||||
from account.serializers import web as serializers
|
from account.serializers import web as serializers
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
from utils.models import GMTokenGenerator
|
from utils.models import GMTokenGenerator
|
||||||
|
from utils.permissions import IsRefreshTokenValid
|
||||||
from utils.views import (JWTCreateAPIView,
|
from utils.views import (JWTCreateAPIView,
|
||||||
JWTUpdateAPIView,
|
JWTUpdateAPIView,
|
||||||
JWTGenericViewMixin)
|
JWTGenericViewMixin)
|
||||||
|
|
@ -129,7 +130,7 @@ class ChangeEmailConfirmView(JWTGenericViewMixin):
|
||||||
|
|
||||||
class RefreshTokenView(JWTGenericViewMixin):
|
class RefreshTokenView(JWTGenericViewMixin):
|
||||||
"""Refresh access_token"""
|
"""Refresh access_token"""
|
||||||
permission_classes = (permissions.IsAuthenticated,)
|
permission_classes = (IsRefreshTokenValid, )
|
||||||
serializer_class = serializers.RefreshTokenSerializer
|
serializer_class = serializers.RefreshTokenSerializer
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
|
||||||
from utils.models import PlatformMixin, ProjectBaseMixin
|
from utils.models import PlatformMixin, ProjectBaseMixin
|
||||||
|
|
||||||
|
|
||||||
# Create your models here.
|
|
||||||
class ApplicationQuerySet(models.QuerySet):
|
class ApplicationQuerySet(models.QuerySet):
|
||||||
"""Application queryset"""
|
"""Application queryset"""
|
||||||
def get_by_natural_key(self, client_id):
|
def get_by_natural_key(self, client_id):
|
||||||
|
|
@ -75,19 +74,9 @@ class JWTAccessTokenQuerySet(models.QuerySet):
|
||||||
|
|
||||||
class JWTAccessToken(PlatformMixin, ProjectBaseMixin):
|
class JWTAccessToken(PlatformMixin, ProjectBaseMixin):
|
||||||
"""GM access token model."""
|
"""GM access token model."""
|
||||||
MOBILE = 0
|
|
||||||
WEB = 1
|
|
||||||
|
|
||||||
SOURCES = (
|
|
||||||
(MOBILE, _('Mobile')),
|
|
||||||
(WEB, _('Web')),
|
|
||||||
)
|
|
||||||
|
|
||||||
user = models.ForeignKey('account.User',
|
user = models.ForeignKey('account.User',
|
||||||
related_name='access_tokens',
|
related_name='access_tokens',
|
||||||
on_delete=models.CASCADE)
|
on_delete=models.CASCADE)
|
||||||
source = models.PositiveSmallIntegerField(choices=SOURCES, default=WEB,
|
|
||||||
verbose_name=_('Source'))
|
|
||||||
refresh_token = models.ForeignKey('JWTRefreshToken',
|
refresh_token = models.ForeignKey('JWTRefreshToken',
|
||||||
related_name='access_tokens',
|
related_name='access_tokens',
|
||||||
on_delete=models.DO_NOTHING)
|
on_delete=models.DO_NOTHING)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import json
|
||||||
|
|
||||||
from braces.views import CsrfExemptMixin
|
from braces.views import CsrfExemptMixin
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.utils import timezone
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
from django.utils.http import urlsafe_base64_decode
|
from django.utils.http import urlsafe_base64_decode
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
@ -13,12 +14,12 @@ from rest_framework import generics
|
||||||
from rest_framework import permissions
|
from rest_framework import permissions
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework_simplejwt.tokens import AccessToken
|
|
||||||
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
|
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
|
||||||
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
|
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
|
||||||
|
|
||||||
from account.models import User
|
from account.models import User
|
||||||
from authorization.models import Application
|
from authorization.models import Application
|
||||||
|
from authorization.models import JWTRefreshToken
|
||||||
from authorization.serializers import common as serializers
|
from authorization.serializers import common as serializers
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
from utils import tokens as jwt_tokens
|
from utils import tokens as jwt_tokens
|
||||||
|
|
@ -222,7 +223,19 @@ class LogoutView(JWTGenericViewMixin):
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
"""Override create method"""
|
"""Override create method"""
|
||||||
access_token = request.COOKIES.get('access_token')
|
# Get token objs by JTI
|
||||||
access_token_obj = AccessToken(access_token)
|
refresh_token_key = request.COOKIES.get('refresh_token')
|
||||||
request.user.expire_access_token(jti=access_token_obj.payload.get('jti'))
|
refresh_token = jwt_tokens.GMRefreshToken(refresh_token_key)
|
||||||
|
refresh_token_qs = JWTRefreshToken.objects.valid()\
|
||||||
|
.by_jti(jti=refresh_token.payload.get('jti'))
|
||||||
|
if not refresh_token_qs.exists():
|
||||||
|
raise utils_exceptions.NotValidRefreshTokenError()
|
||||||
|
|
||||||
|
refresh_token_obj = refresh_token_qs.first()
|
||||||
|
access_token_qs = refresh_token_obj.access_tokens
|
||||||
|
|
||||||
|
# Expire tokens
|
||||||
|
refresh_token_obj.expire()
|
||||||
|
access_token_qs.update(expires_at=timezone.now())
|
||||||
|
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
"""Project custom permissions"""
|
"""Project custom permissions"""
|
||||||
from rest_framework.permissions import BasePermission
|
from rest_framework.permissions import BasePermission
|
||||||
|
from authorization.models import JWTRefreshToken
|
||||||
|
|
||||||
from rest_framework_simplejwt.tokens import AccessToken
|
from rest_framework_simplejwt.tokens import AccessToken
|
||||||
|
from utils.tokens import RefreshToken
|
||||||
|
|
||||||
|
|
||||||
class IsAuthenticatedAndTokenIsValid(BasePermission):
|
class IsAuthenticatedAndTokenIsValid(BasePermission):
|
||||||
|
|
@ -20,3 +22,19 @@ class IsAuthenticatedAndTokenIsValid(BasePermission):
|
||||||
return valid_tokens.exists()
|
return valid_tokens.exists()
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class IsRefreshTokenValid(BasePermission):
|
||||||
|
"""
|
||||||
|
Check if user has a valid refresh token and authenticated
|
||||||
|
"""
|
||||||
|
def has_permission(self, request, view):
|
||||||
|
"""Check permissions by refresh token and default REST permission IsAuthenticated"""
|
||||||
|
refresh_token = request.COOKIES.get('refresh_token')
|
||||||
|
if refresh_token:
|
||||||
|
refresh_token = RefreshToken(refresh_token)
|
||||||
|
refresh_token_qs = JWTRefreshToken.objects.valid()\
|
||||||
|
.by_jti(jti=refresh_token.payload.get('jti'))
|
||||||
|
return refresh_token_qs.exists()
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user