def split_signed_compact(
        signed_compact: Text) -> Tuple[bytes, Text, Text, bytes]:
    """Breaks a signed compact JWT apart at its '.' separators.

    Args:
      signed_compact: A signed compact JWT.

    Returns:
      A (unsigned_compact, json_header, json_payload, signature_or_mac)
      tuple. unsigned_compact is the UTF-8 encoded text in front of the
      last '.'; the other three entries are whatever decode_header,
      decode_payload and decode_signature return for their part.

    Raises:
      _jwt_error.JwtInvalidError: if the token cannot be encoded as UTF-8
        or does not consist of exactly three '.'-separated parts. Errors
        raised by the decode_* helpers propagate unchanged.
    """
    try:
        token_bytes = signed_compact.encode('utf8')
    except UnicodeEncodeError:
        raise _jwt_error.JwtInvalidError('invalid token')

    # Cut off the signature at the last '.', so that the remainder is
    # exactly the byte string the signature or MAC was computed over.
    outer_parts = token_bytes.rsplit(b'.', 1)
    if len(outer_parts) != 2:
        raise _jwt_error.JwtInvalidError('invalid token')
    unsigned_compact, encoded_signature = outer_parts
    signature_or_mac = decode_signature(encoded_signature)

    # What is left must be "header.payload" and nothing else.
    inner_parts = unsigned_compact.split(b'.')
    if len(inner_parts) != 2:
        raise _jwt_error.JwtInvalidError('invalid token')
    encoded_header, encoded_payload = inner_parts

    return (unsigned_compact,
            decode_header(encoded_header),
            decode_payload(encoded_payload),
            signature_or_mac)
def _validate_timestamp_claim(self, name: str):
    """Checks that a timestamp claim, if present, is a number in range.

    A claim that is absent from the payload is accepted.

    Args:
      name: The name of the claim to look up in the payload.

    Raises:
      _jwt_error.JwtInvalidError: if the claim is present but is not a
        number, or does not lie within [0, _MAX_TIMESTAMP_VALUE].
    """
    if name not in self._payload:
        return
    timestamp = self._payload[name]
    # bool is a subclass of int, so it has to be excluded explicitly: a
    # JSON true/false is not a number.
    if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
        raise _jwt_error.JwtInvalidError('claim %s must be a Number' % name)
    # Written as a positive range test on purpose: every comparison with
    # NaN is False, so "too big or too small" would let NaN slip through.
    if not 0 <= timestamp <= _MAX_TIMESTAMP_VALUE:
        raise _jwt_error.JwtInvalidError(
            'timestamp of claim %s is out of range' % name)
def get_kid(key_id: int, prefix: tink_pb2.OutputPrefixType) -> Optional[str]:
    """Returns the encoded key_id, or None.

    Args:
      key_id: The numeric key id.
      prefix: The output prefix type of the key.

    Returns:
      None for prefix RAW. For prefix TINK, the result of base64_encode
      applied to the 4-byte big-endian encoding of key_id, decoded as
      UTF-8.

    Raises:
      _jwt_error.JwtInvalidError: if key_id does not fit into an unsigned
        32-bit integer, or if prefix is neither RAW nor TINK.
    """
    if prefix == tink_pb2.RAW:
        return None
    if prefix == tink_pb2.TINK:
        # '>L' packs an unsigned 32-bit integer, so the largest value that
        # can be packed is 2**32 - 1; 2**32 itself has to be rejected here
        # or struct.pack would raise struct.error.
        if key_id < 0 or key_id >= 2**32:
            raise _jwt_error.JwtInvalidError('invalid key_id')
        return base64_encode(struct.pack('>L', key_id)).decode('utf8')
    raise _jwt_error.JwtInvalidError('unexpected output prefix type')
def _validate_audience_claim(self):
    """Checks the 'aud' claim: it must be a string or a list of strings.

    A payload without an 'aud' claim is accepted.

    Raises:
      _jwt_error.JwtInvalidError: if 'aud' is neither a string nor a
        non-empty list, or if the list has an entry that is not a string.
    """
    if 'aud' not in self._payload:
        return
    aud = self._payload['aud']
    if isinstance(aud, str):
        return
    is_non_empty_list = isinstance(aud, list) and bool(aud)
    if not is_non_empty_list:
        raise _jwt_error.JwtInvalidError('audiences cannot be an empty list')
    for entry in aud:
        if not isinstance(entry, str):
            raise _jwt_error.JwtInvalidError(
                'audiences must only contain strings')
