Example 1
0
    def check_remote(target_edgetpu=False):
        """Check that the remote TVM tflite runtime matches the TFLite interpreter.

        Runs one random-input inference through the TFLite Python interpreter,
        then through TVM's tflite runtime on a local RPC server, and asserts the
        two outputs are equal.

        Parameters
        ----------
        target_edgetpu : bool
            Forwarded to the model/interpreter helpers to select the EdgeTPU
            variant of the model.
        """
        tflite_model_path = get_tflite_model_path(target_edgetpu)

        # inference via tflite interpreter python apis
        interpreter = init_interpreter(tflite_model_path, target_edgetpu)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        # uint8 input — presumably matches a quantized (EdgeTPU-style) model;
        # note the float32 dtype used by the other examples. TODO confirm.
        input_shape = input_details[0]["shape"]
        tflite_input = np.array(np.random.random_sample(input_shape),
                                dtype=np.uint8)
        interpreter.set_tensor(input_details[0]["index"], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]["index"])

        # inference via remote tvm tflite runtime
        server = rpc.Server("localhost")
        try:
            remote = rpc.connect(server.host, server.port)
            ctx = remote.cpu(0)

            with open(tflite_model_path, "rb") as model_fin:
                runtime = tflite_runtime.create(model_fin.read(), ctx)
            runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)
        finally:
            # Always shut the RPC server down, even on assertion failure,
            # so it does not leak (the original never terminated it).
            server.terminate()
Example 2
0
    def check_remote():
        """Check that the remote TVM tflite runtime matches the TFLite interpreter.

        Builds a fresh .tflite model, runs one random float32 inference through
        the TFLite interpreter, then through TVM's tflite runtime over a local
        RPC connection, and asserts both outputs are equal.
        """
        tflite_fname = "model.tflite"
        tflite_model = create_tflite_model()
        temp = util.tempdir()
        tflite_model_path = temp.relpath(tflite_fname)
        # Context manager ensures the handle is closed before the file is
        # re-read below (the original leaked the write handle).
        with open(tflite_model_path, 'wb') as model_fout:
            model_fout.write(tflite_model)

        # inference via tflite interpreter python apis
        interpreter = tflite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        input_shape = input_details[0]['shape']
        tflite_input = np.array(np.random.random_sample(input_shape),
                                dtype=np.float32)
        interpreter.set_tensor(input_details[0]['index'], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])

        # inference via remote tvm tflite runtime
        server = rpc.Server("localhost")
        try:
            remote = rpc.connect(server.host, server.port)
            ctx = remote.cpu(0)
            # NOTE(review): the model bytes are also read locally below, so
            # this upload looks redundant — kept to preserve behavior.
            remote.upload(tflite_model_path)

            with open(tflite_model_path, 'rb') as model_fin:
                runtime = tflite_runtime.create(model_fin.read(), ctx)
            runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)
        finally:
            # Always shut the RPC server down so it does not leak.
            server.terminate()
    def check_remote(server):
        """Run the model through the remote TVM tflite runtime and verify
        its output against the previously computed TFLite reference."""
        remote = rpc.connect(server.host, server.port)
        remote.upload(tflite_model_path)

        device = remote.cpu(0)
        with open(tflite_model_path, "rb") as model_fin:
            rt = tflite_runtime.create(model_fin.read(), device)
        rt.set_input(0, tvm.nd.array(tflite_input, device))
        rt.invoke()
        result = rt.get_output(0)
        np.testing.assert_equal(result.numpy(), tflite_output)
Example 4 (note: the heading for Example 3 — the `check_remote(server)` snippet above — is missing from the listing)
0
def test_remote():
    """End-to-end test: TVM's tflite runtime over RPC vs. the tf.lite interpreter.

    Skips (returns early) when the TVM tflite runtime or tensorflow is not
    available. Otherwise builds a model, computes a reference output with the
    TFLite interpreter, then runs the same input through TVM's tflite runtime
    on a local RPC server and asserts equality.
    """
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print("skip because tensorflow not installed...")
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = utils.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    # Context manager closes the handle before the file is re-read below
    # (the original leaked the write handle).
    with open(tflite_model_path, "wb") as model_fout:
        model_fout.write(tflite_model)

    # inference via tflite interpreter python apis
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    tflite_input = np.array(np.random.random_sample(input_shape),
                            dtype=np.float32)
    interpreter.set_tensor(input_details[0]["index"], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]["index"])

    # inference via remote tvm tflite runtime
    server = rpc.Server("localhost")
    try:
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        # NOTE(review): the model bytes are also read locally below, so this
        # upload looks redundant — kept to preserve behavior.
        remote.upload(tflite_model_path)

        with open(tflite_model_path, "rb") as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)
    finally:
        # Terminate even when the assertion fails, so the RPC server
        # does not outlive the test (the original only terminated on success).
        server.terminate()
Example 5
0
def test_local():
    """Local test: TVM's tflite runtime on CPU vs. the tf.lite interpreter.

    Skips (returns early) when the TVM tflite runtime or tensorflow is not
    available. Otherwise builds a model, computes a reference output with the
    TFLite interpreter, runs the same input through TVM's local tflite
    runtime, and asserts equality.
    """
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print("skip because tensorflow not installed...")
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = util.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    # Context manager closes the handle before the file is re-read below
    # (the original leaked the write handle).
    with open(tflite_model_path, "wb") as model_fout:
        model_fout.write(tflite_model)

    # inference via tflite interpreter python apis
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]["index"], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]["index"])

    # inference via tvm tflite runtime
    with open(tflite_model_path, "rb") as model_fin:
        runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
    runtime.set_input(0, tvm.nd.array(tflite_input))
    runtime.invoke()
    out = runtime.get_output(0)
    np.testing.assert_equal(out.asnumpy(), tflite_output)