Beispiel #1
0
 def test_search_registered_model(self, mock_http):
     """Verify search_registered_models maps its arguments onto the REST request.

     First checks the no-argument call against an empty ``SearchRegisteredModels``
     message, then every subset of the optional parameters (filter_string,
     max_results, page_token, order_by). ``filter_string`` is renamed to the proto
     field ``filter`` before building the expected request message.
     """
     self.store.search_registered_models()
     self._verify_requests(mock_http, "registered-models/search", "GET",
                           SearchRegisteredModels())
     params_list = [
         {"filter_string": "model = 'yo'"},
         {"max_results": 400},
         {"page_token": "blah"},
         {"order_by": ["x", "Y"]},
     ]
     # Test all combinations of params. Subset sizes are derived from params_list so
     # that adding an entry above automatically extends coverage to the full set.
     for size in range(len(params_list) + 1):
         for combination in combinations(params_list, size):
             params = {k: v for d in combination for k, v in d.items()}
             self.store.search_registered_models(**params)
             # The Python API argument is ``filter_string``; the proto field is ``filter``.
             if "filter_string" in params:
                 params["filter"] = params.pop("filter_string")
             self._verify_requests(mock_http, "registered-models/search",
                                   "GET", SearchRegisteredModels(**params))
def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
    """Check that the search handler forwards request fields to the store and builds the reply.

    Each case feeds a ``SearchRegisteredModels`` request to the handler and stubs the store's
    paged result. It then asserts (a) the keyword arguments the handler passed to
    ``search_registered_models`` and (b) the JSON body it returned. ``next_page_token`` is
    expected in the body only when the stubbed store returned a truthy token. For an empty
    request, the expected store kwargs reflect the message defaults (e.g. max_results == 100).
    """
    rmds = [
        RegisteredModel(
            name="model_1",
            creation_timestamp=111,
            last_updated_timestamp=222,
            description="Test model",
            latest_versions=[],
        ),
        RegisteredModel(
            name="model_2",
            creation_timestamp=111,
            last_updated_timestamp=333,
            description="Another model",
            latest_versions=[],
        ),
    ]
    store_search = mock_model_registry_store.search_registered_models
    # Each entry: (request fields, models the store returns, page token the store returns,
    #              keyword arguments the handler must forward to the store)
    cases = [
        (
            {},
            rmds,
            None,
            {"filter_string": "", "max_results": 100, "order_by": [], "page_token": ""},
        ),
        (
            {"filter": "hello"},
            rmds[:1],
            "tok",
            {"filter_string": "hello", "max_results": 100, "order_by": [], "page_token": ""},
        ),
        (
            {"filter": "hi", "max_results": 5},
            [rmds[0]],
            "tik",
            {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": ""},
        ),
        (
            {
                "filter": "hey",
                "max_results": 500,
                "order_by": ["a", "B desc"],
                "page_token": "prev",
            },
            rmds,
            "DONE",
            {
                "filter_string": "hey",
                "max_results": 500,
                "order_by": ["a", "B desc"],
                "page_token": "prev",
            },
        ),
    ]
    for request_fields, page_models, page_token, expected_store_kwargs in cases:
        mock_get_request_message.return_value = SearchRegisteredModels(**request_fields)
        store_search.return_value = PagedList(page_models, page_token)
        resp = _search_registered_models()
        # call_args unpacks into (positional args, keyword args); only kwargs are used here.
        _, forwarded_kwargs = store_search.call_args
        assert forwarded_kwargs == expected_store_kwargs
        expected_body = {"registered_models": jsonify(page_models)}
        if page_token:
            expected_body["next_page_token"] = page_token
        assert json.loads(resp.get_data()) == expected_body
Beispiel #3
0
def _search_registered_models():
    """Handle a SearchRegisteredModels request.

    Parses the incoming request message and forwards its ``filter``, ``max_results``,
    ``order_by`` and ``page_token`` fields to the model registry store. The returned
    registered models are wrapped in a response; the next-page token is included only
    when the store returned a truthy one.
    """
    request = _get_request_message(SearchRegisteredModels())
    registry = _get_model_registry_store()
    page = registry.search_registered_models(
        filter_string=request.filter,
        max_results=request.max_results,
        order_by=request.order_by,
        page_token=request.page_token,
    )
    response = SearchRegisteredModels.Response()
    model_protos = [model.to_proto() for model in page]
    response.registered_models.extend(model_protos)
    # An empty/None token means there is no further page, so the field stays unset.
    if page.token:
        response.next_page_token = page.token
    return _wrap_response(response)
Beispiel #4
0
    def search_registered_models(self,
                                 filter_string=None,
                                 max_results=None,
                                 order_by=None,
                                 page_token=None):
        """
        Search the backend for registered models that satisfy the filter criteria.

        The arguments are packed into a ``SearchRegisteredModels`` request message, which is
        serialized to JSON and sent through ``_call_endpoint``. Every registered model in the
        response is converted back into an entity.

        :param filter_string: Filter query string, defaults to searching all registered models.
        :param max_results: Maximum number of registered models desired.
        :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                         matching search results.
        :param page_token: Token specifying the next page of results. It should be obtained from
                           a ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                 that satisfy the search expressions. The pagination token for the next page can
                 be obtained via the ``token`` attribute of the object.
        """
        request_proto = SearchRegisteredModels(
            filter=filter_string,
            max_results=max_results,
            order_by=order_by,
            page_token=page_token,
        )
        payload = message_to_json(request_proto)
        response_proto = self._call_endpoint(SearchRegisteredModels, payload)
        models = list(map(RegisteredModel.from_proto, response_proto.registered_models))
        return PagedList(models, response_proto.next_page_token)