def validate_header(header: Any, algorithm: Text) -> None:
    """Validates the values of an already parsed header.

    Args:
      header: The parsed header. It is queried with .get() and 'in'.
      algorithm: The expected algorithm name. The header's 'alg' value is
        upper-cased before it is compared with it.

    Raises:
      _jwt_error.JwtInvalidError: if 'alg' is not a string or does not
        match algorithm, or if the header has a 'crit' entry. Errors from
        _validate_algorithm propagate unchanged.
    """
    _validate_algorithm(algorithm)
    hdr_algorithm = header.get('alg', '')
    # 'alg' can hold any parsed value, not just a string. Checking the type
    # first keeps .upper() from failing with an AttributeError on e.g. a
    # number, so the caller always gets a JwtInvalidError.
    if (not isinstance(hdr_algorithm, str) or
            hdr_algorithm.upper() != algorithm):
        raise _jwt_error.JwtInvalidError(
            'Invalid algorithm; expected %s, got %s' %
            (algorithm, hdr_algorithm))
    if 'crit' in header:
        raise _jwt_error.JwtInvalidError(
            'all tokens with crit headers are rejected')
def create(cls,
           *,
           type_header: Optional[str] = None,
           issuer: Optional[str] = None,
           subject: Optional[str] = None,
           audience: Optional[str] = None,
           audiences: Optional[List[str]] = None,
           jwt_id: Optional[str] = None,
           expiration: Optional[datetime.datetime] = None,
           without_expiration: Optional[bool] = None,
           not_before: Optional[datetime.datetime] = None,
           issued_at: Optional[datetime.datetime] = None,
           custom_claims: Optional[Mapping[str, Claim]] = None) -> 'RawJwt':
    """Create a new RawJwt instance.

    Exactly one of expiration and without_expiration has to be set, and at
    most one of audience and audiences.

    Args:
      type_header: Passed through to __init__ unchanged.
      issuer: Stored as 'iss' when truthy.
      subject: Stored as 'sub' when truthy.
      audience: Stored as 'aud' (a single value) when not None.
      audiences: Stored as 'aud' (a shallow copy of the list) when not
        None.
      jwt_id: Stored as 'jti' when not None.
      expiration: Converted with _from_datetime and stored as 'exp'.
      without_expiration: Set to state that the missing expiration is
        intended.
      not_before: Converted with _from_datetime and stored as 'nbf'.
      issued_at: Converted with _from_datetime and stored as 'iat'.
      custom_claims: Additional claims. Lists and dicts are copied by a
        JSON round trip.

    Returns:
      The new instance.

    Raises:
      ValueError: if expiration and without_expiration are both set or
        both unset.
      _jwt_error.JwtInvalidError: if audience and audiences are both set,
        or a custom claim has a name that is not a string or a value of
        an unsupported type.
    """
    if not (expiration or without_expiration):
        raise ValueError('either expiration or without_expiration must be set')
    if expiration and without_expiration:
        raise ValueError(
            'expiration and without_expiration cannot be set at the same time')
    if audience is not None and audiences is not None:
        raise _jwt_error.JwtInvalidError(
            'audience and audiences cannot be set at the same time')

    # The insertion order below is the order of the claims in the payload.
    payload = {}
    if issuer:
        payload['iss'] = issuer
    if subject:
        payload['sub'] = subject
    if jwt_id is not None:
        payload['jti'] = jwt_id
    if audience is not None:
        payload['aud'] = audience
    if audiences is not None:
        payload['aud'] = copy.copy(audiences)
    if expiration:
        payload['exp'] = _from_datetime(expiration)
    if not_before:
        payload['nbf'] = _from_datetime(not_before)
    if issued_at:
        payload['iat'] = _from_datetime(issued_at)

    for claim_name, claim_value in (custom_claims or {}).items():
        _validate_custom_claim_name(claim_name)
        if not isinstance(claim_name, str):
            raise _jwt_error.JwtInvalidError('claim name must be Text')
        if claim_value is None or isinstance(claim_value,
                                             (bool, int, float, str)):
            payload[claim_name] = claim_value
        elif isinstance(claim_value, (list, dict)):
            # The JSON round trip both deep-copies the container and fails
            # on content that cannot be serialized.
            payload[claim_name] = json.loads(json.dumps(claim_value))
        else:
            raise _jwt_error.JwtInvalidError(
                'claim %s has unknown type' % claim_name)

    instance = object.__new__(cls)
    instance.__init__(type_header, payload)
    return instance
