def test_get_latest_versions_with_stages(self, mock_http):
    """A stage-filtered latest-versions lookup POSTs the model and the stage list."""
    model = RegisteredModel("model_1")
    self.store.get_latest_versions(registered_model=model, stages=["blaah"])
    expected_request = GetLatestVersions(registered_model=model.to_proto(),
                                         stages=["blaah"])
    self._verify_requests(mock_http,
                          "registered-models/get-latest-versions",
                          "POST",
                          expected_request)
def test_list_registered_model(self):
    """``list_registered_models`` returns every live model and omits deleted ones."""

    def listed_names():
        # Membership is what matters here; the listing order is not asserted.
        return {rm.name for rm in self.store.list_registered_models()}

    self._rm_maker("A")
    registered_models = self.store.list_registered_models()
    self.assertEqual(len(registered_models), 1)
    self.assertEqual(registered_models[0].name, "A")
    self.assertIsInstance(registered_models[0], RegisteredModelDetailed)

    self._rm_maker("B")
    self.assertEqual(listed_names(), {"A", "B"})

    for name in ("BB", "BA", "AB", "BBC"):
        self._rm_maker(name)
    self.assertEqual(listed_names(), {"A", "B", "BB", "BA", "AB", "BBC"})

    # list should not return deleted models
    self.store.delete_registered_model(RegisteredModel("BA"))
    self.store.delete_registered_model(RegisteredModel("B"))
    self.assertEqual(listed_names(), {"A", "BB", "AB", "BBC"})
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """The search handler forwards the filter string and serializes every matching version."""
    source_filter = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=source_filter)

    # (model name, version, created, last updated, description, user, stage)
    specs = [
        ("model_1", 5, 100, 1200, "v 5", "u1", "Production"),
        ("model_1", 12, 110, 2000, "v 12", "u2", "Production"),
        ("ads_model", 8, 200, 2000, "v 8", "u1", "Staging"),
        ("fraud_detection_model", 345, 1000, 1001, "newest version", "u12", "None"),
    ]
    matching_versions = [
        ModelVersionDetailed(RegisteredModel(name=model_name), version=number,
                             creation_timestamp=created, last_updated_timestamp=updated,
                             description=desc, user_id=user, current_stage=stage,
                             source="A/B/CD", run_id=uuid.uuid4().hex, status="READY",
                             status_message=None)
        for model_name, number, created, updated, desc, user, stage in specs
    ]
    mock_model_registry_store.search_model_versions.return_value = matching_versions

    resp = _search_model_versions()

    positional, _ = mock_model_registry_store.search_model_versions.call_args
    assert positional == (source_filter, )
    expected_body = {"model_versions_detailed": jsonify(matching_versions)}
    assert json.loads(resp.get_data()) == expected_body
def test_list_registered_models(mock_store):
    """The client delegates listing to the store exactly once and returns its page."""
    page = PagedList([RegisteredModel("Model 1"), RegisteredModel("Model 2")], "")
    mock_store.list_registered_models.return_value = page

    client = newModelRegistryClient()
    listed = client.list_registered_models()

    mock_store.list_registered_models.assert_called_once()
    assert len(listed) == 2
def test_search_registered_models(mock_store):
    """Client search forwards filter/limit/order/token to the store and returns its PagedList."""
    store_search = mock_store.search_registered_models
    store_search.return_value = PagedList(
        [RegisteredModel("Model 1"), RegisteredModel("Model 2")], "")

    # Only a filter: the client supplies 100 results, no ordering, no token.
    found = newModelRegistryClient().search_registered_models(filter_string="test filter")
    store_search.assert_called_with("test filter", 100, None, None)
    assert len(found) == 2
    assert found.token == ""

    # Every search argument given explicitly.
    found = newModelRegistryClient().search_registered_models(
        filter_string="another filter",
        max_results=12,
        order_by=["A", "B DESC"],
        page_token="next one",
    )
    store_search.assert_called_with("another filter", 12, ["A", "B DESC"], "next one")
    assert len(found) == 2
    assert found.token == ""

    # A page that carries a continuation token.
    store_search.return_value = PagedList(
        [RegisteredModel("model A"), RegisteredModel("Model zz"), RegisteredModel("Model b")],
        "page 2 token",
    )
    found = newModelRegistryClient().search_registered_models(max_results=5)
    store_search.assert_called_with(None, 5, None, None)
    assert [model.name for model in found] == ["model A", "Model zz", "Model b"]
    assert found.token == "page 2 token"
