def test_default_conda_env_strips_dev_suffix_from_pyspark_version(spark_model_iris, model_path):
    """Check that dev-suffixed versions yield the same default conda env as "2.4.0".

    ``importlib_metadata.version`` is patched to return each candidate version
    string. For every dev-suffixed version, both the env returned by
    ``sparkm.get_default_conda_env()`` and the env file persisted by
    ``sparkm.log_model()`` must equal the env computed for plain "2.4.0".
    Versions without a dev suffix must still show up in the dumped env.
    """
    version_target = "importlib_metadata.version"

    # Baseline: the environment produced for the plain release version.
    with mock.patch(version_target, return_value="2.4.0"):
        expected_env = sparkm.get_default_conda_env()

    dev_versions = ("2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb")
    for dev_version in dev_versions:
        with mock.patch(version_target, return_value=dev_version):
            assert sparkm.get_default_conda_env() == expected_env

            # Log a model while the dev version is patched in, then inspect
            # the conda env file that was stored alongside it.
            with mlflow.start_run():
                sparkm.log_model(spark_model=spark_model_iris.model, artifact_path="model")
                run_id = mlflow.active_run().info.run_id
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=run_id, artifact_path="model"
                )

            downloaded_path = _download_artifact_from_uri(artifact_uri=model_uri)
            flavor_conf = _get_flavor_configuration(
                model_path=downloaded_path, flavor_name=pyfunc.FLAVOR_NAME
            )
            env_file_path = os.path.join(downloaded_path, flavor_conf[pyfunc.ENV])
            with open(env_file_path, "r") as env_file:
                persisted_env = yaml.safe_load(env_file)
            assert persisted_env == expected_env

    for unaffected_version in ("2.0", "2.3.4", "2"):
        with mock.patch(version_target, return_value=unaffected_version):
            # NOTE(review): this is a substring match against the whole YAML
            # dump, so a short value such as "2" may match unrelated text.
            env_yaml = yaml.safe_dump(sparkm.get_default_conda_env())
            assert unaffected_version in env_yaml
def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    spark_model_iris, model_path
):
    """Saving with ``conda_env=None`` must persist the default conda env.

    The model is saved to ``model_path``, the conda env file referenced by the
    pyfunc flavor configuration is loaded, and its contents are compared with
    ``sparkm.get_default_conda_env()``.
    """
    sparkm.save_model(spark_model=spark_model_iris.model, path=model_path, conda_env=None)

    # The pyfunc flavor configuration records the env file's path relative to
    # the model directory.
    flavor_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
    )
    env_file_path = os.path.join(model_path, flavor_conf[pyfunc.ENV])

    with open(env_file_path, "r") as env_file:
        saved_env = yaml.safe_load(env_file)

    assert saved_env == sparkm.get_default_conda_env()
def test_pyspark_version_is_logged_without_dev_suffix(spark_model_iris):
    """Check that logged pip requirements pin pyspark without any dev suffix.

    ``importlib_metadata.version`` is patched to return "2.4.0" plus each dev
    suffix; the logged model's pip requirements must then list
    ``pyspark==2.4.0``. Versions without a dev suffix must appear verbatim in
    the pip dependencies of the default conda env.
    """
    version_target = "importlib_metadata.version"
    unsuffixed_version = "2.4.0"

    for dev_suffix in (".dev0", ".dev", ".dev1", "dev.a", ".devb"):
        patched_version = unsuffixed_version + dev_suffix
        with mock.patch(version_target, return_value=patched_version):
            with mlflow.start_run():
                sparkm.log_model(spark_model=spark_model_iris.model, artifact_path="model")
                model_uri = mlflow.get_artifact_uri("model")
            expected_requirements = ["mlflow", f"pyspark=={unsuffixed_version}"]
            _assert_pip_requirements(model_uri, expected_requirements)

    for unaffected_version in ("2.0", "2.3.4", "2"):
        with mock.patch(version_target, return_value=unaffected_version):
            pip_deps = _get_pip_deps(sparkm.get_default_conda_env())
            expected_pin = f"pyspark=={unaffected_version}"
            assert any(dep == expected_pin for dep in pip_deps)
def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    spark_model_iris,
):
    """Logging without a conda env must persist the default conda env.

    The model is logged inside a fresh run, its artifacts are downloaded, and
    the conda env file referenced by the pyfunc flavor configuration is
    compared with ``sparkm.get_default_conda_env()``.
    """
    artifact_path = "model"

    with mlflow.start_run():
        sparkm.log_model(spark_model=spark_model_iris.model, artifact_path=artifact_path)
        # Build the URI while the run is still active so its id is available.
        run_id = mlflow.active_run().info.run_id
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=run_id, artifact_path=artifact_path
        )

    downloaded_path = _download_artifact_from_uri(artifact_uri=model_uri)
    flavor_conf = _get_flavor_configuration(
        model_path=downloaded_path, flavor_name=pyfunc.FLAVOR_NAME
    )
    env_file_path = os.path.join(downloaded_path, flavor_conf[pyfunc.ENV])

    with open(env_file_path, "r") as env_file:
        logged_env = yaml.safe_load(env_file)

    assert logged_env == sparkm.get_default_conda_env()