def __init__(self, artifact_uri):
    """Build a repository over ACL'd artifacts under ``dbfs:/databricks/mlflow-tracking``.

    :param artifact_uri: A ``dbfs:/<path>`` or ``dbfs://profile@databricks/<path>`` URI whose
                         path lies inside the ACL'd MLflow tracking artifact location.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``artifact_uri`` is not a
                             valid DBFS URI, or does not point into the ACL'd location.

    Sets ``databricks_profile_uri``, ``run_id`` and ``run_relative_artifact_repo_root_path``.
    """
    # Validate the URI shape before touching any remote state.
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            message="DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException(
            message=(
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Artifact operations must use the bare dbfs:/ path, so the profile part is dropped
    # before the base class records ``artifact_uri``.
    stripped_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    super(DatabricksArtifactRepository, self).__init__(stripped_uri)

    # Prefer the profile embedded in the URI; otherwise fall back to the tracking URI.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.tracking.get_tracking_uri()
    self.databricks_profile_uri = profile_uri

    self.run_id = self._extract_run_id(self.artifact_uri)

    # All operations on this repository are relative to the location of ``artifact_uri``
    # inside the run's artifact root, so compute that relative path once up front.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
    relative_root = posixpath.relpath(path=repo_root, start=run_root)
    # When both roots coincide, relpath yields "."; ListArtifacts expects "" instead.
    if run_root == repo_root:
        relative_root = ""
    self.run_relative_artifact_repo_root_path = relative_root
def get_underlying_uri(runs_uri):
    """Resolve a ``runs:/`` URI to the concrete artifact URI it refers to.

    :param runs_uri: A ``runs:/`` URI, parsed by ``RunsArtifactRepository.parse_runs_uri``
                     into a run ID and an artifact path. It may carry Databricks profile info.
    :return: The run's artifact URI for that artifact path. It is passed through
             ``add_databricks_profile_info_to_artifact_uri`` together with the profile URI
             extracted from ``runs_uri``.
    :raises MlflowException: If the resolved URI is itself a ``runs:/`` URI. Resolving it
                             further would never terminate.
    """
    from mlflow.tracking.artifact_utils import get_artifact_uri

    run_id, artifact_path = RunsArtifactRepository.parse_runs_uri(runs_uri)
    tracking_uri = get_databricks_profile_uri_from_artifact_uri(runs_uri)
    uri = get_artifact_uri(run_id, artifact_path, tracking_uri)
    # Guard against an infinite resolution loop. Use an explicit raise rather than
    # ``assert``, because assert statements are removed when Python runs with -O.
    if RunsArtifactRepository.is_runs_uri(uri):
        raise MlflowException(
            "Resolved artifact URI '{}' for '{}' is itself a runs:/ URI; refusing to "
            "resolve it further to avoid an infinite loop.".format(uri, runs_uri)
        )
    return add_databricks_profile_info_to_artifact_uri(uri, tracking_uri)
def get_underlying_uri(uri):
    """Translate a model URI into the download URI of the model version it names.

    The registry is chosen from the Databricks profile embedded in ``uri``. If ``uri`` has no
    profile, the configured registry URI (``mlflow.get_registry_uri()``) is used instead.

    :param uri: A model URI understood by ``get_model_name_and_version``.
    :return: The model version's download URI, passed through
             ``add_databricks_profile_info_to_artifact_uri`` with the chosen registry URI.
    """
    # Note: supporting a registry URI that differs from the tracking URI here would require
    # registry URIs to be settable through environment variables.
    from mlflow.tracking import MlflowClient

    registry_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    if not registry_uri:
        registry_uri = mlflow.get_registry_uri()

    registry_client = MlflowClient(registry_uri=registry_uri)
    model_name, model_version = get_model_name_and_version(registry_client, uri)
    return add_databricks_profile_info_to_artifact_uri(
        registry_client.get_model_version_download_uri(model_name, model_version),
        registry_uri,
    )
def __init__(self, artifact_uri):
    """Initialise a repository for model artifacts held in a Databricks model registry.

    :param artifact_uri: A model URI whose registry (from its embedded profile, or else
                         ``mlflow.get_registry_uri()``) must be a Databricks URI.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the registry resolved for
                             ``artifact_uri`` is not a Databricks one.

    Sets ``databricks_profile_uri``, ``model_name`` and ``model_version``.
    """
    if not is_using_databricks_registry(artifact_uri):
        raise MlflowException(
            message="A valid databricks profile is required to instantiate this repository",
            error_code=INVALID_PARAMETER_VALUE,
        )
    super().__init__(artifact_uri)

    from mlflow.tracking import MlflowClient

    # The URI's own profile wins; otherwise use the globally configured registry.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.get_registry_uri()
    self.databricks_profile_uri = profile_uri

    registry_client = MlflowClient(registry_uri=profile_uri)
    self.model_name, self.model_version = get_model_name_and_version(
        registry_client, artifact_uri
    )
def __init__(self, artifact_uri):
    """Initialise a DBFS repository that talks to DBFS through the REST API.

    :param artifact_uri: A ``dbfs:/<path>`` or ``dbfs://profile@databricks/<path>`` URI.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``artifact_uri`` is not a
                             valid DBFS URI.

    Sets ``get_host_creds``. If the URI names a Databricks profile, the credentials for that
    profile are fetched once, here, and returned by every call. Otherwise the value comes from
    ``_get_host_creds_from_default_store()``, which presumably returns a zero-argument
    callable (TODO confirm against that helper).
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            message="DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Artifact operations use the bare dbfs:/ path, so profile info is removed before the
    # base class records ``artifact_uri``.
    super(DbfsRestArtifactRepository, self).__init__(
        remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    )

    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        self.get_host_creds = _get_host_creds_from_default_store()
        return

    fixed_creds = get_databricks_host_creds(profile_uri)
    self.get_host_creds = lambda: fixed_creds
def get_underlying_uri(uri):
    """Translate a ``models:/`` URI into the download URI of the model version it names.

    If the URI specifies a stage rather than a version, the first entry returned by
    ``get_latest_versions`` for that stage is used.

    :param uri: A model URI parsed by ``ModelsArtifactRepository._parse_uri`` into
                ``(name, version, stage)``.
    :return: The model version's download URI, passed through
             ``add_databricks_profile_info_to_artifact_uri`` with the profile from ``uri``.
    :raises MlflowException: If a stage is given and no model version is in that stage.
    """
    # Note: supporting a registry URI that differs from the tracking URI here would require
    # registry URIs to be settable through environment variables.
    from mlflow.tracking import MlflowClient

    profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    registry_client = MlflowClient(registry_uri=profile_uri)
    model_name, model_version, model_stage = ModelsArtifactRepository._parse_uri(uri)

    if model_stage is not None:
        candidates = registry_client.get_latest_versions(model_name, [model_stage])
        if not candidates:
            raise MlflowException(
                "No versions of model with name '{name}' and "
                "stage '{stage}' found".format(name=model_name, stage=model_stage)
            )
        model_version = candidates[0].version

    download_uri = registry_client.get_model_version_download_uri(model_name, model_version)
    return add_databricks_profile_info_to_artifact_uri(download_uri, profile_uri)
def __init__(self, artifact_uri):
    """Build a repository over ACL'd artifacts under ``dbfs:/databricks/mlflow-tracking``.

    :param artifact_uri: A ``dbfs:/<path>`` or ``dbfs://profile@databricks/<path>`` URI whose
                         path lies inside the ACL'd MLflow tracking artifact location.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``artifact_uri`` is not a
                             valid DBFS URI, or does not point into the ACL'd location.

    Sets ``databricks_profile_uri``, ``run_id``, ``run_relative_artifact_repo_root_path`` and
    ``thread_pool`` (a ``ThreadPoolExecutor`` used for artifact uploads).
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            message="DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not is_databricks_acled_artifacts_uri(artifact_uri):
        raise MlflowException(
            message=(
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Artifact operations must use the bare dbfs:/ path, so the profile part is dropped
    # before the base class records ``artifact_uri``.
    super().__init__(remove_databricks_profile_info_from_artifact_uri(artifact_uri))

    # Prefer the profile embedded in the URI; otherwise fall back to the tracking URI.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.tracking.get_tracking_uri()
    self.databricks_profile_uri = profile_uri

    self.run_id = self._extract_run_id(self.artifact_uri)

    # All operations on this repository are relative to the location of ``artifact_uri``
    # inside the run's artifact root, so compute that relative path once up front.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
    relative_root = posixpath.relpath(path=repo_root, start=run_root)
    # When both roots coincide, relpath yields "."; ListArtifacts expects "" instead.
    self.run_relative_artifact_repo_root_path = "" if run_root == repo_root else relative_root

    # Cap upload parallelism at 8 threads or twice the CPU count, whichever is smaller.
    # If the CPU count cannot be determined, assume 4 CPUs.
    cpu_count = os.cpu_count() or 4
    self.thread_pool = ThreadPoolExecutor(max_workers=min(2 * cpu_count, 8))
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Return an ArtifactRepository subclass for storing artifacts on DBFS.

    This factory handles URIs of the form ``dbfs:/<path>`` or
    ``dbfs://profile@databricks/<path>``. DBFS-backed artifact storage can only be used
    together with the RestStore. Trailing slashes are stripped from the URI before a
    repository is built. The repository is chosen as follows:

    * ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>`` URIs get a
      ``DatabricksArtifactRepository``, which can store access-controlled artifacts.
    * A ``LocalArtifactRepository`` rooted at ``file:///dbfs/<path>`` is returned when all
      of the following hold:

      - the DBFS FUSE mount is available;
      - the ``USE_FUSE_ENV_VAR`` environment variable is not set to ``"false"``
        (case-insensitive);
      - the URI is not a Databricks model registry artifact URI;
      - the URI carries no profile, or the ``databricks`` profile.
    * Otherwise a ``DbfsRestArtifactRepository`` is returned.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    :raises MlflowException: If ``artifact_uri`` is not a valid DBFS URI.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            "DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>, but received "
            + artifact_uri
        )

    trimmed_uri = artifact_uri.rstrip("/")
    # Extracted before any branching, so an unsupported profile fails fast here.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(trimmed_uri)

    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(trimmed_uri)

    # A named profile could point at the current workspace too. To avoid a complex check, only
    # the default `databricks` profile (or no profile) is treated as the current workspace.
    # Falling back to the REST repository for the current workspace still works; it may just
    # be slower.
    use_fuse = (
        mlflow.utils.databricks_utils.is_dbfs_fuse_available()
        and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false"
        and not is_databricks_model_registry_artifacts_uri(artifact_uri)
        and profile_uri in (None, "databricks")
    )
    if not use_fuse:
        return DbfsRestArtifactRepository(trimmed_uri)

    # With the FUSE mount present, write artifacts straight to /dbfs/... through local
    # filesystem APIs.
    dbfs_path = strip_prefix(
        remove_databricks_profile_info_from_artifact_uri(trimmed_uri), "dbfs:/"
    )
    return LocalArtifactRepository("file:///dbfs/{}".format(dbfs_path))
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Check that extracting a profile from ``uri`` raises MlflowException.

    The exception message must match "Unsupported Databricks profile".
    """
    expected_message = "Unsupported Databricks profile"
    with pytest.raises(MlflowException, match=expected_message):
        get_databricks_profile_uri_from_artifact_uri(uri)
def test_get_databricks_profile_uri_from_artifact_uri(uri, result):
    """Check that the profile URI extracted from ``uri`` equals ``result``."""
    extracted = get_databricks_profile_uri_from_artifact_uri(uri)
    assert extracted == result
def is_using_databricks_registry(uri):
    """Report whether ``uri`` resolves to a Databricks model registry.

    The profile URI embedded in ``uri`` is used if present; otherwise the configured
    ``mlflow.get_registry_uri()`` is checked.
    """
    candidate = get_databricks_profile_uri_from_artifact_uri(uri)
    if not candidate:
        candidate = mlflow.get_registry_uri()
    return is_databricks_uri(candidate)
def test_get_databricks_profile_uri_from_artifact_uri_error_cases(uri):
    """Check that extracting a profile from an invalid ``uri`` raises MlflowException."""
    expected_error = MlflowException
    with pytest.raises(expected_error):
        get_databricks_profile_uri_from_artifact_uri(uri)