class TestJOSETransportJWTRequest(unittest.TestCase):
    """Validation tests for ``JOSETransport.verify_jwt_request``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_payload`` and ``sha256``
    in ``launchkey.transports.jose_auth`` is patched to yield a fixed digest,
    so each test changes exactly one claim (or one call argument) and expects
    ``JWTValidationFailure``.
    """

    def setUp(self):
        now = time()
        public_key = APIResponse(valid_private_key,
                                 {"X-IOV-KEY-ID": "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"}, 200)
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=public_key)
        self._transport._server_time_difference = 0, now
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id, valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        subject = '%s:%s' % (self.issuer, self.issuer_id)
        # A single timestamp keeps nbf/iat/exp mutually consistent.
        self.jwt_payload = {
            'aud': subject,
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': now,
            'jti': str(uuid4()),
            'exp': now + 30,
            'iat': now,
            'request': {
                'meth': 'POST',
                'path': '/',
                'hash': self._content_hash,
                'func': 'S256'},
            'sub': subject
        }
        self._transport._get_jwt_payload = MagicMock(return_value=self.jwt_payload)

        # Configure the return value directly rather than calling the mock, so
        # setUp does not leave a recorded call on content_hash_function.
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.return_value = self._content_hash

        patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = patcher.start()
        self.addCleanup(patcher.stop)
        patched.return_value.hexdigest.return_value = self._content_hash

    def _verify_jwt_request(self, compact_jwt="compact.jtw.value", subscriber=None, method='POST', path='/', body="body"):
        """Call ``verify_jwt_request`` with arguments that match ``self.jwt_payload``.

        Override a single argument to make it disagree with the payload.
        ``subscriber`` defaults to the payload's ``sub`` claim.
        """
        if subscriber is None:
            subscriber = self.jwt_payload['sub']
        return self._transport.verify_jwt_request(compact_jwt, subscriber, method, path, body)

    def test_all_params_returns_payload(self):
        actual = self._verify_jwt_request()
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_returns_payload(self):
        actual = self._verify_jwt_request(path=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_still_requires_jwt_request_path(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path=None)

    def test_none_method_returns_payload(self):
        actual = self._verify_jwt_request(method=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_method_still_requires_jwt_request_path(self):
        """A None method argument must still require the JWT ``meth`` claim."""
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method=None)

    def test_invalid_audience_raises_validation_failure(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_issuer_raises_validation_failure(self):
        self.jwt_payload['iss'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_nbf_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_exp_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_sub_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(subscriber='other')

    def test_invalid_iat_raises_validation_failure(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_content_body_raises_validation_failure(self):
        self.jwt_payload['request']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_body_but_response_body_hash_algorithm_raises_validation_failure(self):
        # Body is None while the JWT still carries the hash algorithm ('func').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_no_body_but_response_body_hash_raises_validation_failure(self):
        # Body is None while the JWT still carries the body hash ('hash').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_body_but_no_response_body_hash_algorithm_raises_validation_failure(self):
        # Body present but the JWT lacks the hash algorithm ('func').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_body_but_no_response_body_hash_raises_validation_failure(self):
        # Body present but the JWT lacks the body hash ('hash').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_hash_func_raises_validation_failure(self):
        self.jwt_payload['request']['func'] = 'INVALID'
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_request_raises_validation_failure(self):
        # Matching arguments: the missing 'request' claim is the only defect.
        del self.jwt_payload['request']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_method_raises_validation_failure(self):
        # Matching method argument: the missing 'meth' claim is the only defect.
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_method_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_path_raises_validation_failure(self):
        # Matching path argument: the missing 'path' claim is the only defect.
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_path_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_jwt_raises_validation_failure(self):
        self._transport._get_jwt_payload.side_effect = JWKESTException
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()
Ejemplo n.º 2
0
class TestJOSETransportJWTRequest(unittest.TestCase):
    """Validation tests for ``JOSETransport.verify_jwt_request``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_payload``; ``JWT`` and
    ``sha256`` in ``launchkey.transports.jose_auth`` are patched for the
    duration of each test, so each test changes one claim (or one call
    argument) and expects ``JWTValidationFailure``.
    """

    def setUp(self):
        now = time()
        public_key = APIResponse(valid_private_key, {}, 200)
        self._transport = JOSETransport()
        self._transport.get = MagicMock(return_value=public_key)
        self._transport._server_time_difference = 0, now
        self.issuer = "svc"
        self.issuer_id = uuid4()
        self._transport.set_issuer(self.issuer, self.issuer_id,
                                   valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        subject = '%s:%s' % (self.issuer, self.issuer_id)
        # A single timestamp keeps nbf/iat/exp mutually consistent.
        self.jwt_payload = {
            'aud': subject,
            'iss': 'lka',
            'cty': 'application/json',
            'nbf': now,
            'jti': str(uuid4()),
            'exp': now + 30,
            'iat': now,
            'request': {
                'meth': 'POST',
                'path': '/',
                'hash': self._content_hash,
                'func': 'S256'
            },
            'sub': subject
        }
        self._transport._get_jwt_payload = MagicMock(
            return_value=self.jwt_payload)

        # Configure the return value directly rather than calling the mock, so
        # setUp does not leave a recorded call on content_hash_function.
        self._transport.content_hash_function = MagicMock()
        self._transport.content_hash_function.return_value.hexdigest.\
            return_value = self._content_hash

        # The JWT patcher must be stopped after each test; otherwise the patch
        # stays active and leaks into every test that runs afterwards.
        jwt_patcher = patch("launchkey.transports.jose_auth.JWT",
                            return_value=MagicMock(spec=JWT))
        self._jwt_patch = jwt_patcher.start()
        self.addCleanup(jwt_patcher.stop)
        self._jwt_patch.return_value.unpack.return_value.headers = faux_jwt_headers

        sha_patcher = patch('launchkey.transports.jose_auth.sha256')
        patched = sha_patcher.start()
        self.addCleanup(sha_patcher.stop)
        patched.return_value.hexdigest.return_value = self._content_hash

    def _verify_jwt_request(self,
                            compact_jwt="compact.jtw.value",
                            subscriber=None,
                            method='POST',
                            path='/',
                            body="body"):
        """Call ``verify_jwt_request`` with arguments matching the payload.

        Override a single argument to make it disagree with
        ``self.jwt_payload``. ``subscriber`` defaults to the ``sub`` claim.
        """
        if subscriber is None:
            subscriber = self.jwt_payload['sub']
        return self._transport.verify_jwt_request(compact_jwt, subscriber,
                                                  method, path, body)

    def test_all_params_returns_payload(self):
        actual = self._verify_jwt_request()
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_returns_payload(self):
        actual = self._verify_jwt_request(path=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_path_still_requires_jwt_request_path(self):
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path=None)

    def test_none_method_returns_payload(self):
        actual = self._verify_jwt_request(method=None)
        self.assertEqual(actual, self.jwt_payload)

    def test_none_method_still_requires_jwt_request_path(self):
        """A None method argument must still require the JWT ``meth`` claim."""
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method=None)

    def test_invalid_audience_raises_validation_failure(self):
        self.jwt_payload['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_issuer_raises_validation_failure(self):
        self.jwt_payload['iss'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_nbf_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    @patch("launchkey.transports.jose_auth.time")
    def test_invalid_exp_raises_validation_failure(self, time_patch):
        time_patch.return_value = self.jwt_payload['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_sub_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(subscriber='other')

    def test_invalid_iat_raises_validation_failure(self):
        self.jwt_payload['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_content_body_raises_validation_failure(self):
        self.jwt_payload['request']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_body_but_response_body_hash_algorithm_raises_validation_failure(
            self):
        # Body is None while the JWT still carries the hash algorithm ('func').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_no_body_but_response_body_hash_raises_validation_failure(self):
        # Body is None while the JWT still carries the body hash ('hash').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(body=None)

    def test_body_but_no_response_body_hash_algorithm_raises_validation_failure(
            self):
        # Body present but the JWT lacks the hash algorithm ('func').
        del self.jwt_payload['request']['func']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_body_but_no_response_body_hash_raises_validation_failure(self):
        # Body present but the JWT lacks the body hash ('hash').
        del self.jwt_payload['request']['hash']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_hash_func_raises_validation_failure(self):
        self.jwt_payload['request']['func'] = 'INVALID'
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_request_raises_validation_failure(self):
        # Matching arguments: the missing 'request' claim is the only defect.
        del self.jwt_payload['request']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_no_method_raises_validation_failure(self):
        # Matching method argument: the missing 'meth' claim is the only defect.
        del self.jwt_payload['request']['meth']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_method_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(method='INV')

    def test_no_path_raises_validation_failure(self):
        # Matching path argument: the missing 'path' claim is the only defect.
        del self.jwt_payload['request']['path']
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()

    def test_invalid_path_raises_validation_failure(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request(path='/invalid')

    def test_invalid_jwt_raises_validation_failure(self):
        self._transport._get_jwt_payload.side_effect = JWKESTException
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_request()
Ejemplo n.º 3
0
class TestJOSETransportJWTResponse(unittest.TestCase):
    """Validation tests for ``JOSETransport.verify_jwt_response``.

    ``_get_jwt_payload`` is mocked to return ``self.jwt_response`` and the
    transport's ``content_hash_function`` is mocked to produce a fixed hex
    digest. Each failure test mutates one response claim or passes one
    mismatching argument.
    """

    def setUp(self):
        now = time()
        key_response = APIResponse(valid_private_key, {
            "X-IOV-KEY-ID":
            "59:12:e2:f6:3f:79:d5:1e:18:75:c5:25:ff:b3:b7:f2"
        }, 200)

        transport = JOSETransport()
        transport.get = MagicMock(return_value=key_response)
        transport._server_time_difference = 0, now
        self._transport = transport

        self.issuer = ANY
        self.issuer_id = uuid4()
        transport.set_issuer(self.issuer, self.issuer_id, valid_private_key)
        self._body = str(uuid4())
        self._content_hash = '16793293daadb5a03b7cbbb9d15a1a705b22e762a1a751bc8625dec666101ff2'

        subject = '%s:%s' % (self.issuer, self.issuer_id)
        self.jwt_response = dict(
            aud=subject,
            iss='lka',
            cty='application/json',
            nbf=now,
            jti=str(uuid4()),
            exp=now + 30,
            iat=now,
            response=dict(status=200, hash=self._content_hash, func='S256'),
            sub=subject,
        )
        transport._get_jwt_payload = MagicMock(return_value=self.jwt_response)

        transport.content_hash_function = MagicMock()
        digest = transport.content_hash_function()
        digest.hexdigest.return_value = self._content_hash

    def _verify_jwt_response(self, jti=None, subscriber=None):
        """Call ``verify_jwt_response`` with fresh mock JWT and body arguments.

        ``jti`` and ``subscriber`` default to the ``jti`` and ``sub`` claims of
        ``self.jwt_response``; pass a different value to force a mismatch.
        """
        expected_jti = self.jwt_response['jti'] if jti is None else jti
        expected_sub = (self.jwt_response['sub']
                        if subscriber is None else subscriber)
        return self._transport.verify_jwt_response(MagicMock(), expected_jti,
                                                   MagicMock(), expected_sub)

    def test_verify_jwt_response_success(self):
        self._verify_jwt_response()

    def test_verify_jwt_response_invalid_audience(self):
        self.jwt_response['aud'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_nbf(self, time_patch):
        time_patch.return_value = self.jwt_response['nbf'] - JOSE_JWT_LEEWAY - 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    @patch("launchkey.transports.jose_auth.time")
    def test_verify_jwt_response_invalid_exp(self, time_patch):
        time_patch.return_value = self.jwt_response['exp'] + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_sub(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(subscriber=MagicMock())

    def test_verify_jwt_response_invalid_iat(self):
        self.jwt_response['iat'] = time() + JOSE_JWT_LEEWAY + 1
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_body(self):
        self.jwt_response['response']['hash'] = MagicMock()
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response()

    def test_verify_jwt_response_invalid_content_jti(self):
        with self.assertRaises(JWTValidationFailure):
            self._verify_jwt_response(jti='InvalidJTI')