Example #1
import pathlib
import posixpath
import shutil
import tempfile

from mlflow.store.artifact.dbfs_artifact_repo import DbfsRestArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.uri import add_databricks_profile_info_to_artifact_uri


def _upload_artifacts_to_databricks(
    source, run_id, source_host_uri=None, target_databricks_profile_uri=None
):
    """
    Copy the artifacts from ``source`` to the destination Databricks workspace (DBFS) given by
    ``target_databricks_profile_uri`` or the current tracking URI.
    :param source: Source location for the artifacts to copy.
    :param run_id: Run ID to associate the artifacts with.
    :param source_host_uri: Specifies the source artifact's host URI (e.g. Databricks tracking URI)
        if applicable. If not given, defaults to the current tracking URI.
    :param target_databricks_profile_uri: Specifies the destination Databricks host. If not given,
        defaults to the current tracking URI.
    :return: The DBFS location in the target Databricks workspace the model files have been
        uploaded to.
    """
    from uuid import uuid4

    local_dir = tempfile.mkdtemp()
    try:
        source_with_profile = add_databricks_profile_info_to_artifact_uri(source, source_host_uri)
        _download_artifact_from_uri(source_with_profile, local_dir)
        dest_root = "dbfs:/databricks/mlflow/tmp-external-source/"
        dest_root_with_profile = add_databricks_profile_info_to_artifact_uri(
            dest_root, target_databricks_profile_uri
        )
        dest_repo = DbfsRestArtifactRepository(dest_root_with_profile)
        dest_artifact_path = run_id if run_id else uuid4().hex
        # Allow uploading from the same run id multiple times by randomizing a suffix
        if len(dest_repo.list_artifacts(dest_artifact_path)) > 0:
            dest_artifact_path = dest_artifact_path + "-" + uuid4().hex[0:4]
        dest_repo.log_artifacts(local_dir, artifact_path=dest_artifact_path)
        dirname = pathlib.PurePath(source).name  # innermost directory name
        return posixpath.join(dest_root, dest_artifact_path, dirname)  # new source
    finally:
        shutil.rmtree(local_dir)
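A minimal usage sketch; the source path, run ID, and profile name below are hypothetical, and an active Databricks tracking setup is assumed:

# Hypothetical call: stage a model directory into the target workspace's DBFS,
# then use the returned URI as the new model source.
new_source = _upload_artifacts_to_databricks(
    source="dbfs:/databricks/mlflow/experiments/1/run-artifacts/model",
    run_id="0123456789abcdef",
    target_databricks_profile_uri="databricks://my-target-profile",
)
# new_source -> dbfs:/databricks/mlflow/tmp-external-source/0123456789abcdef/model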
Example #2
def test_init_validation_and_cleaning(self):
    with mock.patch(DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store') \
            as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        repo = get_artifact_repository('dbfs:/test/')
        assert repo.artifact_uri == 'dbfs:/test'
        with pytest.raises(MlflowException):
            DbfsRestArtifactRepository('s3://test')
        with pytest.raises(MlflowException):
            DbfsRestArtifactRepository('dbfs://profile@notdatabricks/test/')
Example #3
def test_init_validation_and_cleaning(self):
    with mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store"
    ) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds("http://host")
        repo = get_artifact_repository("dbfs:/test/")
        assert repo.artifact_uri == "dbfs:/test"
        match = "DBFS URI must be of the form dbfs:/<path>"
        with pytest.raises(MlflowException, match=match):
            DbfsRestArtifactRepository("s3://test")
        with pytest.raises(MlflowException, match=match):
            DbfsRestArtifactRepository("dbfs://profile@notdatabricks/test/")
Example #4
def get_dbfs_artifact_repo(self, artifact_uri):
    try:
        from mlflow.store.artifact.dbfs_artifact_repo import DbfsRestArtifactRepository
    except ImportError:
        logger.warning(
            VERSION_WARNING.format(
                "DbfsRestArtifactRepository from mlflow.store.artifact.dbfs_artifact_repo"
            )
        )
        # Fall back to the import path used by older mlflow versions
        from mlflow.store.dbfs_artifact_repo import DbfsRestArtifactRepository
    # A DBFS artifact_uri must start with the dbfs scheme; by default it is
    # dbfs:/databricks/mlflow/<exp_id>/<run_id>/artifacts
    artifact_uri = artifact_uri.replace("adbazureml", "dbfs", 1)
    logger.info("DBFS artifact uri is {}".format(artifact_uri))
    return DbfsRestArtifactRepository(artifact_uri)
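The replace call above only rewrites the URI scheme, leaving the rest of the path intact. A quick illustration with a hypothetical URI:

# Hypothetical Azure ML-style URI; only the first occurrence of the scheme is rewritten.
uri = "adbazureml:/databricks/mlflow/0/abc123/artifacts"
print(uri.replace("adbazureml", "dbfs", 1))
# -> dbfs:/databricks/mlflow/0/abc123/artifacts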
Example #5
def test_init_get_host_creds_with_databricks_profile_uri(self):
    databricks_host = 'https://something.databricks.com'
    default_host = 'http://host'
    with mock.patch(DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store',
                    return_value=lambda: MlflowHostCreds(default_host)), \
            mock.patch(DBFS_ARTIFACT_REPOSITORY_PACKAGE + '.get_databricks_host_creds',
                       return_value=MlflowHostCreds(databricks_host)):
        repo = DbfsRestArtifactRepository('dbfs://profile@databricks/test/')
        assert repo.artifact_uri == 'dbfs:/test/'
        creds = repo.get_host_creds()
        assert creds.host == databricks_host
        # no databricks_profile_uri given
        repo = DbfsRestArtifactRepository('dbfs:/test/')
        creds = repo.get_host_creds()
        assert creds.host == default_host
Example #6
def test_init_get_host_creds_with_databricks_profile_uri(self):
    databricks_host = "https://something.databricks.com"
    default_host = "http://host"
    with mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store",
        return_value=lambda: MlflowHostCreds(default_host),
    ), mock.patch(
        DBFS_ARTIFACT_REPOSITORY_PACKAGE + ".get_databricks_host_creds",
        return_value=MlflowHostCreds(databricks_host),
    ):
        repo = DbfsRestArtifactRepository("dbfs://profile@databricks/test/")
        assert repo.artifact_uri == "dbfs:/test/"
        creds = repo.get_host_creds()
        assert creds.host == databricks_host
        # no databricks_profile_uri given
        repo = DbfsRestArtifactRepository("dbfs:/test/")
        creds = repo.get_host_creds()
        assert creds.host == default_host
Example #7
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Returns an ArtifactRepository subclass for storing artifacts on DBFS.

    This factory method is used with URIs of the form ``dbfs:/<path>``. DBFS-backed artifact
    storage can only be used together with the RestStore.

    In the special case where the URI is of the form
    ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``,
    a DatabricksArtifactRepository is returned. This repository is capable of storing
    access-controlled artifacts.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    """
    try:
        if supports_acled_artifacts(mlflow.__version__):
            from mlflow.store.artifact.dbfs_artifact_repo import dbfs_artifact_repo_factory
            return dbfs_artifact_repo_factory(artifact_uri)
    except Exception:
        pass

    # For some reason, we must import modules specific to this package within the
    # entrypoint function rather than the top-level module. Otherwise, entrypoint
    # registration fails with import errors
    from mlflow_databricks_artifacts.store.artifact_repo import DatabricksArtifactRepository
    from mlflow_databricks_artifacts.utils.databricks_utils import is_dbfs_fuse_available
    from mlflow_databricks_artifacts.utils.string_utils import strip_prefix
    from mlflow_databricks_artifacts.utils.uri import (
        get_databricks_profile_uri_from_artifact_uri,
        is_databricks_acled_artifacts_uri,
        is_databricks_model_registry_artifacts_uri,
        is_valid_dbfs_uri,
        remove_databricks_profile_info_from_artifact_uri,
    )

    if not is_valid_dbfs_uri(artifact_uri):
        raise MlflowException(
            "DBFS URI must be of the form dbfs:/<path> or "
            + "dbfs://profile@databricks/<path>, but received "
            + artifact_uri
        )

    cleaned_artifact_uri = artifact_uri.rstrip("/")
    db_profile_uri = get_databricks_profile_uri_from_artifact_uri(cleaned_artifact_uri)
    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(cleaned_artifact_uri)
    elif (
        is_dbfs_fuse_available()
        and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false"
        and not is_databricks_model_registry_artifacts_uri(artifact_uri)
        and (db_profile_uri is None or db_profile_uri == "databricks")
    ):
        # If the DBFS FUSE mount is available, write artifacts directly to
        # /dbfs/... using local filesystem APIs.
        # Note: it is possible for a named Databricks profile to point to the current workspace,
        # but we're going to avoid doing a complex check and assume users will use `databricks`
        # to mean the current workspace. Using `DbfsRestArtifactRepository` to access the current
        # workspace's DBFS should still work; it just may be slower.
        final_artifact_uri = remove_databricks_profile_info_from_artifact_uri(cleaned_artifact_uri)
        file_uri = "file:///dbfs/{}".format(strip_prefix(final_artifact_uri, "dbfs:/"))
        return LocalArtifactRepository(file_uri)
    return DbfsRestArtifactRepository(cleaned_artifact_uri)
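A sketch of how the factory dispatches; the URIs below are illustrative, and the concrete class returned depends on the runtime environment:

# Run-scoped, access-controlled artifact location -> DatabricksArtifactRepository
repo = dbfs_artifact_repo_factory("dbfs:/databricks/mlflow-tracking/123/0a1b2c/artifacts")

# Plain DBFS path -> LocalArtifactRepository over file:///dbfs/... when the FUSE
# mount is available and the opt-out env var (USE_FUSE_ENV_VAR above) is not set
# to "false"; otherwise DbfsRestArtifactRepository over the REST API.
repo = dbfs_artifact_repo_factory("dbfs:/my/other/path")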