def test_search_registered_models(mock_store):
    """The registry client passes search arguments positionally to the store and returns its page."""
    store_search = mock_store.search_registered_models

    # Filter only: the store is expected to receive 100 / None / None for the rest.
    store_search.return_value = PagedList(
        [RegisteredModel("Model 1"), RegisteredModel("Model 2")], ""
    )
    page = newModelRegistryClient().search_registered_models(filter_string="test filter")
    store_search.assert_called_with("test filter", 100, None, None)
    assert len(page) == 2
    assert page.token == ""

    # Every argument supplied explicitly.
    page = newModelRegistryClient().search_registered_models(
        filter_string="another filter",
        max_results=12,
        order_by=["A", "B DESC"],
        page_token="next one",
    )
    store_search.assert_called_with("another filter", 12, ["A", "B DESC"], "next one")
    assert len(page) == 2
    assert page.token == ""

    # No filter at all; the store's ordering and token must come back unchanged.
    expected_names = ["model A", "Model zz", "Model b"]
    store_search.return_value = PagedList(
        [RegisteredModel(name) for name in expected_names],
        "page 2 token",
    )
    page = newModelRegistryClient().search_registered_models(max_results=5)
    store_search.assert_called_with(None, 5, None, None)
    assert [rm.name for rm in page] == ["model A", "Model zz", "Model b"]
    assert page.token == "page 2 token"
def test_list_run_infos():
    """
    ``list_run_infos`` must return the ``info`` of every run yielded by ``search_runs``,
    keep the page token, and forward its arguments positionally to ``search_runs``.
    """
    experiment_id = mock.Mock()
    view_type = mock.Mock()
    token = "adfoiweroh12334kj129318934u"

    # First page: search_runs returns a continuation token.
    run_infos = [mock.Mock(), mock.Mock()]
    runs = [mock.Mock(info=info) for info in run_infos]
    with mock.patch.object(
        AbstractStoreTestImpl, "search_runs", return_value=PagedList(runs, token)
    ):
        store = AbstractStoreTestImpl()
        result = store.list_run_infos(experiment_id, view_type)
        # Compare whole lists: an index loop over ``result`` would pass vacuously
        # if the result were empty or shorter than expected.
        assert list(result) == run_infos
        assert result.token == token
        store.search_runs.assert_called_once_with(
            [experiment_id], None, view_type, SEARCH_MAX_RESULTS_DEFAULT, None, None
        )

    # Follow-up page: the token is passed through and the final page carries no token.
    run_infos = [mock.Mock()]
    runs = [mock.Mock(info=info) for info in run_infos]
    with mock.patch.object(
        AbstractStoreTestImpl, "search_runs", return_value=PagedList(runs, None)
    ):
        store = AbstractStoreTestImpl()
        result = store.list_run_infos(experiment_id, view_type, page_token=token)
        assert list(result) == run_infos
        assert result.token is None
        store.search_runs.assert_called_once_with(
            [experiment_id], None, view_type, SEARCH_MAX_RESULTS_DEFAULT, None, token
        )