def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """The list handler passes (max_results, page token) to the store and echoes the next token."""
    mock_get_request_message.return_value = ListRegisteredModels(max_results=50)

    first = RegisteredModel(name="model_1", creation_timestamp=111,
                            last_updated_timestamp=222, description="Test model",
                            latest_versions=[])
    second = RegisteredModel(name="model_2", creation_timestamp=111,
                             last_updated_timestamp=333, description="Another model",
                             latest_versions=[])
    page = PagedList([first, second], "next_pt")
    mock_model_registry_store.list_registered_models.return_value = page

    resp = _list_registered_models()

    positional, _ = mock_model_registry_store.list_registered_models.call_args
    assert positional == (50, "")
    expected_body = {
        "next_page_token": "next_pt",
        "registered_models": jsonify(page),
    }
    assert json.loads(resp.get_data()) == expected_body
def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
    """The search handler maps request fields to store kwargs and serializes each page.

    Each case feeds a request, makes the store return a page of models with a
    token, then checks the kwargs the store received and the JSON body. The
    body carries ``next_page_token`` only when the store returned a token.
    """
    rmds = [
        RegisteredModel(name="model_1", creation_timestamp=111,
                        last_updated_timestamp=222, description="Test model",
                        latest_versions=[]),
        RegisteredModel(name="model_2", creation_timestamp=111,
                        last_updated_timestamp=333, description="Another model",
                        latest_versions=[]),
    ]
    store_search = mock_model_registry_store.search_registered_models

    # (request, models the store returns, token the store returns, expected store kwargs)
    cases = [
        (SearchRegisteredModels(), rmds, None,
         {"filter_string": "", "max_results": 100, "order_by": [], "page_token": ""}),
        (SearchRegisteredModels(filter="hello"), rmds[:1], "tok",
         {"filter_string": "hello", "max_results": 100, "order_by": [], "page_token": ""}),
        (SearchRegisteredModels(filter="hi", max_results=5), [rmds[0]], "tik",
         {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": ""}),
        (SearchRegisteredModels(filter="hey", max_results=500, order_by=["a", "B desc"],
                                page_token="prev"),
         rmds, "DONE",
         {"filter_string": "hey", "max_results": 500, "order_by": ["a", "B desc"],
          "page_token": "prev"}),
    ]

    for request, models, token, expected_kwargs in cases:
        mock_get_request_message.return_value = request
        store_search.return_value = PagedList(models, token)

        resp = _search_registered_models()

        _, call_kwargs = store_search.call_args
        assert call_kwargs == expected_kwargs
        expected_body = {"registered_models": jsonify(models)}
        if token is not None:
            expected_body["next_page_token"] = token
        assert json.loads(resp.get_data()) == expected_body
def test_delete_registered_model(mock_get_request_message, mock_model_registry_store):
    """The delete handler rebuilds the model from the request and hands it to the store."""
    target = RegisteredModel("model_1")
    request = DeleteRegisteredModel(registered_model=target.to_proto())
    mock_get_request_message.return_value = request

    _delete_registered_model()

    positional, _ = mock_model_registry_store.delete_registered_model.call_args
    assert positional == (target, )
def test_update_registered_model_name(self, mock_http):
    """Renaming a registered model sends a PATCH carrying the new name."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(registered_model=model, new_name="model_2")
    expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                             name="model_2")
    self._verify_requests(mock_http, "registered-models/update", "PATCH",
                          expected_request)
def test_update_registered_model_description(self, mock_http):
    """Changing only the description sends a PATCH carrying that description."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(registered_model=model, description="test model")
    expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                             description="test model")
    self._verify_requests(mock_http, "registered-models/update", "PATCH",
                          expected_request)
def test_search_model_versions(mock_store):
    """Client version search passes the filter string to the store unchanged."""
    versions = [ModelVersion(RegisteredModel("Model 1"), number) for number in (1, 2)]
    mock_store.search_model_versions.return_value = versions

    found = newModelRegistryClient().search_model_versions("name=Model 1")

    mock_store.search_model_versions.assert_called_once_with("name=Model 1")
    assert len(found) == 2
def test_update_registered_model_all(self, mock_http):
    """Renaming and re-describing at once sends a single PATCH with both fields."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(registered_model=model, new_name="model_3",
                                       description="rename and describe")
    expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                             name="model_3",
                                             description="rename and describe")
    self._verify_requests(mock_http, "registered-models/update", "PATCH",
                          expected_request)
def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
    """The latest-versions handler forwards the model plus stage filter and serializes results."""
    rm = RegisteredModel("model1")
    mock_get_request_message.return_value = GetLatestVersions(registered_model=rm.to_proto())

    # (version, created, last updated, description, user, stage, source)
    specs = [
        (5, 1, 12, "v 5", "u1", "Production", "A/B"),
        (1, 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        (12, 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    latest = [
        ModelVersionDetailed(registered_model=rm, version=number, creation_timestamp=created,
                             last_updated_timestamp=updated, description=desc, user_id=user,
                             current_stage=stage, source=source, run_id=uuid.uuid4().hex,
                             status="READY", status_message=None)
        for number, created, updated, desc, user, stage, source in specs
    ]
    store_lookup = mock_model_registry_store.get_latest_versions
    store_lookup.return_value = latest

    resp = _get_latest_versions()

    positional, _ = store_lookup.call_args
    assert positional == (rm, [])
    assert json.loads(resp.get_data()) == {"model_versions_detailed": jsonify(latest)}

    # The requested stages must reach the store verbatim.
    for stages in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = GetLatestVersions(
            registered_model=rm.to_proto(), stages=stages)
        _get_latest_versions()
        positional, _ = store_lookup.call_args
        assert positional == (rm, stages)
def test_register_model_with_non_runs_uri():
    """Registering from a non-``runs:/`` URI creates the model, then a version with no run id."""
    model_uri = "s3:/some/path/to/model"
    patch_create_model = mock.patch.object(
        MlflowClient, "create_registered_model",
        return_value=RegisteredModel("Model 1"))
    patch_create_version = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(RegisteredModel("Model 1"), 1))

    with patch_create_model, patch_create_version:
        register_model(model_uri, "Model 1")
        MlflowClient.create_registered_model.assert_called_once_with("Model 1")
        MlflowClient.create_model_version.assert_called_once_with(
            "Model 1", run_id=None, source=model_uri)
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """The update handler passes (model, new name, description) to the store and returns its result."""
    original = RegisteredModel("model_1")
    mock_get_request_message.return_value = UpdateRegisteredModel(
        registered_model=original.to_proto(), name="model_2", description="Test model")
    renamed = RegisteredModel("model_2")
    store_update = mock_model_registry_store.update_registered_model
    store_update.return_value = renamed

    resp = _update_registered_model()

    positional, _ = store_update.call_args
    assert positional == (original, "model_2", "Test model")
    assert json.loads(resp.get_data()) == {"registered_model": jsonify(renamed)}
def test_rename_registered_model(mock_store):
    """Each rename is forwarded to the store by keyword and the store's model is returned."""
    name = "Model 1"
    for new_name in ("New Name", "New Name 2"):
        mock_store.rename_registered_model.return_value = RegisteredModel(new_name)
        renamed = newModelRegistryClient().rename_registered_model(name=name, new_name=new_name)
        mock_store.rename_registered_model.assert_called_with(name=name, new_name=new_name)
        assert renamed.name == new_name
def test_get_registered_model_details(mock_get_request_message, mock_model_registry_store):
    """The details handler looks up the requested model and serializes the store's answer."""
    requested = RegisteredModel("model1")
    mock_get_request_message.return_value = GetRegisteredModelDetails(
        registered_model=requested.to_proto())
    details = RegisteredModelDetailed(name="model_1", creation_timestamp=111,
                                      last_updated_timestamp=222, description="Test model",
                                      latest_versions=[])
    store_lookup = mock_model_registry_store.get_registered_model_details
    store_lookup.return_value = details

    resp = _get_registered_model_details()

    positional, _ = store_lookup.call_args
    assert positional == (requested, )
    assert json.loads(resp.get_data()) == {"registered_model_detailed": jsonify(details)}
def test_register_model_with_runs_uri():
    """A ``runs:/`` URI is resolved to its underlying location before the version is created."""
    patch_create_model = mock.patch.object(
        MlflowClient, "create_registered_model",
        return_value=RegisteredModel("Model 1"))
    patch_underlying_uri = mock.patch(
        "mlflow.store.artifact.runs_artifact_repo.RunsArtifactRepository.get_underlying_uri",
        return_value="s3:/path/to/source")
    patch_create_version = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(RegisteredModel("Model 1"), 1))

    with patch_underlying_uri, patch_create_model, patch_create_version:
        register_model("runs:/run12345/path/to/model", "Model 1")
        MlflowClient.create_registered_model.assert_called_once_with("Model 1")
        MlflowClient.create_model_version.assert_called_once_with(
            "Model 1", "s3:/path/to/source", "run12345")
