165 lines
5.4 KiB
Python
165 lines
5.4 KiB
Python
from django.conf import settings
|
|
from django.db import models
|
|
from django.utils import timezone
|
|
from django.utils.translation import gettext_lazy as _
|
|
from oauth2_provider import models as oauth2_models
|
|
from oauth2_provider.models import AbstractApplication
|
|
from rest_framework_simplejwt import utils
|
|
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
|
|
|
|
from utils.models import PlatformMixin, ProjectBaseMixin
|
|
|
|
|
|
class ApplicationQuerySet(models.QuerySet):
    """QuerySet helpers for the custom OAuth2 ``Application`` model."""

    def by_source(self, source: int):
        """Narrow the queryset to applications whose ``source`` equals *source*."""
        criteria = {'source': source}
        return self.filter(**criteria)

    def get_by_natural_key(self, client_id):
        """Fetch the single application identified by *client_id*.

        Propagates ``DoesNotExist`` / ``MultipleObjectsReturned`` exactly as
        ``QuerySet.get`` raises them.
        """
        return self.get(client_id=client_id)
|
|
|
|
|
|
class ApplicationManager(oauth2_models.ApplicationManager):
    """Manager for the custom ``Application`` model.

    Adds no behaviour of its own: everything is inherited from
    ``oauth2_provider``'s ``ApplicationManager``. ``Application.objects`` is
    built from it via ``from_queryset(ApplicationQuerySet)`` so the queryset
    helpers are exposed on the manager as well.
    """
|
|
|
|
|
|
class Application(PlatformMixin, AbstractApplication):
    """Custom OAuth2 application model.

    Combines ``oauth2_provider``'s ``AbstractApplication`` with
    ``PlatformMixin`` and is declared swappable through the
    ``OAUTH2_PROVIDER_APPLICATION_MODEL`` setting.
    """

    objects = ApplicationManager.from_queryset(ApplicationQuerySet)()

    class Meta(AbstractApplication.Meta):
        swappable = "OAUTH2_PROVIDER_APPLICATION_MODEL"

    def natural_key(self):
        """Return the natural key used by Django serialization.

        Django requires a tuple here: the deserializer calls
        ``get_by_natural_key(*natural_key)``, so returning a bare string would
        splat its characters into separate positional arguments.
        """
        return (self.client_id,)
|
|
|
|
|
|
class JWTAccessTokenManager(models.Manager):
    """Manager for the ``JWTAccessToken`` model."""

    def make(self, user, access_token: AccessToken, refresh_token: RefreshToken):
        """Persist a freshly issued access token, linked to its refresh token.

        :param user: owner of both tokens.
        :param access_token: the simplejwt access token to store.
        :param refresh_token: the simplejwt refresh token issued alongside it;
            it must already be stored as a ``JWTRefreshToken`` for *user*.
        :returns: the saved ``JWTAccessToken``, or ``None`` when no stored
            refresh token matches *user* and the refresh token's jti.
        """
        # Use the configured claim name for BOTH tokens so the lookup matches
        # what JWTRefreshTokenManager stored; simplejwt's default claim is 'jti'.
        jti_claim = settings.SIMPLE_JWT.get('JTI_CLAIM', 'jti')

        # One query instead of exists() + first().
        stored_refresh = JWTRefreshToken.objects.filter(
            user=user,
            jti=refresh_token.payload.get(jti_claim),
        ).first()
        if stored_refresh is None:
            return None

        obj = self.model(
            user=user,
            jti=access_token[jti_claim],
            refresh_token=stored_refresh,
            created_at=access_token.current_time,
            expires_at=utils.datetime_from_epoch(access_token['exp']),
        )
        obj.save()
        return obj
|
|
|
|
|
|
class JWTAccessTokenQuerySet(models.QuerySet):
    """QuerySet helpers for the ``JWTAccessToken`` model."""

    def valid(self):
        """Return only access tokens that have not expired yet.

        A token whose ``expires_at`` equals the current instant counts as
        expired. This matches simplejwt's ``Token.check_exp`` (``exp <= now``
        is expired) and the RFC 7519 ``exp`` claim semantics.
        """
        return self.filter(expires_at__gt=timezone.now())

    def by_jti(self, jti: str):
        """Filter access tokens by their ``jti`` value."""
        return self.filter(jti=jti)

    def by_refresh_token_jti(self, jti):
        """Return all access tokens whose linked refresh token has this ``jti``."""
        return self.filter(refresh_token__jti=jti)