def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
    """
    The search handler hands the request fields to the store as keyword arguments and
    serializes the returned page, adding ``next_page_token`` when the store supplies one.
    """
    store_search = mock_model_registry_store.search_registered_models
    rmds = [
        RegisteredModel(
            name="model_1",
            creation_timestamp=111,
            last_updated_timestamp=222,
            description="Test model",
            latest_versions=[],
        ),
        RegisteredModel(
            name="model_2",
            creation_timestamp=111,
            last_updated_timestamp=333,
            description="Another model",
            latest_versions=[],
        ),
    ]

    def run_handler(request, store_page):
        # Drive one request through the handler; return (store kwargs, decoded JSON body).
        mock_get_request_message.return_value = request
        store_search.return_value = store_page
        response = _search_registered_models()
        _, store_kwargs = store_search.call_args
        return store_kwargs, json.loads(response.get_data())

    # Empty request, store returns no token.
    kwargs, payload = run_handler(SearchRegisteredModels(), PagedList(rmds, None))
    assert kwargs == {"filter_string": "", "max_results": 100, "order_by": [], "page_token": ""}
    assert payload == {"registered_models": jsonify(rmds)}

    # Filter only, store returns a token.
    kwargs, payload = run_handler(
        SearchRegisteredModels(filter="hello"), PagedList(rmds[:1], "tok")
    )
    assert kwargs == {
        "filter_string": "hello",
        "max_results": 100,
        "order_by": [],
        "page_token": "",
    }
    assert payload == {
        "registered_models": jsonify(rmds[:1]),
        "next_page_token": "tok",
    }

    # Filter plus explicit max_results.
    kwargs, payload = run_handler(
        SearchRegisteredModels(filter="hi", max_results=5), PagedList([rmds[0]], "tik")
    )
    assert kwargs == {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": ""}
    assert payload == {
        "registered_models": jsonify([rmds[0]]),
        "next_page_token": "tik",
    }

    # Every request field populated.
    kwargs, payload = run_handler(
        SearchRegisteredModels(
            filter="hey", max_results=500, order_by=["a", "B desc"], page_token="prev"
        ),
        PagedList(rmds, "DONE"),
    )
    assert kwargs == {
        "filter_string": "hey",
        "max_results": 500,
        "order_by": ["a", "B desc"],
        "page_token": "prev",
    }
    assert payload == {
        "registered_models": jsonify(rmds),
        "next_page_token": "DONE",
    }
def _list_experiments(
    self,
    ids=None,
    names=None,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
    eager=False,
):
    """
    Query experiments filtered by lifecycle stage and, optionally, by id and/or name.

    :param ids: Optional collection of experiment ids; each is converted with ``int``.
    :param names: Optional collection of experiment names.
    :param view_type: View type translated into the lifecycle stages to include.
    :param max_results: If not ``None``, results are ordered by experiment id and
                        paginated with at most this many experiments per page.
                        If ``None``, all matching experiments are returned.
    :param page_token: Token from which the start offset is parsed; only used when
                       ``max_results`` is given.
    :param eager: If ``True``, eagerly loads each experiments's tags. If ``False``, these
                  tags are not eagerly loaded and will be loaded if/when their
                  corresponding object properties are accessed from a resulting
                  ``SqlExperiment`` object.
    :return: A ``PagedList`` of experiment entities; its token is ``None`` when
             unpaginated or when no further page was detected.
    """
    paginated = max_results is not None

    filters = [
        SqlExperiment.lifecycle_stage.in_(LifecycleStage.view_type_to_stages(view_type))
    ]
    if ids and len(ids) > 0:
        filters.append(SqlExperiment.experiment_id.in_([int(eid) for eid in ids]))
    if names and len(names) > 0:
        filters.append(SqlExperiment.name.in_(names))

    # One extra row is requested so that the existence of a next page can be detected.
    lookahead_limit = max_results + 1 if paginated else None

    with self.ManagedSessionMaker() as session:
        load_options = self._get_eager_experiment_query_options() if eager else []
        if paginated:
            offset = SearchUtils.parse_start_offset_from_page_token(page_token)
            rows = (
                session.query(SqlExperiment)
                .options(*load_options)
                .order_by(SqlExperiment.experiment_id)
                .filter(*filters)
                .offset(offset)
                .limit(lookahead_limit)
                .all()
            )
        else:
            rows = session.query(SqlExperiment).options(*load_options).filter(*filters).all()
        experiments = [row.to_mlflow_entity() for row in rows]

    if not paginated:
        return PagedList(experiments, None)

    next_token = None
    if len(experiments) == lookahead_limit:
        # The look-ahead row came back, so another page starts right after this one.
        next_token = SearchUtils.create_page_token(offset + max_results)
    return PagedList(experiments[:max_results], next_token)
def test_paginate_lt_maxresults_multipage():
    """
    Fewer runs exist than max_results, yet they are spread over several pages:
    all of them must be collected.
    """
    max_results = 50
    max_per_page = 10
    expected_total = 21

    full_page = PagedList([create_run() for _ in range(10)], "token")
    final_page = PagedList([create_run()], "")
    # Two full pages with a token, then one single-run page without a token.
    fetch_page = mock.Mock(side_effect=[full_page, full_page, final_page])

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)
    assert len(paginated_runs) == expected_total
def test_get_paginated_runs_lt_maxresults_multipage():
    """
    Fewer runs exist than max_results, yet they are spread over several pages:
    all of them must be collected.
    """
    max_results = 50
    expected_total = 21
    full_page = PagedList([create_run() for _ in range(10)], "token")
    final_page = PagedList([create_run()], "")

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10), mock.patch.object(
        MlflowClient, "search_runs", side_effect=[full_page, full_page, final_page]
    ):
        paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
        assert len(paginated_runs) == expected_total
