def _upload_artifacts_to_databricks(
    source, run_id, source_host_uri=None, target_databricks_profile_uri=None
):
    """
    Copy the artifacts from ``source`` to the destination Databricks workspace (DBFS) given by
    ``target_databricks_profile_uri`` or the current tracking URI.

    :param source: Source location for the artifacts to copy.
    :param run_id: Run ID to associate the artifacts with.
    :param source_host_uri: Specifies the source artifact's host URI (e.g. Databricks tracking
                            URI) if applicable. If not given, defaults to the current tracking
                            URI.
    :param target_databricks_profile_uri: Specifies the destination Databricks host. If not
                                          given, defaults to the current tracking URI.
    :return: The DBFS location in the target Databricks workspace the model files have been
             uploaded to.
    """
    from uuid import uuid4

    # TemporaryDirectory guarantees cleanup even if the download/upload raises,
    # replacing the manual mkdtemp / try-finally / rmtree pattern.
    with tempfile.TemporaryDirectory() as local_dir:
        source_with_profile = add_databricks_profile_info_to_artifact_uri(source, source_host_uri)
        _download_artifact_from_uri(source_with_profile, local_dir)
        dest_root = "dbfs:/databricks/mlflow/tmp-external-source/"
        dest_root_with_profile = add_databricks_profile_info_to_artifact_uri(
            dest_root, target_databricks_profile_uri
        )
        dest_repo = DbfsRestArtifactRepository(dest_root_with_profile)
        dest_artifact_path = run_id if run_id else uuid4().hex
        # Allow uploading from the same run id multiple times by randomizing a suffix
        if len(dest_repo.list_artifacts(dest_artifact_path)) > 0:
            dest_artifact_path = dest_artifact_path + "-" + uuid4().hex[:4]
        dest_repo.log_artifacts(local_dir, artifact_path=dest_artifact_path)
        dirname = pathlib.PurePath(source).name  # innermost directory name
        return posixpath.join(dest_root, dest_artifact_path, dirname)  # new source
def get_underlying_uri(runs_uri):
    """Resolve a ``runs:/`` URI to the concrete artifact URI it refers to,
    annotated with the Databricks profile info carried by ``runs_uri``."""
    from mlflow.tracking.artifact_utils import get_artifact_uri

    run_id, artifact_path = RunsArtifactRepository.parse_runs_uri(runs_uri)
    tracking_uri = get_databricks_profile_uri_from_artifact_uri(runs_uri)
    resolved_uri = get_artifact_uri(run_id, artifact_path, tracking_uri)
    # Sanity check: the resolved URI must not itself be a runs:/ URI,
    # otherwise resolution could loop forever.
    assert not RunsArtifactRepository.is_runs_uri(resolved_uri)
    return add_databricks_profile_info_to_artifact_uri(resolved_uri, tracking_uri)
def get_underlying_uri(uri):
    """Resolve a model-registry URI to its download URI, annotated with the
    appropriate Databricks profile info."""
    # Note: to support a registry URI that is different from the tracking URI here,
    # we'll need to add setting of registry URIs via environment variables.
    from mlflow.tracking import MlflowClient

    databricks_profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    if not databricks_profile_uri:
        # Fall back to the globally configured registry URI.
        databricks_profile_uri = mlflow.get_registry_uri()
    client = MlflowClient(registry_uri=databricks_profile_uri)
    name, version = get_model_name_and_version(client, uri)
    download_uri = client.get_model_version_download_uri(name, version)
    return add_databricks_profile_info_to_artifact_uri(download_uri, databricks_profile_uri)
def _get_artifact_repo(self, run_id):
    """
    Return the artifact repository for ``run_id``, caching it in the class-level
    LRU-ish cache to avoid repeated network calls.

    :param run_id: Run ID whose artifact repository to fetch.
    :return: An artifact repository instance for the run's artifact URI.
    """
    # Attempt to fetch the artifact repo from a local cache
    cached_repo = TrackingServiceClient._artifact_repos_cache.get(run_id)
    if cached_repo is not None:
        return cached_repo
    run = self.get_run(run_id)
    artifact_uri = add_databricks_profile_info_to_artifact_uri(
        run.info.artifact_uri, self.tracking_uri)
    artifact_repo = get_artifact_repository(artifact_uri)
    # Cache the artifact repo to avoid a future network call. Evict the oldest
    # entry *before* inserting so the cache never exceeds 1024 entries — the
    # original `> 1024` check allowed the cache to grow to 1025.
    if len(TrackingServiceClient._artifact_repos_cache) >= 1024:
        TrackingServiceClient._artifact_repos_cache.popitem(last=False)
    TrackingServiceClient._artifact_repos_cache[run_id] = artifact_repo
    return artifact_repo
def get_underlying_uri(uri):
    """Resolve a ``models:/`` URI (by version or by stage) to its download URI,
    annotated with the appropriate Databricks profile info."""
    # Note: to support a registry URI that is different from the tracking URI here,
    # we'll need to add setting of registry URIs via environment variables.
    from mlflow.tracking import MlflowClient

    databricks_profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    client = MlflowClient(registry_uri=databricks_profile_uri)
    name, version, stage = ModelsArtifactRepository._parse_uri(uri)
    if stage is not None:
        # A stage-based URI resolves to the newest version currently in that stage.
        matching = client.get_latest_versions(name, [stage])
        if not matching:
            raise MlflowException(
                "No versions of model with name '{name}' and "
                "stage '{stage}' found".format(name=name, stage=stage))
        version = matching[0].version
    download_uri = client.get_model_version_download_uri(name, version)
    return add_databricks_profile_info_to_artifact_uri(
        download_uri, databricks_profile_uri)
def test_add_databricks_profile_info_to_artifact_uri_errors(
        artifact_uri, profile_uri):
    """Invalid profile URIs must raise MlflowException with the expected message."""
    expected_message = "Unsupported Databricks profile"
    with pytest.raises(MlflowException, match=expected_message):
        add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
def test_add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri, result):
    """Profile info is merged into the artifact URI exactly as expected."""
    annotated_uri = add_databricks_profile_info_to_artifact_uri(artifact_uri, profile_uri)
    assert annotated_uri == result
def _get_artifact_repo(self, run_id):
    """Build an artifact repository for the run's artifact URI, annotated with
    this client's tracking profile info."""
    run_info = self.get_run(run_id).info
    annotated_uri = add_databricks_profile_info_to_artifact_uri(
        run_info.artifact_uri, self.tracking_uri
    )
    return get_artifact_repository(annotated_uri)
def test_add_databricks_profile_info_to_artifact_uri_errors(
        artifact_uri, profile_uri):
    """Invalid inputs must raise MlflowException."""
    # Callable form of pytest.raises: invokes the function with the given args
    # and asserts the exception type.
    pytest.raises(
        MlflowException,
        add_databricks_profile_info_to_artifact_uri,
        artifact_uri,
        profile_uri,
    )