def _validate_audience_claim(self):
    """Checks the 'aud' claim and normalizes a single string to a list.

    A payload without an 'aud' claim is accepted. A string value is
    replaced in the payload by a one-element list holding that string.

    Raises:
      _jwt_error.JwtInvalidError: if 'aud' is neither a string nor a
        non-empty list, or if the list has an entry that is not a string.
    """
    payload = self._payload
    if 'aud' not in payload:
        return
    aud = payload['aud']
    if isinstance(aud, Text):
        payload['aud'] = [aud]
        return
    if not (isinstance(aud, list) and aud):
        raise _jwt_error.JwtInvalidError(
            'audiences must be a non-empty list')
    for entry in aud:
        if not isinstance(entry, Text):
            raise _jwt_error.JwtInvalidError(
                'audiences must only contain Text')
def json_loads(json_text: Text) -> Any:
    """Does the same as json.loads, but with some additional validation.

    Args:
      json_text: The JSON document to parse.

    Returns:
      The parsed data, after validate_all_strings has been run on it.

    Raises:
      _jwt_error.JwtInvalidError: if parsing fails, if parsing or
        validation hits the recursion limit, or if a UnicodeEncodeError
        is raised along the way.
    """
    try:
        parsed = json.loads(json_text)
        validate_all_strings(parsed)
    except json.JSONDecodeError:
        raise _jwt_error.JwtInvalidError('Failed to parse JSON string')
    except RecursionError:
        raise _jwt_error.JwtInvalidError(
            'Failed to parse JSON string, too many recursions')
    except UnicodeEncodeError:
        raise _jwt_error.JwtInvalidError('invalid character')
    return parsed
def _base64_decode(encoded_data: bytes) -> bytes:
    """Does a URL-safe base64 decoding without padding.

    Args:
      encoded_data: The unpadded URL-safe base64 data.

    Returns:
      The decoded bytes.

    Raises:
      _jwt_error.JwtInvalidError: if a character is rejected by
        _is_valid_urlsafe_base64_char, or if the decoder reports an error.
    """
    # base64.urlsafe_b64decode skips over characters that are not part of
    # the alphabet instead of failing, so they are rejected up front.
    if not all(_is_valid_urlsafe_base64_char(char) for char in encoded_data):
        raise _jwt_error.JwtInvalidError('invalid token')
    try:
        # The decoder insists on padding but tolerates a surplus of it, so
        # the maximum amount that can ever be needed is always appended.
        return base64.urlsafe_b64decode(encoded_data + b'===')
    except binascii.Error:
        # Happens when the length of encoded_data is (4*i + 1) for some i.
        raise _jwt_error.JwtInvalidError('invalid token')
def validate_header(header: bytes, algorithm: Text) -> None:
    """Parses the header and validates its values.

    Args:
      header: The base64-encoded header part of the token.
      algorithm: The expected algorithm name. The header's 'alg' value is
        upper-cased before it is compared with it.

    Raises:
      _jwt_error.JwtInvalidError: if the decoded header is not valid UTF-8,
        is not a JSON object, has an 'alg' that is not a string or does
        not match algorithm, or has a 'typ' that is not a string equal to
        'JWT' ignoring case. Errors from _validate_algorithm,
        _base64_decode and json_loads propagate unchanged.
    """
    _validate_algorithm(algorithm)
    try:
        json_header = _base64_decode(header).decode('utf8')
    except UnicodeDecodeError:
        # The decoded bytes are arbitrary token content; invalid UTF-8 is
        # an invalid token, not an internal error.
        raise _jwt_error.JwtInvalidError('invalid token')
    decoded_header = json_loads(json_header)
    # Valid JSON can also be a list, string, number or null, none of which
    # offers the dict lookups used below.
    if not isinstance(decoded_header, dict):
        raise _jwt_error.JwtInvalidError(
            'Invalid header; expected a JSON object')
    hdr_algorithm = decoded_header.get('alg', '')
    # Header values may be of any JSON type; check before calling .upper().
    if (not isinstance(hdr_algorithm, str) or
            hdr_algorithm.upper() != algorithm):
        raise _jwt_error.JwtInvalidError(
            'Invalid algorithm; expected %s, got %s' %
            (algorithm, hdr_algorithm))
    if 'typ' in decoded_header:
        header_type = decoded_header['typ']
        if not isinstance(header_type, str) or header_type.upper() != 'JWT':
            raise _jwt_error.JwtInvalidError(
                'Invalid header type; expected JWT, got %s' % (header_type,))
