Beispiel #1
0
    def __init__(self, artifact_uri):
        """
        Initialize the repository for a ``dbfs:/`` artifact location.

        The URI is validated, the run ID is extracted from it, and the location of the
        repository relative to the run's artifact root is stored in
        ``run_relative_artifact_repo_root_path``.

        :param artifact_uri: Artifact location; must start with ``dbfs:/`` and be accepted by
                             ``is_databricks_acled_artifacts_uri``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when either check fails.
        """
        super(DatabricksArtifactRepository, self).__init__(artifact_uri)

        has_dbfs_scheme = artifact_uri.startswith('dbfs:/')
        if not has_dbfs_scheme:
            raise MlflowException(
                message='DatabricksArtifactRepository URI must start with dbfs:/',
                error_code=INVALID_PARAMETER_VALUE)

        if not is_databricks_acled_artifacts_uri(artifact_uri):
            prefix_error = ('Artifact URI incorrect. Expected path prefix to be'
                            ' databricks/mlflow-tracking/path/to/artifact/..')
            raise MlflowException(message=prefix_error, error_code=INVALID_PARAMETER_VALUE)

        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is resolved relative to the artifact root of the
        # MLflow Run that owns `artifact_uri`, so work out where this repository sits beneath
        # that root.
        repo_root = extract_and_normalize_path(artifact_uri)
        run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
        relative_root = posixpath.relpath(path=repo_root, start=run_root)

        # ListArtifact compatibility: identical paths must yield "" rather than "./".
        if run_root == repo_root:
            self.run_relative_artifact_repo_root_path = ""
        else:
            self.run_relative_artifact_repo_root_path = relative_root
    def __init__(self, artifact_uri):
        """
        Initialize the repository for a DBFS artifact location that may carry profile info.

        The URI is validated first; the base class is then initialized with the URI stripped of
        Databricks profile info. The profile URI, the run ID and the location of the repository
        relative to the run's artifact root are stored on the instance.

        :param artifact_uri: Artifact location of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when the URI is not a valid
                                 DBFS URI or is rejected by
                                 ``is_databricks_acled_artifacts_uri``.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            raise MlflowException(
                message="DBFS URI must be of the form dbfs:/<path> or "
                        "dbfs://profile@databricks/<path>",
                error_code=INVALID_PARAMETER_VALUE)
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            prefix_error = ('Artifact URI incorrect. Expected path prefix to be'
                            ' databricks/mlflow-tracking/path/to/artifact/..')
            raise MlflowException(message=prefix_error, error_code=INVALID_PARAMETER_VALUE)

        # Artifact operations work on a plain dbfs:/ path, so the Databricks profile info is
        # removed before the base class records ``artifact_uri``.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super(DatabricksArtifactRepository, self).__init__(profile_free_uri)

        # Prefer the profile embedded in the URI; otherwise use the current tracking URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = mlflow.tracking.get_tracking_uri()
        self.databricks_profile_uri = profile_uri

        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is resolved relative to the artifact root of the
        # MLflow Run that owns `artifact_uri`, so work out where this repository sits beneath
        # that root.
        repo_root = extract_and_normalize_path(artifact_uri)
        run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
        relative_root = posixpath.relpath(path=repo_root, start=run_root)

        # ListArtifact compatibility: identical paths must yield "" rather than "./".
        if run_root == repo_root:
            self.run_relative_artifact_repo_root_path = ""
        else:
            self.run_relative_artifact_repo_root_path = relative_root
Beispiel #3
0
    def __init__(self, artifact_uri):
        """
        Initialize the repository for a DBFS artifact location that may carry profile info.

        The URI is validated first; the base class is then initialized with the URI stripped of
        Databricks profile info. The profile URI, the run ID, the location of the repository
        relative to the run's artifact root, and a thread pool for artifact uploads are stored
        on the instance.

        :param artifact_uri: Artifact location of the form ``dbfs:/<path>`` or
                             ``dbfs://profile@databricks/<path>``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when the URI is not a valid
                                 DBFS URI or is rejected by
                                 ``is_databricks_acled_artifacts_uri``.
        """
        if not is_valid_dbfs_uri(artifact_uri):
            raise MlflowException(
                message=(
                    "DBFS URI must be of the form dbfs:/<path> or "
                    "dbfs://profile@databricks/<path>"
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )
        if not is_databricks_acled_artifacts_uri(artifact_uri):
            prefix_error = (
                "Artifact URI incorrect. Expected path prefix to be"
                " databricks/mlflow-tracking/path/to/artifact/.."
            )
            raise MlflowException(message=prefix_error, error_code=INVALID_PARAMETER_VALUE)

        # Artifact operations work on a plain dbfs:/ path, so the Databricks profile info is
        # removed before the base class records ``artifact_uri``.
        profile_free_uri = remove_databricks_profile_info_from_artifact_uri(artifact_uri)
        super().__init__(profile_free_uri)

        # Prefer the profile embedded in the URI; otherwise use the current tracking URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(artifact_uri)
        if not profile_uri:
            profile_uri = mlflow.tracking.get_tracking_uri()
        self.databricks_profile_uri = profile_uri

        self.run_id = self._extract_run_id(self.artifact_uri)

        # Every operation on this repository is resolved relative to the artifact root of the
        # MLflow Run that owns `artifact_uri`, so work out where this repository sits beneath
        # that root.
        repo_root = extract_and_normalize_path(artifact_uri)
        run_root = extract_and_normalize_path(self._get_run_artifact_root(self.run_id))
        relative_root = posixpath.relpath(path=repo_root, start=run_root)

        # ListArtifact compatibility: identical paths must yield "" rather than "./".
        if run_root == repo_root:
            self.run_relative_artifact_repo_root_path = ""
        else:
            self.run_relative_artifact_repo_root_path = relative_root

        # Cap upload concurrency at the smaller of 8 threads and twice the CPU count; when the
        # CPU count cannot be determined, 4 cores are assumed.
        cpu_total = os.cpu_count() or 4
        worker_limit = min(cpu_total * 2, 8)
        self.thread_pool = ThreadPoolExecutor(max_workers=worker_limit)
Beispiel #4
0
def test_extract_and_normalize_path():
    """Check that differently-slashed spellings of one DBFS URI normalize to the same path."""
    expected_path = "databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts"
    uri_spellings = (
        # Zero, one and three slashes after the scheme.
        "dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        "dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts",
        # Repeated separators inside the path and trailing slashes.
        "dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/",
        "dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
        "dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//",
    )
    for uri in uri_spellings:
        assert extract_and_normalize_path(uri) == expected_path
Beispiel #5
0
    def _extract_run_id(artifact_uri):
        """
        Pull the run ID out of an artifact URI.

        The URI is expected to look like
        dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>
        which, after path extraction and normalization, becomes
        databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts/<path>

        :param artifact_uri: Artifact URI in the layout described above.
        :return: run_id, i.e. the 4th '/'-separated element of the normalized path
        """
        path_segments = extract_and_normalize_path(artifact_uri).split('/')
        # Segment order: databricks, mlflow-tracking, <EXP_ID>, <RUN_ID>, ...
        return path_segments[3]
Beispiel #6
0
def test_extract_and_normalize_path():
    """Check that differently-slashed spellings of one DBFS URI normalize to the same path."""
    expected_path = 'databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts'
    uri_spellings = [
        # Zero, one and three slashes after the scheme.
        'dbfs:databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        'dbfs:///databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
        # Repeated separators inside the path and trailing slashes.
        'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/',
        'dbfs:///databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//',
        'dbfs:databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//',
    ]
    for uri in uri_spellings:
        assert extract_and_normalize_path(uri) == expected_path