Example 1
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Returns an ArtifactRepository subclass for storing artifacts on DBFS.

    This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
    storage can only be used together with the RestStore.

    In the special case where the URI is of the form
    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,
    a DatabricksArtifactRepository is returned. This repository is capable of storing
    access-controlled artifacts.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    """
    cleaned_artifact_uri = artifact_uri.rstrip('/')
    uri_scheme = get_uri_scheme(artifact_uri)
    if uri_scheme != 'dbfs':
        raise MlflowException(
            "DBFS URI must be of the form "
            "dbfs:/<path>, but received {uri}".format(uri=artifact_uri))
    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(cleaned_artifact_uri)
    elif mlflow.utils.databricks_utils.is_dbfs_fuse_available() \
            and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false" \
            and not artifact_uri.startswith("dbfs:/databricks/mlflow-registry"):
        # If the DBFS FUSE mount is available, write artifacts directly to /dbfs/... using
        # local filesystem APIs
        file_uri = "file:///dbfs/{}".format(
            strip_prefix(cleaned_artifact_uri, "dbfs:/"))
        return LocalArtifactRepository(file_uri)
    return DbfsRestArtifactRepository(cleaned_artifact_uri)
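
The factory above picks one of three repository classes based only on the URI and the runtime environment. The following self-contained sketch is not MLflow's code: the ACL check is approximated with a prefix pattern and the FUSE flag name is a stand-in for USE_FUSE_ENV_VAR, but it illustrates the routing decision.

import os
import re

# Approximation of is_databricks_acled_artifacts_uri: run-scoped URIs under
# dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/... are access controlled.
_ACLED_PATTERN = re.compile(r"^dbfs:/databricks/mlflow-tracking/[^/]+/[^/]+(/|$)")

def sketch_route(artifact_uri, fuse_available=False):
    """Return the name of the repository class the factory would plausibly pick."""
    uri = artifact_uri.rstrip("/")
    if _ACLED_PATTERN.match(uri):
        return "DatabricksArtifactRepository"   # access-controlled run artifacts
    # "MLFLOW_ENABLE_DBFS_FUSE_ARTIFACT_REPO" is an assumed stand-in for USE_FUSE_ENV_VAR.
    fuse_enabled = os.environ.get("MLFLOW_ENABLE_DBFS_FUSE_ARTIFACT_REPO", "").lower() != "false"
    if fuse_available and fuse_enabled:
        return "LocalArtifactRepository"        # write through the /dbfs FUSE mount
    return "DbfsRestArtifactRepository"         # fall back to the DBFS REST API

print(sketch_route("dbfs:/databricks/mlflow-tracking/123/abc/artifacts"))  # DatabricksArtifactRepository
print(sketch_route("dbfs:/my/artifacts", fuse_available=True))             # LocalArtifactRepository
print(sketch_route("dbfs:/my/artifacts"))                                  # DbfsRestArtifactRepository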
Example 2
    def test_init_artifact_uri(self, artifact_uri, expected_uri, expected_db_uri):
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".get_databricks_host_creds", return_value=None
        ), mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root", return_value="whatever"
        ), mock.patch(
            "mlflow.tracking.get_tracking_uri", return_value="databricks://getTrackingUriDefault"
        ):
            repo = DatabricksArtifactRepository(artifact_uri)
            assert repo.artifact_uri == expected_uri
            assert repo.databricks_profile_uri == expected_db_uri
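
The artifact_uri, expected_uri and expected_db_uri fixtures are supplied by a pytest parametrization that is not part of the snippet. A hypothetical decorator, with illustrative values only rather than MLflow's actual test cases, could look like this:

import pytest

@pytest.mark.parametrize(
    ("artifact_uri", "expected_uri", "expected_db_uri"),
    [
        # Plain tracking-backed URI: kept as-is; the profile falls back to the
        # mocked tracking URI.
        (
            "dbfs:/databricks/mlflow-tracking/EXP/RUN/artifacts",
            "dbfs:/databricks/mlflow-tracking/EXP/RUN/artifacts",
            "databricks://getTrackingUriDefault",
        ),
        # Profile-qualified URI: the profile is split off into databricks_profile_uri.
        (
            "dbfs://someProfile@databricks/databricks/mlflow-tracking/EXP/RUN/artifacts",
            "dbfs:/databricks/mlflow-tracking/EXP/RUN/artifacts",
            "databricks://someProfile",
        ),
    ],
)
def test_init_artifact_uri_sketch(artifact_uri, expected_uri, expected_db_uri):
    ...  # body would mirror the mocked construction shown above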
Example 3
    def test_init_validation_and_cleaning(self):
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
        ) as get_run_artifact_root_mock:
            get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
            # Basic artifact uri
            repo = get_artifact_repository(
                "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
            )
            assert (
                repo.artifact_uri == "dbfs:/databricks/mlflow-tracking/"
                "MOCK-EXP/MOCK-RUN-ID/artifacts"
            )
            assert repo.run_id == MOCK_RUN_ID
            assert repo.run_relative_artifact_repo_root_path == ""

            with pytest.raises(MlflowException):
                DatabricksArtifactRepository("s3://test")
            with pytest.raises(MlflowException):
                DatabricksArtifactRepository("dbfs:/databricks/mlflow/EXP/RUN/artifact")
            with pytest.raises(MlflowException):
                DatabricksArtifactRepository(
                    "dbfs://*****:*****@notdatabricks/databricks/mlflow-tracking/experiment/1/run/2"
                )
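
The test relies on module-level constants that are not shown in the snippet. Plausible definitions, consistent with the URIs used above but assumed rather than copied from MLflow's test suite, would be:

DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo"
DATABRICKS_ARTIFACT_REPOSITORY = (
    DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository"
)
MOCK_EXP_ID = "MOCK-EXP"
MOCK_RUN_ID = "MOCK-RUN-ID"
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/{}/{}/artifacts".format(
    MOCK_EXP_ID, MOCK_RUN_ID
)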
Example 4
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Returns an ArtifactRepository subclass for storing artifacts on DBFS.

    This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
    storage can only be used together with the RestStore.

    In the special case where the URI is of the form
    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,
    a DatabricksArtifactRepository is returned. This repository is capable of storing
    access-controlled artifacts.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    """
    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            "DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>, but received "
            + artifact_uri
        )

    cleaned_artifact_uri = artifact_uri.rstrip("/")
    db_profile_uri = get_databricks_profile_uri_from_artifact_uri(cleaned_artifact_uri)
    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(cleaned_artifact_uri)
    elif (
        mlflow.utils.databricks_utils.is_dbfs_fuse_available()
        and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false"
        and not is_databricks_model_registry_artifacts_uri(artifact_uri)
        and (db_profile_uri is None or db_profile_uri == "databricks")
    ):
        # If the DBFS FUSE mount is available, write artifacts directly to
        # /dbfs/... using local filesystem APIs.
        # Note: it is possible for a named Databricks profile to point to the current workspace,
        # but we're going to avoid doing a complex check and assume users will use `databricks`
        # to mean the current workspace. Using `DbfsRestArtifactRepository` to access the current
        # workspace's DBFS should still work; it just may be slower.
        final_artifact_uri = remove_databricks_profile_info_from_artifact_uri(cleaned_artifact_uri)
        file_uri = "file:///dbfs/{}".format(strip_prefix(final_artifact_uri, "dbfs:/"))
        return LocalArtifactRepository(file_uri)
    return DbfsRestArtifactRepository(cleaned_artifact_uri)
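
This newer factory also accepts profile-qualified URIs of the form dbfs://<profile>@databricks/<path>. The following self-contained approximation (not MLflow's is_valid_dbfs_uri / get_databricks_profile_uri_from_artifact_uri / remove_databricks_profile_info_from_artifact_uri helpers, and with assumed semantics) shows how such a URI can be split into the profile consulted above and the bare dbfs:/<path> that is ultimately used for storage.

from urllib.parse import urlparse

def split_dbfs_uri(artifact_uri):
    """Split a DBFS URI into (profile_uri_or_None, bare_dbfs_uri) -- assumed semantics."""
    parsed = urlparse(artifact_uri)
    if parsed.scheme != "dbfs":
        raise ValueError("expected a dbfs URI, got " + artifact_uri)
    if not parsed.netloc:
        return None, artifact_uri               # plain dbfs:/<path>, no profile info
    if "@" in parsed.netloc:
        profile = "databricks://" + parsed.netloc.split("@", 1)[0]
    else:
        profile = parsed.netloc                 # e.g. "databricks" for the current workspace
    return profile, "dbfs:" + parsed.path       # drop the authority, keep the path

print(split_dbfs_uri("dbfs:/my/artifacts"))
# (None, 'dbfs:/my/artifacts')
print(split_dbfs_uri("dbfs://someProfile@databricks/databricks/mlflow-tracking/1/2/a"))
# ('databricks://someProfile', 'dbfs:/databricks/mlflow-tracking/1/2/a')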