refactored authorization app

This commit is contained in:
Anatoly 2019-09-03 16:48:06 +03:00
parent b20fe5e6fb
commit 06e563b77a
13 changed files with 345 additions and 148 deletions

View File

@ -93,8 +93,10 @@ class User(ImageMixin, AbstractUser):
self.is_active = switcher self.is_active = switcher
self.save() self.save()
def remove_token(self): def expire_access_token(self, jti):
Token.objects.filter(user=self).delete() access_token_qs = self.access_tokens.by_jti(jti=jti)
if access_token_qs.exists():
access_token_qs.first().expire()
def confirm_email(self): def confirm_email(self):
"""Method to confirm user email address""" """Method to confirm user email address"""
@ -106,8 +108,19 @@ class User(ImageMixin, AbstractUser):
self.is_active = True self.is_active = True
self.save() self.save()
def revoke_access_token(self): def get_body_email_message(self, subject: str, message: str):
print('Revoke token') """Prepare the body of the email message"""
return {
'subject': subject,
'message': str(message),
'from_email': settings.EMAIL_HOST_USER,
'recipient_list': [self.email, ]
}
def send_email(self, subject: str, message: str):
"""Send an email to reset user password"""
send_mail(**self.get_body_email_message(subject=subject,
message=message))
@property @property
def confirm_email_token(self): def confirm_email_token(self):
@ -149,20 +162,6 @@ class User(ImageMixin, AbstractUser):
'domain_uri': settings.DOMAIN_URI, 'domain_uri': settings.DOMAIN_URI,
'site_name': settings.SITE_NAME}) 'site_name': settings.SITE_NAME})
def get_body_email_message(self, subject: str, message: str):
"""Prepare the body of the email message"""
return {
'subject': subject,
'message': str(message),
'from_email': settings.EMAIL_HOST_USER,
'recipient_list': [self.email, ]
}
def send_email(self, subject: str, message: str):
"""Send an email to reset user password"""
send_mail(**self.get_body_email_message(subject=subject,
message=message))
class ResetPasswordTokenQuerySet(models.QuerySet): class ResetPasswordTokenQuerySet(models.QuerySet):
"""Reset password token query set""" """Reset password token query set"""

View File

@ -1,13 +1,12 @@
"""Serializers for account web""" """Serializers for account web"""
from django.conf import settings from django.conf import settings
from django.contrib.auth import password_validation as password_validators from django.contrib.auth import password_validation as password_validators
from rest_framework_simplejwt import tokens
from django.db.models import Q from django.db.models import Q
from rest_framework import serializers from rest_framework import serializers
from django.utils import timezone
from account import models from account import models, tasks
from account import tasks
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils.tokens import GMRefreshToken
class PasswordResetSerializer(serializers.ModelSerializer): class PasswordResetSerializer(serializers.ModelSerializer):
@ -180,27 +179,34 @@ class RefreshTokenSerializer(serializers.Serializer):
def validate(self, attrs): def validate(self, attrs):
"""Override validate method""" """Override validate method"""
refresh_token = self.get_request().COOKIES.get('refresh_token') user = self.get_request().user
if not refresh_token: cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
# Check if refresh_token in COOKIES
if not cookie_refresh_token:
raise utils_exceptions.NotValidRefreshTokenError() raise utils_exceptions.NotValidRefreshTokenError()
token = tokens.RefreshToken(token=refresh_token) refresh_token = GMRefreshToken(cookie_refresh_token)
refresh_token_qs = user.refresh_tokens.valid()\
.by_jti(refresh_token.payload.get('jti'))
# Check if the user has refresh token
if not refresh_token_qs.exists():
raise utils_exceptions.NotValidRefreshTokenError()
data = {'access_token': str(token.access_token)} # Expire existing refresh token
old_refresh_token = refresh_token_qs.first()
old_refresh_token.expire()
if settings.SIMPLE_JWT.get('ROTATE_REFRESH_TOKENS'): # Expire existing access tokens
if settings.SIMPLE_JWT.get('BLACKLIST_AFTER_ROTATION'): user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\
try: .valid()\
# Attempt to blacklist the given refresh token .update(expires_at=timezone.now())
token.blacklist()
except AttributeError:
# If blacklist app not installed, `blacklist` method will
# not be present
pass
token.set_jti() # Create new one for user
token.set_exp() refresh_token = GMRefreshToken.for_user(user)
refresh_token['user'] = user.get_user_info()
data['refresh_token'] = str(token) return {
'access_token': str(refresh_token.access_token),
return data 'refresh_token': str(refresh_token),
}

