def test_update_model_version_stage(self, mock_http):
    """A stage-only update is sent as a PATCH to ``model-versions/update``."""
    model_version = ModelVersion(RegisteredModel("model_1"), 5)
    self.store.update_model_version(model_version=model_version, stage="prod")

    expected_message = UpdateModelVersion(
        model_version=model_version.to_proto(), stage="prod"
    )
    self._verify_requests(mock_http, "model-versions/update", "PATCH", expected_message)
def test_get_model_version_stages(self, mock_http):
    """Listing stages for a version is sent as a POST to ``model-versions/get-stages``."""
    version = ModelVersion(RegisteredModel("model_11"), 8)
    self.store.get_model_version_stages(model_version=version)

    expected_request = GetModelVersionStages(model_version=version.to_proto())
    self._verify_requests(
        mock_http,
        "model-versions/get-stages",
        "POST",
        expected_request,
    )
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """The search handler forwards the filter string and serializes every returned version."""
    query = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=query)

    # Columns: (name, version, created, last_updated, description, user, stage)
    fixtures = [
        ("model_1", "5", 100, 1200, "v 5", "u1", "Production"),
        ("model_1", "12", 110, 2000, "v 12", "u2", "Production"),
        ("ads_model", "8", 200, 2000, "v 8", "u1", "Staging"),
        ("fraud_detection_model", "345", 1000, 1001, "newest version", "u12", "None"),
    ]
    mvds = [
        ModelVersion(
            name=model_name,
            version=version,
            creation_timestamp=created,
            last_updated_timestamp=updated,
            description=desc,
            user_id=user,
            current_stage=stage,
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        )
        for model_name, version, created, updated, desc, user, stage in fixtures
    ]
    mock_model_registry_store.search_model_versions.return_value = mvds

    resp = _search_model_versions()

    positional, _kwargs = mock_model_registry_store.search_model_versions.call_args
    assert positional == (query,)
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}
def test_get_model_name_and_version_with_latest():
    """The ``latest`` alias resolves to the highest version returned by get_latest_versions."""
    candidates = [
        ModelVersion(name="mv1", version="10", creation_timestamp=123, current_stage="Production"),
        ModelVersion(name="mv3", version="20", creation_timestamp=125, current_stage="None"),
        ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Staging"),
    ]
    patcher = mock.patch.object(MlflowClient, "get_latest_versions", return_value=candidates)
    with patcher as latest_versions_mock:
        resolved = get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest")
        assert resolved == ("AdsModel1", "20")
        # ``None`` stages means every stage was requested.
        latest_versions_mock.assert_called_once_with("AdsModel1", None)
def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
    """The delete handler forwards the requested model version to the store."""
    target = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = DeleteModelVersion(model_version=target.to_proto())

    _delete_model_version()

    positional, _kwargs = mock_model_registry_store.delete_model_version.call_args
    assert positional == (target,)
def test_get_model_version_download_uri(self, mock_http):
    """Fetching a download URI is sent as a POST to ``model-versions/get-download-uri``."""
    version = ModelVersion(RegisteredModel("model_11"), 8)
    self.store.get_model_version_download_uri(model_version=version)

    expected_request = GetModelVersionDownloadUri(model_version=version.to_proto())
    self._verify_requests(
        mock_http,
        "model-versions/get-download-uri",
        "POST",
        expected_request,
    )
