Example #1
0
class TestJOSETransport3rdParty(unittest.TestCase):
    """Tests for JOSETransport paths that go through the JOSE libraries.

    ``JWT`` in ``launchkey.transports.jose_auth`` is patched for every test so
    that unpacking any JWT yields ``faux_jwt_headers``.
    """

    # Prefix that _build_jwt_signature is expected to put before the token.
    _AUTH_PREFIX = 'IOV-JWT '

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        # Canned API response returned by the transport's (mocked) ``get``.
        public_key = APIResponse(valid_public_key, transport_request_headers,
                                 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- a zero offset from
        # server time; confirm against JOSETransport.
        self._transport._server_time_difference = 0, time()

        self._jwt_patch = patch("launchkey.transports.jose_auth.JWT",
                                return_value=MagicMock(spec=JWT)).start()
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        # Stop every patch started with ``.start()`` once the test finishes.
        self.addCleanup(patch.stopall)

    def test_public_key_parse(self):
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        # No issuer was set, so there is no key to sign with.
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY,
                                                 ANY, ANY)

    def test_build_jwt_signature(self):
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test",
                                                   str(uuid4()),
                                                   "test:subject", "ABCDEFG")
        self.assertTrue(jwt.startswith(self._AUTH_PREFIX))
        # Slice off the prefix: str.strip() would remove any of the prefix's
        # *characters* from both ends, damaging the token itself.
        token = jwt[len(self._AUTH_PREFIX):]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_jwt_response_error(self):
        # NOTE(review): identical to test_invalid_jwt_response; the name hints
        # a different error scenario was intended -- confirm or remove.
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    @patch.object(JWS, 'verify_compact')
    def test_jwt_error_raises_expected_exception(self, verify_compact_patch):
        verify_compact_patch.side_effect = JWKESTException
        # NOTE(review): test_empty_jwt_response expects InvalidJWTResponse for
        # these same empty headers; both pass only if one exception type
        # derives from the other -- confirm this test reaches verify_compact.
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response({}, ANY, ANY, ANY)
Example #2
0
class TestJOSETransport3rdParty(unittest.TestCase):
    """Tests for JOSETransport key parsing, request signing and JWT responses."""

    # Prefix that _build_jwt_signature is expected to put before the token.
    _AUTH_PREFIX = 'IOV-JWT '

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        # NOTE(review): the canned "public key" response body is
        # valid_private_key; test_public_key_parse overrides ``data`` with
        # valid_public_key before parsing.
        public_key = APIResponse(valid_private_key, {
            "X-IOV-KEY-ID":
            "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"
        }, 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- a zero offset from
        # server time; confirm against JOSETransport.
        self._transport._server_time_difference = 0, time()

    def test_public_key_parse(self):
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        # No issuer was set, so there is no key to sign with.
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY,
                                                 ANY, ANY)

    def test_build_jwt_signature(self):
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test",
                                                   str(uuid4()),
                                                   "test:subject", "ABCDEFG")
        self.assertTrue(jwt.startswith(self._AUTH_PREFIX))
        # Slice off the prefix: str.strip() would remove any of the prefix's
        # *characters* from both ends, damaging the token itself.
        token = jwt[len(self._AUTH_PREFIX):]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)