View File

@ -2,6 +2,7 @@
from fcm_django.models import FCMDevice from fcm_django.models import FCMDevice
from rest_framework import generics, status from rest_framework import generics, status
from rest_framework import permissions from rest_framework import permissions
from utils.permissions import IsAuthenticatedAndTokenIsValid
from rest_framework.response import Response from rest_framework.response import Response
from account import models from account import models

View File

@ -1,7 +1,2 @@
from django.contrib import admin from django.contrib import admin
from authorization import models from authorization import models
@admin.register(models.BlacklistedAccessToken)
class BlacklistedAccessTokenAdmin(admin.ModelAdmin):
"""Admin for BlackListedAccessToken"""

View File

@ -0,0 +1,53 @@
# Generated by Django 2.2.4 on 2019-09-03 11:58
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import django.utils.timezone
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('authorization', '0002_blacklistedaccesstoken'),
]
operations = [
migrations.CreateModel(
name='JWTRefreshToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created', models.DateTimeField(default=django.utils.timezone.now, editable=False, verbose_name='Date created')),
('modified', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('jti', models.CharField(max_length=255, unique=True)),
('created_at', models.DateTimeField(blank=True, null=True)),
('expires_at', models.DateTimeField()),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='refresh_tokens', to=settings.AUTH_USER_MODEL)),
],
options={
'verbose_name': 'Refresh token',
'verbose_name_plural': 'Refresh tokens',
'unique_together': {('user', 'jti')},
},
),
migrations.CreateModel(
name='JWTAccessToken',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created', models.DateTimeField(default=django.utils.timezone.now, editable=False, verbose_name='Date created')),
('modified', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('source', models.PositiveSmallIntegerField(choices=[(0, 'Mobile'), (1, 'Web')], default=1, verbose_name='Source')),
('created_at', models.DateTimeField(blank=True, null=True)),
('expires_at', models.DateTimeField(verbose_name='Expiration datetime')),
('jti', models.CharField(max_length=255, unique=True)),
('refresh_token', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name='access_tokens', to='authorization.JWTRefreshToken')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='access_tokens', to=settings.AUTH_USER_MODEL)),
],
options={
'verbose_name': 'Access token',
'verbose_name_plural': 'Access tokens',
'unique_together': {('user', 'jti')},
},
),
]

View File

@ -0,0 +1,16 @@
# Generated by Django 2.2.4 on 2019-09-03 11:58
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('authorization', '0003_jwtaccesstoken_jwtrefreshtoken'),
]
operations = [
migrations.DeleteModel(
name='BlacklistedAccessToken',
),
]

View File