def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
    """The latest-versions handler forwards name/stages as kwargs and serializes the result."""
    name = "model1"
    mock_get_request_message.return_value = GetLatestVersions(name=name)

    # Columns: (version, created, last_updated, description, user, stage, source)
    rows = [
        ("5", 1, 12, "v 5", "u1", "Production", "A/B"),
        ("1", 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        ("12", 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mvds = [
        ModelVersion(
            name=name,
            version=version,
            creation_timestamp=created,
            last_updated_timestamp=updated,
            description=desc,
            user_id=user,
            current_stage=stage,
            source=src,
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        )
        for version, created, updated, desc, user, stage, src in rows
    ]
    store_method = mock_model_registry_store.get_latest_versions
    store_method.return_value = mvds

    resp = _get_latest_versions()

    _, kwargs = store_method.call_args
    assert kwargs == {"name": name, "stages": []}
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}

    # Every requested stage list must reach the store unchanged.
    for stages in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = GetLatestVersions(name=name, stages=stages)
        _get_latest_versions()
        _, kwargs = store_method.call_args
        assert kwargs == {"name": name, "stages": stages}
def test_search_model_versions(mock_store):
    """The client passes the filter string straight to the store and returns its results."""
    fake_versions = [
        ModelVersion(name="Model 1", version="1", creation_timestamp=123),
        ModelVersion(name="Model 1", version="2", creation_timestamp=124),
    ]
    mock_store.search_model_versions.return_value = fake_versions
    query = "name=Model 1"

    result = newModelRegistryClient().search_model_versions(query)

    mock_store.search_model_versions.assert_called_once_with(query)
    assert len(result) == len(fake_versions)
def test_search_model_versions(mock_store):
    """The client passes the filter string straight to the store and returns its results."""
    fake_versions = [ModelVersion(RegisteredModel("Model 1"), number) for number in (1, 2)]
    mock_store.search_model_versions.return_value = fake_versions
    query = "name=Model 1"

    result = newModelRegistryClient().search_model_versions(query)

    mock_store.search_model_versions.assert_called_once_with(query)
    assert len(result) == len(fake_versions)
def test_update_model_version(mock_get_request_message, mock_model_registry_store):
    """The update handler forwards (version, stage, description) positionally to the store."""
    version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    request = UpdateModelVersion(
        model_version=version.to_proto(),
        stage="Production",
        description="Great model!",
    )
    mock_get_request_message.return_value = request

    _update_model_version()

    positional, _kwargs = mock_model_registry_store.update_model_version.call_args
    assert positional == (version, "Production", "Great model!")
def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """The download-URI handler returns the store's URI under ``artifact_uri``."""
    version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = GetModelVersionDownloadUri(
        model_version=version.to_proto()
    )
    store_method = mock_model_registry_store.get_model_version_download_uri
    store_method.return_value = "some/download/path"

    resp = _get_model_version_download_uri()

    positional, _kwargs = store_method.call_args
    assert positional == (version,)
    assert json.loads(resp.get_data()) == {"artifact_uri": "some/download/path"}
def test_update_model_version_decription(self, mock_http):
    """A description-only update is sent as a PATCH to ``model-versions/update``.

    NOTE(review): "decription" in the test name is a typo; it is left unchanged so the
    test ID stays stable.
    """
    model_version = ModelVersion(RegisteredModel("model_1"), 5)
    self.store.update_model_version(model_version=model_version, description="test model version")

    expected_message = UpdateModelVersion(
        model_version=model_version.to_proto(), description="test model version"
    )
    self._verify_requests(mock_http, "model-versions/update", "PATCH", expected_message)
def test_model_version_stages(mock_get_request_message, mock_model_registry_store):
    """The stages handler returns the store's stage list verbatim under ``stages``."""
    version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = GetModelVersionStages(model_version=version.to_proto())

    # Free-form stage names, including numeric-looking and punctuated ones.
    stages = ["Stage1", "Production", "0", "5% traffic", "None"]
    store_method = mock_model_registry_store.get_model_version_stages
    store_method.return_value = stages

    resp = _get_model_version_stages()

    positional, _kwargs = store_method.call_args
    assert positional == (version,)
    assert json.loads(resp.get_data()) == {"stages": stages}
def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
    """The details handler serializes whatever detailed version the store returns."""
    registered = RegisteredModel("model1")
    requested = ModelVersion(registered_model=registered, version=32)
    mock_get_request_message.return_value = GetModelVersionDetails(
        model_version=requested.to_proto()
    )

    # The stub's version (5) differs from the requested one (32); the assertion below
    # only checks that the handler serializes the store's return value.
    detailed = ModelVersionDetailed(
        registered_model=registered,
        version=5,
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
    )
    store_method = mock_model_registry_store.get_model_version_details
    store_method.return_value = detailed

    resp = _get_model_version_details()

    positional, _kwargs = store_method.call_args
    assert positional == (requested,)
    assert json.loads(resp.get_data()) == {"model_version_detailed": jsonify(detailed)}
def test_create_model_version(mock_get_request_message, mock_model_registry_store):
    """The create handler forwards request fields as kwargs and serializes the new version."""
    run_id = uuid.uuid4().hex
    tags = [
        ModelVersionTag(key=tag_key, value=tag_value)
        for tag_key, tag_value in (("key", "value"), ("anotherKey", "some other value"))
    ]
    run_link = "localhost:5000/path/to/run"
    mock_get_request_message.return_value = CreateModelVersion(
        name="model_1",
        source="A/B",
        run_id=run_id,
        run_link=run_link,
        tags=[tag.to_proto() for tag in tags],
    )
    created = ModelVersion(
        name="model_1", version="12", creation_timestamp=123, tags=tags, run_link=run_link
    )
    mock_model_registry_store.create_model_version.return_value = created

    resp = _create_model_version()

    _, kwargs = mock_model_registry_store.create_model_version.call_args
    expected_scalars = {"name": "model_1", "source": "A/B", "run_id": run_id}
    for field, expected in expected_scalars.items():
        assert kwargs[field] == expected
    # Tags arrive as entity objects; compare them as plain key/value mappings.
    assert {tag.key: tag.value for tag in kwargs["tags"]} == {tag.key: tag.value for tag in tags}
    assert kwargs["run_link"] == run_link
    assert json.loads(resp.get_data()) == {"model_version": jsonify(created)}
def test_create_model_version_no_run_id(mock_store):
    """Omitting run_id sends ``None`` to the store and converts the tag dict to entities."""
    name, version = "Model 1", "1"
    source = "uri:/for/source"
    description = "best model ever"
    tags_dict = {"key": "value", "another key": "some other value"}
    tags = [ModelVersionTag(tag_key, tag_value) for tag_key, tag_value in tags_dict.items()]

    mock_store.create_model_version.return_value = ModelVersion(
        name=name,
        version=version,
        creation_timestamp=123,
        tags=tags,
        run_link=None,
        description=description,
    )

    result = newModelRegistryClient().create_model_version(
        name, source, tags=tags_dict, run_link=None, description=description
    )

    # Positional order seen by the store: name, source, run_id, tags, run_link, description.
    mock_store.create_model_version.assert_called_once_with(
        name, source, None, tags, None, description
    )
    assert result.name == name
    assert result.version == version
    assert result.tags == tags_dict
    assert result.run_id is None
def test_models_artifact_repo_init_with_stage_uri(
    host_creds_mock,
):  # pylint: disable=unused-argument
    """A stage URI is resolved to a DBFS download location wrapped by the models repo."""
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs:/databricks/mlflow-registry/12345/models/keras-model"
    latest = ModelVersion(
        "MyModel",
        "10",
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )

    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]):
        with mock.patch.object(
            MlflowClient, "get_model_version_download_uri", return_value=artifact_location
        ):
            models_repo = ModelsArtifactRepository(model_uri)
            assert models_repo.artifact_uri == model_uri
            assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
            assert models_repo.repo.artifact_uri == artifact_location