def _rsa_ssa_pss_alg_kid_from_private_key_data(
        key_data: tink_pb2.KeyData) -> Tuple[Text, Optional[Text]]:
    """Reads algorithm text and custom kid from JWT RSA-SSA-PSS private key data.

    Args:
      key_data: Key data whose value is a serialized JwtRsaSsaPssPrivateKey.

    Returns:
      An (algorithm text, custom kid) pair, both taken from the public key
      embedded in the private key.

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    private_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey.FromString(
        key_data.value)
    public_key = private_key.public_key
    algorithm_text = _rsa_ssa_pss_algorithm_text(public_key.algorithm)
    return (algorithm_text, _get_custom_kid(public_key))
def _ecdsa_alg_kid_from_private_key_data(
        key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
    """Reads algorithm text and custom kid from JWT ECDSA private key data.

    Args:
      key_data: Key data whose value is a serialized JwtEcdsaPrivateKey.

    Returns:
      An (algorithm text, custom kid) pair, both taken from the public key
      embedded in the private key.

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_ECDSA_PRIVATE_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    private_key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(key_data.value)
    public_key = private_key.public_key
    algorithm_text = _ecdsa_algorithm_text(public_key.algorithm)
    return (algorithm_text, _get_custom_kid(public_key))
def primitive(self, key_data: tink_pb2.KeyData) -> _jwt_mac.JwtMac:
    """Builds the JWT MAC primitive for the given JWT HMAC key data.

    Args:
      key_data: Key data whose value is a serialized JwtHmacKey.

    Returns:
      A _JwtHmac built from the primitive of the wrapped key manager and
      the algorithm looked up for the key's hash type.

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_HMAC_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(key_data.value)
    algorithm = _HASH_TYPE_TO_ALGORITHM[hmac_key.hash_type]
    serialized_key_data = key_data.SerializeToString()
    return _JwtHmac(self._cc_key_manager.primitive(serialized_key_data),
                    algorithm)
def compute_mac_and_encode_with_kid(self, raw_jwt: _raw_jwt.RawJwt,
                                    kid: Optional[Text]) -> Text:
    """Computes a MAC and encodes the token.

    Args:
      raw_jwt: The RawJwt token to be MACed and encoded.
      kid: Optional "kid" header value. It is set by the wrapper for keys
        with output prefix TINK, and it is None for output prefix RAW.

    Returns:
      The MACed token encoded in the JWS compact serialization format.

    Raises:
      _jwt_error.JwtInvalidError: if this primitive has a custom kid and
        kid is set as well. Errors raised while building the compact form
        or computing the MAC propagate unchanged.
    """
    type_header = raw_jwt.type_header() if raw_jwt.has_type_header() else None

    custom_kid = self._custom_kid
    if custom_kid is not None and kid is not None:
        raise _jwt_error.JwtInvalidError(
            'custom_kid must not be set for keys with output prefix type TINK'
        )
    if custom_kid is not None:
        # Without a kid from the wrapper, the key's own kid goes into the
        # header.
        kid = custom_kid

    unsigned = _jwt_format.create_unsigned_compact(
        self._algorithm, type_header, kid, raw_jwt.json_payload())
    mac = self._compute_mac(unsigned)
    return _jwt_format.create_signed_compact(unsigned, mac)
def _rsa_ssa_pkcs1_alg_kid_from_public_key_data(
        key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
    """Reads algorithm text and custom kid from JWT RSA-SSA-PKCS1 public key data.

    Args:
      key_data: Key data whose value is a serialized
        JwtRsaSsaPkcs1PublicKey.

    Returns:
      An (algorithm text, custom kid) pair for the key.

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    public_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey.FromString(
        key_data.value)
    algorithm_text = _rsa_ssa_pkcs1_algorithm_text(public_key.algorithm)
    return (algorithm_text, _get_custom_kid(public_key))
