Beispiel #1
0
def run_model(project):
    """Flash ``project``, request one inference, and return the parsed reply.

    The project is flashed, its transport is opened, and after
    ``aot_transport_init_wait`` completes the ``infer%`` command is written.
    The message obtained from ``aot_transport_find_message`` for ``"result"``
    (60 second timeout) is then split on ``:``.

    Parameters
    ----------
    project :
        Object providing ``flash()`` and a ``transport()`` context manager.

    Returns
    -------
    tuple of (int, int)
        Field 1 of the reply (the result) and field 2 (the time, which is
        logged with an ``ms`` suffix).
    """
    project.flash()

    with project.transport() as channel:
        aot_transport_init_wait(channel)
        channel.write(b"infer%", timeout_sec=5)
        reply = aot_transport_find_message(channel, "result", timeout_sec=60)

    # The reply is colon-separated: index 1 is the result, index 2 the time.
    fields = reply.strip("\n").split(":")
    result = int(fields[1])
    time = int(fields[2])
    _LOG.info(f"Result: {result}\ttime: {time} ms")

    return result, time
Beispiel #2
0
def _run_model(temp_dir, board, west_cmd, lowered, build_config, sample,
               output_shape):
    """Generate a project, flash it, run one inference, and parse the reply.

    All arguments are forwarded unchanged, in order, to
    ``_generate_project``. The resulting project is flashed, the ``infer%``
    command is written to its transport once ``aot_transport_init_wait``
    returns, and the message found for ``"result"`` (60 second timeout) is
    split on ``:``.

    Returns
    -------
    tuple of (int, int)
        Field 1 of the reply (the result) and field 2 (the time, which is
        logged with an ``ms`` suffix).
    """
    project = _generate_project(
        temp_dir, board, west_cmd, lowered, build_config, sample, output_shape
    )
    project.flash()

    with project.transport() as channel:
        aot_transport_init_wait(channel)
        channel.write(b"infer%", timeout_sec=5)
        reply = aot_transport_find_message(channel, "result", timeout_sec=60)

    # The reply is colon-separated: index 1 is the result, index 2 the time.
    fields = reply.strip("\n").split(":")
    result = int(fields[1])
    time = int(fields[2])
    _LOG.info(f"Result: {result}\ttime: {time} ms")

    return result, time
Beispiel #3
0
def test_tflite(temp_dir, board, west_cmd, tvm_debug):
    """Testing a TFLite model.

    Downloads the ``keyword_spotting_quant.tflite`` model and an input sample,
    converts the model to Relay, builds it with the AOT executor and the CRT
    runtime, packages the generated interface header plus input/output data
    headers into a tarball, builds and flashes the project, and finally
    asserts that the result reported over the transport equals 6.

    Parameters
    ----------
    temp_dir, board, west_cmd :
        Forwarded to ``test_utils.build_project``; ``board`` is also used to
        look up the target model in ``test_utils.ZEPHYR_BOARDS``.
    tvm_debug :
        Stored under the ``"debug"`` key of the build configuration.
    """
    model = test_utils.ZEPHYR_BOARDS[board]
    input_shape = (1, 49, 10, 1)
    output_shape = (1, 12)
    build_config = {"debug": tvm_debug}

    model_url = "https://github.com/tlc-pack/web-data/raw/25fe99fb00329a26bd37d3dca723da94316fd34c/testdata/microTVM/model/keyword_spotting_quant.tflite"
    model_path = download_testdata(model_url, "keyword_spotting_quant.tflite", module="model")

    # Import TFLite model. Use a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(model_path, "rb") as model_file:
        tflite_model_buf = model_file.read()
    try:
        import tflite

        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    except AttributeError:
        # Fallback for tflite package layouts where Model is a submodule.
        import tflite.Model

        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

    # Load TFLite model and convert to Relay.
    # The dtype key must be the same tensor name used in shape_dict; the
    # previous key carried a trailing space ("input_1 ") and so never matched.
    relay_mod, params = relay.frontend.from_tflite(
        tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1": "int8"}
    )

    target = tvm.target.target.micro(model)
    executor = Executor(
        "aot", {"unpacked-api": True, "interface-api": "c", "workspace-byte-alignment": 4}
    )
    runtime = Runtime("crt")
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        lowered = relay.build(relay_mod, target, params=params, runtime=runtime, executor=executor)

    sample_url = "https://github.com/tlc-pack/web-data/raw/967fc387dadb272c5a7f8c3461d34c060100dbf1/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy"
    sample_path = download_testdata(sample_url, "keyword_spotting_int8_6.pyc.npy", module="data")
    sample = np.load(sample_path)

    with tempfile.NamedTemporaryFile() as tar_temp_file:
        with tarfile.open(tar_temp_file.name, "w:gz") as tf:
            # The interface header is generated into a scratch directory and
            # added to the tarball before that directory is removed.
            with tempfile.TemporaryDirectory() as tar_temp_dir:
                model_files_path = os.path.join(tar_temp_dir, "include")
                os.mkdir(model_files_path)
                header_path = generate_c_interface_header(
                    lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
                )
                tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))

            test_utils.create_header_file("input_data", sample, "include", tf)
            test_utils.create_header_file(
                "output_data", np.zeros(shape=output_shape, dtype="int8"), "include", tf
            )

        # Build while the temporary tarball still exists on disk.
        project, _ = test_utils.build_project(
            temp_dir,
            board,
            west_cmd,
            lowered,
            build_config,
            extra_files_tar=tar_temp_file.name,
        )

    project.flash()
    with project.transport() as transport:
        aot_transport_init_wait(transport)
        transport.write(b"infer%", timeout_sec=5)
        result_line = aot_transport_find_message(transport, "result", timeout_sec=60)

    # The reply is colon-separated: index 1 is the result, index 2 the time.
    result_line = result_line.strip("\n")
    result_line = result_line.split(":")
    result = int(result_line[1])
    time = int(result_line[2])
    logging.info(f"Result: {result}\ttime: {time} ms")
    assert result == 6