@ -1,9 +1,13 @@
from django.conf import settings
from django.db import models from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from oauth2_provider import models as oauth2_models from oauth2_provider import models as oauth2_models
from oauth2_provider.models import AbstractApplication from oauth2_provider.models import AbstractApplication
from django.utils.translation import gettext_lazy as _ from rest_framework_simplejwt import utils
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from utils.models import PlatformMixin from utils.models import PlatformMixin, ProjectBaseMixin
# Create your models here. # Create your models here.
@ -33,40 +37,133 @@ class Application(PlatformMixin, AbstractApplication):
return (self.client_id,) return (self.client_id,)
class BlacklistedAccessTokenQuerySet(models.QuerySet): class JWTAccessTokenManager(models.Manager):
"""Queryset for model BlacklistedAccessToken""" """Manager for AccessToken model."""
def add_to_db(self, user, access_token: AccessToken, refresh_token: RefreshToken):
"""Create generated tokens to DB"""
refresh_token_qs = JWTRefreshToken.objects.filter(user=user,
jti=refresh_token.payload.get('jti'))
if refresh_token_qs.exists():
jti = access_token[settings.SIMPLE_JWT.get('JTI_CLAIM')]
exp = access_token['exp']
obj = self.model(
user=user,
jti=jti,
refresh_token=refresh_token_qs.first(),
created_at=access_token.current_time,
expires_at=utils.datetime_from_epoch(exp),
)
obj.save()
return obj
def by_user(self, user):
"""Filter by user"""
return self.filter(user=user)
def by_token(self, token): class JWTAccessTokenQuerySet(models.QuerySet):
"""Filter by token""" """QuerySets for AccessToken model."""
return self.filter(token=token)
def by_jti(self, jti): def valid(self):
"""Filter by unique access_token identifier""" """Returns only valid access tokens"""
return self.filter(expires_at__gte=timezone.now())
def by_jti(self, jti: str):
"""Filter by jti field"""
return self.filter(jti=jti)
def by_refresh_token_jti(self, jti):
"""Return all tokens by refresh token jti"""
return self.filter(refresh_token__jti=jti)
class JWTAccessToken(PlatformMixin, ProjectBaseMixin):
"""GM access token model."""
MOBILE = 0
WEB = 1
SOURCES = (
(MOBILE, _('Mobile')),
(WEB, _('Web')),
)
user = models.ForeignKey('account.User',
related_name='access_tokens',
on_delete=models.CASCADE)
source = models.PositiveSmallIntegerField(choices=SOURCES, default=WEB,
verbose_name=_('Source'))
refresh_token = models.ForeignKey('JWTRefreshToken',
related_name='access_tokens',
on_delete=models.DO_NOTHING)
created_at = models.DateTimeField(null=True, blank=True)
expires_at = models.DateTimeField(verbose_name=_('Expiration datetime'))
jti = models.CharField(unique=True, max_length=255)
objects = JWTAccessTokenManager.from_queryset(JWTAccessTokenQuerySet)()
class Meta:
"""Meta class."""
unique_together = ('user', 'jti')
verbose_name = _('Access token')
verbose_name_plural = _('Access tokens')
def __str__(self):
"""String representation method."""
return f'Access token JTI: {self.jti}'
def expire(self):
"""Expire access token."""
self.expires_at = timezone.now()
self.save()
class JWTRefreshTokenManager(models.Manager):
"""Manager for model RefreshToken."""
def add_to_db(self, user, token: RefreshToken):
"""Added generated refresh token to db"""
jti = token[settings.SIMPLE_JWT.get('JTI_CLAIM')]
exp = token['exp']
obj = self.model(
user=user,
jti=jti,
created_at=token.current_time,
expires_at=utils.datetime_from_epoch(exp),
)
obj.save()
return obj
class JWTRefreshTokenQuerySet(models.QuerySet):
"""QuerySets for model RefreshToken."""
def valid(self):
"""Return only balid refresh tokens"""
return self.filter(expires_at__gte=timezone.now())
def by_jti(self, jti: str):
"""Filter by jti field"""
return self.filter(jti=jti) return self.filter(jti=jti)
class BlacklistedAccessToken(models.Model): class JWTRefreshToken(ProjectBaseMixin):
"""GM refresh token model."""
user = models.ForeignKey('account.User', user = models.ForeignKey('account.User',
on_delete=models.CASCADE, related_name='refresh_tokens',
verbose_name=_('User')) on_delete=models.CASCADE)
jti = models.CharField(unique=True, max_length=255)
created_at = models.DateTimeField(null=True, blank=True)
expires_at = models.DateTimeField()
jti = models.CharField(max_length=255, unique=True, objects = JWTRefreshTokenManager.from_queryset(JWTRefreshTokenQuerySet)()
verbose_name=_('Unique access_token identifier'))
token = models.TextField(verbose_name=_('Access token'))
blacklisted_at = models.DateTimeField(auto_now_add=True,
verbose_name=_('Blacklisted datetime'))
objects = BlacklistedAccessTokenQuerySet.as_manager()
class Meta: class Meta:
"""Meta class""" """Meta class."""
unique_together = ('token', 'user') unique_together = ('user', 'jti')
verbose_name = _('Refresh token')
verbose_name_plural = _('Refresh tokens')
def __str__(self): def __str__(self):
return 'Blacklisted access token for {}'.format(self.user) """String representation method."""
return f'Refresh token JTI: {self.jti}'
def expire(self):
"""Expire refresh token."""
self.expires_at = timezone.now()
self.save()

View File

