def check_remote(target_edgetpu=False):
    """Check that TVM's TFLite runtime over RPC matches the TFLite interpreter.

    The model at ``get_tflite_model_path(target_edgetpu)`` is run once through
    the interpreter returned by ``init_interpreter`` and once through
    ``tflite_runtime`` on a local RPC server. The first outputs of the two
    runs must be exactly equal.

    Parameters
    ----------
    target_edgetpu : bool
        Forwarded to ``get_tflite_model_path`` and ``init_interpreter``.
        Presumably this selects the Edge TPU model/interpreter variant (TODO
        confirm against those helpers).

    Raises
    ------
    AssertionError
        If the two outputs differ.
    """
    tflite_model_path = get_tflite_model_path(target_edgetpu)

    # Inference via the TFLite interpreter Python APIs.
    interpreter = init_interpreter(tflite_model_path, target_edgetpu)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    # Sample the full uint8 range. The previous
    # np.array(np.random.random_sample(shape), dtype=np.uint8) truncated
    # every value in [0, 1) to 0, so the test only ever fed all-zeros input.
    tflite_input = np.random.randint(0, 256, size=tuple(input_shape), dtype=np.uint8)
    interpreter.set_tensor(input_details[0]["index"], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]["index"])

    # Inference via the remote TVM TFLite runtime.
    server = rpc.Server("localhost")
    try:
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        with open(tflite_model_path, "rb") as model_fin:
            model_bytes = model_fin.read()
        runtime = tflite_runtime.create(model_bytes, ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)
    finally:
        # Stop the RPC server even when the comparison fails.
        server.terminate()
def check_remote():
    """Check that TVM's TFLite runtime over RPC matches the TFLite interpreter.

    A model built by ``create_tflite_model`` is written to a temporary
    directory and run twice: once with ``tflite.Interpreter`` and once
    through ``tflite_runtime`` on a local RPC server. The runs use the same
    random float32 input, and their first outputs must be exactly equal.

    Raises
    ------
    AssertionError
        If the two outputs differ.
    """
    tflite_fname = "model.tflite"
    tflite_model = create_tflite_model()
    temp = util.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    # Close (and thus flush) the file before the interpreter reads it back.
    with open(tflite_model_path, 'wb') as model_fout:
        model_fout.write(tflite_model)

    # Inference via the TFLite interpreter Python APIs.
    interpreter = tflite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]['shape']
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]['index'])

    # Inference via the remote TVM TFLite runtime.
    server = rpc.Server("localhost")
    try:
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        # Upload the model to the remote side. The return value was never used.
        remote.upload(tflite_model_path)
        with open(tflite_model_path, 'rb') as model_fin:
            model_bytes = model_fin.read()
        runtime = tflite_runtime.create(model_bytes, ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)
    finally:
        # Stop the RPC server even when the comparison fails.
        server.terminate()
def check_remote(server):
    """Run the model through TVM's TFLite runtime on ``server`` and compare outputs.

    Parameters
    ----------
    server
        An object exposing ``host`` and ``port``, passed to ``rpc.connect``.

    Notes
    -----
    ``tflite_model_path``, ``tflite_input`` and ``tflite_output`` are not
    parameters. They must be defined in an enclosing or module scope.

    Raises
    ------
    AssertionError
        If the runtime's first output differs from ``tflite_output``.
    """
    remote = rpc.connect(server.host, server.port)
    # Upload the model to the remote side. The return value was never used.
    remote.upload(tflite_model_path)
    ctx = remote.cpu(0)

    with open(tflite_model_path, "rb") as model_fin:
        model_bytes = model_fin.read()
    runtime = tflite_runtime.create(model_bytes, ctx)
    runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
    runtime.invoke()
    out = runtime.get_output(0)
    np.testing.assert_equal(out.numpy(), tflite_output)
def test_remote():
    """Compare ``tf.lite.Interpreter`` with TVM's TFLite runtime over RPC.

    Returns early, printing a message, if the TVM ``tflite`` runtime is not
    enabled, if ``tvm.tflite_runtime.create`` is not registered, or if
    TensorFlow cannot be imported. Otherwise a model from
    ``_create_tflite_model`` is written to a temporary directory and run
    with the same random float32 input in both runtimes. The first outputs
    must be exactly equal.

    Raises
    ------
    AssertionError
        If the two outputs differ.
    """
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print("skip because tensorflow not installed...")
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = utils.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    # Close (and thus flush) the file before the interpreter reads it back.
    with open(tflite_model_path, "wb") as model_fout:
        model_fout.write(tflite_model)

    # Inference via the TFLite interpreter Python APIs.
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]["index"], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]["index"])

    # Inference via the remote TVM TFLite runtime.
    server = rpc.Server("localhost")
    try:
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        # Upload the model to the remote side. The return value was never used.
        remote.upload(tflite_model_path)
        with open(tflite_model_path, "rb") as model_fin:
            model_bytes = model_fin.read()
        runtime = tflite_runtime.create(model_bytes, ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)
    finally:
        # Previously terminate() was skipped when the assertion failed.
        server.terminate()
def test_local():
    """Compare ``tf.lite.Interpreter`` with TVM's local TFLite runtime.

    Returns early, printing a message, if the TVM ``tflite`` runtime is not
    enabled, if ``tvm.tflite_runtime.create`` is not registered, or if
    TensorFlow cannot be imported. Otherwise a model from
    ``_create_tflite_model`` is written to a temporary directory and run
    with the same random float32 input in both runtimes, the TVM one on
    ``tvm.cpu(0)``. The first outputs must be exactly equal.

    Raises
    ------
    AssertionError
        If the two outputs differ.
    """
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print('skip because tensorflow not installed...')
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = util.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    # Close (and thus flush) the file before the interpreter reads it back.
    with open(tflite_model_path, 'wb') as model_fout:
        model_fout.write(tflite_model)

    # Inference via the TFLite interpreter Python APIs.
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]['shape']
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]['index'])

    # Inference via the local TVM TFLite runtime.
    with open(tflite_model_path, 'rb') as model_fin:
        model_bytes = model_fin.read()
    runtime = tflite_runtime.create(model_bytes, tvm.cpu(0))
    runtime.set_input(0, tvm.nd.array(tflite_input))
    runtime.invoke()
    out = runtime.get_output(0)
    np.testing.assert_equal(out.asnumpy(), tflite_output)