예제 #1
0
    def __init__(self, artifact_uri):
        """
        Pick the delegate artifact repository for ``artifact_uri`` and store it on ``self.repo``.

        :param artifact_uri: URI passed to the parent constructor and used to choose the delegate.
        """
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        super().__init__(artifact_uri)
        if not is_using_databricks_registry(artifact_uri):
            # Not a databricks registry: resolve the underlying URI first, then build the
            # delegate repository from that resolved URI.
            resolved_uri = ModelsArtifactRepository.get_underlying_uri(artifact_uri)
            self.repo = get_artifact_repository(resolved_uri)
            return
        # Use the DatabricksModelsArtifactRepository if a databricks profile is being used.
        self.repo = DatabricksModelsArtifactRepository(artifact_uri)
 def test_init_with_invalid_artifact_uris(self, invalid_artifact_uri):
     """Instantiating the repository from an invalid URI raises ``MlflowException``."""
     # `match` is applied as a regular-expression search against the exception text.
     expected_message = "A valid databricks profile is required to instantiate this repository"
     with pytest.raises(MlflowException, match=expected_message):
         DatabricksModelsArtifactRepository(invalid_artifact_uri)
 def test_init_with_stage_uri_and_profile_is_inferred(self, stage_uri_without_profile):
     """
     A stage URI without a profile still yields a repository whose
     ``databricks_profile_uri`` equals the patched registry URI (``MOCK_PROFILE``).
     """
     production_version = ModelVersion(
         MOCK_MODEL_NAME,
         MOCK_MODEL_VERSION,
         "2345671890",
         "234567890",
         "some description",
         "UserID",
         "Production",
         "source",
         "run12345",
     )
     # Resolving the stage goes through `MlflowClient.get_latest_versions`; stub it out.
     latest_versions_patch = mock.patch.object(
         MlflowClient, "get_latest_versions", return_value=[production_version]
     )
     # Both lookups of the registry URI are patched to report the mock profile.
     models_registry_uri_patch = mock.patch(
         "mlflow.store.artifact.utils.models.mlflow.get_registry_uri", return_value=MOCK_PROFILE
     )
     tracking_registry_uri_patch = mock.patch(
         "mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE
     )
     with latest_versions_patch, models_registry_uri_patch, tracking_registry_uri_patch:
         repo = DatabricksModelsArtifactRepository(stage_uri_without_profile)
         assert repo.artifact_uri == stage_uri_without_profile
         assert repo.model_name == MOCK_MODEL_NAME
         assert repo.model_version == MOCK_MODEL_VERSION
         assert repo.databricks_profile_uri == MOCK_PROFILE
예제 #4
0
 def test_init_with_valid_uri_but_no_profile(self, valid_profileless_artifact_uri):
     """Instantiation raises ``MlflowException`` when ``get_registry_uri`` returns ``None``."""
     # With `get_registry_uri` patched to return None there is no registry URI to fall back
     # on; presumably this makes the `is_using_databricks_registry` check fail.
     registry_uri_patch = mock.patch(
         "mlflow.store.artifact.utils.models.mlflow.get_registry_uri", return_value=None
     )
     with registry_uri_patch, pytest.raises(MlflowException):
         DatabricksModelsArtifactRepository(valid_profileless_artifact_uri)
예제 #5
0
 def test_init_with_version_uri_and_profile_is_inferred(self):
     """
     A version URI without a profile yields a repository whose
     ``databricks_profile_uri`` equals the patched registry URI (``MOCK_PROFILE``).
     """
     uri = MOCK_MODEL_ROOT_URI_WITHOUT_PROFILE
     # First patch: lets the `is_using_databricks_registry` check pass.
     models_registry_uri_patch = mock.patch(
         "mlflow.store.artifact.utils.models.mlflow.get_registry_uri", return_value=MOCK_PROFILE
     )
     # Second patch: supplies `databricks_profile_uri` during instantiation.
     tracking_registry_uri_patch = mock.patch(
         "mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE
     )
     with models_registry_uri_patch:
         with tracking_registry_uri_patch:
             repo = DatabricksModelsArtifactRepository(uri)
             assert repo.artifact_uri == uri
             assert repo.model_name == MOCK_MODEL_NAME
             assert repo.model_version == MOCK_MODEL_VERSION
             assert repo.databricks_profile_uri == MOCK_PROFILE