def test_create_model_version_run_link_with_configured_profile(mock_registry_store):
    """Outside a notebook, the run link comes from workspace info read from Databricks secrets."""
    experiment_id = "test-exp-id"
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    run_id = "runid"
    expected_url = construct_run_url(hostname, experiment_id, run_id, workspace_id)

    fake_run = Run(RunInfo(run_id, experiment_id, "userid", "status", 0, 1, None), None)
    get_run_mock = mock.MagicMock(return_value=fake_run)

    not_in_notebook = mock.patch(
        "mlflow.tracking.client.is_in_databricks_notebook", return_value=False
    )
    workspace_from_secrets = mock.patch(
        "mlflow.tracking.client.get_workspace_info_from_databricks_secrets",
        return_value=(hostname, workspace_id),
    )
    with not_in_notebook, workspace_from_secrets:
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        client.get_run = get_run_mock
        create_mock = mock_registry_store.create_model_version
        create_mock.return_value = ModelVersion(
            "name", 1, 0, 1, source="source", run_id=run_id, run_link=expected_url
        )

        model_version = client.create_model_version("name", "source", "runid")

        assert model_version.run_link == expected_url
        # The store must receive the URL the client built itself.
        create_mock.assert_called_once_with("name", "source", "runid", [], expected_url, None)
