diff --git a/apps/account/models.py b/apps/account/models.py index 5299773c..e7531a43 100644 --- a/apps/account/models.py +++ b/apps/account/models.py @@ -94,10 +94,18 @@ class User(ImageMixin, AbstractUser): self.save() def expire_access_token(self, jti): - access_token_qs = self.access_tokens.by_jti(jti=jti) + # todo: add platform to func parameter + platform = PlatformMixin.WEB + access_token_qs = self.access_tokens.by_jti(jti=jti)\ + .by_platform(platform=platform) if access_token_qs.exists(): access_token_qs.first().expire() + def expire_refresh_token(self, jti): + refresh_token_qs = self.refresh_tokens.by_jti(jti=jti) + if refresh_token_qs.exists(): + refresh_token_qs.first().expire() + def confirm_email(self): """Method to confirm user email address""" self.email_confirmed = True diff --git a/apps/account/serializers/web.py b/apps/account/serializers/web.py index d4e7d1a8..7af38b9a 100644 --- a/apps/account/serializers/web.py +++ b/apps/account/serializers/web.py @@ -2,9 +2,11 @@ from django.conf import settings from django.contrib.auth import password_validation as password_validators from django.db.models import Q -from rest_framework import serializers from django.utils import timezone +from rest_framework import serializers + from account import models, tasks +from authorization.models import JWTRefreshToken from utils import exceptions as utils_exceptions from utils.tokens import GMRefreshToken @@ -179,7 +181,6 @@ class RefreshTokenSerializer(serializers.Serializer): def validate(self, attrs): """Override validate method""" - user = self.get_request().user cookie_refresh_token = self.get_request().COOKIES.get('refresh_token') # Check if refresh_token in COOKIES @@ -187,17 +188,18 @@ class RefreshTokenSerializer(serializers.Serializer): raise utils_exceptions.NotValidRefreshTokenError() refresh_token = GMRefreshToken(cookie_refresh_token) - refresh_token_qs = user.refresh_tokens.valid()\ - .by_jti(refresh_token.payload.get('jti')) + refresh_token_qs = JWTRefreshToken.objects.valid()\ + .by_jti(jti=refresh_token.payload.get('jti')) # Check if the user has refresh token if not refresh_token_qs.exists(): raise utils_exceptions.NotValidRefreshTokenError() # Expire existing refresh token old_refresh_token = refresh_token_qs.first() - old_refresh_token.expire() + user = old_refresh_token.user - # Expire existing access tokens + # Expire existing tokens + old_refresh_token.expire() user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\ .valid()\ .update(expires_at=timezone.now()) diff --git a/apps/account/views/web.py b/apps/account/views/web.py index 10afe18d..3465b1f1 100644 --- a/apps/account/views/web.py +++ b/apps/account/views/web.py @@ -22,6 +22,7 @@ from account.forms import SetPasswordForm from account.serializers import web as serializers from utils import exceptions as utils_exceptions from utils.models import GMTokenGenerator +from utils.permissions import IsRefreshTokenValid from utils.views import (JWTCreateAPIView, JWTUpdateAPIView, JWTGenericViewMixin) @@ -129,7 +130,7 @@ class ChangeEmailConfirmView(JWTGenericViewMixin): class RefreshTokenView(JWTGenericViewMixin): """Refresh access_token""" - permission_classes = (permissions.IsAuthenticated,) + permission_classes = (IsRefreshTokenValid, ) serializer_class = serializers.RefreshTokenSerializer def post(self, request, *args, **kwargs): diff --git a/apps/authorization/models.py b/apps/authorization/models.py index dd525df2..73290063 100644 --- a/apps/authorization/models.py +++ b/apps/authorization/models.py @@ -10,7 +10,6 @@ from rest_framework_simplejwt.tokens import RefreshToken, AccessToken from utils.models import PlatformMixin, ProjectBaseMixin -# Create your models here. class ApplicationQuerySet(models.QuerySet): """Application queryset""" def get_by_natural_key(self, client_id): @@ -75,19 +74,9 @@ class JWTAccessTokenQuerySet(models.QuerySet): class JWTAccessToken(PlatformMixin, ProjectBaseMixin): """GM access token model.""" - MOBILE = 0 - WEB = 1 - - SOURCES = ( - (MOBILE, _('Mobile')), - (WEB, _('Web')), - ) - user = models.ForeignKey('account.User', related_name='access_tokens', on_delete=models.CASCADE) - source = models.PositiveSmallIntegerField(choices=SOURCES, default=WEB, - verbose_name=_('Source')) refresh_token = models.ForeignKey('JWTRefreshToken', related_name='access_tokens', on_delete=models.DO_NOTHING) diff --git a/apps/authorization/views/common.py b/apps/authorization/views/common.py index a7384b30..6591a643 100644 --- a/apps/authorization/views/common.py +++ b/apps/authorization/views/common.py @@ -3,6 +3,7 @@ import json from braces.views import CsrfExemptMixin from django.conf import settings +from django.utils import timezone from django.utils.encoding import force_text from django.utils.http import urlsafe_base64_decode from django.utils.translation import gettext_lazy as _ @@ -13,12 +14,12 @@ from rest_framework import generics from rest_framework import permissions from rest_framework import status from rest_framework.response import Response -from rest_framework_simplejwt.tokens import AccessToken from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer from account.models import User from authorization.models import Application +from authorization.models import JWTRefreshToken from authorization.serializers import common as serializers from utils import exceptions as utils_exceptions from utils import tokens as jwt_tokens @@ -222,7 +223,19 @@ class LogoutView(JWTGenericViewMixin): def post(self, request, *args, **kwargs): """Override create method""" - access_token = request.COOKIES.get('access_token') - access_token_obj = AccessToken(access_token) - request.user.expire_access_token(jti=access_token_obj.payload.get('jti')) + # Get token objs by JTI + refresh_token_key = request.COOKIES.get('refresh_token') + refresh_token = jwt_tokens.GMRefreshToken(refresh_token_key) + refresh_token_qs = JWTRefreshToken.objects.valid()\ + .by_jti(jti=refresh_token.payload.get('jti')) + if not refresh_token_qs.exists(): + raise utils_exceptions.NotValidRefreshTokenError() + + refresh_token_obj = refresh_token_qs.first() + access_token_qs = refresh_token_obj.access_tokens + + # Expire tokens + refresh_token_obj.expire() + access_token_qs.update(expires_at=timezone.now()) + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/apps/utils/permissions.py b/apps/utils/permissions.py index 9731f1b6..9c3e814c 100644 --- a/apps/utils/permissions.py +++ b/apps/utils/permissions.py @@ -1,7 +1,9 @@ """Project custom permissions""" from rest_framework.permissions import BasePermission +from authorization.models import JWTRefreshToken from rest_framework_simplejwt.tokens import AccessToken +from utils.tokens import RefreshToken class IsAuthenticatedAndTokenIsValid(BasePermission): @@ -20,3 +22,19 @@ class IsAuthenticatedAndTokenIsValid(BasePermission): return valid_tokens.exists() else: return False + + +class IsRefreshTokenValid(BasePermission): + """ + Check if user has a valid refresh token and authenticated + """ + def has_permission(self, request, view): + """Check permissions by refresh token and default REST permission IsAuthenticated""" + refresh_token = request.COOKIES.get('refresh_token') + if refresh_token: + refresh_token = RefreshToken(refresh_token) + refresh_token_qs = JWTRefreshToken.objects.valid()\ + .by_jti(jti=refresh_token.payload.get('jti')) + return refresh_token_qs.exists() + else: + return False