def list_experiments(
    self,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
):
    """
    List experiments of the requested view type, optionally one page at a time.

    :param view_type: Qualify requested type of experiments.
    :param max_results: If passed, specifies the maximum number of experiments desired.
                        If not passed, all experiments will be returned.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``list_experiments`` call.
    :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
             :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination
             token for the next page can be obtained via the ``token`` attribute of the
             object.
    """
    from mlflow.utils.search_utils import SearchUtils
    from mlflow.store.entities.paged_list import PagedList

    _validate_list_experiments_max_results(max_results)
    self._check_root_dir()

    # Collect candidate ids: active ones first, then deleted ones.
    candidate_ids = []
    if view_type in (ViewType.ACTIVE_ONLY, ViewType.ALL):
        candidate_ids += self._get_active_experiments(full_path=False)
    if view_type in (ViewType.DELETED_ONLY, ViewType.ALL):
        candidate_ids += self._get_deleted_experiments(full_path=False)

    experiments = []
    for exp_id in candidate_ids:
        try:
            # trap and warn known issues, will raise unexpected exceptions to caller
            experiment = self._get_experiment(exp_id, view_type)
        except MissingConfigException as rnfe:
            # Trap malformed experiments and log warnings.
            logging.warning(
                "Malformed experiment '%s'. Detailed error %s",
                str(exp_id),
                str(rnfe),
                exc_info=True,
            )
            continue
        if experiment:
            experiments.append(experiment)

    if max_results is None:
        return PagedList(experiments, None)

    page, next_page_token = SearchUtils.paginate(experiments, page_token, max_results)
    return PagedList(page, next_page_token)
