def test_api_akv_hsm_key_rollover_single_server():
    """Key rollover against one server whose service key lives in an AKV HSM.

    The server is started with ``key_rollover_interval=5`` and
    ``key_sync_interval=5``. The client sends 12 requests, sleeping 1 s after
    each one. Afterwards it must have seen one or two key rollovers and no
    invalid-key retries. The AKV key created under ``key_name`` is deleted
    whatever the outcome.
    """
    model = MATMUL_1
    port = 8001
    key_name = random_akv_key_name()
    try:
        client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

        # Environment lookups stay inside the try so cleanup still runs if
        # one of them is missing.
        server_options = dict(
            model_path=model['model'],
            port=port,
            key_rollover_interval=5,
            key_sync_interval=5,
            use_akv=True,
            akv_app_id=os.environ['CONFONNX_TEST_APP_ID'],
            akv_app_pwd=os.environ['CONFONNX_TEST_APP_PWD'],
            akv_service_key_name=key_name,
            akv_vault_url=os.environ['CONFONNX_TEST_VAULT_HSM_URL'],
            akv_attestation_url=os.environ['CONFONNX_TEST_ATTESTATION_URL'],
        )
        with Server(**server_options):
            for request_no in range(12):
                print(f'Request #{request_no}')
                client.predict(model['input'])
                print(f'Request #{request_no} -- done')
                time.sleep(1)

        assert client.key_rollover_count in [1, 2]
        assert client.key_invalid_count == 0
    finally:
        # Remove the HSM key so repeated runs do not accumulate keys.
        delete_akv_key(os.environ['CONFONNX_TEST_APP_ID'],
                       os.environ['CONFONNX_TEST_APP_PWD'],
                       os.environ['CONFONNX_TEST_VAULT_HSM_URL'],
                       key_name,
                       is_hsm=True)
def test_api_basic():
    """A single prediction against a local server matches the reference output."""
    model = MATMUL_1
    port = 8001
    server_url = f'http://localhost:{port}/'

    client = Client(server_url, enclave_allow_debug=True)
    with Server(model_path=model['model'], port=port):
        prediction = client.predict(model['input'])

    assert_output_allclose(prediction, model['ref_output'])
def test_api_local_key_rollover():
    """Key rollover with a locally generated (non-AKV) service key.

    The server runs with ``key_rollover_interval=5`` and
    ``key_sync_interval=1``. The client sends 12 requests, sleeping 1 s after
    each one. It must observe two or three rollovers and never need to retry
    because of an invalid key.
    """
    model = MATMUL_1
    port = 8001
    request_total = 12

    client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)
    server = Server(model_path=model['model'],
                    port=port,
                    key_rollover_interval=5,
                    key_sync_interval=1)
    with server:
        for _ in range(request_total):
            client.predict(model['input'])
            time.sleep(1)

    assert client.key_rollover_count in (2, 3)
    assert client.key_invalid_count == 0
def test_api_akv_multiple_servers():
    """Several AKV-backed servers configured with one key name act as one.

    ``num_servers`` servers are started on consecutive ports, all pointing at
    the same AKV vault and the same service key name. A single client then
    sends one request to each server in turn. After the first key exchange,
    no request may trigger a key rollover or an invalid-key retry.

    Cleanup always runs:
    - every started server is stopped;
    - the AKV key is deleted;
    - any server-stop failures are printed and the first one is re-raised,
      even if deleting the key also failed.
    """
    model = MATMUL_1
    num_servers = 3
    start_port = 8001
    ports = list(range(start_port, start_port + num_servers))
    key_name = random_akv_key_name()

    servers = []
    try:
        for port in ports:
            servers.append(
                Server(model_path=model['model'],
                       port=port,
                       use_akv=True,
                       akv_app_id=os.environ['CONFONNX_TEST_APP_ID'],
                       akv_app_pwd=os.environ['CONFONNX_TEST_APP_PWD'],
                       akv_service_key_name=key_name,
                       akv_vault_url=os.environ['CONFONNX_TEST_VAULT_URL']))

        # Placeholder URL; it is replaced before every request below.
        client = Client('http://foo/', enclave_allow_debug=True)

        for port in ports:
            client.url = f'http://localhost:{port}/'
            client.predict(model['input'])
            # All servers are configured with the same AKV service key name,
            # so they should all serve the same key. The client's key must
            # therefore stay valid when moving from one server to the next.
            # key_name is generated for this test and deleted in the finally
            # below, so it is not pre-created by other tests. Presumably the
            # first server creates it in AKV and the others fetch it (confirm).
            assert client.key_rollover_count == 0
            assert client.key_invalid_count == 0
    finally:
        stop_errors = []
        for server in servers:
            try:
                server.stop()
            except Exception as e:
                # Keep stopping the remaining servers; report afterwards.
                stop_errors.append(e)
        try:
            delete_akv_key(os.environ['CONFONNX_TEST_APP_ID'],
                           os.environ['CONFONNX_TEST_APP_PWD'],
                           os.environ['CONFONNX_TEST_VAULT_URL'],
                           key_name,
                           is_hsm=False)
        finally:
            # Report stop failures even when deleting the key raised.
            for e in stop_errors:
                print(e)
        if stop_errors:
            raise stop_errors[0]