def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """The unpaged list handler calls the store without arguments and serializes every model."""
    mock_get_request_message.return_value = ListRegisteredModels()

    # (name, last updated, description)
    specs = (("model_1", 222, "Test model"), ("model_2", 333, "Another model"))
    all_models = [
        RegisteredModel(name=model_name, creation_timestamp=111,
                        last_updated_timestamp=updated, description=desc,
                        latest_versions=[])
        for model_name, updated, desc in specs
    ]
    mock_model_registry_store.list_registered_models.return_value = all_models

    resp = _list_registered_models()

    positional, _ = mock_model_registry_store.list_registered_models.call_args
    assert positional == ()
    assert json.loads(resp.get_data()) == {"registered_models": jsonify(all_models)}
def _get_registered_model_details():
    """Handle a GetRegisteredModelDetails request.

    Rebuilds the requested RegisteredModel from the request proto, asks the
    model registry store for its details, and wraps them in the response.
    """
    request_message = _get_request_message(GetRegisteredModelDetails())
    store = _get_model_registry_store()
    requested_model = RegisteredModel.from_proto(request_message.registered_model)
    details = store.get_registered_model_details(requested_model)
    response_message = GetRegisteredModelDetails.Response(
        registered_model_detailed=details.to_proto())
    return _wrap_response(response_message)
