from django.conf import settings
from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from oauth2_provider import models as oauth2_models
from oauth2_provider.models import AbstractApplication
# NOTE: the name ``utils`` below is ``rest_framework_simplejwt.utils``; the
# project import ``from utils.models import ...`` binds only the two mixins.
from rest_framework_simplejwt import utils
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken

from utils.models import PlatformMixin, ProjectBaseMixin


def _jti_claim():
    """Return the name of the JWT claim that holds the token id (JTI).

    Reads ``SIMPLE_JWT['JTI_CLAIM']`` from Django settings and falls back to
    ``'jti'`` (the djangorestframework-simplejwt default) when the project
    does not override it, so every lookup in this module uses the same claim.
    """
    return getattr(settings, 'SIMPLE_JWT', {}).get('JTI_CLAIM', 'jti')


class ApplicationQuerySet(models.QuerySet):
    """Application queryset"""

    def get_by_natural_key(self, client_id):
        """Return the application whose ``client_id`` equals *client_id*."""
        return self.get(client_id=client_id)

    def by_source(self, source: int):
        """Filter by source parameter"""
        return self.filter(source=source)


class ApplicationManager(oauth2_models.ApplicationManager):
    """Application manager"""


class Application(PlatformMixin, AbstractApplication):
    """Custom oauth2 application model"""

    objects = ApplicationManager.from_queryset(ApplicationQuerySet)()

    class Meta(AbstractApplication.Meta):
        swappable = "OAUTH2_PROVIDER_APPLICATION_MODEL"

    def natural_key(self):
        """Return the natural key tuple ``(client_id,)``."""
        return (self.client_id,)


class JWTAccessTokenManager(models.Manager):
    """Manager for AccessToken model."""

    def make(self, user, access_token: AccessToken, refresh_token: RefreshToken):
        """Persist a freshly issued access token.

        The new row is linked to the stored ``JWTRefreshToken`` that belongs
        to ``user`` and carries the same JTI as ``refresh_token``.

        Args:
            user: Owner of both tokens.
            access_token: Issued simplejwt access token.
            refresh_token: Refresh token the access token was issued with.

        Returns:
            The saved ``JWTAccessToken``, or ``None`` when no matching
            ``JWTRefreshToken`` is stored for ``user`` (nothing is written).
        """
        jti_claim = _jti_claim()
        # Use the same configured claim that JWTRefreshTokenManager.make used
        # to store the refresh token; a single first() replaces the former
        # exists() + first() pair (two queries).
        stored_refresh_token = JWTRefreshToken.objects.filter(
            user=user,
            jti=refresh_token.payload.get(jti_claim),
        ).first()
        if stored_refresh_token is None:
            return None

        obj = self.model(
            user=user,
            jti=access_token[jti_claim],
            refresh_token=stored_refresh_token,
            created_at=access_token.current_time,
            expires_at=utils.datetime_from_epoch(access_token['exp']),
        )
        obj.save()
        return obj


class JWTAccessTokenQuerySet(models.QuerySet):
    """QuerySets for AccessToken model."""

    def valid(self):
        """Return only access tokens whose expiry is now or later."""
        return self.filter(expires_at__gte=timezone.now())

    def by_jti(self, jti: str):
        """Filter by jti field"""
        return self.filter(jti=jti)

    def by_refresh_token_jti(self, jti: str):
        """Return all access tokens linked to the refresh token with *jti*."""
        return self.filter(refresh_token__jti=jti)


class JWTAccessToken(ProjectBaseMixin):
    """GM access token model."""

    user = models.ForeignKey('account.User',
                             related_name='access_tokens',
                             on_delete=models.CASCADE)
    refresh_token = models.OneToOneField('JWTRefreshToken',
                                         related_name='access_token',
                                         on_delete=models.CASCADE,
                                         null=True, blank=True, default=None)
    created_at = models.DateTimeField(null=True, blank=True)
    expires_at = models.DateTimeField(verbose_name=_('Expiration datetime'))
    jti = models.CharField(unique=True, max_length=255)

    objects = JWTAccessTokenManager.from_queryset(JWTAccessTokenQuerySet)()

    class Meta:
        """Meta class."""

        unique_together = ('user', 'jti')
        verbose_name = _('Access token')
        verbose_name_plural = _('Access tokens')

    def __str__(self):
        """String representation method."""
        return f'Access token JTI: {self.jti}'

    def expire(self):
        """Expire the access token by setting its expiry to now and saving."""
        self.expires_at = timezone.now()
        self.save()


class JWTRefreshTokenManager(models.Manager):
    """Manager for model RefreshToken."""

    def make(self, user, token: RefreshToken, source: int):
        """Persist a freshly issued refresh token.

        Args:
            user: Owner of the token.
            token: Issued simplejwt refresh token.
            source: Value stored in the model's ``source`` field.

        Returns:
            The saved ``JWTRefreshToken``.
        """
        obj = self.model(
            user=user,
            jti=token[_jti_claim()],
            created_at=token.current_time,
            expires_at=utils.datetime_from_epoch(token['exp']),
        )
        # A falsy source (None, 0) leaves the field at its model default.
        # NOTE(review): confirm 0 is not a valid source value in PlatformMixin.
        if source:
            obj.source = source
        obj.save()
        return obj


class JWTRefreshTokenQuerySet(models.QuerySet):
    """QuerySets for model RefreshToken."""

    def valid(self):
        """Return only refresh tokens whose expiry is now or later."""
        return self.filter(expires_at__gte=timezone.now())

    def by_jti(self, jti: str):
        """Filter by jti field"""
        return self.filter(jti=jti)

    def by_source(self, source):
        """Return refresh tokens issued for the given source."""
        return self.filter(source=source)


class JWTRefreshToken(PlatformMixin, ProjectBaseMixin):
    """GM refresh token model."""

    user = models.ForeignKey('account.User',
                             related_name='refresh_tokens',
                             on_delete=models.CASCADE)
    jti = models.CharField(unique=True, max_length=255)
    created_at = models.DateTimeField(null=True, blank=True)
    expires_at = models.DateTimeField()

    objects = JWTRefreshTokenManager.from_queryset(JWTRefreshTokenQuerySet)()

    class Meta:
        """Meta class."""

        unique_together = ('user', 'jti')
        verbose_name = _('Refresh token')
        verbose_name_plural = _('Refresh tokens')

    def __str__(self):
        """String representation method."""
        return f'Refresh token JTI: {self.jti}'

    def expire(self):
        """Expire the refresh token by setting its expiry to now and saving."""
        self.expires_at = timezone.now()
        self.save()