@ -5,14 +5,13 @@ from django.contrib.auth import password_validation as password_validators
from django.db.models import Q from django.db.models import Q
from rest_framework import serializers from rest_framework import serializers
from rest_framework import validators as rest_validators from rest_framework import validators as rest_validators
from authorization import tasks
# JWT
from rest_framework_simplejwt import tokens
from account import models as account_models from account import models as account_models
from authorization.models import Application, BlacklistedAccessToken from authorization import tasks
from authorization.models import Application
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils import methods as utils_methods from utils import methods as utils_methods
from utils import tokens
# Mixins # Mixins
@ -30,18 +29,21 @@ class JWTBaseSerializerMixin(serializers.Serializer):
refresh_token = serializers.CharField(read_only=True) refresh_token = serializers.CharField(read_only=True)
access_token = serializers.CharField(read_only=True) access_token = serializers.CharField(read_only=True)
def get_token(self): def get_tokens(self):
"""Create JWT token""" """Create JWT token"""
user = self.instance user = self.instance
token = tokens.RefreshToken.for_user(user) token = tokens.GMRefreshToken.for_user(user)
token['user'] = user.get_user_info() token['user'] = user.get_user_info()
return token return {
'access_token': str(token.access_token),
'refresh_token': str(token),
}
def to_representation(self, instance): def to_representation(self, instance):
"""Override to_representation method""" """Override to_representation method"""
token = self.get_token() tokens = self.get_tokens()
setattr(instance, 'access_token', str(token.access_token)) setattr(instance, 'access_token', tokens.get('access_token'))
setattr(instance, 'refresh_token', str(token)) setattr(instance, 'refresh_token', tokens.get('refresh_token'))
return super().to_representation(instance) return super().to_representation(instance)
@ -136,46 +138,6 @@ class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.Model
self.instance = user self.instance = user
return attrs return attrs
def to_representation(self, instance):
"""Override to_representation method"""
token = self.get_token()
setattr(instance, 'access_token', str(token.access_token))
setattr(instance, 'refresh_token', str(token))
# setattr(instance, 'remember', self.validated_data.get('remember'))
return super().to_representation(instance)
class LogoutSerializer(serializers.ModelSerializer):
"""Serializer class for model Logout"""
class Meta:
model = BlacklistedAccessToken
fields = (
'user',
'token',
'jti'
)
read_only_fields = (
'user',
'token',
'jti'
)
def create(self, validated_data):
"""Override create method"""
request = self.context.get('request')
# Get token bytes from cookies (result: b'Bearer <token>')
token_bytes = utils_methods.get_token_from_cookies(request)
# Get token value from bytes
token = token_bytes.decode().split(' ')[::-1][0]
# Get access token obj
access_token = tokens.AccessToken(token)
# Prepare validated data
validated_data['user'] = request.user
validated_data['token'] = access_token.token
validated_data['jti'] = access_token.payload.get('jti')
return super().create(validated_data)
# OAuth # OAuth
class OAuth2Serialzier(BaseAuthSerializerMixin): class OAuth2Serialzier(BaseAuthSerializerMixin):

View File

