def run_model(project):
    """Flash ``project`` onto the board, request one inference and parse the reply.

    The ``infer%`` command is written to the device transport, then the
    transport is searched for a message tagged ``result``. That reply is split
    on ``":"``; field 1 is taken as the inference result and field 2 as the
    elapsed time (logged in milliseconds).

    Parameters
    ----------
    project
        Built project object exposing ``flash()`` and a ``transport()``
        context manager.

    Returns
    -------
    tuple of (int, int)
        ``(result, time)`` parsed from the device reply.
    """
    project.flash()

    with project.transport() as transport:
        # Presumably blocks until the AOT runtime on the device reports it is
        # ready -- confirm against aot_transport_init_wait.
        aot_transport_init_wait(transport)
        transport.write(b"infer%", timeout_sec=5)
        reply = aot_transport_find_message(transport, "result", timeout_sec=60)

    # Reply presumably looks like "result:<value>:<elapsed>"; only fields 1 and
    # 2 are used, any further fields are ignored.
    fields = reply.strip("\n").split(":")
    result, elapsed_ms = int(fields[1]), int(fields[2])
    _LOG.info(f"Result: {result}\ttime: {elapsed_ms} ms")
    return result, elapsed_ms
def _run_model(temp_dir, board, west_cmd, lowered, build_config, sample, output_shape):
    """Generate a project for ``lowered`` and run a single inference on ``board``.

    Builds the project with ``_generate_project`` and then delegates flashing,
    sending the inference command and parsing the device reply to
    ``run_model``. This keeps the device-protocol logic in a single place
    instead of a second copy here.

    Parameters
    ----------
    temp_dir, board, west_cmd, lowered, build_config, sample, output_shape
        Forwarded unchanged to ``_generate_project``.

    Returns
    -------
    tuple of (int, int)
        ``(result, time)`` as returned by ``run_model``.
    """
    project = _generate_project(
        temp_dir, board, west_cmd, lowered, build_config, sample, output_shape
    )
    return run_model(project)
def test_tflite(temp_dir, board, west_cmd, tvm_debug):
    """Compile the quantized keyword-spotting TFLite model with AOT and run it on ``board``.

    Steps:

    1. Download the model and one int8 input sample.
    2. Convert the model to Relay and build it with the AOT executor (unpacked
       C interface API).
    3. Package the generated C interface header plus input/output data headers
       into a tarball that is handed to the project build.
    4. Flash the board and run one inference.

    The sample is expected to be classified as label 6.
    """
    model = test_utils.ZEPHYR_BOARDS[board]
    input_shape = (1, 49, 10, 1)
    output_shape = (1, 12)
    build_config = {"debug": tvm_debug}

    model_url = "https://github.com/tlc-pack/web-data/raw/25fe99fb00329a26bd37d3dca723da94316fd34c/testdata/microTVM/model/keyword_spotting_quant.tflite"
    model_path = download_testdata(model_url, "keyword_spotting_quant.tflite", module="model")

    # Import TFLite model. Use a context manager so the file handle is closed.
    with open(model_path, "rb") as model_file:
        tflite_model_buf = model_file.read()
    try:
        import tflite

        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    except AttributeError:
        # Presumably older tflite packages expose the class one level deeper.
        import tflite.Model

        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

    # Load TFLite model and convert to Relay. The dtype_dict key must be the
    # same input name used in shape_dict and in the generated C header
    # ("input_1"). With any other key the int8 dtype is not applied.
    relay_mod, params = relay.frontend.from_tflite(
        tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1": "int8"}
    )

    target = tvm.target.target.micro(model)
    executor = Executor(
        "aot", {"unpacked-api": True, "interface-api": "c", "workspace-byte-alignment": 4}
    )
    runtime = Runtime("crt")
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        lowered = relay.build(relay_mod, target, params=params, runtime=runtime, executor=executor)

    sample_url = "https://github.com/tlc-pack/web-data/raw/967fc387dadb272c5a7f8c3461d34c060100dbf1/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy"
    sample_path = download_testdata(sample_url, "keyword_spotting_int8_6.pyc.npy", module="data")
    sample = np.load(sample_path)

    with tempfile.NamedTemporaryFile() as tar_temp_file:
        # The tarball must be closed (fully flushed) before build_project reads
        # it. The NamedTemporaryFile must stay open until then, because it is
        # deleted when closed.
        with tarfile.open(tar_temp_file.name, "w:gz") as tf:
            with tempfile.TemporaryDirectory() as tar_temp_dir:
                model_files_path = os.path.join(tar_temp_dir, "include")
                os.mkdir(model_files_path)
                header_path = generate_c_interface_header(
                    lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
                )
                tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))

            test_utils.create_header_file("input_data", sample, "include", tf)
            test_utils.create_header_file(
                "output_data", np.zeros(shape=output_shape, dtype="int8"), "include", tf
            )

        project, _ = test_utils.build_project(
            temp_dir,
            board,
            west_cmd,
            lowered,
            build_config,
            extra_files_tar=tar_temp_file.name,
        )

    # Reuse the shared flash/infer/parse helper instead of duplicating it.
    result, _ = run_model(project)
    assert result == 6