Example #1
0
def test_requirements_file_log_model(create_requirements_file,
                                     sequential_model):
    """Check that ``log_model`` records an explicit ``requirements_file``.

    The requirements file produced by the ``create_requirements_file`` fixture
    is passed to ``mlflow.pytorch.log_model`` inside an active run. The logged
    model is then fetched via its ``runs:/`` URI. The test asserts that the
    ``pytorch`` flavor of its MLmodel config has a ``requirements_file`` entry
    with a ``path``. It also asserts that the file at that path (relative to
    the model directory) holds exactly the fixture's expected content.
    """
    req_file, expected_content = create_requirements_file
    artifact_dir = "models"

    with mlflow.start_run():
        mlflow.pytorch.log_model(
            pytorch_model=sequential_model,
            artifact_path=artifact_dir,
            requirements_file=req_file,
        )
        active_run_id = mlflow.active_run().info.run_id
        uri = "runs:/{run_id}/{model_path}".format(run_id=active_run_id,
                                                   model_path=artifact_dir)

        # Guard against a vacuous test: the default pip dependencies must not
        # already equal the expected content. Otherwise the final comparison
        # could not show that the explicit file overrode the default one.
        default_pip = get_default_conda_env()["dependencies"][-1]["pip"]
        assert _mlflow_additional_pip_env(default_pip) != expected_content

        with TempDir(remove_on_exit=True) as scratch:
            local_model_dir = _download_artifact_from_uri(uri, scratch.path())
            mlmodel = Model.load(os.path.join(local_model_dir, "MLmodel"))
            pytorch_flavor = mlmodel.flavors["pytorch"]

            assert "requirements_file" in pytorch_flavor
            req_entry = pytorch_flavor["requirements_file"]
            assert "path" in req_entry

            # The recorded path is relative to the downloaded model directory.
            stored_req_path = os.path.join(local_model_dir, req_entry["path"])
            with open(stored_req_path) as handle:
                stored_content = handle.read()
            assert stored_content == expected_content
Example #2
0
def test_requirements_file_save_model(create_requirements_file,
                                      sequential_model):
    """Check that ``save_model`` copies an explicit ``requirements_file``.

    The model is saved to a ``models`` subdirectory of a temporary directory,
    with the fixture's requirements file. The test then asserts three things
    about the ``pytorch`` flavor of the saved MLmodel config:

    - it has a ``requirements_file`` entry;
    - that entry has a ``path``;
    - the file at that path (relative to the model directory) contains exactly
      the fixture's expected content.
    """
    req_file, expected_content = create_requirements_file

    with TempDir(remove_on_exit=True) as scratch:
        saved_model_dir = os.path.join(scratch.path(), "models")
        mlflow.pytorch.save_model(pytorch_model=sequential_model,
                                  path=saved_model_dir,
                                  requirements_file=req_file)

        # Guard against a vacuous test: the default pip dependencies must not
        # already equal the expected content. Otherwise the final comparison
        # could not show that the explicit file overrode the default one.
        default_pip = get_default_conda_env()["dependencies"][-1]["pip"]
        assert _mlflow_additional_pip_env(default_pip) != expected_content

        pytorch_flavor = Model.load(
            os.path.join(saved_model_dir, "MLmodel")).flavors["pytorch"]

        assert "requirements_file" in pytorch_flavor
        req_entry = pytorch_flavor["requirements_file"]
        assert "path" in req_entry

        # The recorded path is relative to the saved model directory.
        with open(os.path.join(saved_model_dir, req_entry["path"])) as handle:
            stored_content = handle.read()
        assert stored_content == expected_content