示例#1
0
    def decrypt_jwt_token(self, token, leeway=None):
        """Decrypt a five-part JWE token and decode the signed JWT inside it.

        Before decrypting, the JWE protected header is checked, and the
        decrypted content-encryption key and the IV are validated with
        ``_check_cek_length`` / ``_check_iv_length``. The token is then
        decrypted with ``super().decrypt`` and the result is passed to
        ``decode_signed_jwt_token``.

        :param token: the compact-serialised JWE token (``str`` or ``bytes``;
            it is normalised with ``strings.to_str`` before being split).
        :param leeway: optional clock-skew leeway forwarded to
            ``decode_signed_jwt_token``. When omitted, that method is called
            with the signed token only, exactly as before.
        :returns: the result of ``decode_signed_jwt_token``.
        :raises NoTokenException: if ``token`` is empty or ``None``.
        :raises InvalidTokenException: if the token does not have five parts,
            the IV/CEK lengths are rejected, or decryption/decoding fails with
            one of the caught library errors.
        """
        try:
            if token:
                token_as_str = strings.to_str(token)
                # NOTE(review): this logs the full token at debug level;
                # confirm that is acceptable for a credential.
                logging.debug("Decrypting signed JWT %s", token_as_str)
                # Split the str form so bytes tokens work too: bytes.split('.')
                # raises TypeError, which is not in the except clause below.
                tokens = token_as_str.split('.')
                if len(tokens) != 5:
                    raise InvalidTokenException("Incorrect size")
                jwe_protected_header = tokens[0]
                self.__check_jwe_protected_header(jwe_protected_header)
                encrypted_key = tokens[1]
                encoded_iv = tokens[2]

                decrypted_key = self._decrypt_key(encrypted_key)
                iv = self._base64_decode(encoded_iv)
                if not self._check_iv_length(iv):
                    raise InvalidTokenException("IV incorrect length")
                if not self._check_cek_length(decrypted_key):
                    raise InvalidTokenException("CEK incorrect length")

                signed_token = super().decrypt(token)
                # Only forward leeway when the caller supplied one, so
                # existing one-argument callers behave exactly as before.
                if leeway is None:
                    return self.decode_signed_jwt_token(signed_token)
                return self.decode_signed_jwt_token(signed_token, leeway)
            else:
                raise NoTokenException("JWT Missing")
        except (jwt.DecodeError, InvalidTag, InternalError, ValueError,
                AssertionError) as e:
            raise InvalidTokenException(repr(e)) from e
示例#2
0
 def _check_headers(self, header_data):
     try:
         headers = self._base64_decode(header_data)
         if not headers:
             raise InvalidTokenException("Missing Headers")
         self._check_for_duplicates(headers)
     except UnicodeDecodeError:
         raise InvalidTokenException("Corrupted Header")
     except ValueError as e:
         raise InvalidTokenException(repr(e))
示例#3
0
 def __check_jwe_protected_header(self, header):
     header = self._base64_decode(header).decode()
     header_data = json.loads(header)
     if not header_data.get("alg"):
         raise InvalidTokenException("Missing Algorithm")
     if header_data.get("alg") != "RSA-OAEP":
         raise InvalidTokenException("Invalid Algorithm")
     if not header_data.get("enc"):
         raise InvalidTokenException("Missing Encoding")
     if header_data.get("enc") != "A256GCM":
         raise InvalidTokenException("Invalid Encoding")
示例#4
0
 def _raise_exception_on_duplicates(ordered_pairs):
     store = {}
     for key, value in ordered_pairs:
         if key in store:
             raise InvalidTokenException("Multiple " + key + " Headers")
         else:
             store[key] = value
     return store
