Exemple #1
0
def _get_flavor_configuration_from_uri(model_uri, flavor_name):
    """
    Fetch the model file located under ``model_uri`` and return the configuration
    stored for one of its flavors.

    :param model_uri: URI of the root directory of the MLflow model whose flavor
                      configuration should be read.
    :param flavor_name: Name of the flavor whose configuration is requested.
    :return: The flavor configuration as a dictionary.
    :raises MlflowException: With ``RESOURCE_DOES_NOT_EXIST`` when the model file cannot
                             be downloaded, or when the model has no ``flavor_name`` flavor.
    """
    try:
        # Building the URI stays inside the ``try`` so that any failure there is
        # reported through the same MlflowException as a failed download.
        mlmodel_uri = append_to_uri_path(model_uri, MLMODEL_FILE_NAME)
        local_mlmodel_path = _download_artifact_from_uri(artifact_uri=mlmodel_uri)
    except Exception as ex:
        failure_message = (
            "Failed to download an \"{model_file}\" model file from \"{model_uri}\": {ex}"
            .format(model_file=MLMODEL_FILE_NAME, model_uri=model_uri, ex=ex))
        raise MlflowException(failure_message, RESOURCE_DOES_NOT_EXIST)

    available_flavors = Model.load(local_mlmodel_path).flavors
    if flavor_name in available_flavors:
        return available_flavors[flavor_name]

    missing_message = "Model does not have the \"{flavor_name}\" flavor".format(
        flavor_name=flavor_name)
    raise MlflowException(missing_message, RESOURCE_DOES_NOT_EXIST)
Exemple #2
0
    def create_run(self, experiment_id, user_id, start_time, tags):
        """
        Create a new run under the given experiment, persist it and return it as an
        MLflow entity.

        The experiment is looked up and checked for being active before the run is
        built. The run gets a freshly generated hex UUID, and its artifact URI is the
        experiment's artifact location extended with that id and the store's artifacts
        folder name. When several tags share a key, only the last value is kept.

        :param experiment_id: Id of the experiment the run belongs to.
        :param user_id: Id of the user recorded on the run.
        :param start_time: Start time recorded on the run.
        :param tags: Iterable of objects exposing ``key`` and ``value`` attributes.
        :return: The result of ``to_mlflow_entity()`` on the saved run.
        """
        with self.ManagedSessionMaker() as session:
            parent_experiment = self.get_experiment(experiment_id)
            self._check_experiment_is_active(parent_experiment)

            new_run_id = uuid.uuid4().hex
            run_artifact_uri = append_to_uri_path(
                parent_experiment.artifact_location,
                new_run_id,
                SqlAlchemyStore.ARTIFACTS_FOLDER_NAME,
            )
            sql_run = SqlRun(
                name="",
                artifact_uri=run_artifact_uri,
                run_uuid=new_run_id,
                experiment_id=experiment_id,
                source_type=SourceType.to_string(SourceType.UNKNOWN),
                source_name="",
                entry_point_name="",
                user_id=user_id,
                status=RunStatus.to_string(RunStatus.RUNNING),
                start_time=start_time,
                end_time=None,
                source_version="",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )

            # Collapse duplicate keys first: a later tag overrides an earlier one.
            latest_value_by_key = {tag.key: tag.value for tag in tags}
            sql_run.tags = [
                SqlTag(key=tag_key, value=tag_value)
                for tag_key, tag_value in latest_value_by_key.items()
            ]
            self._save_to_db(objs=sql_run, session=session)

            return sql_run.to_mlflow_entity()
Exemple #3
0
 def _create_experiment_with_id(self, name, experiment_id, artifact_uri):
     """
     Create the on-disk metadata for an experiment with a caller-supplied id.

     :param name: Name of the experiment.
     :param experiment_id: Id to assign; its string form names the metadata directory.
     :param artifact_uri: Artifact location for the experiment. When falsy, the
                          location is the store's artifact root URI extended with
                          the experiment id.
     :return: The ``experiment_id`` that was passed in.
     """
     if not artifact_uri:
         artifact_uri = append_to_uri_path(self.artifact_root_uri,
                                           str(experiment_id))
     self._check_root_dir()
     experiment_dir = mkdir(self.root_directory, str(experiment_id))
     new_experiment = Experiment(experiment_id, name, artifact_uri,
                                 LifecycleStage.ACTIVE)
     # tags are added to the file system and are not written to this dict on write
     # As such, we should not include them in the meta file.
     meta_contents = dict(new_experiment)
     meta_contents.pop('tags')
     write_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME, meta_contents)
     return experiment_id