def test_models_artifact_repo_init_with_stage_uri_and_not_using_databricks_registry():
    """A non-DBFS download location is delegated to the generic artifact-repo registry."""
    model_uri = "models:/MyModel/Staging"
    artifact_location = "s3://blah_bucket/"
    latest = ModelVersion(
        "MyModel",
        "10",
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )

    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]):
        with mock.patch.object(
            MlflowClient, "get_model_version_download_uri", return_value=artifact_location
        ):
            with mock.patch(
                "mlflow.store.artifact.artifact_repository_registry.get_artifact_repository"
            ) as get_repo_mock:
                get_repo_mock.return_value = None
                ModelsArtifactRepository(model_uri)
                get_repo_mock.assert_called_once_with(artifact_location)
def test_models_artifact_repo_init_with_stage_uri_and_db_profile():
    """The profile in the models URI is carried over into the resolved DBFS URI."""
    model_uri = "models://profile@databricks/MyModel/Staging"
    artifact_location = "dbfs:/databricks/mlflow-registry/12345/models/keras-model"
    final_uri = "dbfs://profile@databricks/databricks/mlflow-registry/12345/models/keras-model"
    latest = ModelVersion(
        "MyModel",
        "10",
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )

    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]):
        with mock.patch.object(
            MlflowClient, "get_model_version_download_uri", return_value=artifact_location
        ):
            with mock.patch(
                "mlflow.store.artifact.dbfs_artifact_repo.DbfsRestArtifactRepository",
                autospec=True,
            ) as dbfs_repo_mock:
                models_repo = ModelsArtifactRepository(model_uri)
                assert models_repo.artifact_uri == model_uri
                assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
                dbfs_repo_mock.assert_called_once_with(final_uri)
def create_model_version(self, name, source, run_id, tags=None, run_link=None, description=None):
    """
    Create a new model version from given source and run ID.

    :param name: Registered model name.
    :param source: Source path where the MLflow model is stored.
    :param run_id: Run ID from MLflow tracking server that generated the model.
    :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                 instances associated with this model version. ``None`` or an empty
                 list sends no tags.
    :param run_link: Link to the run from an MLflow tracking server that generated this model.
    :param description: Description of the version.
    :return: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
             created in the backend.
    """
    tag_protos = []
    for tag in tags or []:
        tag_protos.append(tag.to_proto())

    request = CreateModelVersion(
        name=name,
        source=source,
        run_id=run_id,
        run_link=run_link,
        tags=tag_protos,
        description=description,
    )
    request_json = message_to_json(request)
    response = self._call_endpoint(CreateModelVersion, request_json)
    return ModelVersion.from_proto(response.model_version)
def test_create_model_version_run_link_in_notebook_with_default_profile(mock_registry_store):
    """Inside a notebook, the run link comes from workspace info read via dbutils."""
    experiment_id = "test-exp-id"
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    run_id = "runid"
    expected_url = construct_run_url(hostname, experiment_id, run_id, workspace_id)

    fake_run = Run(RunInfo(run_id, experiment_id, "userid", "status", 0, 1, None), None)
    get_run_mock = mock.MagicMock(return_value=fake_run)

    in_notebook = mock.patch("mlflow.tracking.client.is_in_databricks_notebook", return_value=True)
    workspace_from_dbutils = mock.patch(
        "mlflow.tracking.client.get_workspace_info_from_dbutils",
        return_value=(hostname, workspace_id),
    )
    with in_notebook, workspace_from_dbutils:
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        client.get_run = get_run_mock
        create_mock = mock_registry_store.create_model_version
        create_mock.return_value = ModelVersion(
            "name", 1, 0, 1, source="source", run_id=run_id, run_link=expected_url
        )

        model_version = client.create_model_version("name", "source", "runid")

        assert model_version.run_link == expected_url
        # The store must receive the URL the client built itself.
        create_mock.assert_called_once_with("name", "source", "runid", [], expected_url)