示例#5
0
 def decode_signed_jwt_token(self, signed_token, leeway=0):
     """Validate and decode a signed RS256 JWT with the RRM public key.

     The token is first checked structurally by ``_check_token``, then
     verified and decoded by ``jwt.decode``.

     :param signed_token: the compact-serialised signed JWT.
     :param leeway: clock-skew allowance passed to ``jwt.decode``. The default
         of 0 matches ``jwt.decode``'s own default, so one-argument callers
         get the library's standard behaviour.
     :returns: the decoded claims.
     :raises NoTokenException: if ``signed_token`` is empty or ``None``.
     :raises InvalidTokenException: if the token fails any structural check,
         verification or claim validation, or decodes to an empty payload.
     """
     try:
         if signed_token:
             logger.debug("decoding signed jwt", jwt_token=strings.to_str(signed_token))
             self._check_token(signed_token)
             token = jwt.decode(signed_token, self.rrm_public_key, algorithms=['RS256'],
                                leeway=leeway)
             if not token:
                 raise InvalidTokenException("Missing Payload")
             return token
         else:
             raise NoTokenException("JWT Missing")
     except (jwt.DecodeError,
             jwt.exceptions.InvalidAlgorithmError,
             jwt.exceptions.ExpiredSignatureError,
             jwt.exceptions.InvalidIssuedAtError,
             # Base class of PyJWT's token-validation errors; also covers
             # e.g. ImmatureSignatureError (nbf) and InvalidAudienceError,
             # which would otherwise escape as unhandled errors.
             jwt.exceptions.InvalidTokenError) as e:
         raise InvalidTokenException(repr(e)) from e
示例#6
0
    def _check_token(self, token):
        """Run the structural checks on a signed JWT before it is decoded.

        The token must have exactly three dot-separated segments. The header
        segment, the parsed header values and the payload segment are then
        checked by the sibling ``_check_*`` helpers, in that order.

        :param token: the signed token (converted with ``strings.to_str``).
        :raises InvalidTokenException: if the token does not have three segments
            (the helpers raise their own errors for the other checks).
        """
        text = strings.to_str(token)
        segments = text.split('.')
        # Three segments is equivalent to exactly two '.' separators.
        if len(segments) != 3:
            raise InvalidTokenException("Invalid Token")

        # The signature segment is not inspected here.
        header_segment, payload_segment = segments[0], segments[1]

        self._check_headers(header_segment)
        self._check_header_values(text)
        self._check_payload(payload_segment)
示例#7
0
    def _get_token_data(token):
        ru_ref = token.get("ru_ref")
        collection_exercise_sid = token.get("collection_exercise_sid")
        eq_id = token.get("eq_id")
        form_type = token.get("form_type")

        if ru_ref and collection_exercise_sid and eq_id and form_type:
            return collection_exercise_sid, eq_id, form_type, ru_ref
        else:
            logger.error(
                "Missing values for ru_ref, collection_exercise_sid, form_type or eq_id in token %s",
                token)
            raise InvalidTokenException("Missing values in JWT token")
示例#8
0
def parse_metadata(metadata_to_check):
    """Build the metadata dict from ``metadata_fields``.

    Each field present in *metadata_to_check* is validated with its field's
    ``validate``. Each absent field gets a value from its field's ``generate``.

    :param metadata_to_check: mapping of supplied metadata values.
    :returns: dict of every key in ``metadata_fields`` to its value.
    :raises InvalidTokenException: if validation or generation raises
        RuntimeError, ValueError or TypeError.
    """
    result = {}
    try:
        for name, spec in metadata_fields.items():
            if name not in metadata_to_check:
                logger.debug("generating metadata value", key=name)
                result[name] = spec.generate()
                continue

            value = metadata_to_check[name]
            spec.validate(value)
            logger.debug("parsing metadata", key=name, value=value)
            result[name] = value
    except (RuntimeError, ValueError, TypeError) as err:
        logger.error("unable to parse metadata", exc_info=err)
        raise InvalidTokenException("incorrect data in token")
    return result
示例#9
0
 def _check_payload(self, payload_data):
     try:
         payload = self._base64_decode(payload_data)
         if not payload:
             raise InvalidTokenException("Missing Payload")
         payload_decoded = payload.decode()
         if payload_decoded == "{}":
             raise InvalidTokenException("Missing Payload")
         if payload_decoded.count("iat") == 0:
             raise InvalidTokenException("Missing iat claim")
         if payload_decoded.count("exp") == 0:
             raise InvalidTokenException("Missing exp claim")
         if payload_decoded.count("iat") > 1:
             raise InvalidTokenException("Multiple iat claims")
         if payload_decoded.count("exp") > 1:
             raise InvalidTokenException("Multiple exp claims")
     except (UnicodeDecodeError, IndexError):
         raise InvalidTokenException("Corrupted Payload")
     except ValueError as e:
         raise InvalidTokenException(repr(e))