def test_get_model_version_download_uri(self, mock_http):
    """Fetching a version's download URI POSTs the model version."""
    version = ModelVersion(RegisteredModel("model_11"), 8)
    self.store.get_model_version_download_uri(model_version=version)
    expected_request = GetModelVersionDownloadUri(model_version=version.to_proto())
    self._verify_requests(mock_http, "model-versions/get-download-uri", "POST",
                          expected_request)
def _get_latest_versions():
    """Handle a GetLatestVersions request.

    Passes the requested model and its stage filter to the model registry
    store, then copies every returned version's proto into the response.
    """
    request_message = _get_request_message(GetLatestVersions())
    store = _get_model_registry_store()
    requested_model = RegisteredModel.from_proto(request_message.registered_model)
    versions = store.get_latest_versions(requested_model, request_message.stages)
    response_message = GetLatestVersions.Response()
    version_protos = [version.to_proto() for version in versions]
    response_message.model_versions_detailed.extend(version_protos)
    return _wrap_response(response_message)
def search_registered_models(self, filter_string=None, max_results=None, order_by=None,
                             page_token=None):
    """
    Search the backend for registered models that satisfy the filter criteria.

    :param filter_string: Filter query string; by default all registered models are searched.
    :param max_results: Maximum number of registered models to return.
    :param order_by: List of column names, each with an ASC|DESC annotation, used to
                     order the matching results.
    :param page_token: Token selecting the next page of results, as obtained from an
                       earlier ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects matching the search. The ``token`` attribute of the returned
             object holds the pagination token for the next page.
    """
    request = SearchRegisteredModels(filter=filter_string, max_results=max_results,
                                     order_by=order_by, page_token=page_token)
    req_body = message_to_json(request)
    response_proto = self._call_endpoint(SearchRegisteredModels, req_body)
    models = [RegisteredModel.from_proto(proto) for proto in response_proto.registered_models]
    return PagedList(models, response_proto.next_page_token)