class TestJOSETransport3rdParty(unittest.TestCase):
    """Tests for JOSETransport key parsing, request signing and JWT responses."""

    # Prefix that _build_jwt_signature is expected to put before the token.
    _AUTH_PREFIX = 'IOV-JWT '

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        # NOTE(review): the canned "public key" response body is valid_private_key;
        # test_public_key_parse overrides ``data`` with valid_public_key before parsing.
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- zero offset from server time.
        self._transport._server_time_difference = 0, time()

    def test_public_key_parse(self):
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        # No issuer was set, so there is no key to sign with.
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY, ANY, ANY)

    def test_build_jwt_signature(self):
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test", str(uuid4()), "test:subject", "ABCDEFG")
        self.assertTrue(jwt.startswith(self._AUTH_PREFIX))
        # Slice off the prefix: str.strip() would remove any of the prefix's
        # *characters* from both ends, damaging the token itself.
        token = jwt[len(self._AUTH_PREFIX):]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_jwt_response_error(self):
        # NOTE(review): identical to test_invalid_jwt_response; the name hints a
        # different error scenario was intended -- confirm or remove.
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    @patch.object(JWS, 'verify_compact')
    def test_jwt_error_raises_expected_exception(self, verify_compact_patch):
        verify_compact_patch.side_effect = JWKESTException
        # NOTE(review): test_empty_jwt_response expects InvalidJWTResponse for these
        # same empty headers; confirm this call actually reaches verify_compact.
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response({}, ANY, ANY, ANY)
Example #4
0
class TestCacheAndRetrieveKeyByKid(unittest.TestCase):
    """Tests for retrieving API public keys by the ``kid`` of a response JWT.

    ``JWT``, ``JWS`` and ``import_rsa_key`` in ``launchkey.transports.jose_auth``
    are patched. ``JOSETransport``'s payload, response-header and content-hash
    checks are stubbed out, so only key lookup and caching are exercised.
    """

    def setUp(self):
        self.jti = str(uuid4())
        minified_jwt_payload = {"jti": self.jti, "response": {"status": 200}}

        self._requests_transport = MagicMock(spec=RequestsTransport)
        self._requests_transport.get.return_value = APIResponse(
            valid_private_key, {}, 200)
        self._transport = JOSETransport(http_client=self._requests_transport)
        self._import_rsa_key_patch = patch(
            "launchkey.transports.jose_auth.import_rsa_key",
            return_value=MagicMock(spec=RsaKey)).start()
        self._jwt_patch = patch("launchkey.transports.jose_auth.JWT",
                                return_value=MagicMock(spec=JWT)).start()
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers
        self._jws_patch = patch("launchkey.transports.jose_auth.JWS",
                                return_value=MagicMock(spec=JWS)).start()
        self._jws_patch.return_value.verify_compact.return_value = minified_jwt_payload

        # Stub out the remaining verification steps; these tests only care
        # about how the verification key is obtained.
        patch.object(JOSETransport, "_verify_jwt_payload").start()
        patch.object(JOSETransport, "_verify_jwt_response_headers").start()
        patch.object(JOSETransport, "_verify_jwt_content_hash").start()

        # Stop every patch started with ``.start()`` once the test finishes.
        self.addCleanup(patch.stopall)

    def test_key_retrieved_by_id_in_jwt_header(self):
        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)
        self._requests_transport.get.assert_called_once_with(
            "/public/v3/public-key/%s" % faux_kid, data={})

    def test_key_is_cached_by_id(self):
        # Two verifications with the same kid must hit the API only once.
        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)
        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)
        self._requests_transport.get.assert_called_once()

    def test_key_is_retrieved_by_id_when_key_changed(self):
        jwt1 = MagicMock()
        jwt1.headers = {"alg": "RS512", "typ": "JWT", "kid": "jwt1keyid"}

        jwt2 = MagicMock()
        jwt2.headers = {"alg": "RS512", "typ": "JWT", "kid": "jwt2keyid"}

        # Each JWT is supplied twice because unpack is presumably called twice
        # per verify_jwt_response -- NOTE(review): confirm.
        self._jwt_patch.return_value.unpack.side_effect = [
            jwt1, jwt1, jwt2, jwt2
        ]

        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)
        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)
        self._requests_transport.get.assert_has_calls(
            [
                call('/public/v3/public-key/jwt1keyid', data={}),
                call('/public/v3/public-key/jwt2keyid', data={}),
            ],
            any_order=True)

    @patch("launchkey.transports.jose_auth.RSAKey")
    def test_key_retrieved_is_used_to_verify_payload(self, rsa_key_patch):
        self._requests_transport.get.return_value = MagicMock(spec=APIResponse)
        self._requests_transport.get.return_value.data = valid_public_key
        self._transport.verify_jwt_response(MagicMock(), self.jti, ANY, None)

        # verify_compact must be called exactly once, with the key built by the
        # patched RSAKey.
        self._jws_patch.return_value.verify_compact.assert_called_once_with(
            ANY, keys=[rsa_key_patch.return_value])

        # The RSAKey is built from the import_rsa_key result and the kid taken
        # from the JWT header.
        rsa_key_patch.assert_called_with(
            key=self._import_rsa_key_patch.return_value, kid=faux_kid)

        # NOTE(review): this asserts on the headers of the public-key API
        # response, not on the headers passed to verify_jwt_response -- confirm
        # that this is the intended mock.
        self._requests_transport.get.return_value.headers.get.assert_called_with(
            "X-IOV-JWT")

    def test_raises_when_kid_header_is_missing(self):
        headers_without_kid = {"alg": "RS512", "typ": "JWT"}
        jwt = MagicMock()
        jwt.headers = headers_without_kid
        # Use return_value, not side_effect: a callable side_effect is invoked,
        # so unpack would return jwt's own fresh return_value instead of jwt,
        # and these headers would never be seen.
        self._jwt_patch.return_value.unpack.return_value = jwt
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response(MagicMock(), self.jti, ANY,
                                                None)

    def test_raises_when_kid_header_is_malformed(self):
        headers_with_kid_of_wrong_type = {
            "alg": "RS512",
            "typ": "JWT",
            "kid": 1234
        }
        jwt = MagicMock()
        jwt.headers = headers_with_kid_of_wrong_type
        # Use return_value, not side_effect (see the missing-kid test above).
        self._jwt_patch.return_value.unpack.return_value = jwt
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response(MagicMock(), self.jti, ANY,
                                                None)

    def test_raises_on_api_404(self):
        self._requests_transport.get.return_value = APIResponse(
            "Not Found", {}, 404)
        with self.assertRaises(UnexpectedAPIResponse):
            self._transport.verify_jwt_response(MagicMock(), self.jti, ANY,
                                                None)

    def test_raises_on_malformed_response_object(self):
        self._requests_transport.get.return_value = APIResponse(None, {}, 200)
        with self.assertRaises(UnexpectedAPIResponse):
            self._transport.verify_jwt_response(MagicMock(), self.jti, ANY,
                                                None)

    @patch("launchkey.transports.jose_auth.RSAKey")
    def test_raises_on_rsa_key_parsing_error(self, rsa_key_patch):
        rsa_key_patch.side_effect = TypeError
        with self.assertRaises(UnexpectedAPIResponse):
            self._transport.verify_jwt_response(MagicMock(), self.jti, ANY,
                                                None)
