def __init__(self, artifact_uri):
    """
    Set up the repository for a ``dbfs:/`` artifact location that belongs to an MLflow Run.

    The URI is validated, the run ID is extracted from it, and the location of the URI
    relative to the run's artifact root is stored in
    ``self.run_relative_artifact_repo_root_path``.

    :param artifact_uri: Artifact location. It must start with ``dbfs:/`` and be accepted by
                         ``is_databricks_acled_artifacts_uri``.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if either check fails.
    """
    super(DatabricksArtifactRepository, self).__init__(artifact_uri)

    has_dbfs_scheme = artifact_uri.startswith('dbfs:/')
    if not has_dbfs_scheme:
        raise MlflowException(
            message='DatabricksArtifactRepository URI must start with dbfs:/',
            error_code=INVALID_PARAMETER_VALUE)

    if not is_databricks_acled_artifacts_uri(artifact_uri):
        bad_prefix_message = ('Artifact URI incorrect. Expected path prefix to be'
                              ' databricks/mlflow-tracking/path/to/artifact/..')
        raise MlflowException(message=bad_prefix_message,
                              error_code=INVALID_PARAMETER_VALUE)

    self.run_id = self._extract_run_id(self.artifact_uri)

    # Every operation on this repository is performed relative to the artifact root of the
    # MLflow Run that owns `artifact_uri`, so look up that root and express the repository
    # location relative to it.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root_uri = self._get_run_artifact_root(self.run_id)
    run_root = extract_and_normalize_path(run_root_uri)
    relative_root = posixpath.relpath(path=repo_root, start=run_root)

    if run_root == repo_root:
        # The repository *is* the run's artifact root: store an empty string rather than
        # the relpath result, for ListArtifact compatibility.
        self.run_relative_artifact_repo_root_path = ""
    else:
        self.run_relative_artifact_repo_root_path = relative_root
