def test_requirements_file_log_model(create_requirements_file, sequential_model):
    """Check that ``log_model(requirements_file=...)`` ships the given file with the model.

    A PyTorch model is logged under the artifact path ``models``. The logged artifacts
    are then downloaded again. The ``pytorch`` flavor entry in the downloaded MLmodel
    config must contain a ``requirements_file`` section with a ``path``. The file at
    that path, relative to the model directory, must contain exactly the fixture's
    expected content.
    """
    req_file, expected_content = create_requirements_file

    with mlflow.start_run():
        mlflow.pytorch.log_model(
            pytorch_model=sequential_model,
            artifact_path="models",
            requirements_file=req_file,
        )
        # The run id has to be read while the run is still active.
        run_id = mlflow.active_run().info.run_id
        model_uri = f"runs:/{run_id}/models"

        # Sanity check that the custom requirements differ from MLflow's defaults, so
        # the content comparison at the end exercises the custom file.
        # NOTE(review): assumes _mlflow_additional_pip_env renders the default pip
        # deps in the same textual form as the requirements file — confirm.
        default_pip_deps = get_default_conda_env()["dependencies"][-1]["pip"]
        assert _mlflow_additional_pip_env(default_pip_deps) != expected_content

        with TempDir(remove_on_exit=True) as tmp:
            local_model_dir = _download_artifact_from_uri(model_uri, tmp.path())
            mlmodel = Model.load(os.path.join(local_model_dir, "MLmodel"))
            pytorch_conf = mlmodel.flavors["pytorch"]

            assert "requirements_file" in pytorch_conf
            req_entry = pytorch_conf["requirements_file"]
            assert "path" in req_entry

            # The stored path is relative to the model directory.
            stored_path = os.path.join(local_model_dir, req_entry["path"])
            with open(stored_path) as stored:
                assert stored.read() == expected_content
def test_requirements_file_save_model(create_requirements_file, sequential_model):
    """Check that ``save_model(requirements_file=...)`` copies the given file into the model.

    A PyTorch model is saved to ``<tmp>/models``. The ``pytorch`` flavor entry in the
    resulting MLmodel config must contain a ``requirements_file`` section with a
    ``path``. The file at that path, relative to the model directory, must contain
    exactly the fixture's expected content.
    """
    req_file, expected_content = create_requirements_file

    with TempDir(remove_on_exit=True) as tmp:
        local_model_dir = os.path.join(tmp.path(), "models")
        mlflow.pytorch.save_model(
            pytorch_model=sequential_model,
            path=local_model_dir,
            requirements_file=req_file,
        )

        # Sanity check that the custom requirements differ from MLflow's defaults, so
        # the content comparison at the end exercises the custom file.
        # NOTE(review): assumes _mlflow_additional_pip_env renders the default pip
        # deps in the same textual form as the requirements file — confirm.
        default_pip_deps = get_default_conda_env()["dependencies"][-1]["pip"]
        assert _mlflow_additional_pip_env(default_pip_deps) != expected_content

        mlmodel = Model.load(os.path.join(local_model_dir, "MLmodel"))
        pytorch_conf = mlmodel.flavors["pytorch"]

        assert "requirements_file" in pytorch_conf
        req_entry = pytorch_conf["requirements_file"]
        assert "path" in req_entry

        # The stored path is relative to the model directory.
        stored_path = os.path.join(local_model_dir, req_entry["path"])
        with open(stored_path) as stored:
            assert stored.read() == expected_content