@ -13,15 +13,16 @@ from rest_framework import generics
from rest_framework import permissions from rest_framework import permissions
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_simplejwt import tokens as jwt_tokens from rest_framework_simplejwt.tokens import AccessToken
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
from utils.models import GMTokenGenerator
from account.models import User from account.models import User
from authorization.models import Application from authorization.models import Application
from authorization.serializers import common as serializers from authorization.serializers import common as serializers
from utils import exceptions as utils_exceptions from utils import exceptions as utils_exceptions
from utils import tokens as jwt_tokens
from utils.models import GMTokenGenerator
from utils.views import (JWTGenericViewMixin, from utils.views import (JWTGenericViewMixin,
JWTCreateAPIView) JWTCreateAPIView)
@ -109,7 +110,7 @@ class OAuth2SignUpView(OAuth2ViewMixin, JWTGenericViewMixin):
def get_jwt_token(self, user: User): def get_jwt_token(self, user: User):
"""Get JWT token""" """Get JWT token"""
token = jwt_tokens.RefreshToken.for_user(user) token = jwt_tokens.GMRefreshToken.for_user(user)
# Adding additional information about user to payload # Adding additional information about user to payload
token['user'] = user.get_user_info() token['user'] = user.get_user_info()
return token return token
@ -218,11 +219,10 @@ class LoginByUsernameOrEmailView(JWTAuthViewMixin):
# Logout # Logout
class LogoutView(JWTGenericViewMixin): class LogoutView(JWTGenericViewMixin):
"""Logout user""" """Logout user"""
serializer_class = serializers.LogoutSerializer
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Override create method""" """Override create method"""
serializer = self.get_serializer(data=request.data) access_token = request.COOKIES.get('access_token')
serializer.is_valid(raise_exception=True) access_token_obj = AccessToken(access_token)
serializer.save() request.user.expire_access_token(jti=access_token_obj.payload.get('jti'))
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)

View File

@ -87,7 +87,7 @@ class NotValidAccessTokenError(exceptions.APIException):
class NotValidRefreshTokenError(exceptions.APIException): class NotValidRefreshTokenError(exceptions.APIException):
"""The exception should be thrown when refresh token is not valid """The exception should be thrown when refresh token is not valid
""" """
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Not valid refresh token') default_detail = _('Not valid refresh token')

View File

@ -1,10 +1,7 @@
"""Project custom permissions""" """Project custom permissions"""
from rest_framework.permissions import BasePermission from rest_framework.permissions import BasePermission
from rest_framework_simplejwt.exceptions import TokenBackendError
from authorization.models import BlacklistedAccessToken from rest_framework_simplejwt.tokens import AccessToken
from utils.exceptions import NotValidAccessTokenError
from utils.methods import get_token_from_cookies
class IsAuthenticatedAndTokenIsValid(BasePermission): class IsAuthenticatedAndTokenIsValid(BasePermission):
@ -15,17 +12,11 @@ class IsAuthenticatedAndTokenIsValid(BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):
"""Check permissions by access token and default REST permission IsAuthenticated""" """Check permissions by access token and default REST permission IsAuthenticated"""
user = request.user user = request.user
try: access_token = request.COOKIES.get('access_token')
if user and user.is_authenticated: if user.is_authenticated and access_token:
token_bytes = get_token_from_cookies(request) access_token = AccessToken(access_token)
# Get access token key valid_tokens = user.access_tokens.valid()\
token = token_bytes.decode().split(' ')[1] .by_jti(jti=access_token.payload.get('jti'))
# Check if user access token not expired return valid_tokens.exists()
blacklisted = BlacklistedAccessToken.objects.by_token(token) \
.by_user(user) \
.exists()
return not blacklisted
except TokenBackendError:
raise NotValidAccessTokenError()
else: else:
return False return False

77
apps/utils/tokens.py Normal file
View File

@ -0,0 +1,77 @@
"""Custom tokens based on django-rest-framework-simple-jwt"""
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin
from account.models import User
from authorization.models import JWTRefreshToken, JWTAccessToken
class GMToken(Token):
"""Custom JWT Token class"""
@classmethod
def for_user(cls, user):
"""
Returns an authorization token for the given user that will be provided
after authenticating the user's credentials.
"""
user_id = getattr(user, api_settings.USER_ID_FIELD)
if not isinstance(user_id, int):
user_id = str(user_id)
token = cls()
token[api_settings.USER_ID_CLAIM] = user_id
return token
class GMBlacklistMixin(BlacklistMixin):
"""
If the `rest_framework_simplejwt.token_blacklist` app was configured to be
used, tokens created from `BlacklistMixin` subclasses will insert
themselves into an outstanding token list and also check for their
membership in a token blacklist.
"""
@classmethod
def for_user(cls, user):
"""
Adds this token to the outstanding token list.
"""
token = super().for_user(user)
# Create a record in DB
JWTRefreshToken.objects.add_to_db(user=user, token=token)
return token
class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
"""GM refresh token"""
@property
def access_token(self):
"""
Returns an access token created from this refresh token. Copies all
claims present in this refresh token to the new access token except
those claims listed in the `no_copy_claims` attribute.
"""
access_token = AccessToken()
# Use instantiation time of refresh token as relative timestamp for
# access token "exp" claim. This ensures that both a refresh and
# access token expire relative to the same time if they are created as
# a pair.
access_token.set_exp(from_time=self.current_time)
no_copy = self.no_copy_claims
for claim, value in self.payload.items():
if claim in no_copy:
continue
access_token[claim] = value
# Create a record in DB
user = User.objects.get(id=self.payload.get('user_id'))
JWTAccessToken.objects.add_to_db(user=user,
access_token=access_token,
refresh_token=self)
return access_token

View File

@ -4,7 +4,7 @@ from django.conf import settings
from rest_framework import generics from rest_framework import generics
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_simplejwt import tokens from utils import tokens
# JWT # JWT
@ -21,7 +21,7 @@ class JWTGenericViewMixin(generics.GenericAPIView):
def _create_jwt_token(self, user) -> dict: def _create_jwt_token(self, user) -> dict:
"""Return dictionary with pairs access and refresh tokens""" """Return dictionary with pairs access and refresh tokens"""
token = tokens.RefreshToken.for_user(user) token = tokens.GMRefreshToken.for_user(user)
token['user'] = user.get_user_info() token['user'] = user.get_user_info()
return { return {
'access_token': str(token.access_token), 'access_token': str(token.access_token),