def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
        """Gets [path, metric] pairs for all on-disk checkpoints of a trial.

        Arguments:
            trial (str | os.PathLike | Trial): The log directory of a trial
                (a leading ``~`` is expanded to the user's home directory),
                or a trial instance.
            metric (str): key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default.

        Returns:
            list: One ``[checkpoint_path, metric_value]`` list per checkpoint.
                For a log directory, only checkpoints whose training iteration
                also appears in the trial's dataframe are returned (inner
                join). For a Trial instance, only checkpoints whose storage is
                ``Checkpoint.PERSISTENT`` are returned.

        Raises:
            ValueError: If ``trial`` is neither a path nor a Trial instance.
        """
        if isinstance(trial, (str, os.PathLike)):
            # os.path.expanduser also accepts str-based path-like objects
            # (e.g. pathlib.Path) and returns a plain str for them, so the
            # dataframe lookup below still uses a string key.
            trial_dir = os.path.expanduser(trial)

            # Get checkpoints from logdir.
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # Join with trial dataframe to get metrics. The inner join drops
            # checkpoints whose iteration has no matching result row.
            # NOTE(review): presumably raises KeyError if trial_dir is not a
            # logdir known to this analysis — confirm.
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(
                trial_df, on="training_iteration", how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            checkpoints = trial.checkpoint_manager.best_checkpoints()
            # TODO(ujvl): Remove condition once the checkpoint manager is
            #  modified to only track PERSISTENT checkpoints.
            return [[c.value, c.result[metric]] for c in checkpoints
                    if c.storage == Checkpoint.PERSISTENT]
        else:
            raise ValueError("trial should be a string or a Trial instance.")
Example no. 2
0
    def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
        """Gets paths and metrics of all persistent checkpoints of a trial.

        Args:
            trial (str | os.PathLike | Trial): The log directory of a trial
                (a leading ``~`` is expanded to the user's home directory),
                or a trial instance.
            metric (str): key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default.

        Returns:
            List of [path, metric] for all persistent checkpoints of the trial.
            For a log directory, only checkpoints whose training iteration
            also appears in the trial's dataframe are included (inner join).

        Raises:
            ValueError: If ``trial`` is neither a path nor a Trial instance.
        """
        if isinstance(trial, (str, os.PathLike)):
            # os.path.expanduser also accepts str-based path-like objects
            # (e.g. pathlib.Path) and returns a plain str for them, so the
            # dataframe lookup below still uses a string key.
            trial_dir = os.path.expanduser(trial)
            # Get checkpoints from logdir.
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # Join with trial dataframe to get metrics. The inner join drops
            # checkpoints whose iteration has no matching result row.
            # NOTE(review): presumably raises KeyError if trial_dir is not a
            # logdir known to this analysis — confirm.
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(
                trial_df, on="training_iteration", how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            # NOTE(review): no storage filter is applied here; this assumes
            # best_checkpoints() only yields persistent checkpoints — confirm
            # against the checkpoint manager.
            checkpoints = trial.checkpoint_manager.best_checkpoints()
            return [[c.value, c.result[metric]] for c in checkpoints]
        else:
            raise ValueError("trial should be a string or a Trial instance.")