Example #1
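These snippets come from MLflow's Spark flavor test suite and are shown without their file header. A sketch of the imports they rely on follows; the module paths are assumptions that can shift between MLflow versions, and _assert_pip_requirements / _get_pip_deps are private helpers from MLflow's own test utilities, not public API:

import os
from unittest import mock

import yaml

import mlflow
from mlflow import pyfunc, spark as sparkm
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.model_utils import _get_flavor_configuration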
def test_default_conda_env_strips_dev_suffix_from_pyspark_version(
        spark_model_iris, model_path):
    with mock.patch("importlib_metadata.version", return_value="2.4.0"):
        default_conda_env_standard = sparkm.get_default_conda_env()

    # Any ".dev"-style suffix should be stripped, yielding the same default
    # env as the corresponding release version.
    for dev_version in [
            "2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"
    ]:
        with mock.patch("importlib_metadata.version",
                        return_value=dev_version):
            default_conda_env_dev = sparkm.get_default_conda_env()
            assert default_conda_env_dev == default_conda_env_standard

            with mlflow.start_run():
                sparkm.log_model(spark_model=spark_model_iris.model,
                                 artifact_path="model")
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path="model")

            # The env persisted with the logged model must also match the baseline.
            model_path = _download_artifact_from_uri(artifact_uri=model_uri)
            pyfunc_conf = _get_flavor_configuration(
                model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
            conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
            with open(conda_env_path, "r") as f:
                persisted_conda_env_dev = yaml.safe_load(f)
            assert persisted_conda_env_dev == default_conda_env_standard

    for unaffected_version in ["2.0", "2.3.4", "2"]:
        with mock.patch("importlib_metadata.version",
                        return_value=unaffected_version):
            assert unaffected_version in yaml.safe_dump(
                sparkm.get_default_conda_env())
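The behavior these assertions pin down is easier to see in isolation. A minimal sketch of the suffix-stripping rule, assuming a regex-based approach (an illustrative helper, not MLflow's actual implementation):

import re

def strip_dev_suffix(version):
    # Drop everything from the first "dev" marker, with or without a
    # preceding dot; release versions pass through unchanged.
    return re.sub(r"(\.?)dev.*", "", version)

assert strip_dev_suffix("2.4.0.dev0") == "2.4.0"
assert strip_dev_suffix("2.4.0dev.a") == "2.4.0"
assert strip_dev_suffix("2.3.4") == "2.3.4"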

Example #2
def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        spark_model_iris, model_path):
    sparkm.save_model(spark_model=spark_model_iris.model, path=model_path, conda_env=None)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == sparkm.get_default_conda_env()
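For reference, the dict returned by sparkm.get_default_conda_env() (and persisted as the model's conda.yaml) has roughly the following shape; the exact python, channel, and version entries vary by MLflow release and runtime, so treat this as an assumed example:

default_env = {
    "name": "mlflow-env",
    "channels": ["conda-forge"],  # some releases use "defaults" instead
    "dependencies": [
        "python=3.8.10",  # pinned to the local interpreter version
        "pip",
        {"pip": ["mlflow", "pyspark==2.4.0"]},  # pyspark pinned to the installed version
    ],
}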
Example #3
def test_pyspark_version_is_logged_without_dev_suffix(spark_model_iris):
    unsuffixed_version = "2.4.0"
    for dev_suffix in [".dev0", ".dev", ".dev1", "dev.a", ".devb"]:
        with mock.patch("importlib_metadata.version", return_value=unsuffixed_version + dev_suffix):
            with mlflow.start_run():
                sparkm.log_model(spark_model=spark_model_iris.model, artifact_path="model")
                model_uri = mlflow.get_artifact_uri("model")
            _assert_pip_requirements(model_uri, ["mlflow", f"pyspark=={unsuffixed_version}"])

    for unaffected_version in ["2.0", "2.3.4", "2"]:
        with mock.patch("importlib_metadata.version", return_value=unaffected_version):
            pip_deps = _get_pip_deps(sparkm.get_default_conda_env())
            assert any(x == f"pyspark=={unaffected_version}" for x in pip_deps)
Example #4
def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    spark_model_iris,
):
    artifact_path = "model"
    with mlflow.start_run():
        sparkm.log_model(spark_model=spark_model_iris.model, artifact_path=artifact_path)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
        )

    # Download the logged artifacts and check the persisted conda env.
    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == sparkm.get_default_conda_env()
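As a usage note, the runs:/ URI built above is accepted directly by MLflow's model-loading APIs, so a test like this could equally reload the model instead of downloading raw artifacts; a brief sketch:

import mlflow.pyfunc

# Resolves the runs:/ URI against the tracking store and loads the model
# through the python_function flavor.
loaded_model = mlflow.pyfunc.load_model(model_uri)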