def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
    """The version-delete handler rebuilds the ModelVersion and passes it to the store."""
    doomed = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = DeleteModelVersion(model_version=doomed.to_proto())

    _delete_model_version()

    positional, _ = mock_model_registry_store.delete_model_version.call_args
    assert positional == (doomed, )
def test_get_model_version_stages(self, mock_http):
    """Fetching the allowed stages of a version POSTs the model version."""
    version = ModelVersion(RegisteredModel("model_11"), 8)
    self.store.get_model_version_stages(model_version=version)
    expected_request = GetModelVersionStages(model_version=version.to_proto())
    self._verify_requests(mock_http, "model-versions/get-stages", "POST",
                          expected_request)
def test_update_model_version_stage(self, mock_http):
    """Moving a version to a new stage sends a PATCH carrying that stage."""
    version = ModelVersion(RegisteredModel("model_1"), 5)
    self.store.update_model_version(model_version=version, stage="prod")
    expected_request = UpdateModelVersion(model_version=version.to_proto(), stage="prod")
    self._verify_requests(mock_http, "model-versions/update", "PATCH",
                          expected_request)
def test_create_registered_model(mock_get_request_message, mock_model_registry_store):
    """The create handler calls the store with ``name=`` and serializes the created model."""
    mock_get_request_message.return_value = CreateRegisteredModel(name="model_1")
    created = RegisteredModel("model_1")
    store_create = mock_model_registry_store.create_registered_model
    store_create.return_value = created

    resp = _create_registered_model()

    # The handler passes the name by keyword, so inspect the kwargs half of call_args.
    _, call_kwargs = store_create.call_args
    assert call_kwargs == {"name": "model_1"}
    assert json.loads(resp.get_data()) == {"registered_model": jsonify(created)}
def test_create_registered_model(mock_store):
    """The client converts a tag dict to RegisteredModelTag objects before calling the store."""
    tags_dict = {"key": "value", "another key": "some other value"}
    tags = [RegisteredModelTag(*pair) for pair in tags_dict.items()]
    mock_store.create_registered_model.return_value = RegisteredModel("Model 1", tags=tags)

    created = newModelRegistryClient().create_registered_model("Model 1", tags_dict)

    mock_store.create_registered_model.assert_called_once_with("Model 1", tags)
    assert created.name == "Model 1"
    assert created.tags == tags_dict
def delete_registered_model(self, name):
    """
    Delete a registered model.
    Backend raises exception if a registered model with given name does not exist.

    :param name: Name of the registered model to delete.
    """
    self.store.delete_registered_model(RegisteredModel(name))
def get_model_version_details(self, name, version):
    """
    Fetch the details of a single model version from the store.

    :param name: Name of the containing registered model.
    :param version: Version number of the model version.
    :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
    """
    model_version = ModelVersion(RegisteredModel(name), version)
    return self.store.get_model_version_details(model_version)