def __init__(self, artifact_uri):
    """
    Set up the delegate repository that actually serves the artifacts.

    :param artifact_uri: URI passed on to the parent constructor and used to pick
                         the delegate repository stored in ``self.repo``.
    """
    # Imported inside the method, exactly as the original block does.
    from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

    super().__init__(artifact_uri)
    if not is_using_databricks_registry(artifact_uri):
        # Resolve the URI to the underlying location and use the repository
        # registered for it.
        resolved_uri = ModelsArtifactRepository.get_underlying_uri(artifact_uri)
        self.repo = get_artifact_repository(resolved_uri)
    else:
        # Use the DatabricksModelsArtifactRepository if a databricks profile is being used.
        self.repo = DatabricksModelsArtifactRepository(artifact_uri)
def test_init_with_invalid_artifact_uris(self, invalid_artifact_uri):
    """Instantiating the repository with an invalid URI raises the expected error."""
    # The text is used as the ``match`` pattern for the raised exception.
    expected_message = "A valid databricks profile is required to instantiate this repository"
    with pytest.raises(MlflowException, match=expected_message):
        DatabricksModelsArtifactRepository(invalid_artifact_uri)
def test_init_with_stage_uri_and_profile_is_inferred(self, stage_uri_without_profile):
    """
    A stage URI without a profile yields a repo whose profile is ``MOCK_PROFILE``.

    ``MlflowClient.get_latest_versions`` and both ``get_registry_uri`` lookups are
    patched for the duration of the construction.
    """
    production_version = ModelVersion(
        MOCK_MODEL_NAME,
        MOCK_MODEL_VERSION,
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )
    latest_versions_patch = mock.patch.object(
        MlflowClient, "get_latest_versions", return_value=[production_version]
    )
    utils_registry_patch = mock.patch(
        "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
        return_value=MOCK_PROFILE,
    )
    tracking_registry_patch = mock.patch(
        "mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE
    )
    # Patches are entered in the same order as before: client, utils, tracking.
    with latest_versions_patch, utils_registry_patch, tracking_registry_patch:
        repo = DatabricksModelsArtifactRepository(stage_uri_without_profile)
        assert repo.artifact_uri == stage_uri_without_profile
        assert repo.model_name == MOCK_MODEL_NAME
        assert repo.model_version == MOCK_MODEL_VERSION
        assert repo.databricks_profile_uri == MOCK_PROFILE
def test_init_with_valid_uri_but_no_profile(self, valid_profileless_artifact_uri):
    """A profile-less URI raises ``MlflowException`` when the registry URI is ``None``."""
    # Patch `get_registry_uri` to return None; per the original note, this is
    # meant to make the `is_using_databricks_registry` check fail.
    registry_uri_patch = mock.patch(
        "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
        return_value=None,
    )
    with registry_uri_patch, pytest.raises(MlflowException):
        DatabricksModelsArtifactRepository(valid_profileless_artifact_uri)
def test_init_with_version_uri_and_profile_is_inferred(self):
    """A version URI without a profile yields a repo whose profile is ``MOCK_PROFILE``."""
    profileless_uri = MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE
    # First patch: lets the `is_using_databricks_registry` check pass.
    utils_registry_patch = mock.patch(
        "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
        return_value=MOCK_PROFILE,
    )
    # Second patch: provides `databricks_profile_uri` during instantiation.
    tracking_registry_patch = mock.patch(
        "mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE
    )
    with utils_registry_patch, tracking_registry_patch:
        repo = DatabricksModelsArtifactRepository(profileless_uri)
        assert repo.artifact_uri == profileless_uri
        assert repo.model_name == MOCK_MODEL_NAME
        assert repo.model_version == MOCK_MODEL_VERSION
        assert repo.databricks_profile_uri == MOCK_PROFILE
def test_init_with_stage_uri_containing_profile(self, stage_uri_with_profile):
    """A stage URI that already carries a profile yields a fully populated repo."""
    production_version = ModelVersion(
        MOCK_MODEL_NAME,
        MOCK_MODEL_VERSION,
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )
    # `get_latest_versions` is replaced so it returns only the version built above.
    with mock.patch.object(
        MlflowClient, "get_latest_versions", return_value=[production_version]
    ):
        repo = DatabricksModelsArtifactRepository(stage_uri_with_profile)
        assert repo.artifact_uri == stage_uri_with_profile
        assert repo.model_name == MOCK_MODEL_NAME
        assert repo.model_version == MOCK_MODEL_VERSION
        assert repo.databricks_profile_uri == MOCK_PROFILE
def test_init_with_invalid_artifact_uris(self, invalid_artifact_uri):
    """Instantiating the repository with an invalid URI raises ``MlflowException``."""
    expect_mlflow_error = pytest.raises(MlflowException)
    with expect_mlflow_error:
        DatabricksModelsArtifactRepository(invalid_artifact_uri)
def test_init_with_version_uri_containing_profile(self):
    """A version URI that carries a profile yields the expected URI, name and version."""
    uri_with_profile = MOCK_MODEL_ROOT_URI_WITH_PROFILE
    repo = DatabricksModelsArtifactRepository(uri_with_profile)
    assert repo.artifact_uri == uri_with_profile
    assert repo.model_name == MOCK_MODEL_NAME
    assert repo.model_version == MOCK_MODEL_VERSION
