def test_is_valid_dbfs_uri(uri, result):
    """Check that ``is_valid_dbfs_uri`` classifies ``uri`` as ``result``.

    ``uri`` and ``result`` are supplied by the caller -- presumably a
    ``pytest.mark.parametrize`` decorator outside this view (TODO confirm).
    """
    actual = is_valid_dbfs_uri(uri)
    assert actual == result
def save_model(
    spark_model,
    path,
    mlflow_model=None,
    conda_env=None,
    dfs_tmpdir=None,
    sample_input=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
):
    """
    Save a Spark MLlib Model to a local path.

    By default, this function saves models using the Spark MLlib persistence mechanism.
    Additionally, if a sample input is specified using the ``sample_input`` parameter, the model
    is also serialized in MLeap format and the MLeap flavor is added.

    :param spark_model: Spark model to be saved - MLflow can only save descendants of
                        pyspark.ml.Model which implement MLReadable and MLWritable. A model that
                        is not already a ``PipelineModel`` is wrapped in a single-stage
                        ``PipelineModel`` before it is saved.
    :param path: Local path where the model is to be saved.
    :param mlflow_model: MLflow model config this flavor is being added to. If ``None``, a new
                         ``Model()`` is created.
    :param conda_env: Either a dictionary representation of a Conda environment or the path to a
                      Conda environment yaml file. If provided, this describes the environment
                      this model should be run in. At minimum, it should specify the dependencies
                      contained in :func:`get_default_conda_env()`. If `None`, the default
                      :func:`get_default_conda_env()` environment is added to the model.
                      The following is an *example* dictionary representation of a Conda
                      environment::

                        {
                            'name': 'mlflow-env',
                            'channels': ['defaults'],
                            'dependencies': [
                                'python=3.7.0',
                                'pyspark=2.3.0'
                            ]
                        }
    :param dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
                       filesystem if running in local mode. The model is written to this
                       destination and then copied to the requested local path. This is necessary
                       as Spark ML models read from and write to DFS if running on a cluster. All
                       temporary files created on the DFS are removed if this operation
                       completes successfully. Defaults to ``/tmp/mlflow``.
    :param sample_input: A sample input that is used to add the MLeap flavor to the model.
                         This must be a PySpark DataFrame that the model can evaluate. If
                         ``sample_input`` is ``None``, the MLeap flavor is not added.
    :param signature: (Experimental) :py:class:`ModelSignature <mlflow.models.ModelSignature>`
                      describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
                      The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
                      from datasets with valid model input (e.g. the training dataset with target
                      column omitted) and valid model output (e.g. model predictions generated on
                      the training dataset), for example:

                      .. code-block:: python

                        from mlflow.models.signature import infer_signature
                        train = df.drop_column("target_label")
                        predictions = ... # compute model predictions
                        signature = infer_signature(train, predictions)
    :param input_example: (Experimental) Input example provides one or several instances of valid
                          model input. The example can be used as a hint of what data to feed the
                          model. The given example will be converted to a Pandas DataFrame and then
                          serialized to json using the Pandas split-oriented format. Bytes are
                          base64-encoded.

    .. code-block:: python
        :caption: Example

        import mlflow.spark
        from pyspark.ml.pipeline import PipelineModel

        # your pyspark.ml.pipeline.PipelineModel type
        model = ...
        mlflow.spark.save_model(model, "spark-model")
    """
    _validate_model(spark_model)
    from pyspark.ml import PipelineModel

    # Persistence always goes through a PipelineModel; wrap bare models in a one-stage pipeline.
    pipeline_model = (
        spark_model if isinstance(spark_model, PipelineModel) else PipelineModel([spark_model])
    )
    if mlflow_model is None:
        mlflow_model = Model()

    # On a cluster Spark ML writes to DFS, so the model is first persisted to a temporary
    # DFS location and only afterwards transferred to the requested local path.
    staging_uri = _tmp_path(DFS_TMP if dfs_tmpdir is None else dfs_tmpdir)
    pipeline_model.save(staging_uri)
    local_model_dir = os.path.abspath(os.path.join(path, _SPARK_MODEL_PATH_SUB))

    # The staged model lives on DBFS when either
    #   (a) the staging URI is an explicit DBFS URI ("dbfs:/my-directory"), or
    #   (b) we are running on a Databricks cluster and the URI is schemeless, i.e. it looks
    #       like an absolute filesystem path such as "/my-directory".
    # The cluster check is only evaluated when (a) does not hold.
    if is_valid_dbfs_uri(staging_uri):
        staged_on_dbfs = True
    else:
        staged_on_dbfs = (
            databricks_utils.is_in_cluster() and posixpath.abspath(staging_uri) == staging_uri
        )

    if staged_on_dbfs:
        # Move the staged files via the FUSE path that corresponds to the DBFS/HDFS URI.
        staging_fuse_path = dbfs_hdfs_uri_to_fuse_path(staging_uri)
        shutil.move(src=staging_fuse_path, dst=local_model_dir)
    else:
        _HadoopFileSystem.copy_to_local_file(staging_uri, local_model_dir, remove_src=True)

    _save_model_metadata(
        dst_dir=path,
        spark_model=pipeline_model,
        mlflow_model=mlflow_model,
        sample_input=sample_input,
        conda_env=conda_env,
        signature=signature,
        input_example=input_example,
    )