|
|
|
|
|
|
class JWTAccessToken(ProjectBaseMixin):
    """Stored JWT access token issued to a user."""

    user = models.ForeignKey(
        'account.User',
        on_delete=models.CASCADE,
        related_name='access_tokens',
    )
    # Optional one-to-one link to the refresh token issued alongside this one.
    refresh_token = models.OneToOneField(
        'JWTRefreshToken',
        on_delete=models.CASCADE,
        related_name='access_token',
        default=None,
        blank=True,
        null=True,
    )
    created_at = models.DateTimeField(blank=True, null=True)
    expires_at = models.DateTimeField(verbose_name=_('Expiration datetime'))
    jti = models.CharField(max_length=255, unique=True)

    objects = JWTAccessTokenManager.from_queryset(JWTAccessTokenQuerySet)()

    class Meta:
        """Model options for stored access tokens."""

        verbose_name = _('Access token')
        verbose_name_plural = _('Access tokens')
        unique_together = ('user', 'jti')

    def __str__(self):
        """Human-readable label showing the token's JTI."""
        return f'Access token JTI: {self.jti}'

    def expire(self):
        """Set the expiry to the current instant and save the whole row."""
        moment = timezone.now()
        self.expires_at = moment
        self.save()
|
|
|
|
|
|
class JWTRefreshTokenManager(models.Manager):
    """Manager for the ``JWTRefreshToken`` model."""

    def make(self, user, token: RefreshToken, source: int):
        """Persist a freshly issued refresh token.

        :param user: owner of the token.
        :param token: the simplejwt refresh token to store.
        :param source: value assigned to the instance's ``source`` attribute;
            when ``None`` the attribute is left as the model initialised it.
        :returns: the saved ``JWTRefreshToken``.
        """
        # simplejwt's default claim name is 'jti'; avoid indexing with None
        # when the project settings omit JTI_CLAIM.
        jti_claim = settings.SIMPLE_JWT.get('JTI_CLAIM', 'jti')
        obj = self.model(
            user=user,
            jti=token[jti_claim],
            created_at=token.current_time,
            expires_at=utils.datetime_from_epoch(token['exp']),
        )
        # Compare with None rather than truthiness so a source of 0 is kept.
        if source is not None:
            obj.source = source
        obj.save()
        return obj
|
|
|
|
|
|
class JWTRefreshTokenQuerySet(models.QuerySet):
    """QuerySet helpers for the ``JWTRefreshToken`` model."""

    def valid(self):
        """Return only refresh tokens that have not expired yet.

        A token whose ``expires_at`` equals the current instant counts as
        expired. This matches simplejwt's ``Token.check_exp`` (``exp <= now``
        is expired) and the RFC 7519 ``exp`` claim semantics.
        """
        return self.filter(expires_at__gt=timezone.now())

    def by_jti(self, jti: str):
        """Filter refresh tokens by their ``jti`` value."""
        return self.filter(jti=jti)

    def by_source(self, source):
        """Return refresh tokens whose ``source`` equals *source*."""
        return self.filter(source=source)
|
|
|
|
|
|
class JWTRefreshToken(PlatformMixin, ProjectBaseMixin):
    """Stored JWT refresh token issued to a user."""

    user = models.ForeignKey(
        'account.User',
        on_delete=models.CASCADE,
        related_name='refresh_tokens',
    )
    jti = models.CharField(max_length=255, unique=True)
    created_at = models.DateTimeField(blank=True, null=True)
    expires_at = models.DateTimeField()

    objects = JWTRefreshTokenManager.from_queryset(JWTRefreshTokenQuerySet)()

    class Meta:
        """Model options for stored refresh tokens."""

        verbose_name = _('Refresh token')
        verbose_name_plural = _('Refresh tokens')
        unique_together = ('user', 'jti')

    def __str__(self):
        """Human-readable label showing the token's JTI."""
        return f'Refresh token JTI: {self.jti}'

    def expire(self):
        """Set the expiry to the current instant and save the whole row."""
        moment = timezone.now()
        self.expires_at = moment
        self.save()
|