76 lines
2.6 KiB
Python
76 lines
2.6 KiB
Python
"""Custom JWT tokens built on top of djangorestframework-simplejwt."""
|
|
from typing import Optional

from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import Token, AccessToken, RefreshToken, BlacklistMixin

from authorization.models import JWTRefreshToken, JWTAccessToken
|
|
|
|
|
|
class GMToken(Token):
    """Custom JWT token base class used by the GM token types."""

    @classmethod
    def for_user(cls, user):
        """
        Build a fresh token identifying ``user``.

        The identifier is read from the user attribute named by
        ``api_settings.USER_ID_FIELD`` and written to the claim named by
        ``api_settings.USER_ID_CLAIM``. Integer identifiers are stored
        unchanged; anything else is passed through ``str()`` (presumably so
        values such as UUIDs serialize cleanly into the payload).

        :param user: the user whose credentials have been authenticated.
        :return: a new instance of ``cls`` carrying the user-id claim.
        """
        raw_id = getattr(user, api_settings.USER_ID_FIELD)
        claim_value = raw_id if isinstance(raw_id, int) else str(raw_id)

        new_token = cls()
        new_token[api_settings.USER_ID_CLAIM] = claim_value
        return new_token
|
|
|
|
|
|
class GMBlacklistMixin(BlacklistMixin):
    """
    Blacklist mixin that records every issued token in ``JWTRefreshToken``.

    It extends simplejwt's ``BlacklistMixin`` and overrides ``for_user`` so
    that each token it creates is persisted through
    ``JWTRefreshToken.objects.make``, together with an optional ``source``.

    NOTE(review): ``super().for_user`` follows the MRO of the concrete class.
    With the base order used by ``GMRefreshToken`` (this mixin, then
    ``GMToken``, then ``RefreshToken``), it resolves to ``GMToken.for_user``.
    That method does not chain further, so simplejwt's own outstanding-token
    bookkeeping in ``BlacklistMixin.for_user`` is not reached there. Re-check
    this if the base order changes.
    """

    @classmethod
    def for_user(cls, user, source: Optional[int] = None):
        """
        Create a token for ``user`` and persist a record of it.

        :param user: the user the token is issued for; must provide
            ``get_user_info()``.
        :param source: optional integer forwarded unchanged to
            ``JWTRefreshToken.objects.make``. It presumably identifies where
            the login originated; confirm against the model.
        :return: the new token, with the user's info stored under the
            ``'user'`` claim.
        """
        token = super().for_user(user)
        # Embed user details in the payload. Access tokens derived from this
        # token copy every claim not listed in ``no_copy_claims``, so these
        # details normally propagate to them as well.
        token['user'] = user.get_user_info()
        JWTRefreshToken.objects.make(user=user, token=token, source=source)
        return token
|
|
|
|
|
|
class GMRefreshToken(GMBlacklistMixin, GMToken, RefreshToken):
    """
    GM refresh token.

    This combines three pieces:

    - ``GMBlacklistMixin``, whose ``for_user`` persists issued tokens.
    - ``GMToken``, which supplies the user-id claim logic.
    - simplejwt's ``RefreshToken``.

    Deriving an access token from it also records that access token in the
    database.
    """

    @property
    def access_token(self):
        """
        Return a new access token derived from this refresh token.

        Every claim of this refresh token is copied to the access token,
        except those listed in ``no_copy_claims``. The access token's expiry
        is computed from this refresh token's ``current_time``, so a pair
        created together expires relative to the same instant.

        Side effect: each access of this property builds a *new* access token
        and calls ``JWTAccessToken.objects.make`` to record it in the DB,
        linked to this refresh token and its user. Read it once and reuse the
        result.

        :raises User.DoesNotExist: if no user matches the token's user-id
            claim.
        """
        # Local import, presumably to avoid a circular import between apps.
        from account.models import User

        access_token = AccessToken()

        # Use instantiation time of refresh token as relative timestamp for
        # access token "exp" claim. This ensures that both a refresh and
        # access token expire relative to the same time if they are created as
        # a pair.
        access_token.set_exp(from_time=self.current_time)

        no_copy = self.no_copy_claims
        for claim, value in self.payload.items():
            if claim not in no_copy:
                access_token[claim] = value

        # Resolve the user with the same configurable field/claim pair that
        # GMToken.for_user uses to write the claim. The previous code
        # hard-coded ``id`` / ``'user_id'`` and broke under custom
        # USER_ID_FIELD / USER_ID_CLAIM settings.
        user_id = self.payload.get(api_settings.USER_ID_CLAIM)
        user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})

        # Create a record in DB
        JWTAccessToken.objects.make(user=user,
                                    access_token=access_token,
                                    refresh_token=self)
        return access_token
|