Example #1
0
    def _trial_to_result(self, trial: Trial) -> Result:
        """Build a user-facing ``Result`` from a finished ``Trial``.

        Collects the trial's latest checkpoint, all persisted checkpoints
        with their metrics, the last reported metrics, any captured error,
        the log directory, and (when available) the per-trial metrics
        dataframe from the experiment analysis.
        """
        latest_checkpoint = trial.checkpoint.to_air_checkpoint()

        # Pair every persisted checkpoint with the metrics reported
        # alongside it.
        best_checkpoints = []
        for tracked in trial.get_trial_checkpoints():
            best_checkpoints.append(
                (tracked.to_air_checkpoint(), tracked.metrics))

        # The metrics dataframe is only available when an experiment
        # analysis object was attached.
        if self._experiment_analysis:
            metrics_df = self._experiment_analysis.trial_dataframes.get(
                trial.logdir)
        else:
            metrics_df = None

        log_dir = Path(trial.logdir) if trial.logdir else None

        return Result(
            checkpoint=latest_checkpoint,
            metrics=trial.last_result.copy(),
            error=self._populate_exception(trial),
            log_dir=log_dir,
            metrics_dataframe=metrics_df,
            best_checkpoints=best_checkpoints,
        )
Example #2
0
    def get_trial_checkpoints_paths(
            self,
            trial: Trial,
            metric: Optional[str] = None) -> List[Tuple[str, Number]]:
        """Gets paths and metrics of all persistent checkpoints of a trial.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.

        Returns:
            List of [path, metric] for all persistent checkpoints of the trial.
        """
        # Fall back to the analysis default, then to the iteration counter.
        metric = metric or self.default_metric or TRAINING_ITERATION

        if isinstance(trial, Trial):
            # Support metrics given as paths, e.g.
            # "info/learner/default_policy/policy_loss".
            return [(c.dir_or_data, unflattened_lookup(metric, c.metrics))
                    for c in trial.get_trial_checkpoints()]

        if not isinstance(trial, str):
            raise ValueError("trial should be a string or a Trial instance.")

        # ``trial`` is a log directory: discover checkpoints on disk and
        # join them with the trial's metrics dataframe.
        logdir = os.path.expanduser(trial)
        checkpoint_df = TrainableUtil.get_checkpoints_paths(logdir)
        merged = checkpoint_df.merge(
            self.trial_dataframes[logdir],
            on="training_iteration",
            how="inner")
        return merged[["chkpt_path", metric]].values.tolist()