Example #1
0
class TestJOSETransport3rdParty(unittest.TestCase):
    """JOSETransport tests that go through the third-party JOSE objects.

    ``launchkey.transports.jose_auth.JWT`` is patched for every test so that
    unpacking a JWT yields ``faux_jwt_headers``. Every patch started here is
    undone by the ``patch.stopall`` cleanup registered in ``setUp``.
    """

    def setUp(self):
        self._transport = JOSETransport()
        # Public-key fetches go through a mocked ``get`` that returns a real
        # APIResponse built from the module-level fixtures.
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        public_key = APIResponse(valid_public_key, transport_request_headers,
                                 200)
        self._transport.get.return_value = public_key
        # NOTE(review): presumably (offset, measured-at) with a zero offset
        # from the server clock -- confirm against JOSETransport.
        self._transport._server_time_difference = 0, time()

        self._jwt_patch = patch("launchkey.transports.jose_auth.JWT",
                                return_value=MagicMock(spec=JWT)).start()
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        self.addCleanup(patch.stopall)

    def test_public_key_parse(self):
        """A single public key in the response parses to one RSAKey."""
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        """Signing without an issuer key raises NoIssuerKey."""
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY,
                                                 ANY, ANY)

    def test_build_jwt_signature(self):
        """A signed request header carries a three-part compact JWT."""
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test",
                                                   str(uuid4()),
                                                   "test:subject", "ABCDEFG")
        self.assertIn('IOV-JWT', jwt)
        # str.strip() removes any of the given *characters* from both ends,
        # not a prefix, so it would also eat trailing base64url characters
        # such as "T", "V" or "-". Take everything after the scheme instead.
        token = jwt.partition('IOV-JWT ')[2]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        """A malformed X-IOV-JWT header raises InvalidJWTResponse."""
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_jwt_response_error(self):
        """A malformed X-IOV-JWT header raises InvalidJWTResponse."""
        # NOTE(review): identical inputs and expectation to
        # test_invalid_jwt_response; consider giving this test a distinct
        # scenario or removing it.
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        """Headers without X-IOV-JWT raise InvalidJWTResponse."""
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    @patch.object(JWS, 'verify_compact')
    def test_jwt_error_raises_expected_exception(self, verify_compact_patch):
        """A JWKEST error surfaces as JWTValidationFailure."""
        verify_compact_patch.side_effect = JWKESTException
        # NOTE(review): these headers are empty -- the same input that
        # test_empty_jwt_response expects to raise InvalidJWTResponse -- so
        # this may never reach JWS.verify_compact; confirm the patched path
        # is actually exercised.
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response({}, ANY, ANY, ANY)
class TestJOSETransport3rdParty(unittest.TestCase):
    """JOSETransport tests that go through the third-party JOSE objects."""

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        # NOTE(review): the mocked key response carries valid_private_key;
        # test_public_key_parse replaces .data with valid_public_key before
        # parsing, so confirm no other test depends on this fixture's data.
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()

    def test_public_key_parse(self):
        """A single public key in the response parses to one RSAKey."""
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        """Signing without an issuer key raises NoIssuerKey."""
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY, ANY, ANY)

    def test_build_jwt_signature(self):
        """A signed request header carries a three-part compact JWT."""
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test", str(uuid4()), "test:subject", "ABCDEFG")
        self.assertIn('IOV-JWT', jwt)
        # str.strip() removes a character *set* from both ends rather than a
        # prefix, and could eat trailing base64url characters; slice off the
        # scheme instead.
        token = jwt.partition('IOV-JWT ')[2]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        """A malformed X-IOV-JWT header raises InvalidJWTResponse."""
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_jwt_response_error(self):
        """A malformed X-IOV-JWT header raises InvalidJWTResponse."""
        # NOTE(review): duplicates test_invalid_jwt_response exactly.
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        """Headers without X-IOV-JWT raise InvalidJWTResponse."""
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    @patch.object(JWS, 'verify_compact')
    def test_jwt_error_raises_expected_exception(self, verify_compact_patch):
        """A JWKEST error surfaces as JWTValidationFailure."""
        verify_compact_patch.side_effect = JWKESTException
        # NOTE(review): empty headers are also the input of
        # test_empty_jwt_response, so verify_compact may never be reached.
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response({}, ANY, ANY, ANY)
class TestJOSETransportJWT(unittest.TestCase):
    """Request-signing tests for JOSETransport."""

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        # NOTE(review): _get_jwt_signature is stubbed to 'x.x.x'; whether
        # _build_jwt_signature calls it is not visible from this file.
        self._transport._get_jwt_signature = MagicMock(return_value='x.x.x')
        public_key = APIResponse(valid_public_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()

    def test_build_jwt_signature(self):
        """A signed request header carries a three-part compact JWT."""
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test", str(uuid4()), "test:subject", "ABCDEFG")
        self.assertIn('IOV-JWT', jwt)
        # str.strip() strips a character *set*, not a prefix, and could eat
        # trailing base64url characters; slice off the scheme instead.
        token = jwt.partition('IOV-JWT ')[2]
        self.assertEqual(len(token.split(".")), 3)
Example #4
0
class TestJOSETransport3rdParty(unittest.TestCase):
    """JOSETransport tests that go through the third-party JOSE objects."""

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        # NOTE(review): the mocked key response carries valid_private_key;
        # test_public_key_parse swaps in valid_public_key before parsing.
        public_key = APIResponse(valid_private_key, {
            "X-IOV-KEY-ID":
            "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"
        }, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()

    def test_public_key_parse(self):
        """A single public key in the response parses to one RSAKey."""
        self._transport.get.return_value.data = valid_public_key
        keys = self._transport.api_public_keys
        self.assertEqual(len(keys), 1)
        self.assertIsInstance(keys[0], RSAKey)

    def test_build_jwt_signature_no_key(self):
        """Signing without an issuer key raises NoIssuerKey."""
        with self.assertRaises(NoIssuerKey):
            self._transport._build_jwt_signature(MagicMock(spec=str), ANY, ANY,
                                                 ANY, ANY)

    def test_build_jwt_signature(self):
        """A signed request header carries a three-part compact JWT."""
        self._transport.set_issuer(ANY, uuid4(), valid_private_key)
        jwt = self._transport._build_jwt_signature("PUT", "/test",
                                                   str(uuid4()),
                                                   "test:subject", "ABCDEFG")
        self.assertIn('IOV-JWT', jwt)
        # str.strip() strips a character *set*, not a prefix, and could eat
        # trailing base64url characters; slice off the scheme instead.
        token = jwt.partition('IOV-JWT ')[2]
        self.assertEqual(len(token.split(".")), 3)

    def test_invalid_jwt_response(self):
        """A malformed X-IOV-JWT header raises InvalidJWTResponse."""
        headers = {"X-IOV-JWT": 'invalid'}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)

    def test_empty_jwt_response(self):
        """Headers without X-IOV-JWT raise InvalidJWTResponse."""
        headers = {}
        with self.assertRaises(InvalidJWTResponse):
            self._transport.verify_jwt_response(headers, ANY, ANY, ANY)
class TestJOSETransportIssuers(unittest.TestCase):
    """Issuer configuration and issuer-key management on JOSETransport."""

    def setUp(self):
        self._transport = JOSETransport()

    def test_add_issuer_key(self):
        """Adding a key grows issuer_private_keys from 0 to 1."""
        self.assertEqual(len(self._transport.issuer_private_keys), 0)
        self._transport.add_issuer_key(valid_private_key)
        self.assertEqual(len(self._transport.issuer_private_keys), 1)

    def test_add_duplicate_issuer_key(self):
        """Re-adding the same key returns a falsy value and adds nothing."""
        self.assertEqual(len(self._transport.issuer_private_keys), 0)
        self._transport.add_issuer_key(valid_private_key)
        self.assertEqual(len(self._transport.issuer_private_keys), 1)
        resp = self._transport.add_issuer_key(valid_private_key)
        self.assertFalse(resp)
        self.assertEqual(len(self._transport.issuer_private_keys), 1)

    def test_generate_key_id(self):
        """The stored key's kid is its expected colon-separated fingerprint."""
        self._transport.add_issuer_key(valid_private_key)
        self.assertEqual(self._transport.issuer_private_keys[0].kid,
                         '59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2')

    def test_set_url(self):
        """set_url delegates to the HTTP client exactly once."""
        self._transport._http_client = MagicMock()
        self._transport.set_url(ANY, ANY)
        self._transport._http_client.set_url.assert_called_once()

    def test_set_issuer_invalid_entity_id(self):
        """A non-UUID entity id raises InvalidEntityID."""
        with self.assertRaises(InvalidEntityID):
            self._transport.set_issuer(ANY, ANY, ANY)

    def test_set_issuer_invalid_entity_issuer(self):
        """An issuer outside the allowed list raises InvalidIssuer."""
        with self.assertRaises(InvalidIssuer):
            self._transport.set_issuer(MagicMock(spec=str), uuid4(), ANY)

    def test_set_issuer_invalid_private_key(self):
        """An unparsable private key raises InvalidPrivateKey."""
        with self.assertRaises(InvalidPrivateKey):
            self._transport.set_issuer(ANY, uuid4(), "InvalidKey")

    @patch("launchkey.transports.jose_auth.RSAKey")
    @patch("launchkey.transports.jose_auth.import_rsa_key")
    def test_issuer_list(self, import_rsa_key_patch, rsa_key_patch):
        """Every issuer in VALID_JWT_ISSUER_LIST is accepted by set_issuer."""
        # Stacked @patch decorators inject their mocks bottom-up: the
        # innermost one (import_rsa_key) is the first argument. The previous
        # parameter names were swapped, so the RSAKey-spec mock was being
        # configured on import_rsa_key rather than on the RSAKey constructor.
        rsa_key_patch.return_value = MagicMock(spec=RSAKey)
        import_rsa_key_patch.return_value = MagicMock()
        self._transport.add_issuer_key = MagicMock()
        for issuer in VALID_JWT_ISSUER_LIST:
            self._transport.set_issuer(issuer, uuid4(), ANY)
Example #6
0
class TestJOSETransportJWTRequest(unittest.TestCase):
    """Claim validation performed by ``JOSETransport.verify_jwt_request``.

    ``_get_jwt_payload`` is replaced by a mock returning ``self.jwt_payload``.
    Each test mutates that dict (or the call arguments) to describe one JWT
    scenario and checks whether verification accepts or rejects it.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        public_key = APIResponse(valid_private_key, {}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # Claims of a request JWT that should verify successfully.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'request': {
                'meth': 'POST',
                'path': '/',
                'hash': self._content_hash,
                'func': 'S256'
            },
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()
        # Configure the digest through return_value rather than calling the
        # mock, which would record a spurious call during setup.
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        # Keep the patcher so the patch is undone after each test; a bare
        # patch(...).start() with no cleanup leaves jose_auth.JWT replaced
        # for every test that runs afterwards.
        jwt_patcher = patch("launchkey.transports.jose_auth.JWT",
                            return_value=MagicMock(spec=JWT))
        self._jwt_patch = jwt_patcher.start()
        self.addCleanup(jwt_patcher.stop)
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_request(self,
                            compact_jwt="compact.jtw.value",
                            subscriber=None,
                            method='POST',
                            path='/',
                            body="body"):
        """Call verify_jwt_request with defaults that match jwt_payload."""
        subscriber = self.jwt_payload[
            'sub'] if subscriber is None else subscriber
        return self._transport.verify_jwt_request(compact_jwt, subscriber,
                                                  method, path, body)

    def test_all_params_returns_payload(self):
        actual = self._verify_jwt_request()
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_returns_payload(self):
        actual = self._transport.verify_jwt_request(MagicMock(),
                                                    self.jwt_payload['sub'],
                                                    'POST', None, MagicMock())
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_still_requires_jwt_request_path(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path=None)

    def test_none_method_returns_payload(self):
        actual = self._verify_jwt_request(method=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_method_still_requires_jwt_request_path(self):
        # Despite the name, this covers the request *method* claim ('meth').
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method=None)

    def test_invalid_audience_raises_validation_failure(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_issuer_raises_validation_failure(self):
        self.jwt_payload['iss'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_nbf_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_exp_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_sub_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(subscriber='other')

    def test_invalid_iat_raises_validation_failure(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_content_body_raises_validation_failure(self):
        self.jwt_payload['request']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    # The four tests below delete the claim their names refer to. Previously
    # each pair deleted the opposite key; the set of scenarios is unchanged.

    def test_no_body_but_response_body_hash_algorithm_raises_validation_failure(
            self):
        # No body, yet the JWT still declares a hash algorithm ('func').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_no_body_but_response_body_hash_raises_validation_failure(self):
        # No body, yet the JWT still carries a body hash ('hash').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_body_but_no_response_body_hash_algorithm_raises_validation_failure(
            self):
        # A body is sent but the JWT declares no hash algorithm ('func').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_body_but_no_response_body_hash_raises_validation_failure(self):
        # A body is sent but the JWT carries no body hash ('hash').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_hash_func_raises_validation_failure(self):
        self.jwt_payload['request']['func'] = 'INVALID'
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_request_raises_validation_failure(self):
        del self.jwt_payload['request']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_method_raises_validation_failure(self):
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_invalid_method_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_path_raises_validation_failure(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_path_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_jwt_raises_validation_failure(self):
        self._transport._get_jwt_payload.side_effect = JWKESTException
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()
Example #7
0
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Claim validation performed by ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is replaced by a mock returning ``self.jwt_payload``,
    and ``self._headers`` holds HTTP headers that match its ``response``
    claims. Tests mutate either side and check the verification outcome.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        public_key = APIResponse(valid_private_key, {}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # Keep the patcher so the patch is undone after each test; a bare
        # patch(...).start() with no cleanup leaves jose_auth.JWT replaced
        # for every test that runs afterwards.
        jwt_patcher = patch("launchkey.transports.jose_auth.JWT",
                            return_value=MagicMock(spec=JWT))
        self._jwt_patch = jwt_patcher.start()
        self.addCleanup(jwt_patcher.stop)
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        # Claims of a response JWT that should verify successfully.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256',
                'cache': 'expected cache-control',
                'location': 'expected location'
            },
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }

        # HTTP headers matching the 'cache' and 'location' response claims.
        self._headers = {
            'Location': 'expected location',
            'Cache-Control': 'expected cache-control'
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_response(self,
                             headers=None,
                             jti=None,
                             body=False,
                             subject=None,
                             status_code=200):
        """Call verify_jwt_response with defaults that match jwt_payload.

        ``body`` uses ``False`` as its "not supplied" marker so that tests
        can pass ``None`` explicitly to mean "no body".
        """
        headers = self._headers if headers is None else headers
        jti = self.jwt_payload['jti'] if jti is None else jti
        body = MagicMock() if body is False else body
        subject = self.jwt_payload['sub'] if subject is None else subject
        return self._transport.verify_jwt_response(headers, jti, body, subject,
                                                   status_code)

    def test_verify_jwt_response_success_returns_payload(self):
        actual = self._verify_jwt_response()
        self.assertEqual(actual, self.jwt_payload)

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        # NOTE(review): headers are a bare MagicMock here, so this may fail
        # on header checks before the subject is compared -- confirm.
        with self.assertRaises(JWTValidationFailure):
            self._transport.verify_jwt_response(MagicMock(),
                                                self.jwt_payload['jti'],
                                                MagicMock(), MagicMock())

    def test_verify_jwt_response_invalid_sub_401(self):
        self.jwt_payload['response']['status'] = 401
        self.jwt_payload['aud'] = "public"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject="Invalid subject")

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_payload['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')

    def test_verify_no_response_raises_validation_failure(self):
        del self.jwt_payload['response']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_algorithm_raises_validation_failure(
            self):
        # No body, yet the JWT still declares a hash algorithm ('func').
        del self.jwt_payload['response']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_raises_validation_failure(
            self):
        # No body, yet the JWT still carries a body hash ('hash').
        del self.jwt_payload['response']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_invalid_status_code_raises_validation_failure(self):
        self.jwt_payload['response']['status'] = 999
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_any_status_with_no_status_code_still_passes(self):
        self.assertIsNotNone(self._verify_jwt_response(status_code=None))

    def test_verify_no_status_in_jwt_response_and_no_status_code_raises_validation_failure(
            self):
        del self.jwt_payload['response']['status']
        # The call itself must raise; an assertion on its return value inside
        # assertRaises could never run and only obscured the intent.
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(status_code=None)

    def test_verify_no_cache_control_header_and_jwt_response_cache_raises_validation_failure(
            self):
        del self._headers['Cache-Control']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_cache_control_header_and_jwt_with_no_response_cache_raises_validation_failure(
            self):
        del self.jwt_payload['response']['cache']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_cache_control_header_raises_validation_failure(
            self):
        self.jwt_payload['response']['cache'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_no_location_header_and_jwt_response_cache_raises_validation_failure(
            self):
        # Despite "cache" in the name, this pairs a missing Location header
        # with a JWT that still has a 'location' claim.
        del self._headers['Location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_location_header_and_jwt_with_no_response_cache_raises_validation_failure(
            self):
        # Despite "cache" in the name, this removes the 'location' claim
        # while the Location header is still present.
        del self.jwt_payload['response']['location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_location_header_raises_validation_failure(self):
        self.jwt_payload['response']['location'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()
class TestJOSETransportJWTRequest(unittest.TestCase):
    """Claim validation performed by ``JOSETransport.verify_jwt_request``.

    ``_get_jwt_payload`` is replaced by a mock returning ``self.jwt_payload``;
    each test mutates that dict to describe one JWT scenario.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id, valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # Claims of a request JWT that should verify successfully.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'request': {
                'meth': 'POST',
                'path': '/',
                'hash': self._content_hash,
                'func': 'S256'},
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }
        self._transport._get_jwt_payload = MagicMock(return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()

        # Configure the digest via return_value instead of calling the mock,
        # which would record a spurious call during setup.
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_request(self, compact_jwt="compact.jtw.value", subscriber=None, method='POST', path='/', body="body"):
        """Call verify_jwt_request with defaults that match jwt_payload."""
        subscriber = self.jwt_payload['sub'] if subscriber is None else subscriber
        return self._transport.verify_jwt_request(compact_jwt, subscriber, method, path, body)

    def test_all_params_returns_payload(self):
        actual = self._verify_jwt_request()
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_returns_payload(self):
        actual = self._transport.verify_jwt_request(MagicMock(), self.jwt_payload['sub'], 'POST', None, MagicMock())
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_still_requires_jwt_request_path(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path=None)

    def test_none_method_returns_payload(self):
        actual = self._verify_jwt_request(method=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_method_still_requires_jwt_request_path(self):
        # Despite the name, this covers the request *method* claim ('meth').
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method=None)

    def test_invalid_audience_raises_validation_failure(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_issuer_raises_validation_failure(self):
        self.jwt_payload['iss'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_nbf_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_exp_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_sub_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(subscriber='other')

    def test_invalid_iat_raises_validation_failure(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_content_body_raises_validation_failure(self):
        self.jwt_payload['request']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    # The four tests below delete the claim their names refer to. Previously
    # each pair deleted the opposite key; the set of scenarios is unchanged.

    def test_no_body_but_response_body_hash_algorithm_raises_validation_failure(self):
        # No body, yet the JWT still declares a hash algorithm ('func').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_no_body_but_response_body_hash_raises_validation_failure(self):
        # No body, yet the JWT still carries a body hash ('hash').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_body_but_no_response_body_hash_algorithm_raises_validation_failure(self):
        # A body is sent but the JWT declares no hash algorithm ('func').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_body_but_no_response_body_hash_raises_validation_failure(self):
        # A body is sent but the JWT carries no body hash ('hash').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_hash_func_raises_validation_failure(self):
        self.jwt_payload['request']['func'] = 'INVALID'
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_request_raises_validation_failure(self):
        del self.jwt_payload['request']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_method_raises_validation_failure(self):
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_invalid_method_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_path_raises_validation_failure(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_path_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_jwt_raises_validation_failure(self):
        self._transport._get_jwt_payload.side_effect = JWKESTException
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Tests for ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_payload``. Each test
    alters that dict (or ``self._headers``) or one helper argument to create
    a single invalid condition, then asserts the verifier rejects it.
    """

    def setUp(self):
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(spec=APIResponse))
        # NOTE(review): this stands in for the API public-key response but is
        # built from valid_private_key; presumably harmless while
        # _get_jwt_payload is mocked, but valid_public_key may be intended.
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id, valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        # Claims of a valid response JWT. 'cache' and 'location' mirror the
        # Cache-Control and Location entries in self._headers below.
        self.jwt_payload = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256',
                'cache': 'expected cache-control',
                'location': 'expected location'},
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }

        self._headers = {'Location': 'expected location', 'Cache-Control': 'expected cache-control'}
        self._transport._get_jwt_payload = MagicMock(return_value=self.jwt_payload)
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        # Make sha256(...).hexdigest() in jose_auth yield the hash claimed above.
        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        patched.return_value = MagicMock()
        patched.return_value.hexdigest.return_value = self._content_hash
        self.addCleanup(patcher.stop)

    def _verify_jwt_response(self, headers=None, jti=None, body=False, subject=None, status_code=200):
        """Call ``verify_jwt_response`` with valid defaults for every argument.

        Override one argument to make only that aspect invalid. ``body`` uses
        ``False`` as its "not given" sentinel because ``None`` is a meaningful
        value here (tests pass ``body=None`` for "no body").
        """
        headers = self._headers if headers is None else headers
        jti = self.jwt_payload['jti'] if jti is None else jti
        body = MagicMock() if body is False else body
        subject = self.jwt_payload['sub'] if subject is None else subject
        return self._transport.verify_jwt_response(headers, jti, body, subject, status_code)

    def test_verify_jwt_response_success_returns_payload(self):
        actual = self._verify_jwt_response()
        self.assertEqual(actual, self.jwt_payload)

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        # Use the helper so headers, jti, body and status are all valid and
        # only the subject mismatch can trigger the failure. The previous
        # version passed MagicMock headers, which could fail the Location /
        # Cache-Control checks instead.
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject="Invalid subject")

    def test_verify_jwt_response_invalid_sub_401(self):
        self.jwt_payload['response']['status'] = 401
        self.jwt_payload['aud'] = "public"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subject="Invalid subject")

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_payload['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')

    def test_verify_no_response_raises_validation_failure(self):
        del self.jwt_payload['response']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_algorithm_raises_validation_failure(self):
        # 'func' (the algorithm) remains while 'hash' is removed.
        del self.jwt_payload['response']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_no_body_but_response_body_hash_raises_validation_failure(self):
        # 'hash' remains while 'func' is removed.
        del self.jwt_payload['response']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(body=None)

    def test_verify_invalid_status_code_raises_validation_failure(self):
        self.jwt_payload['response']['status'] = 999
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_any_status_with_no_status_code_still_passes(self):
        self.assertIsNotNone(self._verify_jwt_response(status_code=None))

    def test_verify_no_status_in_jwt_response_and_no_status_code_raises_validation_failure(self):
        del self.jwt_payload['response']['status']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(status_code=None)

    def test_verify_no_cache_control_header_and_jwt_response_cache_raises_validation_failure(self):
        del self._headers['Cache-Control']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_cache_control_header_and_jwt_with_no_response_cache_raises_validation_failure(self):
        del self.jwt_payload['response']['cache']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_cache_control_header_raises_validation_failure(self):
        self.jwt_payload['response']['cache'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_no_location_header_and_jwt_response_cache_raises_validation_failure(self):
        # Despite "cache" in the name, this exercises the Location header.
        del self._headers['Location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_location_header_and_jwt_with_no_response_cache_raises_validation_failure(self):
        # Despite "cache" in the name, this removes the 'location' claim.
        del self.jwt_payload['response']['location']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_invalid_location_header_raises_validation_failure(self):
        self.jwt_payload['response']['location'] = "Unexpected"
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()
# Example #10
# 0
class TestJOSETransportJWTResponse(unittest.TestCase):
    def setUp(self):
        """Build a transport whose JWT decoding and content hashing are mocked.

        ``_get_jwt_payload`` returns ``self.jwt_response`` so each test can
        tamper with one claim, and ``content_hash_function(...).hexdigest()``
        is stubbed to return the hash claimed in the response.
        """
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=MagicMock(
            spec=APIResponse))
        # NOTE(review): stands in for the API public-key response but is
        # built from valid_private_key; confirm this is intended.
        public_key = APIResponse(valid_private_key, {
            "X-IOV-KEY-ID":
            "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"
        }, 200)
        self._transport.get.return_value = public_key
        self._transport._server_time_difference = 0, time()
        self.issuer = ANY
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        self.jwt_response = {
            'aud': '%s:%s' % (self.issuer, self.issuer_id),
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': time(),
            'jti': str(uuid4()),
            'exp': time() + 30,
            'iat': time(),
            'response': {
                'status': 200,
                'hash': self._content_hash,
                'func': 'S256'
            },
            'sub': '%s:%s' % (self.issuer, self.issuer_id)
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_response)
        self._transport.content_hash_function = MagicMock()
        # Configure through return_value instead of calling the mock, so
        # setUp does not leave a spurious recorded call on the mock.
        hash_result = self._transport.content_hash_function.return_value
        hash_result.hexdigest.return_value = self._content_hash

    def test_verify_jwt_response_success(self):
        """Verification with the JWT's own jti and sub completes without raising."""
        claims = self.jwt_response
        self._transport.verify_jwt_response(
            MagicMock(), claims['jti'], MagicMock(), claims['sub'])

    def test_verify_jwt_response_invalid_audience(self):
        """Replacing the ``aud`` claim with an arbitrary mock fails verification."""
        claims = self.jwt_response
        claims['aud'] = MagicMock()
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), claims['jti'], MagicMock(),
                          claims['sub'])

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        """Fails when jose_auth's clock reads one second before ``nbf`` minus the leeway."""
        claims = self.jwt_response
        time_patch.return_value = claims['nbf'] - JOSE_JWT_LEEWAY - 1
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), claims['jti'], MagicMock(),
                          claims['sub'])

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        """Fails when jose_auth's clock reads one second past ``exp`` plus the leeway."""
        claims = self.jwt_response
        time_patch.return_value = claims['exp'] + JOSE_JWT_LEEWAY + 1
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), claims['jti'], MagicMock(),
                          claims['sub'])

    def test_verify_jwt_response_invalid_sub(self):
        """Passing a mock instead of the JWT's ``sub`` fails verification.

        NOTE(review): headers and body are mocks too; confirm the subject
        check is what actually triggers the failure here.
        """
        valid_jti = self.jwt_response['jti']
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), valid_jti, MagicMock(), MagicMock())

    def test_verify_jwt_response_invalid_iat(self):
        """An ``iat`` more than the leeway in the future fails verification."""
        claims = self.jwt_response
        claims['iat'] = time() + JOSE_JWT_LEEWAY + 1
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), claims['jti'], MagicMock(),
                          claims['sub'])

    def test_verify_jwt_response_invalid_content_body(self):
        """A response ``hash`` that differs from the stubbed digest fails verification."""
        claims = self.jwt_response
        claims['response']['hash'] = MagicMock()
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), claims['jti'], MagicMock(),
                          claims['sub'])

    def test_verify_jwt_response_invalid_content_jti(self):
        """An expected jti that differs from the JWT's ``jti`` fails verification."""
        valid_sub = self.jwt_response['sub']
        self.assertRaises(JWTValidationFailure,
                          self._transport.verify_jwt_response,
                          MagicMock(), 'InvalidJTI', MagicMock(), valid_sub)