示例#10
0
    def _check_header_values(token):
        """Validate the unverified JOSE header of a signed JWT.

        The header must carry ``typ`` == ``JWT``, ``alg`` == ``RS256`` and
        ``kid`` == ``EDCRRM``, compared case-insensitively via ``.upper()``.

        :param token: the compact-serialised signed JWT.
        :raises InvalidTokenException: if the header is empty, a required
            field is missing, or a field has an unexpected value or type.
        """
        header = jwt.get_unverified_header(token)

        if not header:
            raise InvalidTokenException("Missing Headers")

        token_type = header.get('typ')
        algorithm = header.get('alg')
        key_id = header.get('kid')

        if not token_type:
            raise InvalidTokenException("Missing Type")
        if not algorithm:
            raise InvalidTokenException("Missing Algorithm")
        if not key_id:
            raise InvalidTokenException("Missing kid")
        # Header values come from the token itself. A non-string value such as
        # a number has no .upper() and would otherwise raise AttributeError.
        if not isinstance(token_type, str) or token_type.upper() != 'JWT':
            raise InvalidTokenException("Invalid Type")
        if not isinstance(algorithm, str) or algorithm.upper() != 'RS256':
            raise InvalidTokenException("Invalid Algorithm")
        if not isinstance(key_id, str) or key_id.upper() != 'EDCRRM':
            raise InvalidTokenException("Invalid Key Identifier")
示例#11
0
def parse_metadata(metadata_to_check):
    """Resolve a value for every entry in ``metadata_fields``.

    Supplied values are checked with the field's ``validate``. Missing ones
    are produced by the field's ``generate``.

    :param metadata_to_check: mapping of supplied metadata values.
    :returns: dict of every key in ``metadata_fields`` to its value.
    :raises InvalidTokenException: if validation or generation raises
        RuntimeError, ValueError or TypeError (the error is logged first).
    """
    result = {}
    try:
        for name, spec in metadata_fields.items():
            logger.debug("parse_metadata: Adding attr %s", name)
            supplied = name in metadata_to_check
            if supplied:
                value = metadata_to_check[name]
                spec.validate(value)
                logger.debug("with value %s", value)
            else:
                logger.debug("Generating value for %s", name)
                value = spec.generate()
            result[name] = value
    except (RuntimeError, ValueError, TypeError) as err:
        logger.error("parse_metadata: Unable to parse")
        logger.exception(err)
        raise InvalidTokenException("Incorrect data in token")
    return result
示例#12
0
def decrypt_token(encrypted_token):
    """Decrypt an encrypted JWT using the keys in the application config.

    Builds a ``JWTDecryptor`` from the SR private key (and its password) and
    the RRM public key in ``current_app.config``. It then decrypts the token
    with ``EQ_JWT_LEEWAY_IN_SECONDS`` as leeway and checks the resulting
    claims with ``is_valid_metadata``.

    :param encrypted_token: the encrypted token supplied by the caller.
    :returns: the decrypted token claims.
    :raises NoTokenException: if ``encrypted_token`` is ``None``.
    :raises InvalidTokenException: if ``is_valid_metadata`` reports a missing
        field (the decryptor may raise its own errors too).
    """
    logger.debug("decrypting token")
    if encrypted_token is None:
        raise NoTokenException("Please provide a token")

    config = current_app.config
    decryptor = JWTDecryptor(
        config['EQ_USER_AUTHENTICATION_SR_PRIVATE_KEY'],
        config['EQ_USER_AUTHENTICATION_SR_PRIVATE_KEY_PASSWORD'],
        config['EQ_USER_AUTHENTICATION_RRM_PUBLIC_KEY'],
    )
    leeway = config['EQ_JWT_LEEWAY_IN_SECONDS']
    claims = decryptor.decrypt_jwt_token(encrypted_token, leeway)

    is_valid, missing_field = is_valid_metadata(claims)
    if not is_valid:
        raise InvalidTokenException("Missing value {}".format(missing_field))

    logger.debug("token decrypted")
    return claims
 def test(self):
     """Check that str() of an InvalidTokenException is its constructor message."""
     message = "test"
     self.assertEqual(message, str(InvalidTokenException(message)))
示例#14
0
def _check_user_data(token):
    """Raise if ``is_valid_metadata`` reports a missing field in *token*.

    :param token: the decoded token claims.
    :raises InvalidTokenException: naming the field reported as missing.
    """
    is_valid, missing_field = is_valid_metadata(token)
    if is_valid:
        return
    raise InvalidTokenException("Missing value {}".format(missing_field))