def test_pagination(page_token, max_results, matching_runs, expected_next_page_token):
    """Verify ``SearchUtils.paginate`` slices runs and emits the right next-page token.

    :param page_token: Decoded page-token dict (or None) that gets JSON+base64
        encoded before being passed to ``paginate``.
    :param max_results: Page size handed to ``SearchUtils.paginate``.
    :param matching_runs: Expected indices (into ``runs``) of the returned page.
    :param expected_next_page_token: Decoded token expected back, or None when no
        further page exists.
    """
    # Three runs identical except for their ids; built in a comprehension
    # instead of repeating the verbose RunInfo constructor three times.
    runs = [
        Run(
            run_info=RunInfo(
                run_uuid=str(i),
                run_id=str(i),
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData([], [], []),
        )
        for i in range(3)
    ]
    encoded_page_token = None
    if page_token:
        encoded_page_token = base64.b64encode(json.dumps(page_token).encode("utf-8"))
    paginated_runs, next_page_token = SearchUtils.paginate(runs, encoded_page_token, max_results)
    # Map each returned run back to its index in the source list so the page
    # contents can be compared positionally.
    paginated_run_indices = []
    for run in paginated_runs:
        for i, r in enumerate(runs):
            if r == run:
                paginated_run_indices.append(i)
                break
    assert paginated_run_indices == matching_runs
    decoded_next_page_token = None
    if next_page_token:
        decoded_next_page_token = json.loads(base64.b64decode(next_page_token))
    assert decoded_next_page_token == expected_next_page_token
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Search runs in the given experiments; filter, sort and paginate in memory."""
    # TODO: push search query into backend database layer
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)
    stages = set(LifecycleStage.view_type_to_stages(run_view_type))
    with self.ManagedSessionMaker() as session:
        # Eagerly load each run's summary metrics, params, and tags:
        # ``to_mlflow_entity`` reads those attributes, and eager loading
        # avoids the extra per-attribute queries a lazy model would issue.
        query = session.query(SqlRun).options(*self._get_eager_run_query_options())
        query = query.filter(
            SqlRun.experiment_id.in_(experiment_ids),
            SqlRun.lifecycle_stage.in_(stages))
        runs = [record.to_mlflow_entity() for record in query.all()]
    matching = SearchUtils.filter(runs, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def list_experiments(
    self,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
):
    """
    :param view_type: Qualify requested type of experiments.
    :param max_results: If passed, specifies the maximum number of experiments desired.
                        If not passed, all experiments will be returned.
    :param page_token: Token specifying the next page of results. It should be obtained
                       from a ``list_experiments`` call.
    :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
             :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination
             token for the next page can be obtained via the ``token`` attribute of the
             object.
    """
    from mlflow.utils.search_utils import SearchUtils
    from mlflow.store.entities.paged_list import PagedList

    _validate_list_experiments_max_results(max_results)
    self._check_root_dir()
    # Gather experiment ids for the requested lifecycle view(s).
    experiment_ids = []
    if view_type in (ViewType.ACTIVE_ONLY, ViewType.ALL):
        experiment_ids.extend(self._get_active_experiments(full_path=False))
    if view_type in (ViewType.DELETED_ONLY, ViewType.ALL):
        experiment_ids.extend(self._get_deleted_experiments(full_path=False))
    experiments = []
    for exp_id in experiment_ids:
        # Known issue: malformed experiment configs are logged and skipped
        # rather than aborting the whole listing; unexpected exceptions
        # still propagate to the caller.
        try:
            experiment = self._get_experiment(exp_id, view_type)
        except MissingConfigException as exc:
            logging.warning(
                "Malformed experiment '%s'. Detailed error %s",
                str(exp_id),
                str(exc),
                exc_info=True,
            )
            continue
        if experiment:
            experiments.append(experiment)
    if max_results is None:
        return PagedList(experiments, None)
    page, next_page_token = SearchUtils.paginate(experiments, page_token, max_results)
    return PagedList(page, next_page_token)
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Fetch runs for the experiments, then filter, sort and paginate in memory."""
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            databricks_pb2.INVALID_PARAMETER_VALUE)
    # Hydrate every run in every requested experiment via get_run.
    runs = [
        self.get_run(info.run_id)
        for experiment_id in experiment_ids
        for info in self._list_run_infos(experiment_id, run_view_type)
    ]
    matching = SearchUtils.filter(runs, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Search runs across the given experiments, paginating the sorted result."""
    # TODO: push search query into backend database layer
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)
    with self.ManagedSessionMaker() as session:
        # Materialize every run of every experiment as an MLflow entity
        # before the session closes.
        runs = []
        for experiment_id in experiment_ids:
            for sql_run in self._list_runs(session, experiment_id, run_view_type):
                runs.append(sql_run.to_mlflow_entity())
    matching = SearchUtils.filter(runs, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def _search_runs(
    self,
    experiment_ids,
    filter_string,
    run_view_type,
    max_results,
    order_by,
    page_token,
):
    """Load, filter, sort and paginate runs for the given experiments."""
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE,
        )
    runs = []
    for experiment_id in experiment_ids:
        run_ids = self._list_runs_ids(experiment_id, run_view_type)
        for raw_run in self._get_run_list(run_ids):
            info = _dict_to_run_info(raw_run)
            # Metrics, params and tags are stored separately; hydrate the
            # full Run entity per run id.
            run_id = info.run_id
            data = RunData(
                self.get_all_metrics(run_id),
                self.get_all_params(run_id),
                self.get_all_tags(run_id),
            )
            runs.append(Run(info, data))
    matching = SearchUtils.filter(runs, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def test_invalid_page_tokens(page_token, error_message):
    """An invalid page token must make ``SearchUtils.paginate`` raise MlflowException."""
    # Use pytest's ``match=`` to assert on the exception message directly
    # instead of the manual ``error_message in e.value.message`` check —
    # the idiomatic pytest form. NOTE(review): ``match`` treats the pattern
    # as a regex (re.search), so parametrized messages containing regex
    # metacharacters should be pre-escaped by the caller.
    with pytest.raises(MlflowException, match=error_message):
        SearchUtils.paginate([], page_token, 1)
def test_invalid_page_tokens(page_token, error_message):
    """Paginating with a bad page token raises MlflowException with the given message."""
    expected_pattern = error_message
    with pytest.raises(MlflowException, match=expected_pattern):
        SearchUtils.paginate([], page_token, 1)