def _get_model_version_stages():
    """Handle a GetModelVersionStages request by asking the registry store for stages.

    The requested model version is decoded from the incoming protobuf, passed to the
    store, and the returned stage names are copied into the response message.
    """
    request = _get_request_message(GetModelVersionStages())
    store = _get_model_registry_store()
    stage_names = store.get_model_version_stages(ModelVersion.from_proto(request.model_version))

    response = GetModelVersionStages.Response()
    response.stages.extend(stage_names)
    return _wrap_response(response)
def _get_model_version_download_uri():
    """Handle a GetModelVersionDownloadUri request.

    Decodes the requested model version, asks the registry store for its download URI,
    and returns that URI in the response's ``artifact_uri`` field.
    """
    request = _get_request_message(GetModelVersionDownloadUri())
    store = _get_model_registry_store()
    uri = store.get_model_version_download_uri(ModelVersion.from_proto(request.model_version))
    return _wrap_response(GetModelVersionDownloadUri.Response(artifact_uri=uri))
def _model_version(name, version, stage, source="some:/source", run_id="run13579"):
    """Build a ModelVersion fixture with fixed timestamps, description and user ID.

    Only the name, version, stage, source and run ID vary between callers.
    """
    created_at, updated_at = "2345671890", "234567890"
    description, user_id = "some description", "UserID"
    return ModelVersion(
        name, version, created_at, updated_at, description, user_id, stage, source, run_id
    )
def _get_model_version_details():
    """Handle a GetModelVersionDetails request.

    Decodes the requested model version, fetches its detailed record from the registry
    store, and returns it (converted to protobuf) as ``model_version_detailed``.
    """
    request = _get_request_message(GetModelVersionDetails())
    store = _get_model_registry_store()
    details = store.get_model_version_details(ModelVersion.from_proto(request.model_version))
    response = GetModelVersionDetails.Response(model_version_detailed=details.to_proto())
    return _wrap_response(response)
def test_init_with_stage_uri_and_profile_is_inferred(self, stage_uri_without_profile):
    """Without a profile in the URI, the repo takes its profile from the registry URI."""
    latest = ModelVersion(
        MOCK_MODEL_NAME,
        MOCK_MODEL_VERSION,
        "2345671890",
        "234567890",
        "some description",
        "UserID",
        "Production",
        "source",
        "run12345",
    )
    latest_patch = mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest])
    utils_registry_patch = mock.patch(
        "mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
        return_value=MOCK_PROFILE,
    )
    tracking_registry_patch = mock.patch(
        "mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE
    )

    with latest_patch, utils_registry_patch, tracking_registry_patch:
        repo = DatabricksModelsArtifactRepository(stage_uri_without_profile)
        assert repo.artifact_uri == stage_uri_without_profile
        assert repo.model_name == MOCK_MODEL_NAME
        assert repo.model_version == MOCK_MODEL_VERSION
        assert repo.databricks_profile_uri == MOCK_PROFILE
def get_model_version_details(self, name, version):
    """
    Fetch the detailed record of one model version from the backing store.

    :param name: Name of the containing registered model.
    :param version: Version number of the model version.
    :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
    """
    store = self.store
    target = ModelVersion(RegisteredModel(name), version)
    return store.get_model_version_details(target)
def delete_model_version(self, name, version):
    """
    Delete a model version in the backing store.

    :param name: Name of the containing registered model.
    :param version: Version number of the model version.
    """
    store = self.store
    target = ModelVersion(RegisteredModel(name), version)
    store.delete_model_version(target)
def test_create_model_version(mock_store):
    """The client forwards name/source/run_id positionally and returns the store's result."""
    name, version = "Model 1", "1"
    source, run_id = "uri:/for/source", "run123"
    mock_store.create_model_version.return_value = ModelVersion(
        name=name, version=version, creation_timestamp=123
    )

    result = newModelRegistryClient().create_model_version(name, source, run_id)

    mock_store.create_model_version.assert_called_once_with(name, source, run_id)
    assert (result.name, result.version) == (name, version)