コード例 #1
0
def test_pagination(page_token, max_results, matching_runs, expected_next_page_token):
    """Check the page and next-page token produced by ``SearchUtils.paginate``.

    Three runs with ids "0", "1" and "2" are paginated.

    :param page_token: Token payload to JSON-serialize and base64-encode before it is
                       handed to ``paginate``; a falsy value means no token is passed.
    :param max_results: Page size forwarded to ``paginate``.
    :param matching_runs: Expected indices (into the fixture list) of the returned runs.
    :param expected_next_page_token: Expected decoded next-page token payload, or ``None``
                                     when no further token is expected.
    """
    def _make_run(identifier):
        # All fixture runs are identical except for their uuid / id.
        return Run(
            run_info=RunInfo(
                run_uuid=identifier, run_id=identifier, experiment_id=0,
                user_id="user-id", status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData([], [], []))

    runs = [_make_run(str(n)) for n in range(3)]

    token_arg = (
        base64.b64encode(json.dumps(page_token).encode("utf-8")) if page_token else None
    )
    page, token_out = SearchUtils.paginate(runs, token_arg, max_results)

    # Translate every returned run into the position of its first equal fixture run;
    # a returned run with no equal fixture run contributes nothing.
    positions = []
    for returned in page:
        first_match = next(
            (idx for idx, candidate in enumerate(runs) if candidate == returned), None)
        if first_match is not None:
            positions.append(first_match)
    assert positions == matching_runs

    decoded = json.loads(base64.b64decode(token_out)) if token_out else None
    assert decoded == expected_next_page_token
コード例 #2
0
    def _search_runs(self, experiment_ids, filter_string, run_view_type,
                     max_results, order_by, page_token):
        """Load the runs of the given experiments, then filter, sort and paginate them.

        :param experiment_ids: Experiment ids whose runs are queried.
        :param filter_string: Filter expression handed to ``SearchUtils.filter``.
        :param run_view_type: View type, translated to the lifecycle stages to include.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Ordering specification handed to ``SearchUtils.sort``.
        :param page_token: Pagination token handed to ``SearchUtils.paginate``.
        :return: A ``(runs, next_page_token)`` pair as produced by ``SearchUtils.paginate``.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is
                                 above the threshold.
        """
        # TODO: push search query into backend database layer
        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            message = (
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}"
            ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
            raise MlflowException(message, INVALID_PARAMETER_VALUE)

        allowed_stages = set(LifecycleStage.view_type_to_stages(run_view_type))
        with self.ManagedSessionMaker() as session:
            # Summary metrics, params and tags are loaded eagerly because
            # ``to_mlflow_entity()`` reads them; lazy loading would otherwise issue
            # additional database queries at attribute access time.
            run_query = session.query(SqlRun)
            run_query = run_query.options(*self._get_eager_run_query_options())
            run_query = run_query.filter(
                SqlRun.experiment_id.in_(experiment_ids),
                SqlRun.lifecycle_stage.in_(allowed_stages))
            entities = [row.to_mlflow_entity() for row in run_query.all()]

        # Filtering, ordering and paging happen in memory, after the session is closed.
        matching = SearchUtils.filter(entities, filter_string)
        ordered = SearchUtils.sort(matching, order_by)
        page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
        return page, next_page_token
コード例 #3
0
    def list_experiments(
        self,
        view_type=ViewType.ACTIVE_ONLY,
        max_results=None,
        page_token=None,
    ):
        """
        List the experiments matching ``view_type``, optionally one page at a time.

        Experiments whose lookup raises ``MissingConfigException`` are logged as malformed
        and left out of the result.

        :param view_type: Qualify requested type of experiments.
        :param max_results: If passed, specifies the maximum number of experiments desired. If not
                            passed, all experiments will be returned.
        :param page_token: Token specifying the next page of results. It should be obtained from
                           a ``list_experiments`` call.
        :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
                 :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination token
                 for the next page can be obtained via the ``token`` attribute of the object.
        """
        from mlflow.utils.search_utils import SearchUtils
        from mlflow.store.entities.paged_list import PagedList

        _validate_list_experiments_max_results(max_results)
        self._check_root_dir()

        # Collect candidate experiment ids: active ones first, then deleted ones.
        candidate_ids = []
        if view_type == ViewType.ACTIVE_ONLY or view_type == ViewType.ALL:
            candidate_ids.extend(self._get_active_experiments(full_path=False))
        if view_type == ViewType.DELETED_ONLY or view_type == ViewType.ALL:
            candidate_ids.extend(self._get_deleted_experiments(full_path=False))

        found = []
        for exp_id in candidate_ids:
            try:
                # Only the known "missing config" failure is trapped; anything
                # unexpected propagates to the caller.
                loaded = self._get_experiment(exp_id, view_type)
                if loaded:
                    found.append(loaded)
            except MissingConfigException as rnfe:
                # Malformed experiments are reported and skipped.
                logging.warning(
                    "Malformed experiment '%s'. Detailed error %s",
                    str(exp_id),
                    str(rnfe),
                    exc_info=True,
                )

        # Without a page size the full list is returned and there is no next-page token.
        if max_results is None:
            return PagedList(found, None)

        page, next_page_token = SearchUtils.paginate(found, page_token, max_results)
        return PagedList(page, next_page_token)
コード例 #4
0
ファイル: file_store.py プロジェクト: winnerineast/mlflow
 def _search_runs(self, experiment_ids, filter_string, run_view_type,
                  max_results, order_by, page_token):
     """Gather the runs of the given experiments, then filter, sort and paginate them.

     :param experiment_ids: Experiment ids whose runs are collected.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type forwarded to ``_list_run_infos``.
     :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering specification handed to ``SearchUtils.sort``.
     :param page_token: Pagination token handed to ``SearchUtils.paginate``.
     :return: A ``(runs, next_page_token)`` pair as produced by ``SearchUtils.paginate``.
     :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is
                              above the threshold.
     """
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         message = (
             "Invalid value for request parameter max_results. It must be at "
             "most {}, but got value {}"
         ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
         raise MlflowException(message, databricks_pb2.INVALID_PARAMETER_VALUE)

     # Each listed run info is resolved to a full run via ``get_run``.
     collected = []
     for experiment_id in experiment_ids:
         for info in self._list_run_infos(experiment_id, run_view_type):
             collected.append(self.get_run(info.run_id))

     matching = SearchUtils.filter(collected, filter_string)
     ordered = SearchUtils.sort(matching, order_by)
     page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
     return page, next_page_token
コード例 #5
0
 def _search_runs(self, experiment_ids, filter_string, run_view_type,
                  max_results, order_by, page_token):
     """List the runs of the given experiments, then filter, sort and paginate them.

     All work, including filtering, sorting and pagination, is performed while the
     managed session is open.

     :param experiment_ids: Experiment ids whose runs are listed.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type forwarded to ``_list_runs``.
     :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering specification handed to ``SearchUtils.sort``.
     :param page_token: Pagination token handed to ``SearchUtils.paginate``.
     :return: A ``(runs, next_page_token)`` pair as produced by ``SearchUtils.paginate``.
     :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is
                              above the threshold.
     """
     # TODO: push search query into backend database layer
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         message = (
             "Invalid value for request parameter max_results. It must be at "
             "most {}, but got value {}"
         ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
         raise MlflowException(message, INVALID_PARAMETER_VALUE)

     with self.ManagedSessionMaker() as session:
         # Convert every listed run, experiment by experiment, to its entity form.
         entities = []
         for exp in experiment_ids:
             for listed_run in self._list_runs(session, exp, run_view_type):
                 entities.append(listed_run.to_mlflow_entity())

         matching = SearchUtils.filter(entities, filter_string)
         ordered = SearchUtils.sort(matching, order_by)
         page, next_page_token = SearchUtils.paginate(
             ordered, page_token, max_results)
         return page, next_page_token
コード例 #6
0
ファイル: dynamodb_store.py プロジェクト: brightsparc/mlflow
    def _search_runs(
        self,
        experiment_ids,
        filter_string,
        run_view_type,
        max_results,
        order_by,
        page_token,
    ):
        """Assemble full runs for the given experiments, then filter, sort and paginate.

        :param experiment_ids: Experiment ids whose runs are collected.
        :param filter_string: Filter expression handed to ``SearchUtils.filter``.
        :param run_view_type: View type forwarded to ``_list_runs_ids``.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Ordering specification handed to ``SearchUtils.sort``.
        :param page_token: Pagination token handed to ``SearchUtils.paginate``.
        :return: A ``(runs, next_page_token)`` pair as produced by ``SearchUtils.paginate``.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is
                                 above the threshold.
        """
        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            message = (
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}"
            ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
            raise MlflowException(message, INVALID_PARAMETER_VALUE)

        assembled = []
        for experiment_id in experiment_ids:
            listed_ids = self._list_runs_ids(experiment_id, run_view_type)
            # Convert every stored record first, before any per-run lookups are made.
            infos = list(map(_dict_to_run_info, self._get_run_list(listed_ids)))
            for info in infos:
                # Attach the run's metrics, params and tags (looked up in that order).
                current_id = info.run_id
                data = RunData(
                    self.get_all_metrics(current_id),
                    self.get_all_params(current_id),
                    self.get_all_tags(current_id),
                )
                assembled.append(Run(info, data))

        matching = SearchUtils.filter(assembled, filter_string)
        ordered = SearchUtils.sort(matching, order_by)
        page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
        return page, next_page_token
コード例 #7
0
def test_invalid_page_tokens(page_token, error_message):
    """Check that ``SearchUtils.paginate`` rejects a bad page token.

    :param page_token: Token passed to ``paginate`` together with an empty list and a
                       page size of 1.
    :param error_message: Text that must be contained in the raised exception's message.
    """
    with pytest.raises(MlflowException) as raised:
        SearchUtils.paginate([], page_token, 1)

    reported = raised.value.message
    assert error_message in reported
コード例 #8
0
def test_invalid_page_tokens(page_token, error_message):
    """Check that ``SearchUtils.paginate`` rejects a bad page token.

    :param page_token: Token passed to ``paginate`` together with an empty list and a
                       page size of 1.
    :param error_message: Pattern handed to ``pytest.raises`` as ``match`` for the raised
                          ``MlflowException``.
    """
    expected_failure = pytest.raises(MlflowException, match=error_message)
    with expected_failure:
        SearchUtils.paginate([], page_token, 1)