def __init__(self, artifact_uri):
        """
        Initialize the repository from an access-controlled Databricks artifact URI.

        The URI is validated, handed to the parent initializer with its Databricks profile
        info removed, and finally resolved relative to the artifact root of the run it
        belongs to.

        :param artifact_uri: Artifact URI of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when the URI is not a valid
                                 DBFS URI or fails ``is_databricks_acled_artifacts_uri``.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            invalid_dbfs_msg = ("DBFS URI must be of the form dbfs:/<path> or " +
                                "dbfs://profile@databricks/<path>")
            raise MlflowException(message=invalid_dbfs_msg, error_code=INVALID_PARAMETER_VALUE)
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            bad_prefix_msg = ('Artifact URI incorrect. Expected path prefix to be'
                              ' databricks/mlflow-tracking/path/to/artifact/..')
            raise MlflowException(message=bad_prefix_msg, error_code=INVALID_PARAMETER_VALUE)

        # Artifact operations work on a plain dbfs:/ path, so the parent class receives the
        # URI without the Databricks profile info.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super(DatabricksArtifactRepository, self).__init__(profile_free_uri)

        # Prefer the profile embedded in the URI; otherwise use the current tracking URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = mlflow.tracking.get_tracking_uri()
        self.databricks_profile_uri = profile_uri
        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is performed relative to the run's artifact
        # root, so work out where `artifact_uri` sits underneath that root.
        repo_root_path = extract_and_normalize_path(artifact_uri)
        run_root_uri = self._get_run_artifact_root(self.run_id)
        run_root_path = extract_and_normalize_path(run_root_uri)
        relative_root = posixpath.relpath(path=repo_root_path, start=run_root_path)
        if run_root_path == repo_root_path:
            # Identical paths: use an empty string over "./" for ListArtifact compatibility.
            relative_root = ""
        self.run_relative_artifact_repo_root_path = relative_root
 def get_underlying_uri(runs_uri):
     """
     Translate a runs URI into the artifact URI it points at.

     The URI is split into a run ID and an artifact path, the artifact URI for that pair is
     looked up, and the Databricks profile taken from ``runs_uri`` (if any) is attached to
     the result.

     :param runs_uri: URI accepted by ``RunsArtifactRepository.parse_runs_uri``.
     :return: The looked-up artifact URI after passing through
              ``add_databricks_profile_info_to_artifact_uri``.
     """
     from mlflow.tracking.artifact_utils import get_artifact_uri
     run_id, artifact_path = RunsArtifactRepository.parse_runs_uri(runs_uri)
     profile_uri = get_databricks_profile_uri_from_artifact_uri(runs_uri)
     resolved_uri = get_artifact_uri(run_id, artifact_path, profile_uri)
     # A result that is itself a runs URI would have to be resolved again, endlessly.
     assert not RunsArtifactRepository.is_runs_uri(resolved_uri)  # avoid an infinite loop
     return add_databricks_profile_info_to_artifact_uri(resolved_uri, profile_uri)
    def get_underlying_uri(uri):
        """
        Resolve a model URI to the download URI of the model version it names.

        The Databricks profile is read from ``uri``; when none is present the current
        registry URI is used. That value selects the registry the client talks to and is
        also attached to the returned URI.

        :param uri: Model URI understood by ``get_model_name_and_version``.
        :return: The model version's download URI after passing through
                 ``add_databricks_profile_info_to_artifact_uri``.
        """
        # Note: to support a registry URI that is different from the tracking URI here,
        # we'll need to add setting of registry URIs via environment variables.
        from mlflow.tracking import MlflowClient

        profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
        if not profile_uri:
            profile_uri = mlflow.get_registry_uri()
        registry_client = MlflowClient(registry_uri=profile_uri)
        model_name, model_version = get_model_name_and_version(registry_client, uri)
        source_uri = registry_client.get_model_version_download_uri(model_name, model_version)
        return add_databricks_profile_info_to_artifact_uri(source_uri, profile_uri)
示例#4
0
    def __init__(self, artifact_uri):
        """
        Initialize the repository for a model stored in a Databricks model registry.

        :param artifact_uri: Model URI; it must pass ``is_using_databricks_registry``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when that check fails.
        """
        if not is_using_databricks_registry(artifact_uri):
            raise MlflowException(
                message="A valid databricks profile is required to instantiate this repository",
                error_code=INVALID_PARAMETER_VALUE,
            )
        super().__init__(artifact_uri)
        from mlflow.tracking import MlflowClient

        # Prefer the profile embedded in the URI; otherwise use the current registry URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = mlflow.get_registry_uri()
        self.databricks_profile_uri = profile_uri

        registry_client = MlflowClient(registry_uri=self.databricks_profile_uri)
        name_and_version = get_model_name_and_version(registry_client, artifact_uri)
        self.model_name, self.model_version = name_and_version
    def __init__(self, artifact_uri):
        """
        Initialize the repository and decide where host credentials come from.

        :param artifact_uri: Artifact URI of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when the URI is not a valid
                                 DBFS URI.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            invalid_dbfs_msg = ("DBFS URI must be of the form dbfs:/<path> or " +
                                "dbfs://profile@databricks/<path>")
            raise MlflowException(message=invalid_dbfs_msg, error_code=INVALID_PARAMETER_VALUE)

        # Artifact operations work on a plain dbfs:/ path, so the parent class receives the
        # URI without the Databricks profile info.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super(DbfsRestArtifactRepository, self).__init__(profile_free_uri)

        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            # No profile in the URI: use whatever the default-store helper hands back.
            # NOTE(review): presumably that helper returns a callable -- confirm.
            self.get_host_creds = _get_host_creds_from_default_store()
            return

        # Credentials are looked up once here; the accessor keeps returning that same object.
        profile_host_creds = get_databricks_host_creds(profile_uri)
        self.get_host_creds = lambda: profile_host_creds
 def get_underlying_uri(uri):
     """
     Resolve a model URI to the download URI of the model version it names.

     When the parsed URI carries a stage instead of a version, the first entry returned by
     ``get_latest_versions`` for that stage supplies the version.

     :param uri: Model URI accepted by ``ModelsArtifactRepository._parse_uri``.
     :return: The model version's download URI after passing through
              ``add_databricks_profile_info_to_artifact_uri``.
     :raises MlflowException: when a stage is given but no version is found for it.
     """
     # Note: to support a registry URI that is different from the tracking URI here,
     # we'll need to add setting of registry URIs via environment variables.
     from mlflow.tracking import MlflowClient
     profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
     registry_client = MlflowClient(registry_uri=profile_uri)
     model_name, model_version, model_stage = ModelsArtifactRepository._parse_uri(uri)
     if model_stage is not None:
         stage_versions = registry_client.get_latest_versions(model_name, [model_stage])
         if len(stage_versions) == 0:
             not_found_template = ("No versions of model with name '{name}' and "
                                   "stage '{stage}' found")
             raise MlflowException(
                 not_found_template.format(name=model_name, stage=model_stage))
         model_version = stage_versions[0].version
     source_uri = registry_client.get_model_version_download_uri(model_name, model_version)
     return add_databricks_profile_info_to_artifact_uri(source_uri, profile_uri)