def validate_header(json_header: Text, algorithm: Text) -> None:
    """Parses the header and validates its values.

    Args:
      json_header: The header as JSON text.
      algorithm: The expected algorithm name. The header's 'alg' value is
        upper-cased before it is compared with it.

    Raises:
      _jwt_error.JwtInvalidError: if the header is not a JSON object, or
        its 'alg' is not a string or does not match algorithm. Errors
        from _validate_algorithm and json_loads propagate unchanged.
    """
    _validate_algorithm(algorithm)
    decoded_header = json_loads(json_header)
    # Valid JSON can also be a list, string, number or null, none of which
    # offers the .get() used below.
    if not isinstance(decoded_header, dict):
        raise _jwt_error.JwtInvalidError(
            'Invalid header; expected a JSON object')
    hdr_algorithm = decoded_header.get('alg', '')
    # Header values may be of any JSON type; check before calling .upper()
    # so that e.g. a numeric 'alg' is reported as an invalid token.
    if (not isinstance(hdr_algorithm, str) or
            hdr_algorithm.upper() != algorithm):
        raise _jwt_error.JwtInvalidError(
            'Invalid algorithm; expected %s, got %s' %
            (algorithm, hdr_algorithm))
def _base64_decode(encoded_data: bytes) -> bytes:
    """Does a URL-safe base64 decoding without padding.

    Args:
      encoded_data: The unpadded URL-safe base64 data.

    Returns:
      The decoded bytes.

    Raises:
      _jwt_error.JwtInvalidError: if a character is rejected by
        _is_valid_urlsafe_base64_char, or if the decoder reports an error.
    """
    # Imported locally so this function does not depend on a module-level
    # import of binascii.
    import binascii

    # base64.urlsafe_b64decode ignores all non-base64 chars. We don't want
    # that.
    for c in encoded_data:
        if not _is_valid_urlsafe_base64_char(c):
            raise _jwt_error.JwtInvalidError('invalid token')
    # base64.urlsafe_b64decode requires padding, but does not mind too much
    # padding. So we simply add the maximum amount of padding needed.
    padded_encoded_data = encoded_data + b'==='
    try:
        return base64.urlsafe_b64decode(padded_encoded_data)
    except binascii.Error:
        # Raised when the length of encoded_data is (4*i + 1) for some i.
        # No amount of padding makes that a valid encoding, so it is an
        # invalid token rather than an internal error.
        raise _jwt_error.JwtInvalidError('invalid token')
def create(cls,
           issuer: Optional[Text] = None,
           subject: Optional[Text] = None,
           audiences: Optional[List[Text]] = None,
           jwt_id: Optional[Text] = None,
           expiration: Optional[datetime.datetime] = None,
           not_before: Optional[datetime.datetime] = None,
           issued_at: Optional[datetime.datetime] = None,
           custom_claims: Mapping[Text, Claim] = None) -> 'RawJwt':
    """Create a new RawJwt instance.

    Args:
      issuer: Stored as 'iss' when truthy.
      subject: Stored as 'sub' when truthy.
      audiences: Stored as 'aud' (a shallow copy of the list) when not
        None.
      jwt_id: Stored as 'jti' when not None.
      expiration: Converted with _from_datetime and stored as 'exp'.
      not_before: Converted with _from_datetime and stored as 'nbf'.
      issued_at: Converted with _from_datetime and stored as 'iat'.
      custom_claims: Additional claims. Lists and dicts are copied by a
        JSON round trip.

    Returns:
      The new instance.

    Raises:
      _jwt_error.JwtInvalidError: if a custom claim has a name that is not
        a string or a value of an unsupported type.
    """
    # The insertion order below is the order of the claims in the payload.
    payload = {}
    if issuer:
        payload['iss'] = issuer
    if subject:
        payload['sub'] = subject
    if jwt_id is not None:
        payload['jti'] = jwt_id
    if audiences is not None:
        payload['aud'] = copy.copy(audiences)
    if expiration:
        payload['exp'] = _from_datetime(expiration)
    if not_before:
        payload['nbf'] = _from_datetime(not_before)
    if issued_at:
        payload['iat'] = _from_datetime(issued_at)

    for claim_name, claim_value in (custom_claims or {}).items():
        _validate_custom_claim_name(claim_name)
        if not isinstance(claim_name, Text):
            raise _jwt_error.JwtInvalidError('claim name must be Text')
        if claim_value is None or isinstance(claim_value,
                                             (bool, int, float, Text)):
            payload[claim_name] = claim_value
        elif isinstance(claim_value, (list, dict)):
            # The JSON round trip both deep-copies the container and fails
            # on content that cannot be serialized.
            payload[claim_name] = json.loads(json.dumps(claim_value))
        else:
            raise _jwt_error.JwtInvalidError(
                'claim %s has unknown type' % claim_name)

    instance = object.__new__(cls)
    instance.__init__(payload)
    return instance