# Example No. 5
def _main_provision_model_key(args) -> None:
    """CLI handler: send a model key to the inference server.

    Where the key comes from:
    - ``args.model_key`` when it is non-empty;
    - otherwise the contents of ``args.model_key_file``.

    Enclave verification settings:
    - an enclave signing key is read from ``args.enclave_signing_key_file``,
      if given;
    - ``args.enclave_hash`` and ``args.enclave_allow_debug`` are passed
      through to the :class:`Client` unchanged.
    """
    def read_text(path):
        # Return the whole file content exactly as read (no stripping).
        with open(path) as handle:
            return handle.read()

    key = args.model_key or read_text(args.model_key_file)

    signing_key_path = args.enclave_signing_key_file
    enclave_signing_key = read_text(signing_key_path) if signing_key_path else None

    client = Client(url=args.url,
                    auth=get_auth(args),
                    enclave_signing_key=enclave_signing_key,
                    enclave_hash=args.enclave_hash,
                    enclave_allow_debug=args.enclave_allow_debug)
    client.provision_model_key(key)
def test_api_invalid_key():
    """The client recovers when the server's key changes underneath it.

    Two servers are started one after another on the same port, without AKV.
    Restarting the server stands in for key rollovers, so the client's key
    from the first session is stale in the second. The first request must
    succeed without any invalid-key event. The request after the restart
    must succeed too: that only works if the client notices the invalid key
    and repeats the key exchange, which is counted exactly once.
    """
    model = MATMUL_1
    port = 8001

    client = Client(f'http://localhost:{port}/', enclave_allow_debug=True)

    # Session 1 establishes a key; session 2 (a fresh server) invalidates it.
    for expected_invalid_count in (0, 1):
        with Server(model_path=model['model'], port=port):
            client.predict(model['input'])
        assert client.key_invalid_count == expected_invalid_count