Exemple #4
0
def _get_flavor_backend(model_uri, **kwargs):
    """
    Download the model file for ``model_uri`` into a temporary directory, load it and
    pick a flavor backend for the model.

    A ``models:/`` URI is first translated to its underlying URI; any other URI is
    used unchanged.

    :param model_uri: URI of the MLflow model.
    :param kwargs: Passed through unchanged to ``get_flavor_backend``.
    :return: The selected flavor backend.
    :raises Exception: When ``get_flavor_backend`` yields no backend.
    """
    with TempDir() as tmp:
        resolved_uri = (
            ModelsArtifactRepository.get_underlying_uri(model_uri)
            if ModelsArtifactRepository.is_models_uri(model_uri)
            else model_uri
        )
        mlmodel_uri = append_to_uri_path(resolved_uri, MLMODEL_FILE_NAME)
        mlmodel_path = _download_artifact_from_uri(mlmodel_uri,
                                                   output_path=tmp.path())
        # Load while the temporary directory still exists.
        model = Model.load(mlmodel_path)

    flavor_name, flavor_backend = get_flavor_backend(model, **kwargs)
    if flavor_backend is not None:
        _logger.info("Selected backend for flavor '%s'", flavor_name)
        return flavor_backend
    raise Exception("No suitable flavor backend was found for the model.")
Exemple #5
0
def load_model(model_uri, dfs_tmpdir=None):
    """
    Load a Spark MLlib model stored as an MLflow model.

    ``runs:/`` and ``models:/`` URIs are translated to their underlying URI (the
    translation is logged) before the flavor configuration is read. The model is then
    loaded from the ``model_data`` location named in that configuration.

    :param model_uri: The location, in URI format, of the MLflow model, for example:

                      - ``/Users/me/path/to/local/model``
                      - ``relative/path/to/local/model``
                      - ``s3://my_bucket/path/to/model``
                      - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
                      - ``models:/<model_name>/<model_version>``
                      - ``models:/<model_name>/<stage>``

                      For more information about supported URI schemes, see
                      `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
                      artifact-locations>`_.
    :param dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
                       filesystem if running in local mode. The model is loaded from this
                       destination. Defaults to ``/tmp/mlflow``.
    :return: pyspark.ml.pipeline.PipelineModel

    .. code-block:: python
        :caption: Example

        import mlflow.spark
        model = mlflow.spark.load_model("spark-model")
        # ``spark`` below is assumed to be an active SparkSession.
        # Prepare test documents, which are unlabeled (id, text) tuples.
        test = spark.createDataFrame([
            (4, "spark i j k"),
            (5, "l m n"),
            (6, "spark hadoop spark"),
            (7, "apache hadoop")], ["id", "text"])
        # Make predictions on test documents
        prediction = model.transform(test)
    """
    # Pick the repository that knows how to translate this URI scheme, if any.
    if RunsArtifactRepository.is_runs_uri(model_uri):
        resolving_repo = RunsArtifactRepository
    elif ModelsArtifactRepository.is_models_uri(model_uri):
        resolving_repo = ModelsArtifactRepository
    else:
        resolving_repo = None

    if resolving_repo is not None:
        requested_uri = model_uri
        model_uri = resolving_repo.get_underlying_uri(requested_uri)
        _logger.info("'%s' resolved as '%s'", requested_uri, model_uri)

    spark_flavor_conf = _get_flavor_configuration_from_uri(model_uri, FLAVOR_NAME)
    model_data_uri = append_to_uri_path(model_uri, spark_flavor_conf["model_data"])
    return _load_model(model_uri=model_data_uri, dfs_tmpdir=dfs_tmpdir)
Exemple #6
0
 def _get_artifact_dir(self, experiment_id, run_uuid):
     """
     Build the artifact directory URI for a run.

     The run id is validated first. The result is the experiment's artifact location
     extended with the run id and the store's artifacts folder name.

     :param experiment_id: Id of the experiment that owns the run.
     :param run_uuid: Id of the run.
     :return: The value produced by ``append_to_uri_path`` for those three parts.
     """
     _validate_run_id(run_uuid)
     experiment_location = self.get_experiment(experiment_id).artifact_location
     return append_to_uri_path(experiment_location, run_uuid,
                               FileStore.ARTIFACTS_FOLDER_NAME)
Exemple #7
0
def validate_append_to_uri_path_test_cases(cases):
    """
    Assert that ``append_to_uri_path`` produces the expected URI for every case.

    Each case is checked twice: once with the path passed as a single argument, and
    once with the path split by ``posixpath.split`` into a head and a tail that are
    passed as two separate arguments.

    :param cases: Iterable of ``(input_uri, input_path, expected_output_uri)`` triples.
    """
    for base_uri, path_suffix, expected_uri in cases:
        joined_in_one_piece = append_to_uri_path(base_uri, path_suffix)
        assert joined_in_one_piece == expected_uri

        head, tail = posixpath.split(path_suffix)
        joined_in_two_pieces = append_to_uri_path(base_uri, head, tail)
        assert joined_in_two_pieces == expected_uri
Exemple #8
0
 def _get_artifact_location(self, experiment_id):
     """
     Build the artifact location for an experiment.

     :param experiment_id: Id of the experiment; its string form is used as the
                           path component.
     :return: The store's artifact root URI extended with the experiment id.
     """
     root_uri = self.artifact_root_uri
     return append_to_uri_path(root_uri, str(experiment_id))