def validate(validator: JwtValidator, raw_jwt: _raw_jwt.RawJwt) -> None:
    """Validates a jwt.RawJwt and raises JwtInvalidError if it is invalid.

    This function is called by the JWT primitives and does not need to be
    called by the user.

    The checks run in this order: expiration, not-before, issuer, subject,
    audience. Both time checks are relaxed by the validator's clock skew.

    Args:
      validator: a jwt.JwtValidator that defines how to validate tokens.
      raw_jwt: a jwt.RawJwt token to validate.

    Raises:
      jwt.JwtInvalidError: if the token has expired, is not yet valid, or
        its issuer, subject or audience does not meet what the validator
        expects.
    """
    if validator.has_fixed_now():
        now = validator.fixed_now()
    else:
        now = datetime.datetime.now(tz=datetime.timezone.utc)

    # RFC 7519 defines 'exp' as the time "on or after which" the token must
    # not be accepted, so a token expiring exactly at the (skew-adjusted)
    # current time is already expired; hence <= rather than <.
    if (raw_jwt.has_expiration() and
            raw_jwt.expiration() <= now - validator.clock_skew()):
        raise _jwt_error.JwtInvalidError('token has expired since %s' %
                                         raw_jwt.expiration())
    # 'nbf' is the time before which the token must not be accepted, so a
    # token becoming valid exactly now is fine.
    if (raw_jwt.has_not_before() and
            raw_jwt.not_before() > now + validator.clock_skew()):
        raise _jwt_error.JwtInvalidError('token cannot be used before %s' %
                                         raw_jwt.not_before())

    if validator.has_issuer():
        if not raw_jwt.has_issuer():
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; missing expected issuer %s.' %
                validator.issuer())
        if validator.issuer() != raw_jwt.issuer():
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; expected issuer %s, but got %s' %
                (validator.issuer(), raw_jwt.issuer()))

    if validator.has_subject():
        if not raw_jwt.has_subject():
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; missing expected subject %s.' %
                validator.subject())
        if validator.subject() != raw_jwt.subject():
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; expected subject %s, but got %s' %
                (validator.subject(), raw_jwt.subject()))

    if validator.has_audience():
        if (not raw_jwt.has_audiences() or
                validator.audience() not in raw_jwt.audiences()):
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; missing expected audience %s.' %
                validator.audience())
    else:
        # A token addressed to specific audiences is rejected unless the
        # validator names one of them.
        if raw_jwt.has_audiences():
            raise _jwt_error.JwtInvalidError(
                'invalid JWT; token has audience set, but validator not.')
