From 39e833ec1da0359d502dfc89edcdebad88b8fde2 Mon Sep 17 00:00:00 2001 From: Anatoly Date: Thu, 29 Aug 2019 09:43:31 +0300 Subject: [PATCH] added custom JWT authentication --- apps/authorization/serializers/common.py | 7 +++-- apps/utils/authentication.py | 36 ++++++++++++++++++++++++ apps/utils/exceptions.py | 2 +- apps/utils/methods.py | 16 +++++++++-- apps/utils/permissions.py | 24 ++++++++++------ apps/utils/views.py | 4 +-- project/settings/base.py | 5 ++-- project/settings/local.py | 3 -- 8 files changed, 75 insertions(+), 22 deletions(-) create mode 100644 apps/utils/authentication.py diff --git a/apps/authorization/serializers/common.py b/apps/authorization/serializers/common.py index 2a38ca09..36ae8a3e 100644 --- a/apps/authorization/serializers/common.py +++ b/apps/authorization/serializers/common.py @@ -200,8 +200,11 @@ class LogoutSerializer(serializers.ModelSerializer): def validate(self, attrs): """Override validated data""" request = self.context.get('request') - token = request.headers.get('Authorization') \ - .split(' ')[::-1][0] + # Get token bytes from cookies (result: b'Bearer ') + token_bytes = utils_methods.get_token_from_cookies(request) + # Get token value from bytes + token = token_bytes.decode().split(' ')[::-1][0] + # Get access token obj access_token = tokens.AccessToken(token) # Prepare validated data attrs['user'] = request.user diff --git a/apps/utils/authentication.py b/apps/utils/authentication.py new file mode 100644 index 00000000..044d6d75 --- /dev/null +++ b/apps/utils/authentication.py @@ -0,0 +1,36 @@ +"""Custom authentication based on JWTAuthentication class""" +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework_simplejwt.authentication import JWTAuthentication +from rest_framework_simplejwt.settings import api_settings + +from utils.methods import get_token_from_cookies + +AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES + +if not isinstance(api_settings.AUTH_HEADER_TYPES, (list, tuple)): + AUTH_HEADER_TYPES = (AUTH_HEADER_TYPES,) + +AUTH_HEADER_TYPE_BYTES = set( + h.encode(HTTP_HEADER_ENCODING) + for h in AUTH_HEADER_TYPES +) + + +class GMJWTAuthentication(JWTAuthentication): + """ + An authentication plugin that authenticates requests through a JSON web + token provided in a request cookies. + """ + + def authenticate(self, request): + token = get_token_from_cookies(request) + if token is None: + return None + + raw_token = self.get_raw_token(token) + if raw_token is None: + return None + + validated_token = self.get_validated_token(raw_token) + + return self.get_user(validated_token), None diff --git a/apps/utils/exceptions.py b/apps/utils/exceptions.py index 16eb12fc..90348498 100644 --- a/apps/utils/exceptions.py +++ b/apps/utils/exceptions.py @@ -73,7 +73,7 @@ class NotValidUsernameError(exceptions.APIException): class NotValidTokenError(exceptions.APIException): """The exception should be thrown when token in url is not valid """ - status_code = status.HTTP_400_BAD_REQUEST + status_code = status.HTTP_401_UNAUTHORIZED default_detail = _('Not valid token') diff --git a/apps/utils/methods.py b/apps/utils/methods.py index a26c306c..9dcafa97 100644 --- a/apps/utils/methods.py +++ b/apps/utils/methods.py @@ -16,13 +16,23 @@ def generate_code(digits=6, string_output=True): return str(value) if string_output else value +def get_token_from_cookies(request): + """Get access token from request cookies""" + cookies = request.COOKIES + if cookies.get('access_token'): + token = f'Bearer {cookies.get("access_token")}' + return token.encode() + + def get_token_from_request(request): """Get access token from request""" + token = None if 'Authorization' in request.headers: if isinstance(request, HttpRequest): - return request.headers.get('Authorization').split(' ')[::-1][0] - elif isinstance(request, Request): - return request._request.headers.get('Authorization').split(' ')[::-1][0] + token = request.headers.get('Authorization').split(' ')[::-1][0] + if isinstance(request, Request): + token = request.headers.get('Authorization').split(' ')[::-1][0] + return token def username_validator(username: str) -> bool: diff --git a/apps/utils/permissions.py b/apps/utils/permissions.py index 3e4a1d33..d1f8c430 100644 --- a/apps/utils/permissions.py +++ b/apps/utils/permissions.py @@ -1,7 +1,10 @@ """Project custom permissions""" from rest_framework.permissions import BasePermission +from rest_framework_simplejwt.exceptions import TokenBackendError + from authorization.models import BlacklistedAccessToken -from utils.methods import get_token_from_request +from utils.exceptions import NotValidTokenError +from utils.methods import get_token_from_cookies class IsAuthenticatedAndTokenIsValid(BasePermission): @@ -12,12 +15,17 @@ class IsAuthenticatedAndTokenIsValid(BasePermission): def has_permission(self, request, view): """Check permissions by access token and default REST permission IsAuthenticated""" user = request.user - if user and user.is_authenticated: - token = get_token_from_request(request) - # Check if user access token not expired - expired = BlacklistedAccessToken.objects.by_token(token)\ - .by_user(user)\ - .exists() - return not expired + try: + if user and user.is_authenticated: + token_bytes = get_token_from_cookies(request) + # Get access token key + token = token_bytes.decode().split(' ')[1] + # Check if user access token not expired + blacklisted = BlacklistedAccessToken.objects.by_token(token) \ + .by_user(user) \ + .exists() + return not blacklisted + except TokenBackendError: + raise NotValidTokenError() else: return False diff --git a/apps/utils/views.py b/apps/utils/views.py index f413e3cc..b222a4d7 100644 --- a/apps/utils/views.py +++ b/apps/utils/views.py @@ -67,8 +67,8 @@ class JWTGenericViewMixin(generics.GenericAPIView): """Update COOKIES in response from namedtuple""" for cookie in cookies: # todo: remove config for develop - import os - configuration = os.environ.get('SETTINGS_CONFIGURATION', None) + from os import environ + configuration = environ.get('SETTINGS_CONFIGURATION', None) if configuration == 'development': response.set_cookie(key=cookie.key, value=cookie.value, diff --git a/project/settings/base.py b/project/settings/base.py index 45dbb693..e73dbb65 100644 --- a/project/settings/base.py +++ b/project/settings/base.py @@ -205,10 +205,9 @@ REST_FRAMEWORK = { 'DEFAULT_PAGINATION_CLASS': 'utils.pagination.ProjectMobilePagination', 'COERCE_DECIMAL_TO_STRING': False, 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.TokenAuthentication', - 'rest_framework.authentication.SessionAuthentication', # JWT - 'rest_framework_simplejwt.authentication.JWTAuthentication', + 'utils.authentication.GMJWTAuthentication', + 'rest_framework.authentication.SessionAuthentication', ), 'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.AcceptHeaderVersioning', 'DEFAULT_VERSION': (AVAILABLE_VERSIONS['current'],), diff --git a/project/settings/local.py b/project/settings/local.py index 2858973c..93e6948b 100644 --- a/project/settings/local.py +++ b/project/settings/local.py @@ -21,6 +21,3 @@ API_HOST_URL = 'http://%s' % API_HOST BROKER_URL = 'amqp://rabbitmq:5672' CELERY_RESULT_BACKEND = BROKER_URL CELERY_BROKER_URL = BROKER_URL - -# Increase access token lifetime for local deploy -SIMPLE_JWT['ACCESS_TOKEN_LIFETIME'] = timedelta(days=365)