def list_run_infos(
    self,
    experiment_id,
    run_view_type,
    max_results=SEARCH_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Return run information for runs which belong to the experiment_id.

    Delegates to ``search_runs`` for the single experiment with no filter string and
    keeps only the ``info`` attribute of each returned run.

    :param experiment_id: The experiment id which to search
    :param run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL runs
    :param max_results: Maximum number of results desired.
    :param order_by: List of order_by clauses.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``list_run_infos`` call.
    :return: A list of :py:class:`mlflow.entities.RunInfo` objects that satisfy the
             search expressions. The pagination token for the next page can be obtained
             via the ``token`` attribute of the object; however, some store
             implementations may not support pagination and thus the returned token
             would not be meaningful in such cases.
    """
    runs_page = self.search_runs(
        [experiment_id], None, run_view_type, max_results, order_by, page_token
    )
    infos = [run.info for run in runs_page]
    return PagedList(infos, runs_page.token)
def search_runs(
    self,
    experiment_ids,
    filter_string,
    run_view_type,
    max_results=SEARCH_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Return runs that match the given list of search expressions within the experiments.

    The actual search is performed by ``self._search_runs``, whose ``(runs, token)``
    result is wrapped in a ``PagedList``.

    :param experiment_ids: List of experiment ids to scope the search
    :param filter_string: A search filter string.
    :param run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL runs
    :param max_results: Maximum number of runs desired.
    :param order_by: List of order_by clauses.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``search_runs`` call.
    :return: A list of :py:class:`mlflow.entities.Run` objects that satisfy the search
             expressions. The pagination token for the next page can be obtained via the
             ``token`` attribute of the object; however, some store implementations may
             not support pagination and thus the returned token would not be meaningful
             in such cases.
    """
    matching_runs, next_token = self._search_runs(
        experiment_ids,
        filter_string,
        run_view_type,
        max_results,
        order_by,
        page_token,
    )
    return PagedList(matching_runs, next_token)
def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """
    The list handler forwards ``max_results`` and the page token positionally to the
    store and returns the models together with the store's next page token.
    """
    store_list = mock_model_registry_store.list_registered_models
    mock_get_request_message.return_value = ListRegisteredModels(max_results=50)

    first_model = RegisteredModel(
        name="model_1",
        creation_timestamp=111,
        last_updated_timestamp=222,
        description="Test model",
        latest_versions=[],
    )
    second_model = RegisteredModel(
        name="model_2",
        creation_timestamp=111,
        last_updated_timestamp=333,
        description="Another model",
        latest_versions=[],
    )
    rmds = PagedList([first_model, second_model], "next_pt")
    store_list.return_value = rmds

    resp = _list_registered_models()

    positional_args, _ = store_list.call_args
    assert positional_args == (50, "")
    assert json.loads(resp.get_data()) == {
        "next_page_token": "next_pt",
        "registered_models": jsonify(rmds),
    }
def search_registered_models(self, filter_string=None, max_results=None, order_by=None,
                             page_token=None):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Sends a ``SearchRegisteredModels`` request through ``self._call_endpoint`` and
    converts each returned proto into a ``RegisteredModel`` entity.

    :param filter_string: Filter query string, defaults to searching all registered models.
    :param max_results: Maximum number of registered models desired.
    :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                     matching search results.
    :param page_token: Token specifying the next page of results. It should be obtained from
                        a ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
            that satisfy the search expressions. The pagination token for the next page can be
            obtained via the ``token`` attribute of the object.
    """
    request = SearchRegisteredModels(
        filter=filter_string,
        max_results=max_results,
        order_by=order_by,
        page_token=page_token,
    )
    response_proto = self._call_endpoint(SearchRegisteredModels, message_to_json(request))
    entities = list(map(RegisteredModel.from_proto, response_proto.registered_models))
    return PagedList(entities, response_proto.next_page_token)
def test_list_registered_models(mock_store):
    """Listing through the client hits the store exactly once and returns its models."""
    store_list = mock_store.list_registered_models
    models = [RegisteredModel("Model 1"), RegisteredModel("Model 2")]
    store_list.return_value = PagedList(models, "")

    listed = newModelRegistryClient().list_registered_models()

    store_list.assert_called_once()
    assert len(listed) == 2
def list_experiments(
    self,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
):
    """
    List experiments by sending a ``ListExperiments`` request through ``self._call_endpoint``.

    :param view_type: Qualify requested type of experiments.
    :param max_results: If passed, specifies the maximum number of experiments desired. If not
                        passed, the server will pick a maximum number of results to return.
    :param page_token: Token specifying the next page of results. It should be obtained from
                        a ``list_experiments`` call.
    :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
             :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination token
             for the next page can be obtained via the ``token`` attribute of the object.
    """
    request = ListExperiments(
        view_type=view_type, max_results=max_results, page_token=page_token
    )
    response_proto = self._call_endpoint(ListExperiments, message_to_json(request))
    experiments = list(map(Experiment.from_proto, response_proto.experiments))

    # Reading an unset `next_page_token` yields an empty string (the default value for
    # a string proto field), so field presence is checked explicitly to report ``None``.
    if response_proto.HasField("next_page_token"):
        token = response_proto.next_page_token
    else:
        token = None
    return PagedList(experiments, token)
def test_paginate_gt_maxresults_multipage():
    """
    More runs match than max_results and several pages are needed: exactly
    max_results runs must come back, with the last request sized to the remainder.
    """
    max_results = 20
    max_per_page = 8
    full_page = PagedList([create_run() for _ in range(8)], "abc")
    remainder_page = PagedList([create_run() for _ in range(4)], "def")
    fetch_page = mock.Mock(side_effect=[full_page, full_page, remainder_page])

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)

    # Two full-size requests, then one asking only for what is still missing.
    expected_calls = [
        mock.call(8, None),
        mock.call(8, "abc"),
        mock.call(20 % 8, "abc"),
    ]
    fetch_page.assert_has_calls(expected_calls)
    assert len(paginated_runs) == 20
def test_get_paginated_runs_eq_maxresults_token():
    """
    The runs returned equal max_results, which is also a whole number of pages.
    The server may or may not send a token back (depending on whether it knows more
    runs exist). Here a token IS sent back; no further page must be requested.
    """
    max_results = 10
    first_page = PagedList([create_run() for _ in range(10)], "abc")
    empty_page = PagedList([], "")

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10), mock.patch.object(
        MlflowClient, "search_runs", side_effect=[first_page, empty_page]
    ):
        paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
        MlflowClient.search_runs.assert_called_once()
        assert len(paginated_runs) == 10
def test_paginate_eq_maxresults_token():
    """
    The runs returned equal max_results, which is also a whole number of pages.
    The server may or may not send a token back (depending on whether it knows more
    runs exist). Here a token IS sent back; no further page must be requested.
    """
    max_results = 10
    max_per_page = 10
    first_page = PagedList([create_run() for _ in range(10)], "abc")
    empty_page = PagedList([], "")
    fetch_page = mock.Mock(side_effect=[first_page, empty_page])

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)

    fetch_page.assert_called_once()
    assert len(paginated_runs) == 10
def test_get_paginated_runs_gt_maxresults_multipage():
    """
    More runs match than max_results and several pages are needed: exactly
    max_results runs must come back, with the last request sized to the remainder.
    """
    max_results = 20
    full_page = PagedList([create_run() for _ in range(8)], "abc")
    remainder_page = PagedList([create_run() for _ in range(4)], "def")

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 8), mock.patch.object(
        MlflowClient, "search_runs", side_effect=[full_page, full_page, remainder_page]
    ):
        paginated_runs = _get_paginated_runs([12], "", ViewType.ACTIVE_ONLY, max_results, None)
        # Two full-size requests, then one asking only for what is still missing.
        expected_calls = [
            mock.call([12], "", ViewType.ACTIVE_ONLY, 8, None, None),
            mock.call([12], "", ViewType.ACTIVE_ONLY, 8, None, "abc"),
            mock.call([12], "", ViewType.ACTIVE_ONLY, 20 % 8, None, "abc"),
        ]
        MlflowClient.search_runs.assert_has_calls(expected_calls)
        assert len(paginated_runs) == 20
def test_search_runs_default_view_type(mock_get_request_message, mock_tracking_store):
    """
    When the request carries no view type, the store must be called with
    ViewType.ACTIVE_ONLY as its third positional argument.
    """
    store_search = mock_tracking_store.search_runs
    mock_get_request_message.return_value = SearchRuns(experiment_ids=["0"])
    store_search.return_value = PagedList([], None)

    _search_runs()

    positional_args = store_search.call_args[0]
    assert positional_args[2] == ViewType.ACTIVE_ONLY
def test_get_paginated_runs_lt_maxresults_onepage_nonetoken():
    """
    Fewer runs exist than max_results and they fit on one page. The token on that
    last page is None rather than the empty string.
    """
    max_results = 50
    only_page = PagedList([create_run() for _ in range(5)], None)

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10), mock.patch.object(
        MlflowClient, "search_runs", return_value=only_page
    ):
        paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
        MlflowClient.search_runs.assert_called_once()
        assert len(paginated_runs) == 5
def test_paginate_gt_maxresults_onepage():
    """
    More runs match than max_results, but max_results fits within one page:
    a single request for max_results runs is expected.
    """
    max_results = 10
    max_per_page = 20
    only_page = PagedList([create_run() for _ in range(10)], "abc")
    fetch_page = mock.Mock(return_value=only_page)

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)

    fetch_page.assert_called_once_with(max_results, None)
    assert len(paginated_runs) == 10
def test_paginate_lt_maxresults_onepage_nonetoken():
    """
    Fewer runs exist than max_results and they fit on one page. The token on that
    last page is None rather than the empty string.
    """
    max_results = 50
    max_per_page = 10
    only_page = PagedList([create_run() for _ in range(5)], None)
    fetch_page = mock.Mock(return_value=only_page)

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)

    fetch_page.assert_called_once()
    assert len(paginated_runs) == 5
def test_paginate_lt_maxresults_onepage():
    """
    Fewer runs exist than max_results and they fit on one page, so exactly one
    page must be fetched.
    """
    max_results = 50
    max_per_page = 10
    only_page = PagedList([create_run() for _ in range(5)], "")
    fetch_page = mock.Mock(return_value=only_page)

    paginated_runs = _paginate(fetch_page, max_per_page, max_results)

    fetch_page.assert_called_once()
    assert len(paginated_runs) == 5
def search_model_versions(self, filter_string):
    """
    Search for model versions in backend that satisfy the filter criteria.

    Versions whose stage equals ``STAGE_DELETED_INTERNAL`` are always excluded.

    :param filter_string: A filter string expression. Supports at most one equality
                          condition on ``name``, ``source_path`` or ``run_id``, e.g.
                          ``name = 'model_name'``. A filter that parses to no
                          conditions applies no additional restriction.
    :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
             objects. The token of the returned list is always ``None``.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the filter contains
                             more than one condition, uses a comparator other than
                             ``=``, or references an unsupported key.
    """
    parsed_filter = SearchUtils.parse_filter_for_model_versions(filter_string)

    if len(parsed_filter) > 1:
        # Fixed message: the last example was missing its closing quote and the
        # sentences ran together without a space.
        raise MlflowException(
            "Model Registry expects filter to be one of "
            "\"name = '<model_name>'\" or "
            "\"source_path = '<source_path>'\" or \"run_id = '<run_id>'\". "
            "Input filter string: %s. " % filter_string,
            error_code=INVALID_PARAMETER_VALUE,
        )

    conditions = []
    if len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        if filter_dict["comparator"] != "=":
            raise MlflowException(
                "Model Registry search filter only supports equality(=) "
                "comparator. Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE,
            )
        # Filter key -> column it is compared against.
        columns_by_key = {
            "name": SqlModelVersion.name,
            "source_path": SqlModelVersion.source,
            "run_id": SqlModelVersion.run_id,
        }
        if filter_dict["key"] not in columns_by_key:
            raise MlflowException(
                "Invalid filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE,
            )
        conditions.append(columns_by_key[filter_dict["key"]] == filter_dict["value"])

    with self.ManagedSessionMaker() as session:
        conditions.append(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
        sql_model_versions = session.query(SqlModelVersion).filter(*conditions).all()
        model_versions = [mv.to_mlflow_entity() for mv in sql_model_versions]
        return PagedList(model_versions, None)
def test_get_paginated_runs_gt_maxresults_onepage():
    """
    More runs match than max_results, but max_results fits within one page:
    a single request for max_results runs is expected.
    """
    max_results = 10
    only_page = PagedList([create_run() for _ in range(10)], "abc")

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 20), mock.patch.object(
        MlflowClient, "search_runs", return_value=only_page
    ):
        paginated_runs = _get_paginated_runs([123], "", ViewType.ACTIVE_ONLY, max_results, None)
        MlflowClient.search_runs.assert_called_once_with(
            [123], "", ViewType.ACTIVE_ONLY, max_results, None, None
        )
        assert len(paginated_runs) == 10
def search_model_versions(self, filter_string):
    """
    Search for model versions in backend that satisfy the filter criteria.

    Sends a ``SearchModelVersions`` request through ``self._call_endpoint`` and converts
    each returned proto into a ``ModelVersion`` entity.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition either name of model like ``name = 'model_name'`` or
                          ``run_id = '...'``.
    :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
             objects, carrying the response's ``next_page_token`` as its token.
    """
    request = SearchModelVersions(filter=filter_string)
    response_proto = self._call_endpoint(SearchModelVersions, message_to_json(request))
    versions = list(map(ModelVersion.from_proto, response_proto.model_versions))
    return PagedList(versions, response_proto.next_page_token)
def test_get_paginated_runs_lt_maxresults_onepage():
    """
    Fewer runs exist than max_results and they fit on one page, so exactly one
    page must be fetched.
    """
    max_results = 50
    only_page = PagedList([create_run() for _ in range(5)], "")

    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10), mock.patch.object(
        MlflowClient, "search_runs", return_value=only_page
    ):
        paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
        MlflowClient.search_runs.assert_called_once()
        assert len(paginated_runs) == 5
def list_registered_models(self, max_results, page_token):
    """
    List of all registered models.

    Sends a ``ListRegisteredModels`` request through ``self._call_endpoint`` and converts
    each returned proto into a ``RegisteredModel`` entity.

    :param max_results: Maximum number of registered models desired.
    :param page_token: Token specifying the next page of results. It should be obtained from
                        a ``list_registered_models`` call.
    :return: PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects, carrying the response's ``next_page_token`` as its token.
    """
    request = ListRegisteredModels(page_token=page_token, max_results=max_results)
    response_proto = self._call_endpoint(ListRegisteredModels, message_to_json(request))
    models = list(map(RegisteredModel.from_proto, response_proto.registered_models))
    return PagedList(models, response_proto.next_page_token)
def search_registered_models(
    self,
    filter_string=None,
    max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
    order_by=None,
    page_token=None,
):
    """
    Search for registered models in backend that satisfy the filter criteria.

    :param filter_string: Filter query string, defaults to searching all registered
                          models. At most one condition is accepted; it is matched
                          against the model name with ``=``, ``LIKE`` or ``ILIKE``.
    :param max_results: Maximum number of registered models desired.
    :param order_by: List of column names with ASC|DESC annotation, to be used for
                     ordering matching search results.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects that satisfy the search expressions. The pagination token for the
             next page can be obtained via the ``token`` attribute of the object.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``max_results`` exceeds
                             the allowed threshold, the comparator is unsupported, or the
                             filter contains more than one condition.
    """
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. "
            "It must be at most {}, but got value {}".format(
                SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, max_results
            ),
            INVALID_PARAMETER_VALUE,
        )

    parsed_filter = SearchUtils.parse_filter_for_registered_models(filter_string)
    parsed_orderby = self._parse_search_registered_models_order_by(order_by)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # we query for max_results + 1 items to check whether there is another page to return.
    # this remediates having to make another query which returns no items.
    max_results_for_query = max_results + 1

    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        comparator = filter_dict["comparator"].upper()
        if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            # Fixed message: adjacent literals used to join as "case-sensitivepartial".
            raise MlflowException(
                "Search registered models filter expression only "
                "supports the equality(=) comparator, case-sensitive "
                "partial match (LIKE), and case-insensitive partial "
                "match (ILIKE). Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE,
            )
        if comparator == SearchUtils.LIKE_OPERATOR:
            conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
        elif comparator == SearchUtils.ILIKE_OPERATOR:
            conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
        else:
            conditions = [SqlRegisteredModel.name == filter_dict["value"]]
    else:
        supported_ops = "".join(
            "(" + op + ")" for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS
        )
        sample_query = 'name {} "<model_name>"'.format(supported_ops)
        # Fixed message: the filter string used to run straight into the next sentence.
        raise MlflowException(
            "Invalid filter string: {}. "
            "Search registered models supports filter expressions like: {}".format(
                filter_string, sample_query
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    with self.ManagedSessionMaker() as session:
        query = (
            session.query(SqlRegisteredModel)
            .filter(*conditions)
            .order_by(*parsed_orderby)
            .limit(max_results_for_query)
        )
        if page_token:
            query = query.offset(offset)
        sql_registered_models = query.all()

        # The look-ahead row coming back means a further page exists.
        next_page_token = None
        if len(sql_registered_models) == max_results_for_query:
            next_page_token = SearchUtils.create_page_token(offset + max_results)
        rm_entities = [rm.to_mlflow_entity() for rm in sql_registered_models][:max_results]
        return PagedList(rm_entities, next_page_token)
def search_registered_models(
        self,
        filter_string=None,
        max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Results are ordered by model name ascending. On SQLite, ``LIKE`` is made
    case-sensitive for the session via ``PRAGMA case_sensitive_like``.

    :param filter_string: Filter query string, defaults to searching all registered
                          models. At most one condition is accepted; it is matched
                          against the model name with ``=``, ``LIKE`` or ``ILIKE``,
                          e.g. ``name = 'model_name'``.
    :param max_results: Maximum number of registered models desired.
    :param order_by: List of column names with ASC|DESC annotation, to be used for
                     ordering matching search results.
                     Note:: This field is currently not supported.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``search_registered_models`` call.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects that satisfy the search expressions. The pagination token for the
             next page can be obtained via the ``token`` attribute of the object.
    :raises NotImplementedError: If ``order_by`` is non-empty.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``max_results`` exceeds
                             the allowed threshold, the comparator is unsupported, or the
                             filter contains more than one condition.
    """
    if order_by:
        raise NotImplementedError(
            "Order by is not implemented for search registered models.")
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        # Fixed message: a space was missing between the two sentences.
        raise MlflowException(
            "Invalid value for request parameter max_results. "
            "It must be at most {}, but got value {}".format(
                SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)

    parsed_filter = SearchUtils.parse_filter_for_registered_models(filter_string)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # we query for max_results + 1 items to check whether there is another page to return.
    # this remediates having to make another query which returns no items.
    max_results_for_query = max_results + 1

    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        comparator = filter_dict['comparator'].upper()
        if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            # Fixed message: adjacent literals used to join as "case-sensitivepartial".
            raise MlflowException(
                'Search registered models filter expression only '
                'supports the equality(=) comparator, case-sensitive '
                'partial match (LIKE), and case-insensitive partial '
                'match (ILIKE). Input filter string: %s' % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        if comparator == SearchUtils.LIKE_OPERATOR:
            conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
        elif comparator == SearchUtils.ILIKE_OPERATOR:
            conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
        else:
            conditions = [SqlRegisteredModel.name == filter_dict["value"]]
    else:
        supported_ops = ''.join(
            '(' + op + ')' for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
        sample_query = f'name {supported_ops} "<model_name>"'
        # Fixed message: the filter string used to run straight into the next sentence.
        raise MlflowException(
            f'Invalid filter string: {filter_string}. '
            f'Search registered models supports filter expressions like: {sample_query}',
            error_code=INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        if self.db_type == SQLITE:
            session.execute("PRAGMA case_sensitive_like = true;")
        query = (
            session.query(SqlRegisteredModel)
            .filter(*conditions)
            .order_by(SqlRegisteredModel.name.asc())
            .limit(max_results_for_query)
        )
        if page_token:
            query = query.offset(offset)
        sql_registered_models = query.all()

        # The look-ahead row coming back means a further page exists.
        next_page_token = None
        if len(sql_registered_models) == max_results_for_query:
            next_page_token = SearchUtils.create_page_token(offset + max_results)
        rm_entities = [rm.to_mlflow_entity() for rm in sql_registered_models][:max_results]
        return PagedList(rm_entities, next_page_token)
def search_registered_models(self,
                             filter_string,
                             page_token=None,
                             max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Results are ordered by model name ascending. On SQLite, ``LIKE`` is made
    case-sensitive for the session via ``PRAGMA case_sensitive_like``.

    :param filter_string: A filter string expression. At most one condition is accepted;
                          it must use the ``name`` key with ``=``, ``LIKE`` or ``ILIKE``,
                          e.g. ``name = 'model_name'``.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``search_registered_models`` call.
    :param max_results: Maximum number of registered models desired.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects that satisfy the search expressions. The pagination token for the
             next page can be obtained via the ``token`` attribute of the object; it is
             ``None`` when no further results were found.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``max_results`` exceeds
                             the allowed threshold, the comparator or key is unsupported,
                             or the filter contains more than one condition.
    """
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        # Fixed message: a space was missing between the two sentences.
        raise MlflowException(
            "Invalid value for request parameter max_results. "
            "It must be at most {}, but got value {}".format(
                SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)

    parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)
    # Query one row more than requested: a token is then issued only when another
    # page really exists, instead of whenever the current page happens to be full
    # (which forced callers into an extra, empty fetch).
    max_results_for_query = max_results + 1

    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        if filter_dict["comparator"] not in \
                SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            # Fixed message: adjacent literals used to join as "case-sensitivepartial".
            raise MlflowException(
                'Search registered models filter expression only '
                'supports the equality(=) comparator, case-sensitive '
                'partial match (LIKE), and case-insensitive partial '
                'match (ILIKE). Input filter string: %s' % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        if filter_dict["key"] != "name":
            raise MlflowException('Invalid filter string: %s' % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
        if filter_dict["comparator"] == "LIKE":
            conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
        elif filter_dict["comparator"] == "ILIKE":
            conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
        else:
            conditions = [SqlRegisteredModel.name == filter_dict["value"]]
    else:
        supported_ops = ''.join(
            '(' + op + ')' for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
        sample_query = f'name {supported_ops} "<model_name>"'
        # Fixed message: the filter string used to run straight into the next sentence.
        raise MlflowException(
            f'Invalid filter string: {filter_string}. '
            f'Search registered models supports filter expressions like: {sample_query}',
            error_code=INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        if self.db_type == SQLITE:
            session.execute("PRAGMA case_sensitive_like = true;")
        query = (
            session.query(SqlRegisteredModel)
            .filter(*conditions)
            .order_by(SqlRegisteredModel.name.asc())
            .limit(max_results_for_query)
        )
        if page_token:
            query = query.offset(offset)
        sql_registered_models = query.all()

        next_page_token = None
        if len(sql_registered_models) == max_results_for_query:
            next_page_token = SearchUtils.create_page_token(offset + max_results)
        # Drop the look-ahead row before handing the page back.
        registered_models = [
            rm.to_mlflow_entity() for rm in sql_registered_models
        ][:max_results]
        return PagedList(registered_models, next_page_token)