def test_api_akv_hsm_key_rollover_single_server():
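    """Exercise key rollover with a service key stored in an AKV managed HSM.

    A single server rolls its key every 5 s; 12 requests spaced 1 s apart
    should observe one or two rollovers and never hit an invalid key.
    """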
    model = MATMUL_1
    port = 8001
    key_name = random_akv_key_name()
    try:
        client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

        with Server(model_path=model['model'],
                    port=port,
                    key_rollover_interval=5,
                    key_sync_interval=5,
                    use_akv=True,
                    akv_app_id=os.environ['CONFONNX_TEST_APP_ID'],
                    akv_app_pwd=os.environ['CONFONNX_TEST_APP_PWD'],
                    akv_service_key_name=key_name,
                    akv_vault_url=os.environ['CONFONNX_TEST_VAULT_HSM_URL'],
                    akv_attestation_url=os.environ['CONFONNX_TEST_ATTESTATION_URL']):
            for i in range(12):
                print(f'Request #{i}')
                client.predict(model['input'])
                print(f'Request #{i} -- done')
                time.sleep(1)

        assert client.key_rollover_count in [1, 2]
        assert client.key_invalid_count == 0
    finally:
        delete_akv_key(os.environ['CONFONNX_TEST_APP_ID'],
                       os.environ['CONFONNX_TEST_APP_PWD'],
                       os.environ['CONFONNX_TEST_VAULT_HSM_URL'],
                       key_name,
                       is_hsm=True)


def test_api_local_key_rollover():
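    """Exercise key rollover with a locally generated key (no AKV).

    The server rolls its key every 5 s and syncs every 1 s; 12 requests
    spaced 1 s apart should observe two or three rollovers and no invalid
    keys.
    """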
    model = MATMUL_1
    port = 8001

    client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

    with Server(model_path=model['model'],
                port=port,
                key_rollover_interval=5,
                key_sync_interval=1):
        for _ in range(12):
            client.predict(model['input'])
            time.sleep(1)

    assert client.key_rollover_count in [2, 3]
    assert client.key_invalid_count == 0


def test_api_akv_multiple_servers():
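    """Run several servers that share one AKV-managed service key.

    Each server is queried once with the same client; because every server
    fetches the same key from AKV, no rollover or invalid-key events are
    expected.
    """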
    model = MATMUL_1
    num_servers = 3
    start_port = 8001
    ports = list(range(start_port, start_port + num_servers))
    key_name = random_akv_key_name()

    servers = []
    try:
        for port in ports:
            servers.append(
                Server(model_path=model['model'],
                       port=port,
                       use_akv=True,
                       akv_app_id=os.environ['CONFONNX_TEST_APP_ID'],
                       akv_app_pwd=os.environ['CONFONNX_TEST_APP_PWD'],
                       akv_service_key_name=key_name,
                       akv_vault_url=os.environ['CONFONNX_TEST_VAULT_URL']))

        client = Client('http://foo/', enclave_allow_debug=True)

        for port in ports:
            client.url = f'http://localhost:{port}/'
            client.predict(model['input'])
            # All servers should share the same service key from AKV.
            # This assumes the key already exists in AKV, which is the case
            # because the other AKV tests above have already run.
            assert client.key_rollover_count == 0
            assert client.key_invalid_count == 0
    finally:
        stop_errors = []
        for server in servers:
            try:
                server.stop()
            except Exception as e:
                stop_errors.append(e)
        delete_akv_key(os.environ['CONFONNX_TEST_APP_ID'],
                       os.environ['CONFONNX_TEST_APP_PWD'],
                       os.environ['CONFONNX_TEST_VAULT_URL'],
                       key_name,
                       is_hsm=False)
        for e in stop_errors:
            print(e)
        if stop_errors:
            raise stop_errors[0]


def test_api_invalid_key():
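    """Verify that the client recovers after its negotiated key becomes invalid.

    Restarting the server without AKV discards the service key, so the next
    request forces the client to redo the key exchange exactly once.
    """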
    model = MATMUL_1
    port = 8001

    client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

    with Server(model_path=model['model'], port=port):
        client.predict(model['input'])

    assert client.key_invalid_count == 0

    # Simulate multiple key rollovers by restarting the server (without using AKV).
    # The client's key then becomes invalid.
    with Server(model_path=model['model'], port=port):
        # The client must transparently repeat the key exchange here;
        # otherwise the following request would fail.
        client.predict(model['input'])

    assert client.key_invalid_count == 1


def test_api_basic():
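    """Single inference round trip; the output must match the reference output."""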
    model = MATMUL_1
    port = 8001

    client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

    with Server(model_path=model['model'], port=port):
        output = client.predict(model['input'])

    assert_output_allclose(output, model['ref_output'])


def _main_predict(args) -> None:
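    """Run inference through the confidential inference client.

    Inputs come either from ONNX TensorProto files (args.pb_in together with
    args.pb_in_names) or from a JSON file (args.json_in). Outputs can be
    written as TensorProto files (args.pb_out) and/or as JSON (args.json_out).
    """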
    inputs = {}
    if args.pb_in:
        for path, input_name in zip(args.pb_in, args.pb_in_names):
            with open(path, 'rb') as fp:
                tensor = onnx.TensorProto()
                tensor.ParseFromString(fp.read())
            print('Using {} as "{}" input'.format(path, input_name))
            inputs[input_name] = numpy_helper.to_array(tensor)
    elif args.json_in:
        with open(args.json_in) as fp:
            obj = json.load(fp)
        for input_name, spec in obj.items():
            arr = np.asarray(spec['values'], dtype=spec['type'])
            inputs[input_name] = arr
    else:
        raise NotImplementedError
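    # Expected shape of the args.json_in file, inferred from the parsing
    # above (hypothetical input name "X"): each entry maps an input name to
    # its numpy dtype string and nested value list, e.g.
    #   {"X": {"type": "float32", "values": [[1.0, 2.0], [3.0, 4.0]]}}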

    if args.verbose:
        print('Inputs:')
        for input_name, arr in inputs.items():
            print(' {}: dtype={} shape={}'.format(input_name, arr.dtype,
                                                  arr.shape))

    enclave_signing_key = None
    if args.enclave_signing_key_file:
        with open(args.enclave_signing_key_file) as fp:
            enclave_signing_key = fp.read()

    if args.enclave_model_hash_file:
        with open(args.enclave_model_hash_file) as f:
            enclave_model_hash = f.read()
    else:
        enclave_model_hash = args.enclave_model_hash

    c = Client(url=args.url,
               auth=get_auth(args),
               enclave_signing_key=enclave_signing_key,
               enclave_hash=args.enclave_hash,
               enclave_model_hash=enclave_model_hash,
               enclave_allow_debug=args.enclave_allow_debug)

    try:
        outputs = c.predict(inputs)
    except Exception as e:
        if args.verbose:
            raise
        else:
            print(f'{C.FAIL}{C.BOLD}ERROR: {e}{C.END}')
            sys.exit(1)

    if args.verbose:
        print('Outputs:')
        for output_name, arr in outputs.items():
            print(' {}: dtype={} shape={}'.format(output_name, arr.dtype,
                                                  arr.shape))

    if args.pb_out:
        os.makedirs(args.pb_out, exist_ok=True)
        for i, (output_name, arr) in enumerate(outputs.items()):
            filename = 'output_{}.pb'.format(i)
            print('Saving "{}" output as {}'.format(output_name, filename))
            path = os.path.join(args.pb_out, filename)
            tensor = numpy_helper.from_array(arr, output_name)
            with open(path, 'wb') as fp:
                fp.write(tensor.SerializeToString())

    if args.json_out:
        print('Saving inference results to {}'.format(args.json_out))
        with open(args.json_out, 'w') as fp:
            json.dump(outputs, fp, cls=NumpyEncoder, sort_keys=True)
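
# NumpyEncoder is defined elsewhere in this module. A minimal sketch of such
# an encoder (an assumption, not the project's actual implementation) turns
# numpy values into JSON-serializable Python types:
#
#   class NumpyEncoder(json.JSONEncoder):
#       def default(self, o):
#           if isinstance(o, np.ndarray):
#               return o.tolist()
#           if isinstance(o, np.generic):
#               return o.item()
#           return super().default(o)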