def databricks_model_artifact_repo():
    """Return a ``DatabricksModelsArtifactRepository`` for the mock root URI with a profile."""
    repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
    return repo
class ModelsArtifactRepository(ArtifactRepository):
    """
    Artifact repository for model versions in the model registry, addressed via URIs such as:

      - `models:/<model_name>/<model_version>`
      - `models:/<model_name>/<stage>` (the latest model version in the given stage)
      - `models:/<model_name>/latest` (the latest of all model versions)

    This class is a thin wrapper: it resolves the artifact path to an absolute URI,
    instantiates the artifact repository for that URI, and forwards read operations to it.
    """

    def __init__(self, artifact_uri):
        """
        Choose and store the delegate repository in ``self.repo``.

        :param artifact_uri: URI passed on to the parent constructor and used to pick
                             the delegate repository.
        """
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        super().__init__(artifact_uri)
        if not is_using_databricks_registry(artifact_uri):
            resolved_uri = ModelsArtifactRepository.get_underlying_uri(artifact_uri)
            self.repo = get_artifact_repository(resolved_uri)
            # TODO: it may be nice to fall back to the source URI explicitly here if for some
            #  reason we don't get a download URI here, or fail during the download itself.
        else:
            # Use the DatabricksModelsArtifactRepository if a databricks profile is being used.
            self.repo = DatabricksModelsArtifactRepository(artifact_uri)

    @staticmethod
    def is_models_uri(uri):
        """
        Tell whether ``uri`` uses the ``models`` scheme.

        :param uri: URI string to inspect.
        :return: ``True`` if the parsed scheme equals ``"models"``, otherwise ``False``.
        """
        parsed = urllib.parse.urlparse(uri)
        return parsed.scheme == "models"

    @staticmethod
    def get_underlying_uri(uri):
        """
        Resolve a model URI to its download URI, annotated with databricks profile info.

        :param uri: Model URI to resolve.
        :return: The value returned by ``add_databricks_profile_info_to_artifact_uri`` for
                 the model version's download URI and the profile URI in use.
        """
        # Note: to support a registry URI that is different from the tracking URI here,
        # we'll need to add setting of registry URIs via environment variables.
        from mlflow.tracking import MlflowClient

        # Prefer the profile found in the artifact URI; otherwise use the registry URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
        if not profile_uri:
            profile_uri = mlflow.get_registry_uri()

        registry_client = MlflowClient(registry_uri=profile_uri)
        model_name, model_version = get_model_name_and_version(registry_client, uri)
        raw_download_uri = registry_client.get_model_version_download_uri(
            model_name, model_version
        )
        return add_databricks_profile_info_to_artifact_uri(raw_download_uri, profile_uri)

    def log_artifact(self, local_file, artifact_path=None):
        """
        Unsupported for ``models:/`` URIs; always raises.

        :param local_file: Path to artifact to log
        :param artifact_path: Directory within the run's artifact directory in which to log the
                              artifact
        :raises ValueError: Always; the message points to ``register_model`` instead.
        """
        raise ValueError(
            "log_artifact is not supported for models:/ URIs. Use register_model instead."
        )

    def log_artifacts(self, local_dir, artifact_path=None):
        """
        Unsupported for ``models:/`` URIs; always raises.

        :param local_dir: Directory of local artifacts to log
        :param artifact_path: Directory within the run's artifact directory in which to log the
                              artifacts
        :raises ValueError: Always; the message points to ``register_model`` instead.
        """
        raise ValueError(
            "log_artifacts is not supported for models:/ URIs. Use register_model instead."
        )

    def list_artifacts(self, path):
        """
        List the artifacts directly under ``path`` by delegating to the wrapped repository.

        :param path: Relative source path that contain desired artifacts
        :return: Whatever ``self.repo.list_artifacts(path)`` returns (documented by the
                 original as a list of FileInfo entries listed directly under path).
        """
        delegate = self.repo
        return delegate.list_artifacts(path)

    def download_artifacts(self, artifact_path, dst_path=None):
        """
        Download an artifact file or directory by delegating to the wrapped repository.

        The caller is responsible for managing the lifecycle of the downloaded artifacts.

        :param artifact_path: Relative source path to the desired artifacts.
        :param dst_path: Absolute path of the local filesystem destination directory to which to
                         download the specified artifacts. Forwarded unchanged (default
                         ``None``) to the wrapped repository.
        :return: Whatever ``self.repo.download_artifacts`` returns (documented by the original
                 as the absolute local path containing the desired artifacts).
        """
        delegate = self.repo
        return delegate.download_artifacts(artifact_path, dst_path)

    def _download_file(self, remote_file_path, local_path):
        """
        Forward a single-file download to the wrapped repository.

        :param remote_file_path: Source path to the remote file, relative to the root
                                 directory of the artifact repository.
        :param local_path: The path to which to save the downloaded file.
        """
        delegate = self.repo
        delegate._download_file(remote_file_path, local_path)

    def delete_artifacts(self, artifact_path=None):
        """
        Not implemented for this repository; always raises.

        :param artifact_path: Unused.
        :raises MlflowException: Always.
        """
        raise MlflowException("Not implemented yet")