Example #1
0
def test_tvmc_import_package_project_dir(tflite_mobilenet_v1_1_quant,
                                         tflite_compile_model):
    """Check that 'project_dir' is accepted when importing MLF archives only.

    Importing an MLF archive with ``project_dir`` set must succeed. Importing
    a Classic archive with ``project_dir`` set must raise a ``TVMCException``
    whose message matches exactly.
    """
    pytest.importorskip("tflite")

    # Only passed through to TVMCPackage; this test never creates or inspects it.
    project_dir = "/tmp/foobar"

    # Generate a MLF archive.
    compiled_model_mlf_tvmc_package = tflite_compile_model(
        tflite_mobilenet_v1_1_quant, output_format="mlf")

    # Import the MLF archive setting 'project_dir'. It must succeed.
    mlf_archive_path = compiled_model_mlf_tvmc_package.package_path
    tvmc_package = TVMCPackage(mlf_archive_path, project_dir=project_dir)
    assert tvmc_package.type == "mlf", "Can't load the MLF archive passing the project directory!"

    # Generate a Classic archive.
    compiled_model_classic_tvmc_package = tflite_compile_model(
        tflite_mobilenet_v1_1_quant)

    # Import the Classic archive setting 'project_dir'.
    # It must fail since setting 'project_dir' is only supported when importing a MLF archive.
    classic_archive_path = compiled_model_classic_tvmc_package.package_path
    with pytest.raises(TVMCException) as exp:
        # The constructor is expected to raise, so its result is not bound.
        TVMCPackage(classic_archive_path, project_dir=project_dir)

    # Must match the exception text exactly, including the trailing ".!".
    expected_reason = "Setting 'project_dir' is only allowed when importing a MLF.!"
    on_error = "A TVMCException was caught but its reason is not the expected one."
    assert str(exp.value) == expected_reason, on_error
Example #2
0
def test_tvmc_import_package_mlf(tflite_compiled_model_mlf):
    """Import a pre-built MLF archive and verify the fields TVMCPackage sets.

    For an MLF archive:
    - the library name and path stay unset;
    - the graph and params are populated;
    - the package type is 'mlf'.
    """
    pytest.importorskip("tflite")

    # The fixture holds a model already compiled and exported to a MLF
    # archive; only its on-disk location is needed for the import.
    package_path = tflite_compiled_model_mlf.package_path

    # Constructing TVMCPackage triggers the import of the archive.
    imported = TVMCPackage(package_path)

    # Library fields must stay empty for an MLF archive.
    assert imported.lib_name is None, ".lib_name must not be set in the MLF archive."
    assert imported.lib_path is None, ".lib_path must not be set in the MLF archive."
    # Graph and params must be recovered from the archive.
    assert imported.graph is not None, ".graph must be set in the MLF archive."
    assert imported.params is not None, ".params must be set in the MLF archive."
    assert imported.type == "mlf", ".type must be set to 'mlf' in the MLF format."
Example #3
0
def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
    """Compile a Keras model to a package file and load it back.

    The compiled module is written as ``saved.tar`` under a fresh temporary
    directory. Reloading it through TVMCPackage must expose ``lib_path`` and
    ``graph`` as ``str`` and ``params`` as a ``bytearray``.
    """
    # Some CI environments won't offer TensorFlow/Keras, so skip if it is not present.
    pytest.importorskip("tensorflow")

    expected_temp_dir = tmpdir_factory.mktemp("saved_output")
    expected_file_name = "saved.tar"
    module_file = os.path.join(expected_temp_dir, expected_file_name)

    tvmc_model = tvmc.load(keras_resnet50)
    tvmc.compile(tvmc_model, target="llvm", dump_code="ll", package_path=module_file)

    assert os.path.exists(module_file), "output file {0} should exist".format(module_file)

    # Test that we can load back in a module.
    tvmc_package = TVMCPackage(package_path=module_file)
    # isinstance is the idiomatic type check (and tolerates subclasses).
    assert isinstance(tvmc_package.lib_path, str), ".lib_path must be a str."
    assert isinstance(tvmc_package.graph, str), ".graph must be a str."
    assert isinstance(tvmc_package.params, bytearray), ".params must be a bytearray."
Example #4
0
def test_tvmc_import_package_mlf_aot(tflite_mobilenet_v1_1_quant,
                                     tflite_compile_model):
    """Import a MLF archive compiled for the C target with the AOT executor.

    For the AOT executor no graph is expected in the archive. The library
    name and path must stay unset, while params and the 'mlf' type must be set.
    """
    pytest.importorskip("tflite")

    # Compile and export a model to a MLF archive so it can be imported.
    compiled_mlf_package = tflite_compile_model(
        tflite_mobilenet_v1_1_quant,
        target="c --executor=aot",
        output_format="mlf",
        pass_context_configs=["tir.disable_vectorize=1"],
    )
    archive_path = compiled_mlf_package.package_path

    # Import the MLF archive. TVMCPackage constructor will call import_package method.
    tvmc_package = TVMCPackage(archive_path)

    assert tvmc_package.lib_name is None, ".lib_name must not be set in the MLF archive."
    assert tvmc_package.lib_path is None, ".lib_path must not be set in the MLF archive."
    assert tvmc_package.graph is None, ".graph must not be set in the MLF archive for AOT executor."
    assert tvmc_package.params is not None, ".params must be set in the MLF archive."
    assert tvmc_package.type == "mlf", ".type must be set to 'mlf' in the MLF format."