# gault-millau/apps/utils/tokens.py
#
# 76 lines
# 2.6 KiB
# Python

"""Custom tokens based on django-rest-framework-simple-jwt"""
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin
from authorization.models import JWTRefreshToken, JWTAccessToken
class GMToken(Token):
    """Base JWT token for GM that stamps the user identifier claim."""

    @classmethod
    def for_user(cls, user):
        """
        Build a token of this class for an already-authenticated user.

        The identifier is read from the user attribute named by
        ``api_settings.USER_ID_FIELD`` and stored under the claim named by
        ``api_settings.USER_ID_CLAIM``. Integer identifiers are stored
        unchanged; any other value is converted with ``str`` first.
        """
        identifier = getattr(user, api_settings.USER_ID_FIELD)
        # Non-integer identifiers are stringified before being stored in
        # the payload.
        if isinstance(identifier, int):
            claim_value = identifier
        else:
            claim_value = str(identifier)

        new_token = cls()
        new_token[api_settings.USER_ID_CLAIM] = claim_value
        return new_token
class GMBlacklistMixin(BlacklistMixin):
    """
    GM flavour of simplejwt's ``BlacklistMixin``.

    As with the upstream mixin, when the
    `rest_framework_simplejwt.token_blacklist` app is configured, tokens
    built from subclasses add themselves to the outstanding token list and
    are checked against the token blacklist. On top of that, ``for_user``
    embeds the user's info in the token and registers the token through
    ``JWTRefreshToken.objects.make``.
    """

    @classmethod
    def for_user(cls, user, source: int = None):
        """
        Create a refresh token for ``user`` and register it.

        The token produced by the parent ``for_user`` gets an extra
        ``'user'`` claim holding ``user.get_user_info()``. It is then handed
        to ``JWTRefreshToken.objects.make`` together with ``source``, which
        is passed through untouched (its meaning is defined by that manager).
        """
        refresh = super().for_user(user)
        refresh['user'] = user.get_user_info()
        JWTRefreshToken.objects.make(
            user=user,
            token=refresh,
            source=source,
        )
        return refresh
class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
    """GM refresh token"""

    @property
    def access_token(self):
        """
        Returns an access token created from this refresh token. Copies all
        claims present in this refresh token to the new access token except
        those claims listed in the `no_copy_claims` attribute.

        Besides building the token, every access of this property looks up
        the owning user and calls ``JWTAccessToken.objects.make`` to register
        the new access token against this refresh token.

        Raises ``User.DoesNotExist`` when no user matches the identifier
        stored in the payload.
        """
        # Function-scope import kept from the original
        # (presumably to avoid a circular import -- TODO confirm).
        from account.models import User

        access_token = AccessToken()

        # Use instantiation time of refresh token as relative timestamp for
        # access token "exp" claim. This ensures that both a refresh and
        # access token expire relative to the same time if they are created as
        # a pair.
        access_token.set_exp(from_time=self.current_time)

        no_copy = self.no_copy_claims
        for claim, value in self.payload.items():
            if claim in no_copy:
                continue
            access_token[claim] = value

        # Resolve the user through the same settings GMToken.for_user used to
        # write the claim, instead of hard-coding 'user_id' / 'id'. With the
        # default simplejwt settings this is the same lookup as before.
        user_id = self.payload.get(api_settings.USER_ID_CLAIM)
        user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})

        # Create a record in DB
        JWTAccessToken.objects.make(user=user,
                                    access_token=access_token,
                                    refresh_token=self)
        return access_token