Example #1
0
def _upload_artifacts_to_databricks(
    source, run_id, source_host_uri=None, target_databricks_profile_uri=None
):
    """
    Copy the artifacts from ``source`` into the destination Databricks workspace (DBFS) given by
    ``target_databricks_profile_uri`` or the current tracking URI.

    :param source: Source location for the artifacts to copy.
    :param run_id: Run ID to associate the artifacts with.
    :param source_host_uri: Specifies the source artifact's host URI (e.g. Databricks tracking URI)
        if applicable. If not given, defaults to the current tracking URI.
    :param target_databricks_profile_uri: Specifies the destination Databricks host. If not given,
        defaults to the current tracking URI.
    :return: The DBFS location in the target Databricks workspace the model files have been
        uploaded to.
    """
    from uuid import uuid4

    staging_dir = tempfile.mkdtemp()
    try:
        # Pull the artifacts down locally, tagging the source with its host profile if given.
        source_with_profile = add_databricks_profile_info_to_artifact_uri(source, source_host_uri)
        _download_artifact_from_uri(source_with_profile, staging_dir)

        dest_root = "dbfs:/databricks/mlflow/tmp-external-source/"
        dest_repo = DbfsRestArtifactRepository(
            add_databricks_profile_info_to_artifact_uri(dest_root, target_databricks_profile_uri)
        )

        dest_artifact_path = run_id or uuid4().hex
        # Allow uploading from the same run id multiple times by randomizing a suffix
        # whenever the destination path is already occupied.
        if dest_repo.list_artifacts(dest_artifact_path):
            dest_artifact_path = "{}-{}".format(dest_artifact_path, uuid4().hex[:4])
        dest_repo.log_artifacts(staging_dir, artifact_path=dest_artifact_path)

        # New source location: <dest_root>/<artifact_path>/<innermost directory of source>.
        return posixpath.join(dest_root, dest_artifact_path, pathlib.PurePath(source).name)
    finally:
        shutil.rmtree(staging_dir)
Example #2
0
 def get_underlying_uri(runs_uri):
     """Resolve a ``runs:/`` URI to the concrete artifact URI it refers to."""
     from mlflow.tracking.artifact_utils import get_artifact_uri

     run_id, artifact_path = RunsArtifactRepository.parse_runs_uri(runs_uri)
     tracking_uri = get_databricks_profile_uri_from_artifact_uri(runs_uri)
     resolved = get_artifact_uri(run_id, artifact_path, tracking_uri)
     # A runs:/ URI must never resolve to another runs:/ URI — avoid an infinite loop.
     assert not RunsArtifactRepository.is_runs_uri(resolved)
     return add_databricks_profile_info_to_artifact_uri(resolved, tracking_uri)
Example #3
0
    def get_underlying_uri(uri):
        """Resolve a model-registry URI to the download URI of the referenced version."""
        # Note: to support a registry URI that is different from the tracking URI here,
        # we'll need to add setting of registry URIs via environment variables.
        from mlflow.tracking import MlflowClient

        profile_uri = get_databricks_profile_uri_from_artifact_uri(uri) or mlflow.get_registry_uri()
        client = MlflowClient(registry_uri=profile_uri)
        name, version = get_model_name_and_version(client, uri)
        download_uri = client.get_model_version_download_uri(name, version)
        return add_databricks_profile_info_to_artifact_uri(download_uri, profile_uri)
Example #4
0
 def _get_artifact_repo(self, run_id):
     """Return the artifact repository for ``run_id``, served from a process-wide cache."""
     repo = TrackingServiceClient._artifact_repos_cache.get(run_id)
     if repo is None:
         run = self.get_run(run_id)
         artifact_uri = add_databricks_profile_info_to_artifact_uri(
             run.info.artifact_uri, self.tracking_uri)
         repo = get_artifact_repository(artifact_uri)
         # Cache the repo to avoid a future network call; bound the cache by evicting
         # the oldest entry once it grows past 1024 elements.
         if len(TrackingServiceClient._artifact_repos_cache) > 1024:
             TrackingServiceClient._artifact_repos_cache.popitem(last=False)
         TrackingServiceClient._artifact_repos_cache[run_id] = repo
     return repo
Example #5
0
 def get_underlying_uri(uri):
     """Resolve a ``models:/`` URI to the download URI of the referenced model version."""
     # Note: to support a registry URI that is different from the tracking URI here,
     # we'll need to add setting of registry URIs via environment variables.
     from mlflow.tracking import MlflowClient

     profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
     client = MlflowClient(registry_uri=profile_uri)
     name, version, stage = ModelsArtifactRepository._parse_uri(uri)
     # A stage reference (e.g. models:/name/Staging) resolves to its latest version.
     if stage is not None:
         latest = client.get_latest_versions(name, [stage])
         if not latest:
             raise MlflowException(
                 "No versions of model with name '{name}' and "
                 "stage '{stage}' found".format(name=name, stage=stage))
         version = latest[0].version
     download_uri = client.get_model_version_download_uri(name, version)
     return add_databricks_profile_info_to_artifact_uri(
         download_uri, profile_uri)
Example #6
0
def test_add_databricks_profile_info_to_artifact_uri_errors(artifact_uri, profile_uri):
    # Invalid (artifact_uri, profile_uri) combinations must raise an MlflowException
    # whose message mentions the unsupported profile.
    with pytest.raises(MlflowException, match="Unsupported Databricks profile"):
        add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
Example #7
0
def test_add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri, result):
    # Parametrized: each (artifact_uri, profile_uri) pair must map to `result`.
    actual = add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
    assert actual == result
Example #8
0
 def _get_artifact_repo(self, run_id):
     """Build an artifact repository for the run's artifact root, tagged with this
     client's tracking profile."""
     info = self.get_run(run_id).info
     uri = add_databricks_profile_info_to_artifact_uri(info.artifact_uri, self.tracking_uri)
     return get_artifact_repository(uri)
Example #9
0
def test_add_databricks_profile_info_to_artifact_uri_errors(artifact_uri, profile_uri):
    # Each parametrized (artifact_uri, profile_uri) pair is expected to be rejected.
    with pytest.raises(MlflowException):
        add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)