示例#7
0
    def __init__(self, artifact_uri):
        """
        Initialize the repository from an access-controlled Databricks artifact URI.

        The URI is validated, handed to the parent initializer with its Databricks profile
        info removed, and resolved relative to the artifact root of the run it belongs to.
        A thread pool for artifact uploads is created as well.

        :param artifact_uri: Artifact URI of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when the URI is not a valid
                                 DBFS URI or fails ``is_databricks_acled_artifacts_uri``.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            invalid_dbfs_msg = (
                "DBFS URI must be of the form dbfs:/<path> or "
                + "dbfs://profile@databricks/<path>"
            )
            raise MlflowException(message=invalid_dbfs_msg, error_code=INVALID_PARAMETER_VALUE)
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            bad_prefix_msg = (
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            )
            raise MlflowException(message=bad_prefix_msg, error_code=INVALID_PARAMETER_VALUE)

        # Artifact operations work on a plain dbfs:/ path, so the parent class receives the
        # URI without the Databricks profile info.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super().__init__(profile_free_uri)

        # Prefer the profile embedded in the URI; otherwise use the current tracking URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = mlflow.tracking.get_tracking_uri()
        self.databricks_profile_uri = profile_uri
        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is performed relative to the run's artifact
        # root, so work out where `artifact_uri` sits underneath that root.
        repo_root_path = extract_and_normalize_path(artifact_uri)
        run_root_uri = self._get_run_artifact_root(self.run_id)
        run_root_path = extract_and_normalize_path(run_root_uri)
        relative_root = posixpath.relpath(path=repo_root_path, start=run_root_path)
        if run_root_path == repo_root_path:
            # Identical paths: use an empty string over "./" for ListArtifact compatibility.
            relative_root = ""
        self.run_relative_artifact_repo_root_path = relative_root

        # Cap upload concurrency at the smaller of 8 threads and twice the CPU count; a
        # missing CPU count is treated as 4 cores.
        cpu_total = os.cpu_count() or 4
        upload_workers = min(2 * cpu_total, 8)
        self.thread_pool = ThreadPoolExecutor(max_workers=upload_workers)
示例#8
0
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Build the ArtifactRepository implementation to use for a DBFS artifact URI.

    This factory handles URIs of the form ``dbfs:/<path>``. DBFS-backed artifact storage
    can only be used together with the RestStore.

    Selection order:

    1. URIs of the form ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>`` yield
       a DatabricksArtifactRepository, which is capable of storing access controlled
       artifacts.
    2. If the DBFS FUSE mount is available, FUSE has not been disabled through the
       ``USE_FUSE_ENV_VAR`` environment variable, the URI is not a model registry artifact
       URI, and the URI's profile is absent or ``databricks``, a LocalArtifactRepository
       over ``file:///dbfs/...`` is returned.
    3. Anything else yields a DbfsRestArtifactRepository.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    :raises MlflowException: when ``artifact_uri`` is not a valid DBFS URI.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            "DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>, but received "
            + artifact_uri
        )

    trimmed_uri = artifact_uri.rstrip("/")
    profile_uri = get_databricks_profile_uri_from_artifact_uri(trimmed_uri)
    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(trimmed_uri)

    # Note: it is possible for a named Databricks profile to point to the current workspace,
    # but we're going to avoid doing a complex check and assume users will use `databricks`
    # to mean the current workspace. Using `DbfsRestArtifactRepository` to access the current
    # workspace's DBFS should still work; it just may be slower.
    fuse_usable = (
        mlflow.utils.databricks_utils.is_dbfs_fuse_available()
        and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false"
        and not is_databricks_model_registry_artifacts_uri(artifact_uri)
        and (profile_uri is None or profile_uri == "databricks")
    )
    if not fuse_usable:
        return DbfsRestArtifactRepository(trimmed_uri)

    # The DBFS FUSE mount is available, so write artifacts directly to /dbfs/... using
    # local filesystem APIs.
    profile_free_uri = remove_databricks_profile_info_from_artifact_uri(trimmed_uri)
    dbfs_relative_path = strip_prefix(profile_free_uri, "dbfs:/")
    return LocalArtifactRepository("file:///dbfs/{}".format(dbfs_relative_path))
示例#9
0
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Extracting the profile from ``uri`` must raise the "Unsupported Databricks profile" error."""
    expected_failure = pytest.raises(MlflowException, match="Unsupported Databricks profile")
    with expected_failure:
        get_databricks_profile_uri_from_artifact_uri(uri)
示例#10
0
def test_get_databricks_profile_uri_from_artifact_uri(uri, result):
    """The profile URI extracted from ``uri`` must equal ``result``."""
    extracted_profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    assert extracted_profile_uri == result
示例#11
0
def is_using_databricks_registry(uri):
    """
    Report whether ``uri`` resolves to a Databricks registry.

    The profile URI embedded in ``uri`` is checked when present; otherwise the current
    registry URI is checked instead.

    :param uri: Artifact URI to inspect.
    :return: The result of ``is_databricks_uri`` for the chosen profile URI.
    """
    profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    if not profile_uri:
        profile_uri = mlflow.get_registry_uri()
    return is_databricks_uri(profile_uri)
示例#12
0
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Extracting the profile from ``uri`` must raise an MlflowException."""
    expected_failure = pytest.raises(MlflowException)
    with expected_failure:
        get_databricks_profile_uri_from_artifact_uri(uri)