def primitive(self, key_data: tink_pb2.KeyData) -> _jwt_mac.JwtMacInternal:
    """Builds the internal JWT MAC primitive for the given JWT HMAC key data.

    Args:
      key_data: Key data whose value is a serialized JwtHmacKey.

    Returns:
      A _JwtHmac built from the primitive of the wrapped key manager, the
      algorithm string looked up for the key's algorithm, and the key's
      custom kid (None when the key has no custom_kid field).

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_HMAC_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(key_data.value)
    algorithm = _ALGORITHM_STRING[hmac_key.algorithm]
    cc_mac = self._cc_key_manager.primitive(key_data.SerializeToString())
    custom_kid = (
        hmac_key.custom_kid.value if hmac_key.HasField('custom_kid') else None)
    return _JwtHmac(cc_mac, algorithm, custom_kid)
def verify_mac_and_decode(
        self, compact: Text, validator: _jwt_validator.JwtValidator
) -> _verified_jwt.VerifiedJwt:
    """Verifies, validates and decodes a MACed compact JWT token.

    The MAC is verified before the header or payload is parsed.

    Args:
      compact: The token in compact serialization.
      validator: Defines how the decoded token is validated.

    Returns:
      The verified token.

    Raises:
      _jwt_error.JwtInvalidError: if compact cannot be encoded as UTF-8 or
        does not consist of exactly three '.'-separated parts. Errors from
        MAC verification, header validation, payload decoding and token
        validation propagate unchanged.
    """
    try:
        encoded = compact.encode('utf8')
    except UnicodeEncodeError:
        # A str holding lone surrogates cannot be encoded; that is a
        # malformed token, not an internal error.
        raise _jwt_error.JwtInvalidError('invalid token')
    try:
        unsigned_compact, encoded_signature = encoded.rsplit(b'.', 1)
    except ValueError:
        raise _jwt_error.JwtInvalidError('invalid token')
    signature = _jwt_format.decode_signature(encoded_signature)
    self._verify_mac(signature, unsigned_compact)
    try:
        encoded_header, encoded_payload = unsigned_compact.split(b'.')
    except ValueError:
        raise _jwt_error.JwtInvalidError('invalid token')
    _jwt_format.validate_header(encoded_header, self._algorithm)
    json_payload = _jwt_format.decode_payload(encoded_payload)
    raw_jwt = _raw_jwt.RawJwt.from_json_payload(json_payload)
    _jwt_validator.validate(validator, raw_jwt)
    return _verified_jwt.VerifiedJwt._create(raw_jwt)  # pylint: disable=protected-access
def __init__(self, payload: Dict[Text, Any]) -> None:
    """Stores the payload and validates the claims with a fixed meaning.

    Args:
      payload: The claims. It is stored as is, without copying.

    Raises:
      _jwt_error.JwtInvalidError: if payload is not a dict. Errors raised
        by the individual claim validators propagate unchanged.
    """
    # The payload is not copied: the only callers are create and
    # from_json_payload.
    if not isinstance(payload, Dict):
        raise _jwt_error.JwtInvalidError('payload must be a dict')
    self._payload = payload
    for string_claim in ('iss', 'sub', 'jti'):
        self._validate_string_claim(string_claim)
    for number_claim in ('exp', 'nbf', 'iat'):
        self._validate_number_claim(number_claim)
    self._validate_audience_claim()
def validate_header(header: Any,
                    algorithm: str,
                    tink_kid: Optional[str] = None,
                    custom_kid: Optional[str] = None) -> None:
    """Validates the values of an already parsed header.

    Args:
      header: The parsed header. It is queried with .get() and 'in'.
      algorithm: The expected algorithm name. The header's 'alg' value is
        upper-cased before it is compared with it.
      tink_kid: When set, the header must have a 'kid', which is checked
        against this value with _validate_kid_header.
      custom_kid: When set, a 'kid' in the header is checked against this
        value with _validate_kid_header; a missing 'kid' is accepted.

    Raises:
      _jwt_error.JwtInvalidError: if 'alg' is not a string or does not
        match algorithm, if the header has a 'crit' entry, if tink_kid
        and custom_kid are both set, or if tink_kid is set and the header
        has no 'kid'. Errors from _validate_algorithm and
        _validate_kid_header propagate unchanged.
    """
    _validate_algorithm(algorithm)
    hdr_algorithm = header.get('alg', '')
    # 'alg' can hold any parsed value, not just a string. Checking the type
    # first keeps .upper() from failing with an AttributeError on e.g. a
    # number, so the caller always gets a JwtInvalidError.
    if (not isinstance(hdr_algorithm, str) or
            hdr_algorithm.upper() != algorithm):
        raise _jwt_error.JwtInvalidError(
            'Invalid algorithm; expected %s, got %s' %
            (algorithm, hdr_algorithm))
    if 'crit' in header:
        raise _jwt_error.JwtInvalidError(
            'all tokens with crit headers are rejected')
    if tink_kid is not None and custom_kid is not None:
        raise _jwt_error.JwtInvalidError(
            'custom_kid can only be set for RAW keys')
    if tink_kid is not None:
        if 'kid' not in header:
            # for output prefix type TINK, the kid header is required
            raise _jwt_error.JwtInvalidError('missing kid in header')
        _validate_kid_header(header, tink_kid)
    if custom_kid is not None and 'kid' in header:
        _validate_kid_header(header, custom_kid)
def __init__(self, type_header: Optional[str],
             payload: Dict[str, Any]) -> None:
    """Stores type header and payload and validates the known claims.

    Args:
      type_header: Stored as is.
      payload: The claims. It is stored as is, without copying.

    Raises:
      _jwt_error.JwtInvalidError: if payload is not a dict. Errors raised
        by the individual claim validators propagate unchanged.
    """
    # The payload is not copied: the only callers are create and
    # from_json_payload.
    if not isinstance(payload, Dict):
        raise _jwt_error.JwtInvalidError('payload must be a dict')
    self._type_header = type_header
    self._payload = payload
    for string_claim in ('iss', 'sub', 'jti'):
        self._validate_string_claim(string_claim)
    for timestamp_claim in ('exp', 'nbf', 'iat'):
        self._validate_timestamp_claim(timestamp_claim)
    self._validate_audience_claim()
def sign_and_encode_with_kid(self, raw_jwt: _raw_jwt.RawJwt,
                             kid: Optional[str]) -> str:
    """Computes a signature and encodes the token.

    Args:
      raw_jwt: The RawJwt token to be signed and encoded.
      kid: Optional "kid" header value. It is set by the wrapper for keys
        with output prefix TINK, and it is None for output prefix RAW.

    Returns:
      The signed token, as produced by _jwt_format.create_signed_compact.

    Raises:
      _jwt_error.JwtInvalidError: if this primitive has a custom kid and
        kid is set as well. Errors raised while building the compact form
        or signing propagate unchanged.
    """
    custom_kid = self._custom_kid
    if custom_kid is not None and kid is not None:
        raise _jwt_error.JwtInvalidError(
            'custom_kid must not be set for keys with output prefix type TINK'
        )
    if custom_kid is not None:
        # Without a kid from the wrapper, the key's own kid goes into the
        # header.
        kid = custom_kid

    unsigned = _jwt_format.create_unsigned_compact(self._algorithm, kid,
                                                   raw_jwt)
    signature = self._sign(unsigned)
    return _jwt_format.create_signed_compact(unsigned, signature)
def _validate_number_claim(self, name: Text):
    """Checks that the named claim, if present, is an int or a float.

    Args:
      name: The name of the claim to look up in the payload.

    Raises:
      _jwt_error.JwtInvalidError: if the claim is present with a value of
        another type.
    """
    if name not in self._payload:
        return
    if isinstance(self._payload[name], (int, float)):
        return
    raise _jwt_error.JwtInvalidError('claim %s must be a Number' % name)
def _validate_string_claim(self, name: Text):
    """Checks that the named claim, if present, is a string.

    Args:
      name: The name of the claim to look up in the payload.

    Raises:
      _jwt_error.JwtInvalidError: if the claim is present with a value of
        another type.
    """
    if name not in self._payload:
        return
    if isinstance(self._payload[name], Text):
        return
    raise _jwt_error.JwtInvalidError('claim %s must be a String' % name)
def _ecdsa_algorithm_text(algorithm: jwt_ecdsa_pb2.JwtEcdsaAlgorithm) -> Text:
    """Looks up the text for a JWT ECDSA algorithm value.

    Args:
      algorithm: The algorithm value to look up in _ECDSA_ALGORITHM_TEXTS.

    Returns:
      The entry of _ECDSA_ALGORITHM_TEXTS for algorithm.

    Raises:
      _jwt_error.JwtInvalidError: if there is no entry for algorithm.
    """
    if algorithm in _ECDSA_ALGORITHM_TEXTS:
        return _ECDSA_ALGORITHM_TEXTS[algorithm]
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
def _ecdsa_algorithm_from_public_key_data(key_data: tink_pb2.KeyData) -> Text:
    """Reads the algorithm text from JWT ECDSA public key data.

    Args:
      key_data: Key data whose value is a serialized JwtEcdsaPublicKey.

    Returns:
      The result of _ecdsa_algorithm_text for the key's algorithm.

    Raises:
      _jwt_error.JwtInvalidError: if key_data has a different type URL.
    """
    if key_data.type_url != _JWT_ECDSA_PUBLIC_KEY_TYPE:
        raise _jwt_error.JwtInvalidError('Invalid key data key type')
    public_key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(key_data.value)
    algorithm = public_key.algorithm
    return _ecdsa_algorithm_text(algorithm)
def _validate_custom_claim_name(name: Text) -> None:
    """Rejects custom claim names that are listed in _REGISTERED_NAMES.

    Args:
      name: The proposed custom claim name.

    Raises:
      _jwt_error.JwtInvalidError: if name is one of _REGISTERED_NAMES.
    """
    if name not in _REGISTERED_NAMES:
        return
    raise _jwt_error.JwtInvalidError(
        'registered name %s cannot be custom claim name' % name)