def __init__(self, artifact_uri):
    """
    Set up the repository for a DBFS artifact location that belongs to an MLflow Run.

    The URI is validated, stripped of Databricks profile info before being handed to the
    base class, and then used to derive the run ID, the Databricks profile URI and the
    repository location relative to the run's artifact root.

    :param artifact_uri: Artifact location of the form ``dbfs:/<path>`` or
                         ``dbfs://profile@databricks/<path>`` that is also accepted by
                         ``is_databricks_acled_artifacts_uri``.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if either check fails.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        invalid_uri_message = ("DBFS URI must be of the form dbfs:/<path> or "
                               + "dbfs://profile@databricks/<path>")
        raise MlflowException(message=invalid_uri_message,
                              error_code=INVALID_PARAMETER_VALUE)

    if not is_databricks_acled_artifacts_uri(artifact_uri):
        bad_prefix_message = ('Artifact URI incorrect. Expected path prefix to be'
                              ' databricks/mlflow-tracking/path/to/artifact/..')
        raise MlflowException(message=bad_prefix_message,
                              error_code=INVALID_PARAMETER_VALUE)

    # Artifact operations must run against a dbfs:/ path that carries no Databricks profile
    # info, so the base class receives the stripped form of the URI.
    profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    super(DatabricksArtifactRepository, self).__init__(profile_free_uri)

    # Use the profile embedded in the URI when there is one, else the current tracking URI.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.tracking.get_tracking_uri()
    self.databricks_profile_uri = profile_uri

    self.run_id = self._extract_run_id(self.artifact_uri)

    # Every operation on this repository is performed relative to the artifact root of the
    # MLflow Run that owns `artifact_uri`, so look up that root and express the repository
    # location relative to it.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root_uri = self._get_run_artifact_root(self.run_id)
    run_root = extract_and_normalize_path(run_root_uri)
    relative_root = posixpath.relpath(path=repo_root, start=run_root)

    if run_root == repo_root:
        # The repository *is* the run's artifact root: store an empty string rather than
        # the relpath result, for ListArtifact compatibility.
        self.run_relative_artifact_repo_root_path = ""
    else:
        self.run_relative_artifact_repo_root_path = relative_root
def __init__(self, artifact_uri):
    """
    Set up the repository for a DBFS artifact location that belongs to an MLflow Run.

    The URI is validated, stripped of Databricks profile info before being handed to the
    base class, and then used to derive the run ID, the Databricks profile URI and the
    repository location relative to the run's artifact root. A bounded thread pool is
    created last and stored in ``self.thread_pool``.

    :param artifact_uri: Artifact location of the form ``dbfs:/<path>`` or
                         ``dbfs://profile@databricks/<path>`` that is also accepted by
                         ``is_databricks_acled_artifacts_uri``.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if either check fails.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        invalid_uri_message = (
            "DBFS URI must be of the form dbfs:/<path> or " + "dbfs://profile@databricks/<path>"
        )
        raise MlflowException(message=invalid_uri_message, error_code=INVALID_PARAMETER_VALUE)

    if not is_databricks_acled_artifacts_uri(artifact_uri):
        bad_prefix_message = (
            "Artifact URI incorrect. Expected path prefix to be"
            " databricks/mlflow-tracking/path/to/artifact/.."
        )
        raise MlflowException(message=bad_prefix_message, error_code=INVALID_PARAMETER_VALUE)

    # Artifact operations must run against a dbfs:/ path that carries no Databricks profile
    # info, so the base class receives the stripped form of the URI.
    profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
    super().__init__(profile_free_uri)

    # Use the profile embedded in the URI when there is one, else the current tracking URI.
    profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
    if not profile_uri:
        profile_uri = mlflow.tracking.get_tracking_uri()
    self.databricks_profile_uri = profile_uri

    self.run_id = self._extract_run_id(self.artifact_uri)

    # Every operation on this repository is performed relative to the artifact root of the
    # MLflow Run that owns `artifact_uri`, so look up that root and express the repository
    # location relative to it.
    repo_root = extract_and_normalize_path(artifact_uri)
    run_root_uri = self._get_run_artifact_root(self.run_id)
    run_root = extract_and_normalize_path(run_root_uri)
    relative_root = posixpath.relpath(path=repo_root, start=run_root)

    if run_root == repo_root:
        # The repository *is* the run's artifact root: store an empty string rather than
        # the relpath result, for ListArtifact compatibility.
        self.run_relative_artifact_repo_root_path = ""
    else:
        self.run_relative_artifact_repo_root_path = relative_root

    # Cap the threads used for artifact uploads at min(2 * CPU cores, 8). If os.cpu_count()
    # returns a falsy value, 4 cores are assumed.
    cpu_total = os.cpu_count() or 4
    worker_limit = min(cpu_total * 2, 8)
    self.thread_pool = ThreadPoolExecutor(max_workers=worker_limit)
def test_extract_and_normalize_path():
    """Check that differently written ``dbfs:`` URIs all normalize to the same path."""
    expected_path = "databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    # The variants differ in the number of slashes after the scheme, repeated slashes
    # between path segments, and trailing slashes.
    uri_variants = [
        "dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/",
        "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
        "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
    ]
    for uri in uri_variants:
        assert extract_and_normalize_path(uri) == expected_path
def _extract_run_id(artifact_uri):
    """
    Pull the run ID out of an artifact URI.

    The URI is expected to look like
    dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>, which
    ``extract_and_normalize_path`` is expected to turn into
    databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>.
    The run ID is therefore the 4th ``/``-separated segment of the normalized path.

    :param artifact_uri: Artifact URI in the form described above.
    :return: run_id extracted from the artifact_uri
    """
    path_segments = extract_and_normalize_path(artifact_uri).split('/')
    # Index 0-2 hold "databricks", "mlflow-tracking" and <EXP_ID>; index 3 is <RUN_ID>.
    return path_segments[3]
def test_extract_and_normalize_path():
    """Check that differently written ``dbfs:`` URIs all normalize to the same path."""
    normalized = 'databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts'
    # Inputs vary the slashes after the scheme, between segments, and at the end.
    equivalent_uris = (
        'dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        'dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/',
        'dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//',
        'dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//',
    )
    for candidate in equivalent_uris:
        result = extract_and_normalize_path(candidate)
        assert result == normalized