def test_search_registered_model(self, mock_http):
    """Check the GET request issued for every combination of search arguments."""
    endpoint = "registered-models/search"

    # Baseline: a call without arguments sends an empty request message.
    self.store.search_registered_models()
    self._verify_requests(mock_http, endpoint, "GET", SearchRegisteredModels())

    params_list = [
        {"filter_string": "model = 'yo'"},
        {"max_results": 400},
        {"page_token": "blah"},
        {"order_by": ["x", "Y"]},
    ]
    # test all combination of params (subset sizes 0 through 4)
    for sz in range(len(params_list) + 1):
        for combination in combinations(params_list, sz):
            kwargs = {}
            for single_param in combination:
                kwargs.update(single_param)
            self.store.search_registered_models(**kwargs)
            # The store takes ``filter_string`` while the request message
            # is built with ``filter``, so rename the key before comparing.
            if "filter_string" in kwargs:
                kwargs["filter"] = kwargs.pop("filter_string")
            self._verify_requests(
                mock_http, endpoint, "GET", SearchRegisteredModels(**kwargs)
            )
def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
    """Drive the search handler with four requests and check both sides.

    For each scenario the test verifies the keyword arguments forwarded to
    the store and the JSON body returned by the handler.
    """
    rmds = [
        RegisteredModel(
            name="model_1",
            creation_timestamp=111,
            last_updated_timestamp=222,
            description="Test model",
            latest_versions=[],
        ),
        RegisteredModel(
            name="model_2",
            creation_timestamp=111,
            last_updated_timestamp=333,
            description="Another model",
            latest_versions=[],
        ),
    ]

    def check(request, models, token, expected_kwargs):
        # Arrange the mocks, invoke the handler, then compare both the
        # store call and the serialized response.
        mock_get_request_message.return_value = request
        mock_model_registry_store.search_registered_models.return_value = PagedList(
            models, token
        )
        resp = _search_registered_models()
        # ``call_args`` unpacks to (positional args, keyword args); the
        # handler is expected to pass everything by keyword.
        _, kwargs = mock_model_registry_store.search_registered_models.call_args
        assert kwargs == expected_kwargs
        expected_body = {"registered_models": jsonify(models)}
        if token:
            expected_body["next_page_token"] = token
        assert json.loads(resp.get_data()) == expected_body

    # Empty request: the asserted kwargs are the message's unset-field values
    # and no page token is returned.
    check(
        SearchRegisteredModels(),
        rmds,
        None,
        {"filter_string": "", "max_results": 100, "order_by": [], "page_token": ""},
    )

    # Filter only.
    check(
        SearchRegisteredModels(filter="hello"),
        rmds[:1],
        "tok",
        {"filter_string": "hello", "max_results": 100, "order_by": [], "page_token": ""},
    )

    # Filter plus an explicit result limit.
    check(
        SearchRegisteredModels(filter="hi", max_results=5),
        [rmds[0]],
        "tik",
        {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": ""},
    )

    # Every field populated.
    check(
        SearchRegisteredModels(
            filter="hey", max_results=500, order_by=["a", "B desc"], page_token="prev"
        ),
        rmds,
        "DONE",
        {
            "filter_string": "hey",
            "max_results": 500,
            "order_by": ["a", "B desc"],
            "page_token": "prev",
        },
    )
def _search_registered_models():
    """Serve a registered-model search request.

    Reads the ``SearchRegisteredModels`` request message, forwards its
    ``filter``, ``max_results``, ``order_by`` and ``page_token`` fields to
    the model registry store, and wraps the resulting page in a
    ``SearchRegisteredModels.Response``.

    :return: The value of ``_wrap_response`` applied to the response message.
    """
    request = _get_request_message(SearchRegisteredModels())
    page = _get_model_registry_store().search_registered_models(
        filter_string=request.filter,
        max_results=request.max_results,
        order_by=request.order_by,
        page_token=request.page_token,
    )

    response = SearchRegisteredModels.Response()
    protos = [model.to_proto() for model in page]
    response.registered_models.extend(protos)

    # Only set the token when the store reported one; a falsy token leaves
    # the field unset on the response.
    next_token = page.token
    if next_token:
        response.next_page_token = next_token
    return _wrap_response(response)
def search_registered_models(
    self, filter_string=None, max_results=None, order_by=None, page_token=None
):
    """
    Query the backend for registered models matching the given criteria.

    :param filter_string: Filter query string, defaults to searching all registered models.
    :param max_results: Maximum number of registered models desired.
    :param order_by: List of column names with ASC|DESC annotation, used to order the
                     matching search results.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a previous ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects that satisfy the search expressions. The pagination token for the
             next page is available via the ``token`` attribute of the returned object.
    """
    # Note the rename: the public argument is ``filter_string`` but the
    # request message field is ``filter``.
    request = SearchRegisteredModels(
        filter=filter_string,
        max_results=max_results,
        order_by=order_by,
        page_token=page_token,
    )
    req_body = message_to_json(request)
    response_proto = self._call_endpoint(SearchRegisteredModels, req_body)

    registered_models = []
    for model_proto in response_proto.registered_models:
        registered_models.append(RegisteredModel.from_proto(model_proto))
    return PagedList(registered_models, response_proto.next_page_token)