# Example No. 7
def _main_predict(args) -> None:
    """CLI handler: run one inference request and optionally save the results.

    Inputs are read from one of two sources:
    - ONNX ``TensorProto`` files ``args.pb_in``, named pairwise by
      ``args.pb_in_names``;
    - a JSON file ``args.json_in`` mapping each input name to an object with
      ``values`` and ``type`` keys.

    The request goes through a :class:`Client` built from ``args``. Outputs
    are optionally written as ``output_<i>.pb`` files into the directory
    ``args.pb_out`` and/or as a single JSON document to ``args.json_out``.

    Raises:
        NotImplementedError: if neither ``pb_in`` nor ``json_in`` is set.

    If the prediction fails and ``args.verbose`` is set, the exception
    propagates. Otherwise the error is printed and the process exits with
    status 1.
    """
    inputs = {}
    if args.pb_in:
        # NOTE(review): zip() silently ignores files/names beyond the shorter
        # of the two lists.
        for path, input_name in zip(args.pb_in, args.pb_in_names):
            with open(path, 'rb') as fp:
                tensor = onnx.TensorProto()
                tensor.ParseFromString(fp.read())
            print('Using {} as "{}" input'.format(path, input_name))
            inputs[input_name] = numpy_helper.to_array(tensor)
    elif args.json_in:
        with open(args.json_in) as fp:
            json_inputs = json.load(fp)
        # Use a distinct loop name so the parsed document is not rebound.
        for input_name, spec in json_inputs.items():
            inputs[input_name] = np.asarray(spec['values'], dtype=spec['type'])
    else:
        raise NotImplementedError

    if args.verbose:
        print('Inputs:')
        for input_name, arr in inputs.items():
            print(' {}: dtype={} shape={}'.format(input_name, arr.dtype,
                                                  arr.shape))

    enclave_signing_key = None
    if args.enclave_signing_key_file:
        with open(args.enclave_signing_key_file) as fp:
            enclave_signing_key = fp.read()

    # A hash file takes precedence over a hash given directly on the CLI.
    if args.enclave_model_hash_file:
        with open(args.enclave_model_hash_file) as f:
            enclave_model_hash = f.read()
    else:
        enclave_model_hash = args.enclave_model_hash

    c = Client(url=args.url,
               auth=get_auth(args),
               enclave_signing_key=enclave_signing_key,
               enclave_hash=args.enclave_hash,
               enclave_model_hash=enclave_model_hash,
               enclave_allow_debug=args.enclave_allow_debug)

    try:
        outputs = c.predict(inputs)
    except Exception as e:
        # CLI boundary: full traceback in verbose mode, short message otherwise.
        if args.verbose:
            raise
        else:
            print(f'{C.FAIL}{C.BOLD}ERROR: {e}{C.END}')
            sys.exit(1)

    if args.verbose:
        print('Outputs:')
        for output_name, arr in outputs.items():
            print(' {}: dtype={} shape={}'.format(output_name, arr.dtype,
                                                  arr.shape))

    if args.pb_out:
        os.makedirs(args.pb_out, exist_ok=True)
        for i, (output_name, arr) in enumerate(outputs.items()):
            filename = 'output_{}.pb'.format(i)
            print('Saving "{}" output as {}'.format(output_name, filename))
            path = os.path.join(args.pb_out, filename)
            tensor = numpy_helper.from_array(arr, output_name)
            with open(path, 'wb') as fp:
                fp.write(tensor.SerializeToString())

    if args.json_out:
        print('Saving inference results to {}'.format(args.json_out))
        with open(args.json_out, 'w') as fp:
            json.dump(outputs, fp, cls=NumpyEncoder, sort_keys=True)


sgx_report = get_sgx_report(public_key_path)

# Inference
# CONFONNX_DUMP_QUOTE must be present in the environment before the client
# below is used. Presumably it makes the client dump the enclave quote; confirm.
# The check is an explicit raise rather than `assert`, because `assert`
# statements are removed when Python runs with -O.
if 'CONFONNX_DUMP_QUOTE' not in os.environ:
    raise RuntimeError('CONFONNX_DUMP_QUOTE must be set in the environment')

CONFONNX_URL = os.environ.get("CONFONNX_URL", "http://127.0.0.1:8888")
CONFONNX_API_KEY = os.environ.get("CONFONNX_API_KEY")
auth = None
if CONFONNX_API_KEY:
    auth = {'user': '******', 'pass': CONFONNX_API_KEY}
# In a production scenario, the keyword arguments enclave_signing_key
# or enclave_hash should be specified to verify the enclave identity.
confonnx_client = Client(url=CONFONNX_URL, auth=auth)


def gray2rgb(image):
    """Convert a 2-D grayscale array to an 8-bit, 3-channel image.

    Steps:
    1. Shift every value by ``abs(min(image))``.
    2. If the resulting maximum is positive, divide by it, so the brightest
       pixel maps to 255.
    3. Write ``value * 255`` into all three channels of a ``uint8`` array;
       the float-to-uint8 assignment truncates.

    NOTE(review): adding ``abs(min)`` moves the minimum to 0 only when it is
    negative or zero. An image whose minimum is positive is shifted further
    up instead of being min-max normalized. This is kept unchanged for
    compatibility.

    Args:
        image: array-like of shape ``(rows, cols)``. It is never modified.

    Returns:
        ``np.uint8`` array of shape ``(rows, cols, 3)`` with identical channels.
    """
    image = np.asarray(image)
    # Work on a private floating-point copy, for two reasons: the caller's
    # array must not be shifted and scaled in place, and in-place true
    # division would fail on integer input.
    if np.issubdtype(image.dtype, np.floating):
        image = image.copy()
    else:
        image = image.astype(np.float64)
    rows, cols = image.shape
    image += np.abs(np.min(image))
    image_max = np.abs(np.max(image))
    if image_max > 0:
        image /= image_max
    ret = np.empty((rows, cols, 3), dtype=np.uint8)
    ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255
    return ret


def outline(image, mask, color):
    mask = np.round(mask)