예제 #6
0
 def test_init_with_stage_uri_containing_profile(self, stage_uri_with_profile):
     """
     A stage URI that carries a profile yields a repository exposing the mocked model
     name and version, with ``databricks_profile_uri`` equal to ``MOCK_PROFILE``.
     """
     production_version = ModelVersion(
         MOCK_MODEL_NAME,
         MOCK_MODEL_VERSION,
         "2345671890",
         "234567890",
         "some description",
         "UserID",
         "Production",
         "source",
         "run12345",
     )
     # Stub the registry lookup so the stage resolves to the version built above.
     with mock.patch.object(
         MlflowClient, "get_latest_versions", return_value=[production_version]
     ):
         repo = DatabricksModelsArtifactRepository(stage_uri_with_profile)
         assert repo.artifact_uri == stage_uri_with_profile
         assert repo.model_name == MOCK_MODEL_NAME
         assert repo.model_version == MOCK_MODEL_VERSION
         assert repo.databricks_profile_uri == MOCK_PROFILE
예제 #7
0
 def test_init_with_invalid_artifact_uris(self, invalid_artifact_uri):
     """Instantiating the repository from an invalid URI raises ``MlflowException``."""
     expect_mlflow_exception = pytest.raises(MlflowException)
     with expect_mlflow_exception:
         DatabricksModelsArtifactRepository(invalid_artifact_uri)
예제 #8
0
 def test_init_with_version_uri_containing_profile(self):
     """A version URI that carries a profile exposes the mocked URI, model name and version."""
     uri = MOCK_MODEL_ROOT_URI_WITH_PROFILE
     repository = DatabricksModelsArtifactRepository(uri)
     assert repository.artifact_uri == uri
     assert repository.model_name == MOCK_MODEL_NAME
     assert repository.model_version == MOCK_MODEL_VERSION
예제 #9
0
def databricks_model_artifact_repo():
    """Return a ``DatabricksModelsArtifactRepository`` for ``MOCK_MODEL_ROOT_URI_WITH_PROFILE``."""
    # NOTE(review): presumably a pytest fixture; no decorator is visible in this snippet.
    repo = DatabricksModelsArtifactRepository(MOCK_MODEL_ROOT_URI_WITH_PROFILE)
    return repo
