Ejemplo n.º 1
0
def test_run_tflite_module__with_profile__valid_input(
        use_vm, tflite_mobilenet_v1_1_quant, tflite_compile_model,
        imagenet_cat):
    """Run a compiled quantized TFLite MobileNet v1 with profiling enabled.

    Checks that the tiger-cat class appears among the top-5 predictions, and
    that the returned result exposes ``outputs`` as a dict containing
    ``"output_0"`` and ``times`` as a ``BenchmarkResult``.
    """
    # Some CI environments won't offer TFLite, so skip if it is not present.
    pytest.importorskip("tflite")

    inputs = np.load(imagenet_cat)
    input_dict = {"input": inputs["input"].astype("uint8")}

    tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant,
                                                 use_vm=use_vm)
    result = tvmc.run(
        tflite_compiled_model,
        inputs=input_dict,
        hostname=None,  # no RPC host given; presumably executes locally
        device="cpu",
        profile=True,
    )

    # Collect the top 5 results; row 0 holds the predicted class ids.
    top_5_results = get_top_results(result, 5)
    top_5_ids = top_5_results[0]

    # IDs were collected from this reference:
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/
    # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
    tiger_cat_mobilenet_id = 283

    assert (
        tiger_cat_mobilenet_id
        in top_5_ids), "tiger cat is expected in the top-5 for mobilenet v1"
    assert isinstance(result.outputs, dict)
    assert isinstance(result.times, BenchmarkResult)
    assert "output_0" in result.outputs
Ejemplo n.º 2
0
def test_run_tflite_module_with_rpc(tflite_mobilenet_v1_1_quant,
                                    tflite_compile_model, imagenet_cat):
    """
    Test to check that TVMC run is functional when it is being used in
    conjunction with an RPC server.

    The RPC server is terminated in a ``finally`` clause so it does not
    outlive the run, even when ``tvmc.run`` raises.
    """
    # Some CI environments won't offer TFLite, so skip if it is not present.
    pytest.importorskip("tflite")

    inputs = np.load(imagenet_cat)
    input_dict = {"input": inputs["input"].astype("uint8")}

    tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)

    server = rpc.Server("127.0.0.1", 9099)
    try:
        result = tvmc.run(
            tflite_compiled_model,
            inputs=input_dict,
            hostname=server.host,
            port=server.port,
            device="cpu",
        )
    finally:
        # Shut the server down once the run has finished (or failed) so it
        # is not left running for the rest of the test session.
        server.terminate()

    # Collect the top 5 results; row 0 holds the predicted class ids.
    top_5_results = get_top_results(result, 5)
    top_5_ids = top_5_results[0]

    # IDs were collected from this reference:
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/
    # java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
    tiger_cat_mobilenet_id = 283

    assert (
        tiger_cat_mobilenet_id
        in top_5_ids), "tiger cat is expected in the top-5 for mobilenet v1"
    assert isinstance(result.outputs, dict)
    assert "output_0" in result.outputs
Ejemplo n.º 3
0
def test_get_top_results_keep_results():
    """Check the size of what ``get_top_results`` returns for a fake result.

    Requests 3 results from a single 2x4 ``output_0`` array and expects a
    2-row container whose rows each hold exactly 3 entries.
    """
    result_to_rank = TVMCResult(
        outputs={"output_0": np.array([[1, 2, 3, 4], [5, 6, 7, 8]])},
        times=None,
    )
    wanted = 3
    top = get_top_results(result_to_rank, wanted)

    rows_expected = 2
    assert len(top) == rows_expected

    per_row_expected = 3
    for row_idx in (0, 1):
        assert len(top[row_idx]) == per_row_expected