Example #5
0
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Claim and header validation tests for ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_payload``, and ``sha256``
    in ``launchkey.transports.jose_auth`` is patched to yield
    ``self._content_hash``. Each test mutates the payload or ``self._headers`` to
    make exactly one check fail.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        public_key = APIResponse(valid_private_key, {}, 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- a zero offset from
        # server time; confirm against JOSETransport.
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # Register the stop as a cleanup: without it the JWT patch outlives the
        # test and stacks up across setUp calls.
        jwt_patcher = patch("launchkey.transports.jose_auth.JWT",
                            return_value=MagicMock(spec=JWT))
        self._jwt_patch = jwt_patcher.start()
        self.addCleanup(jwt_patcher.stop)
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        # A payload that passes every check; tests break one field at a time.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256',
                'cache': 'expected cache-control',
                'location': 'expected location'
            },
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }

        # HTTP headers matching the payload's response.cache / response.location.
        self._headers = {
            'Location': 'expected location',
            'Cache-Control': 'expected cache-control'
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_response(self,
                             headers=None,
                             jti=None,
                             body=False,
                             subject=None,
                             status_code=200):
        """Call ``verify_jwt_response`` with the fixture's valid defaults.

        ``headers``, ``jti`` and ``subject`` default to the fixture values when
        ``None``. ``body`` uses ``False`` as its "not given" sentinel (giving a
        MagicMock), so that ``None`` can be passed to mean "no body".
        """
        headers = self._headers if headers is None else headers
        jti = self.jwt_payload['jti'] if jti is None else jti
        body = MagicMock() if body is False else body
        subject = self.jwt_payload['sub'] if subject is None else subject
        return self._transport.verify_jwt_response(headers, jti, body, subject,
                                                   status_code)

    def test_verify_jwt_response_success_returns_payload(self):
        actual = self._verify_jwt_response()
        self.assertEqual(actual, self.jwt_payload)

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        # "Now" is just beyond the allowed leeway before nbf.
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        # "Now" is just beyond the allowed leeway after exp.
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response(MagicMock(),
                                                self.jwt_payload['jti'],
                                                MagicMock(), MagicMock())

    def test_verify_jwt_response_invalid_sub_401(self):
        self.jwt_payload['response']['status'] = 401
        self.jwt_payload['aud'] = "public"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject="Invalid subject")

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_payload['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')

    def test_verify_no_response_raises_validation_failure(self):
        del self.jwt_payload['response']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_algorithm_raises_validation_failure(
            self):
        # Only the hash algorithm ('func') remains in the payload.
        del self.jwt_payload['response']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_raises_validation_failure(
            self):
        # Only the hash itself remains in the payload.
        del self.jwt_payload['response']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_invalid_status_code_raises_validation_failure(self):
        self.jwt_payload['response']['status'] = 999
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_any_status_with_no_status_code_still_passes(self):
        self.assertIsNotNone(self._verify_jwt_response(status_code=None))

    def test_verify_no_status_in_jwt_response_and_no_status_code_raises_validation_failure(
            self):
        del self.jwt_payload['response']['status']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(status_code=None)

    def test_verify_no_cache_control_header_and_jwt_response_cache_raises_validation_failure(
            self):
        del self._headers['Cache-Control']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_cache_control_header_and_jwt_with_no_response_cache_raises_validation_failure(
            self):
        del self.jwt_payload['response']['cache']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_cache_control_header_raises_validation_failure(
            self):
        self.jwt_payload['response']['cache'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_no_location_header_and_jwt_response_cache_raises_validation_failure(
            self):
        # NOTE: despite "cache" in the name, this covers response.location
        # present with no Location header.
        del self._headers['Location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_location_header_and_jwt_with_no_response_cache_raises_validation_failure(
            self):
        # NOTE: despite "cache" in the name, this covers a Location header
        # with no response.location in the payload.
        del self.jwt_payload['response']['location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_location_header_raises_validation_failure(self):
        self.jwt_payload['response']['location'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Claim and header validation tests for ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_payload``, and ``sha256`` in
    ``launchkey.transports.jose_auth`` is patched to yield ``self._content_hash``.
    Each test mutates the payload or ``self._headers`` to make one check fail.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- zero offset from server time.
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id, valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # A payload that passes every check; tests break one field at a time.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256',
                'cache': 'expected cache-control',
                'location': 'expected location'},
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }

        # HTTP headers matching the payload's response.cache / response.location.
        self._headers = {'Location': 'expected location', 'Cache-Control': 'expected cache-control'}
        self._transport._get_jwt_payload = MagicMock(return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_response(self, headers=None, jti=None, body=False, subject=None, status_code=200):
        """Call ``verify_jwt_response`` with the fixture's valid defaults.

        ``headers``, ``jti`` and ``subject`` default to fixture values when ``None``.
        ``body`` uses ``False`` as its "not given" sentinel (giving a MagicMock), so
        that ``None`` can be passed to mean "no body".
        """
        headers = self._headers if headers is None else headers
        jti = self.jwt_payload['jti'] if jti is None else jti
        body = MagicMock() if body is False else body
        subject = self.jwt_payload['sub'] if subject is None else subject
        return self._transport.verify_jwt_response(headers, jti, body, subject, status_code)

    def test_verify_jwt_response_success_returns_payload(self):
        actual = self._verify_jwt_response()
        self.assertEqual(actual, self.jwt_payload)

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        # "Now" is just beyond the allowed leeway before nbf.
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        # "Now" is just beyond the allowed leeway after exp.
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response(MagicMock(), self.jwt_payload['jti'], MagicMock(), MagicMock())

    def test_verify_jwt_response_invalid_sub_401(self):
        self.jwt_payload['response']['status'] = 401
        self.jwt_payload['aud'] = "public"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject="Invalid subject")

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_payload['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')

    def test_verify_no_response_raises_validation_failure(self):
        del self.jwt_payload['response']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_algorithm_raises_validation_failure(self):
        # Only the hash algorithm ('func') remains in the payload.
        del self.jwt_payload['response']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_raises_validation_failure(self):
        # Only the hash itself remains in the payload.
        del self.jwt_payload['response']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_invalid_status_code_raises_validation_failure(self):
        self.jwt_payload['response']['status'] = 999
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_any_status_with_no_status_code_still_passes(self):
        self.assertIsNotNone(self._verify_jwt_response(status_code=None))

    def test_verify_no_status_in_jwt_response_and_no_status_code_raises_validation_failure(self):
        del self.jwt_payload['response']['status']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(status_code=None)

    def test_verify_no_cache_control_header_and_jwt_response_cache_raises_validation_failure(self):
        del self._headers['Cache-Control']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_cache_control_header_and_jwt_with_no_response_cache_raises_validation_failure(self):
        del self.jwt_payload['response']['cache']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_cache_control_header_raises_validation_failure(self):
        self.jwt_payload['response']['cache'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_no_location_header_and_jwt_response_cache_raises_validation_failure(self):
        # NOTE: despite "cache" in the name, this covers response.location present
        # with no Location header.
        del self._headers['Location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_location_header_and_jwt_with_no_response_cache_raises_validation_failure(self):
        # NOTE: despite "cache" in the name, this covers a Location header with no
        # response.location in the payload.
        del self.jwt_payload['response']['location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_location_header_raises_validation_failure(self):
        self.jwt_payload['response']['location'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()
Example #7
0
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Claim validation tests for ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_response``. Each test
    mutates that dict (or passes a wrong jti/subject) so one check fails.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        public_key = APIResponse(valid_private_key, {
            "X-IOV-KEY-ID":
            "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"
        }, 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured_at) -- a zero offset from
        # server time; confirm against JOSETransport.
        self._transport._server_time_difference = 0, time()
        self.issuer = ANY
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # A payload that passes every check; tests break one field at a time.
        self.jwt_response = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256'
            },
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_response)
        self._transport.content_hash_function = MagicMock()
        # Configure the hash object via ``return_value`` instead of calling the
        # mock, so setUp does not record a spurious call on it.
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

    def _verify_jwt_response(self, jti=None, subject=None):
        """Call ``verify_jwt_response`` with mock headers and body.

        ``jti`` and ``subject`` default to the fixture's ``jti`` and ``sub``
        claims when ``None``.
        """
        jti = self.jwt_response['jti'] if jti is None else jti
        subject = self.jwt_response['sub'] if subject is None else subject
        return self._transport.verify_jwt_response(MagicMock(), jti,
                                                   MagicMock(), subject)

    def test_verify_jwt_response_success(self):
        self._verify_jwt_response()

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_response['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        # "Now" is just beyond the allowed leeway before nbf.
        time_patch.return_value = self.jwt_response['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        # "Now" is just beyond the allowed leeway after exp.
        time_patch.return_value = self.jwt_response['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject=MagicMock())

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_response['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_response['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')