예제 #10
0
class ModelsArtifactRepository(ArtifactRepository):
    """
    Artifact repository for model versions addressed through the model registry, using URIs
    of the form:
      - `models:/<model_name>/<model_version>`
      - `models:/<model_name>/<stage>`  (the latest model version in the given stage)
      - `models:/<model_name>/latest` (the latest of all model versions)
    The class is a thin wrapper: at construction time it picks a concrete repository for the
    URI, keeps it in ``self.repo``, and forwards list/download operations to it. Logging
    artifacts and deleting artifacts are not supported.
    """

    def __init__(self, artifact_uri):
        """
        Pick the delegate artifact repository for ``artifact_uri`` and store it on ``self.repo``.

        :param artifact_uri: URI passed to the parent constructor and used to choose the delegate.
        """
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        super().__init__(artifact_uri)
        if not is_using_databricks_registry(artifact_uri):
            resolved_uri = ModelsArtifactRepository.get_underlying_uri(artifact_uri)
            self.repo = get_artifact_repository(resolved_uri)
            # TODO: it may be nice to fall back to the source URI explicitly here if for some
            #  reason we don't get a download URI here, or fail during the download itself.
            return
        # Use the DatabricksModelsArtifactRepository if a databricks profile is being used.
        self.repo = DatabricksModelsArtifactRepository(artifact_uri)

    @staticmethod
    def is_models_uri(uri):
        """
        Tell whether ``uri`` uses the ``models`` scheme.

        :param uri: URI string to inspect.

        :return: ``True`` when the parsed scheme is exactly ``models``, otherwise ``False``.
        """
        parsed = urllib.parse.urlparse(uri)
        return parsed.scheme == "models"

    @staticmethod
    def get_underlying_uri(uri):
        """
        Resolve a model URI to the download URI of the model version it refers to.

        :param uri: Model URI from which the model name and version are derived.

        :return: The model version's download URI, with databricks profile information added
                 via ``add_databricks_profile_info_to_artifact_uri``.
        """
        # Note: to support a registry URI that is different from the tracking URI here,
        # we'll need to add setting of registry URIs via environment variables.
        from mlflow.tracking import MlflowClient

        # Prefer a profile embedded in the URI; otherwise use the configured registry URI.
        profile_uri = get_databricks_profile_uri_from_artifact_uri(uri)
        if not profile_uri:
            profile_uri = mlflow.get_registry_uri()

        registry_client = MlflowClient(registry_uri=profile_uri)
        model_name, model_version = get_model_name_and_version(registry_client, uri)
        download_uri = registry_client.get_model_version_download_uri(model_name, model_version)
        return add_databricks_profile_info_to_artifact_uri(download_uri, profile_uri)

    def log_artifact(self, local_file, artifact_path=None):
        """
        Unsupported for ``models:/`` URIs; always raises.

        :param local_file: Path to artifact to log (unused).
        :param artifact_path: Directory within the artifact directory in which the artifact
                              would be logged (unused).

        :raises ValueError: Always.
        """
        message = "log_artifact is not supported for models:/ URIs. Use register_model instead."
        raise ValueError(message)

    def log_artifacts(self, local_dir, artifact_path=None):
        """
        Unsupported for ``models:/`` URIs; always raises.

        :param local_dir: Directory of local artifacts to log (unused).
        :param artifact_path: Directory within the artifact directory in which the artifacts
                              would be logged (unused).

        :raises ValueError: Always.
        """
        message = "log_artifacts is not supported for models:/ URIs. Use register_model instead."
        raise ValueError(message)

    def list_artifacts(self, path):
        """
        List the artifacts directly under ``path`` by delegating to the wrapped repository.

        :param path: Relative source path that contains the desired artifacts.

        :return: Whatever the wrapped repository's ``list_artifacts`` returns for ``path``
                 (a list of FileInfo entries, per the repository contract).
        """
        delegate = self.repo
        return delegate.list_artifacts(path)

    def download_artifacts(self, artifact_path, dst_path=None):
        """
        Download an artifact file or directory by delegating to the wrapped repository.
        The caller is responsible for managing the lifecycle of the downloaded artifacts.

        :param artifact_path: Relative source path to the desired artifacts.
        :param dst_path: Absolute path of the local filesystem destination directory to which
                         to download the specified artifacts. This directory must already
                         exist. If unspecified, the wrapped repository decides where the
                         artifacts are placed.

        :return: Whatever the wrapped repository's ``download_artifacts`` returns (the absolute
                 local path of the downloaded artifacts, per the repository contract).
        """
        delegate = self.repo
        return delegate.download_artifacts(artifact_path, dst_path)

    def _download_file(self, remote_file_path, local_path):
        """
        Download a single remote file to ``local_path`` through the wrapped repository.

        :param remote_file_path: Source path to the remote file, relative to the root
                                 directory of the artifact repository.
        :param local_path: The path to which to save the downloaded file.
        """
        delegate = self.repo
        delegate._download_file(remote_file_path, local_path)

    def delete_artifacts(self, artifact_path=None):
        """
        Not implemented; always raises.

        :param artifact_path: Path of the artifacts that would be deleted (unused).

        :raises MlflowException: Always.
        """
        raise MlflowException("Not implemented yet")