refactored authorization app
This commit is contained in:
parent
b20fe5e6fb
commit
06e563b77a
|
|
@ -93,8 +93,10 @@ class User(ImageMixin, AbstractUser):
|
||||||
self.is_active = switcher
|
self.is_active = switcher
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def remove_token(self):
|
def expire_access_token(self, jti):
|
||||||
Token.objects.filter(user=self).delete()
|
access_token_qs = self.access_tokens.by_jti(jti=jti)
|
||||||
|
if access_token_qs.exists():
|
||||||
|
access_token_qs.first().expire()
|
||||||
|
|
||||||
def confirm_email(self):
|
def confirm_email(self):
|
||||||
"""Method to confirm user email address"""
|
"""Method to confirm user email address"""
|
||||||
|
|
@ -106,8 +108,19 @@ class User(ImageMixin, AbstractUser):
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def revoke_access_token(self):
|
def get_body_email_message(self, subject: str, message: str):
|
||||||
print('Revoke token')
|
"""Prepare the body of the email message"""
|
||||||
|
return {
|
||||||
|
'subject': subject,
|
||||||
|
'message': str(message),
|
||||||
|
'from_email': settings.EMAIL_HOST_USER,
|
||||||
|
'recipient_list': [self.email, ]
|
||||||
|
}
|
||||||
|
|
||||||
|
def send_email(self, subject: str, message: str):
|
||||||
|
"""Send an email to reset user password"""
|
||||||
|
send_mail(**self.get_body_email_message(subject=subject,
|
||||||
|
message=message))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def confirm_email_token(self):
|
def confirm_email_token(self):
|
||||||
|
|
@ -149,20 +162,6 @@ class User(ImageMixin, AbstractUser):
|
||||||
'domain_uri': settings.DOMAIN_URI,
|
'domain_uri': settings.DOMAIN_URI,
|
||||||
'site_name': settings.SITE_NAME})
|
'site_name': settings.SITE_NAME})
|
||||||
|
|
||||||
def get_body_email_message(self, subject: str, message: str):
|
|
||||||
"""Prepare the body of the email message"""
|
|
||||||
return {
|
|
||||||
'subject': subject,
|
|
||||||
'message': str(message),
|
|
||||||
'from_email': settings.EMAIL_HOST_USER,
|
|
||||||
'recipient_list': [self.email, ]
|
|
||||||
}
|
|
||||||
|
|
||||||
def send_email(self, subject: str, message: str):
|
|
||||||
"""Send an email to reset user password"""
|
|
||||||
send_mail(**self.get_body_email_message(subject=subject,
|
|
||||||
message=message))
|
|
||||||
|
|
||||||
|
|
||||||
class ResetPasswordTokenQuerySet(models.QuerySet):
|
class ResetPasswordTokenQuerySet(models.QuerySet):
|
||||||
"""Reset password token query set"""
|
"""Reset password token query set"""
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
"""Serializers for account web"""
|
"""Serializers for account web"""
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import password_validation as password_validators
|
from django.contrib.auth import password_validation as password_validators
|
||||||
from rest_framework_simplejwt import tokens
|
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
from django.utils import timezone
|
||||||
from account import models
|
from account import models, tasks
|
||||||
from account import tasks
|
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
|
from utils.tokens import GMRefreshToken
|
||||||
|
|
||||||
|
|
||||||
class PasswordResetSerializer(serializers.ModelSerializer):
|
class PasswordResetSerializer(serializers.ModelSerializer):
|
||||||
|
|
@ -180,27 +179,34 @@ class RefreshTokenSerializer(serializers.Serializer):
|
||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
"""Override validate method"""
|
"""Override validate method"""
|
||||||
refresh_token = self.get_request().COOKIES.get('refresh_token')
|
user = self.get_request().user
|
||||||
if not refresh_token:
|
cookie_refresh_token = self.get_request().COOKIES.get('refresh_token')
|
||||||
|
|
||||||
|
# Check if refresh_token in COOKIES
|
||||||
|
if not cookie_refresh_token:
|
||||||
raise utils_exceptions.NotValidRefreshTokenError()
|
raise utils_exceptions.NotValidRefreshTokenError()
|
||||||
|
|
||||||
token = tokens.RefreshToken(token=refresh_token)
|
refresh_token = GMRefreshToken(cookie_refresh_token)
|
||||||
|
refresh_token_qs = user.refresh_tokens.valid()\
|
||||||
|
.by_jti(refresh_token.payload.get('jti'))
|
||||||
|
# Check if the user has refresh token
|
||||||
|
if not refresh_token_qs.exists():
|
||||||
|
raise utils_exceptions.NotValidRefreshTokenError()
|
||||||
|
|
||||||
data = {'access_token': str(token.access_token)}
|
# Expire existing refresh token
|
||||||
|
old_refresh_token = refresh_token_qs.first()
|
||||||
|
old_refresh_token.expire()
|
||||||
|
|
||||||
if settings.SIMPLE_JWT.get('ROTATE_REFRESH_TOKENS'):
|
# Expire existing access tokens
|
||||||
if settings.SIMPLE_JWT.get('BLACKLIST_AFTER_ROTATION'):
|
user.access_tokens.by_refresh_token_jti(jti=old_refresh_token.jti)\
|
||||||
try:
|
.valid()\
|
||||||
# Attempt to blacklist the given refresh token
|
.update(expires_at=timezone.now())
|
||||||
token.blacklist()
|
|
||||||
except AttributeError:
|
|
||||||
# If blacklist app not installed, `blacklist` method will
|
|
||||||
# not be present
|
|
||||||
pass
|
|
||||||
|
|
||||||
token.set_jti()
|
# Create new one for user
|
||||||
token.set_exp()
|
refresh_token = GMRefreshToken.for_user(user)
|
||||||
|
refresh_token['user'] = user.get_user_info()
|
||||||
|
|
||||||
data['refresh_token'] = str(token)
|
return {
|
||||||
|
'access_token': str(refresh_token.access_token),
|
||||||
return data
|
'refresh_token': str(refresh_token),
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
from fcm_django.models import FCMDevice
|
from fcm_django.models import FCMDevice
|
||||||
from rest_framework import generics, status
|
from rest_framework import generics, status
|
||||||
from rest_framework import permissions
|
from rest_framework import permissions
|
||||||
|
from utils.permissions import IsAuthenticatedAndTokenIsValid
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
|
||||||
from account import models
|
from account import models
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,2 @@
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from authorization import models
|
from authorization import models
|
||||||
|
|
||||||
|
|
||||||
@admin.register(models.BlacklistedAccessToken)
|
|
||||||
class BlacklistedAccessTokenAdmin(admin.ModelAdmin):
|
|
||||||
"""Admin for BlackListedAccessToken"""
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Generated by Django 2.2.4 on 2019-09-03 11:58
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import django.utils.timezone
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||||
|
('authorization', '0002_blacklistedaccesstoken'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='JWTRefreshToken',
|
||||||
|
fields=[
|
||||||
|
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||||
|
('created', models.DateTimeField(default=django.utils.timezone.now, editable=False, verbose_name='Date created')),
|
||||||
|
('modified', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
|
||||||
|
('jti', models.CharField(max_length=255, unique=True)),
|
||||||
|
('created_at', models.DateTimeField(blank=True, null=True)),
|
||||||
|
('expires_at', models.DateTimeField()),
|
||||||
|
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='refresh_tokens', to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Refresh token',
|
||||||
|
'verbose_name_plural': 'Refresh tokens',
|
||||||
|
'unique_together': {('user', 'jti')},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='JWTAccessToken',
|
||||||
|
fields=[
|
||||||
|
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||||
|
('created', models.DateTimeField(default=django.utils.timezone.now, editable=False, verbose_name='Date created')),
|
||||||
|
('modified', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
|
||||||
|
('source', models.PositiveSmallIntegerField(choices=[(0, 'Mobile'), (1, 'Web')], default=1, verbose_name='Source')),
|
||||||
|
('created_at', models.DateTimeField(blank=True, null=True)),
|
||||||
|
('expires_at', models.DateTimeField(verbose_name='Expiration datetime')),
|
||||||
|
('jti', models.CharField(max_length=255, unique=True)),
|
||||||
|
('refresh_token', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, related_name='access_tokens', to='authorization.JWTRefreshToken')),
|
||||||
|
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='access_tokens', to=settings.AUTH_USER_MODEL)),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Access token',
|
||||||
|
'verbose_name_plural': 'Access tokens',
|
||||||
|
'unique_together': {('user', 'jti')},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Generated by Django 2.2.4 on 2019-09-03 11:58
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('authorization', '0003_jwtaccesstoken_jwtrefreshtoken'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.DeleteModel(
|
||||||
|
name='BlacklistedAccessToken',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
|
from django.conf import settings
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.utils import timezone
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
from oauth2_provider import models as oauth2_models
|
from oauth2_provider import models as oauth2_models
|
||||||
from oauth2_provider.models import AbstractApplication
|
from oauth2_provider.models import AbstractApplication
|
||||||
from django.utils.translation import gettext_lazy as _
|
from rest_framework_simplejwt import utils
|
||||||
|
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
|
||||||
|
|
||||||
from utils.models import PlatformMixin
|
from utils.models import PlatformMixin, ProjectBaseMixin
|
||||||
|
|
||||||
|
|
||||||
# Create your models here.
|
# Create your models here.
|
||||||
|
|
@ -33,40 +37,133 @@ class Application(PlatformMixin, AbstractApplication):
|
||||||
return (self.client_id,)
|
return (self.client_id,)
|
||||||
|
|
||||||
|
|
||||||
class BlacklistedAccessTokenQuerySet(models.QuerySet):
|
class JWTAccessTokenManager(models.Manager):
|
||||||
"""Queryset for model BlacklistedAccessToken"""
|
"""Manager for AccessToken model."""
|
||||||
|
def add_to_db(self, user, access_token: AccessToken, refresh_token: RefreshToken):
|
||||||
|
"""Create generated tokens to DB"""
|
||||||
|
refresh_token_qs = JWTRefreshToken.objects.filter(user=user,
|
||||||
|
jti=refresh_token.payload.get('jti'))
|
||||||
|
if refresh_token_qs.exists():
|
||||||
|
jti = access_token[settings.SIMPLE_JWT.get('JTI_CLAIM')]
|
||||||
|
exp = access_token['exp']
|
||||||
|
obj = self.model(
|
||||||
|
user=user,
|
||||||
|
jti=jti,
|
||||||
|
refresh_token=refresh_token_qs.first(),
|
||||||
|
created_at=access_token.current_time,
|
||||||
|
expires_at=utils.datetime_from_epoch(exp),
|
||||||
|
)
|
||||||
|
obj.save()
|
||||||
|
return obj
|
||||||
|
|
||||||
def by_user(self, user):
|
|
||||||
"""Filter by user"""
|
|
||||||
return self.filter(user=user)
|
|
||||||
|
|
||||||
def by_token(self, token):
|
class JWTAccessTokenQuerySet(models.QuerySet):
|
||||||
"""Filter by token"""
|
"""QuerySets for AccessToken model."""
|
||||||
return self.filter(token=token)
|
|
||||||
|
|
||||||
def by_jti(self, jti):
|
def valid(self):
|
||||||
"""Filter by unique access_token identifier"""
|
"""Returns only valid access tokens"""
|
||||||
|
return self.filter(expires_at__gte=timezone.now())
|
||||||
|
|
||||||
|
def by_jti(self, jti: str):
|
||||||
|
"""Filter by jti field"""
|
||||||
|
return self.filter(jti=jti)
|
||||||
|
|
||||||
|
def by_refresh_token_jti(self, jti):
|
||||||
|
"""Return all tokens by refresh token jti"""
|
||||||
|
return self.filter(refresh_token__jti=jti)
|
||||||
|
|
||||||
|
|
||||||
|
class JWTAccessToken(PlatformMixin, ProjectBaseMixin):
|
||||||
|
"""GM access token model."""
|
||||||
|
MOBILE = 0
|
||||||
|
WEB = 1
|
||||||
|
|
||||||
|
SOURCES = (
|
||||||
|
(MOBILE, _('Mobile')),
|
||||||
|
(WEB, _('Web')),
|
||||||
|
)
|
||||||
|
|
||||||
|
user = models.ForeignKey('account.User',
|
||||||
|
related_name='access_tokens',
|
||||||
|
on_delete=models.CASCADE)
|
||||||
|
source = models.PositiveSmallIntegerField(choices=SOURCES, default=WEB,
|
||||||
|
verbose_name=_('Source'))
|
||||||
|
refresh_token = models.ForeignKey('JWTRefreshToken',
|
||||||
|
related_name='access_tokens',
|
||||||
|
on_delete=models.DO_NOTHING)
|
||||||
|
created_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
expires_at = models.DateTimeField(verbose_name=_('Expiration datetime'))
|
||||||
|
jti = models.CharField(unique=True, max_length=255)
|
||||||
|
|
||||||
|
objects = JWTAccessTokenManager.from_queryset(JWTAccessTokenQuerySet)()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
"""Meta class."""
|
||||||
|
unique_together = ('user', 'jti')
|
||||||
|
verbose_name = _('Access token')
|
||||||
|
verbose_name_plural = _('Access tokens')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""String representation method."""
|
||||||
|
return f'Access token JTI: {self.jti}'
|
||||||
|
|
||||||
|
def expire(self):
|
||||||
|
"""Expire access token."""
|
||||||
|
self.expires_at = timezone.now()
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
|
||||||
|
class JWTRefreshTokenManager(models.Manager):
|
||||||
|
"""Manager for model RefreshToken."""
|
||||||
|
|
||||||
|
def add_to_db(self, user, token: RefreshToken):
|
||||||
|
"""Added generated refresh token to db"""
|
||||||
|
jti = token[settings.SIMPLE_JWT.get('JTI_CLAIM')]
|
||||||
|
exp = token['exp']
|
||||||
|
obj = self.model(
|
||||||
|
user=user,
|
||||||
|
jti=jti,
|
||||||
|
created_at=token.current_time,
|
||||||
|
expires_at=utils.datetime_from_epoch(exp),
|
||||||
|
)
|
||||||
|
obj.save()
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class JWTRefreshTokenQuerySet(models.QuerySet):
|
||||||
|
"""QuerySets for model RefreshToken."""
|
||||||
|
|
||||||
|
def valid(self):
|
||||||
|
"""Return only balid refresh tokens"""
|
||||||
|
return self.filter(expires_at__gte=timezone.now())
|
||||||
|
|
||||||
|
def by_jti(self, jti: str):
|
||||||
|
"""Filter by jti field"""
|
||||||
return self.filter(jti=jti)
|
return self.filter(jti=jti)
|
||||||
|
|
||||||
|
|
||||||
class BlacklistedAccessToken(models.Model):
|
class JWTRefreshToken(ProjectBaseMixin):
|
||||||
|
"""GM refresh token model."""
|
||||||
user = models.ForeignKey('account.User',
|
user = models.ForeignKey('account.User',
|
||||||
on_delete=models.CASCADE,
|
related_name='refresh_tokens',
|
||||||
verbose_name=_('User'))
|
on_delete=models.CASCADE)
|
||||||
|
jti = models.CharField(unique=True, max_length=255)
|
||||||
|
created_at = models.DateTimeField(null=True, blank=True)
|
||||||
|
expires_at = models.DateTimeField()
|
||||||
|
|
||||||
jti = models.CharField(max_length=255, unique=True,
|
objects = JWTRefreshTokenManager.from_queryset(JWTRefreshTokenQuerySet)()
|
||||||
verbose_name=_('Unique access_token identifier'))
|
|
||||||
token = models.TextField(verbose_name=_('Access token'))
|
|
||||||
|
|
||||||
blacklisted_at = models.DateTimeField(auto_now_add=True,
|
|
||||||
verbose_name=_('Blacklisted datetime'))
|
|
||||||
|
|
||||||
objects = BlacklistedAccessTokenQuerySet.as_manager()
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
"""Meta class"""
|
"""Meta class."""
|
||||||
unique_together = ('token', 'user')
|
unique_together = ('user', 'jti')
|
||||||
|
verbose_name = _('Refresh token')
|
||||||
|
verbose_name_plural = _('Refresh tokens')
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return 'Blacklisted access token for {}'.format(self.user)
|
"""String representation method."""
|
||||||
|
return f'Refresh token JTI: {self.jti}'
|
||||||
|
|
||||||
|
def expire(self):
|
||||||
|
"""Expire refresh token."""
|
||||||
|
self.expires_at = timezone.now()
|
||||||
|
self.save()
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,13 @@ from django.contrib.auth import password_validation as password_validators
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from rest_framework import validators as rest_validators
|
from rest_framework import validators as rest_validators
|
||||||
from authorization import tasks
|
|
||||||
# JWT
|
|
||||||
from rest_framework_simplejwt import tokens
|
|
||||||
|
|
||||||
from account import models as account_models
|
from account import models as account_models
|
||||||
from authorization.models import Application, BlacklistedAccessToken
|
from authorization import tasks
|
||||||
|
from authorization.models import Application
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
from utils import methods as utils_methods
|
from utils import methods as utils_methods
|
||||||
|
from utils import tokens
|
||||||
|
|
||||||
|
|
||||||
# Mixins
|
# Mixins
|
||||||
|
|
@ -30,18 +29,21 @@ class JWTBaseSerializerMixin(serializers.Serializer):
|
||||||
refresh_token = serializers.CharField(read_only=True)
|
refresh_token = serializers.CharField(read_only=True)
|
||||||
access_token = serializers.CharField(read_only=True)
|
access_token = serializers.CharField(read_only=True)
|
||||||
|
|
||||||
def get_token(self):
|
def get_tokens(self):
|
||||||
"""Create JWT token"""
|
"""Create JWT token"""
|
||||||
user = self.instance
|
user = self.instance
|
||||||
token = tokens.RefreshToken.for_user(user)
|
token = tokens.GMRefreshToken.for_user(user)
|
||||||
token['user'] = user.get_user_info()
|
token['user'] = user.get_user_info()
|
||||||
return token
|
return {
|
||||||
|
'access_token': str(token.access_token),
|
||||||
|
'refresh_token': str(token),
|
||||||
|
}
|
||||||
|
|
||||||
def to_representation(self, instance):
|
def to_representation(self, instance):
|
||||||
"""Override to_representation method"""
|
"""Override to_representation method"""
|
||||||
token = self.get_token()
|
tokens = self.get_tokens()
|
||||||
setattr(instance, 'access_token', str(token.access_token))
|
setattr(instance, 'access_token', tokens.get('access_token'))
|
||||||
setattr(instance, 'refresh_token', str(token))
|
setattr(instance, 'refresh_token', tokens.get('refresh_token'))
|
||||||
return super().to_representation(instance)
|
return super().to_representation(instance)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -136,46 +138,6 @@ class LoginByUsernameOrEmailSerializer(JWTBaseSerializerMixin, serializers.Model
|
||||||
self.instance = user
|
self.instance = user
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def to_representation(self, instance):
|
|
||||||
"""Override to_representation method"""
|
|
||||||
token = self.get_token()
|
|
||||||
setattr(instance, 'access_token', str(token.access_token))
|
|
||||||
setattr(instance, 'refresh_token', str(token))
|
|
||||||
# setattr(instance, 'remember', self.validated_data.get('remember'))
|
|
||||||
return super().to_representation(instance)
|
|
||||||
|
|
||||||
|
|
||||||
class LogoutSerializer(serializers.ModelSerializer):
|
|
||||||
"""Serializer class for model Logout"""
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
model = BlacklistedAccessToken
|
|
||||||
fields = (
|
|
||||||
'user',
|
|
||||||
'token',
|
|
||||||
'jti'
|
|
||||||
)
|
|
||||||
read_only_fields = (
|
|
||||||
'user',
|
|
||||||
'token',
|
|
||||||
'jti'
|
|
||||||
)
|
|
||||||
|
|
||||||
def create(self, validated_data):
|
|
||||||
"""Override create method"""
|
|
||||||
request = self.context.get('request')
|
|
||||||
# Get token bytes from cookies (result: b'Bearer <token>')
|
|
||||||
token_bytes = utils_methods.get_token_from_cookies(request)
|
|
||||||
# Get token value from bytes
|
|
||||||
token = token_bytes.decode().split(' ')[::-1][0]
|
|
||||||
# Get access token obj
|
|
||||||
access_token = tokens.AccessToken(token)
|
|
||||||
# Prepare validated data
|
|
||||||
validated_data['user'] = request.user
|
|
||||||
validated_data['token'] = access_token.token
|
|
||||||
validated_data['jti'] = access_token.payload.get('jti')
|
|
||||||
return super().create(validated_data)
|
|
||||||
|
|
||||||
|
|
||||||
# OAuth
|
# OAuth
|
||||||
class OAuth2Serialzier(BaseAuthSerializerMixin):
|
class OAuth2Serialzier(BaseAuthSerializerMixin):
|
||||||
|
|
|
||||||
|
|
@ -13,15 +13,16 @@ from rest_framework import generics
|
||||||
from rest_framework import permissions
|
from rest_framework import permissions
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework_simplejwt import tokens as jwt_tokens
|
from rest_framework_simplejwt.tokens import AccessToken
|
||||||
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
|
from rest_framework_social_oauth2.oauth2_backends import KeepRequestCore
|
||||||
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
|
from rest_framework_social_oauth2.oauth2_endpoints import SocialTokenServer
|
||||||
from utils.models import GMTokenGenerator
|
|
||||||
|
|
||||||
from account.models import User
|
from account.models import User
|
||||||
from authorization.models import Application
|
from authorization.models import Application
|
||||||
from authorization.serializers import common as serializers
|
from authorization.serializers import common as serializers
|
||||||
from utils import exceptions as utils_exceptions
|
from utils import exceptions as utils_exceptions
|
||||||
|
from utils import tokens as jwt_tokens
|
||||||
|
from utils.models import GMTokenGenerator
|
||||||
from utils.views import (JWTGenericViewMixin,
|
from utils.views import (JWTGenericViewMixin,
|
||||||
JWTCreateAPIView)
|
JWTCreateAPIView)
|
||||||
|
|
||||||
|
|
@ -109,7 +110,7 @@ class OAuth2SignUpView(OAuth2ViewMixin, JWTGenericViewMixin):
|
||||||
|
|
||||||
def get_jwt_token(self, user: User):
|
def get_jwt_token(self, user: User):
|
||||||
"""Get JWT token"""
|
"""Get JWT token"""
|
||||||
token = jwt_tokens.RefreshToken.for_user(user)
|
token = jwt_tokens.GMRefreshToken.for_user(user)
|
||||||
# Adding additional information about user to payload
|
# Adding additional information about user to payload
|
||||||
token['user'] = user.get_user_info()
|
token['user'] = user.get_user_info()
|
||||||
return token
|
return token
|
||||||
|
|
@ -218,11 +219,10 @@ class LoginByUsernameOrEmailView(JWTAuthViewMixin):
|
||||||
# Logout
|
# Logout
|
||||||
class LogoutView(JWTGenericViewMixin):
|
class LogoutView(JWTGenericViewMixin):
|
||||||
"""Logout user"""
|
"""Logout user"""
|
||||||
serializer_class = serializers.LogoutSerializer
|
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
"""Override create method"""
|
"""Override create method"""
|
||||||
serializer = self.get_serializer(data=request.data)
|
access_token = request.COOKIES.get('access_token')
|
||||||
serializer.is_valid(raise_exception=True)
|
access_token_obj = AccessToken(access_token)
|
||||||
serializer.save()
|
request.user.expire_access_token(jti=access_token_obj.payload.get('jti'))
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ class NotValidAccessTokenError(exceptions.APIException):
|
||||||
class NotValidRefreshTokenError(exceptions.APIException):
|
class NotValidRefreshTokenError(exceptions.APIException):
|
||||||
"""The exception should be thrown when refresh token is not valid
|
"""The exception should be thrown when refresh token is not valid
|
||||||
"""
|
"""
|
||||||
status_code = status.HTTP_400_BAD_REQUEST
|
status_code = status.HTTP_401_UNAUTHORIZED
|
||||||
default_detail = _('Not valid refresh token')
|
default_detail = _('Not valid refresh token')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,7 @@
|
||||||
"""Project custom permissions"""
|
"""Project custom permissions"""
|
||||||
from rest_framework.permissions import BasePermission
|
from rest_framework.permissions import BasePermission
|
||||||
from rest_framework_simplejwt.exceptions import TokenBackendError
|
|
||||||
|
|
||||||
from authorization.models import BlacklistedAccessToken
|
from rest_framework_simplejwt.tokens import AccessToken
|
||||||
from utils.exceptions import NotValidAccessTokenError
|
|
||||||
from utils.methods import get_token_from_cookies
|
|
||||||
|
|
||||||
|
|
||||||
class IsAuthenticatedAndTokenIsValid(BasePermission):
|
class IsAuthenticatedAndTokenIsValid(BasePermission):
|
||||||
|
|
@ -15,17 +12,11 @@ class IsAuthenticatedAndTokenIsValid(BasePermission):
|
||||||
def has_permission(self, request, view):
|
def has_permission(self, request, view):
|
||||||
"""Check permissions by access token and default REST permission IsAuthenticated"""
|
"""Check permissions by access token and default REST permission IsAuthenticated"""
|
||||||
user = request.user
|
user = request.user
|
||||||
try:
|
access_token = request.COOKIES.get('access_token')
|
||||||
if user and user.is_authenticated:
|
if user.is_authenticated and access_token:
|
||||||
token_bytes = get_token_from_cookies(request)
|
access_token = AccessToken(access_token)
|
||||||
# Get access token key
|
valid_tokens = user.access_tokens.valid()\
|
||||||
token = token_bytes.decode().split(' ')[1]
|
.by_jti(jti=access_token.payload.get('jti'))
|
||||||
# Check if user access token not expired
|
return valid_tokens.exists()
|
||||||
blacklisted = BlacklistedAccessToken.objects.by_token(token) \
|
|
||||||
.by_user(user) \
|
|
||||||
.exists()
|
|
||||||
return not blacklisted
|
|
||||||
except TokenBackendError:
|
|
||||||
raise NotValidAccessTokenError()
|
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
77
apps/utils/tokens.py
Normal file
77
apps/utils/tokens.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"""Custom tokens based on django-rest-framework-simple-jwt"""
|
||||||
|
from rest_framework_simplejwt.settings import api_settings
|
||||||
|
from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin
|
||||||
|
|
||||||
|
from account.models import User
|
||||||
|
from authorization.models import JWTRefreshToken, JWTAccessToken
|
||||||
|
|
||||||
|
|
||||||
|
class GMToken(Token):
|
||||||
|
"""Custom JWT Token class"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_user(cls, user):
|
||||||
|
"""
|
||||||
|
Returns an authorization token for the given user that will be provided
|
||||||
|
after authenticating the user's credentials.
|
||||||
|
"""
|
||||||
|
user_id = getattr(user, api_settings.USER_ID_FIELD)
|
||||||
|
if not isinstance(user_id, int):
|
||||||
|
user_id = str(user_id)
|
||||||
|
|
||||||
|
token = cls()
|
||||||
|
token[api_settings.USER_ID_CLAIM] = user_id
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class GMBlacklistMixin(BlacklistMixin):
|
||||||
|
"""
|
||||||
|
If the `rest_framework_simplejwt.token_blacklist` app was configured to be
|
||||||
|
used, tokens created from `BlacklistMixin` subclasses will insert
|
||||||
|
themselves into an outstanding token list and also check for their
|
||||||
|
membership in a token blacklist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_user(cls, user):
|
||||||
|
"""
|
||||||
|
Adds this token to the outstanding token list.
|
||||||
|
"""
|
||||||
|
token = super().for_user(user)
|
||||||
|
# Create a record in DB
|
||||||
|
JWTRefreshToken.objects.add_to_db(user=user, token=token)
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
|
||||||
|
"""GM refresh token"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def access_token(self):
|
||||||
|
"""
|
||||||
|
Returns an access token created from this refresh token. Copies all
|
||||||
|
claims present in this refresh token to the new access token except
|
||||||
|
those claims listed in the `no_copy_claims` attribute.
|
||||||
|
"""
|
||||||
|
access_token = AccessToken()
|
||||||
|
|
||||||
|
# Use instantiation time of refresh token as relative timestamp for
|
||||||
|
# access token "exp" claim. This ensures that both a refresh and
|
||||||
|
# access token expire relative to the same time if they are created as
|
||||||
|
# a pair.
|
||||||
|
access_token.set_exp(from_time=self.current_time)
|
||||||
|
|
||||||
|
no_copy = self.no_copy_claims
|
||||||
|
for claim, value in self.payload.items():
|
||||||
|
if claim in no_copy:
|
||||||
|
continue
|
||||||
|
access_token[claim] = value
|
||||||
|
|
||||||
|
# Create a record in DB
|
||||||
|
user = User.objects.get(id=self.payload.get('user_id'))
|
||||||
|
JWTAccessToken.objects.add_to_db(user=user,
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=self)
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
|
@ -4,7 +4,7 @@ from django.conf import settings
|
||||||
from rest_framework import generics
|
from rest_framework import generics
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework_simplejwt import tokens
|
from utils import tokens
|
||||||
|
|
||||||
|
|
||||||
# JWT
|
# JWT
|
||||||
|
|
@ -21,7 +21,7 @@ class JWTGenericViewMixin(generics.GenericAPIView):
|
||||||
|
|
||||||
def _create_jwt_token(self, user) -> dict:
|
def _create_jwt_token(self, user) -> dict:
|
||||||
"""Return dictionary with pairs access and refresh tokens"""
|
"""Return dictionary with pairs access and refresh tokens"""
|
||||||
token = tokens.RefreshToken.for_user(user)
|
token = tokens.GMRefreshToken.for_user(user)
|
||||||
token['user'] = user.get_user_info()
|
token['user'] = user.get_user_info()
|
||||||
return {
|
return {
|
||||||
'access_token': str(token.access_token),
|
'access_token': str(token.access_token),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user