Example #1
    def evaluate(self, epoch):
        # TODO: AT THIS MOMENT THIS CODE IS THE SAME AS SUPER
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(test_paths,
                                                   stat_prefix="Test"))
        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))
        if hasattr(self.explo_env, "log_diagnostics"):
            print('TODO: WE NEED LOG_DIAGNOSTICS IN ENV')
            self.explo_env.log_diagnostics(test_paths)

        average_returns = eval_util.get_average_returns(test_paths)
        statistics['Average Test Return'] = average_returns

        # Record the data
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.plotter is not None:
            self.plotter.draw()
Example #2
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        test_paths = [None for _ in range(self._n_demons)]
        for demon in range(self._n_demons):
            logger.log("[%02d] Collecting samples for evaluation" % demon)
            test_paths[demon] = self.eval_sampler.obtain_samples()

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths[demon],
                    stat_prefix="[%02d] Test" % demon,
                ))

        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))
        if hasattr(self.explo_env, "log_diagnostics"):
            print('TODO: WE NEED LOG_DIAGNOSTICS IN ENV')
            # Log diagnostics for the last demon's paths; indexing with the
            # loop variable `demon` here relied on it leaking out of the loop.
            self.explo_env.log_diagnostics(test_paths[-1])

        # Record the data
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self.plotter is not None:
            self.plotter.draw()
Example #3
    def evaluate(self, epoch):
        statistics = OrderedDict()
        self._update_logging_data()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="[I] Test",
            ))

        if self._exploration_paths:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        else:
            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths,
                    stat_prefix="Exploration",
                ))

        if self._log_tensorboard:
            self._summary_writer.add_scalar(
                'Evaluation/avg_return', statistics['[I] Test Returns Mean'],
                self._n_epochs)

            self._summary_writer.add_scalar(
                'Evaluation/avg_reward', statistics['[I] Test Rewards Mean'],
                self._n_epochs)

        if hasattr(self.explo_env, "log_diagnostics"):
            pass
            # # TODO: CHECK ENV LOG_DIAGNOSTICS
            # print('TODO: WE NEED LOG_DIAGNOSTICS IN ENV')

        # Record the data
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        # Epoch Plotter
        if self._epoch_plotter is not None:
            self._epoch_plotter.draw()

        # Reset log_data
        for key in self.log_data.keys():
            self.log_data[key].fill(0)
Example #4
    def _end_epoch(self, epoch):
        """
        Computations at the end of an epoch.
        Returns:

        """
        logger.log("Epoch Duration: {0}".format(time.time() -
                                                self._epoch_start_time))
        logger.log("Started Training: {0}".format(self._can_train()))
        logger.pop_prefix()

        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)
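
A minimal, runnable sketch of the hook contract above: _end_epoch calls every function registered in post_epoch_funcs with the algorithm instance and the epoch number. DummyAlgo and print_epoch_marker are hypothetical stand-ins for illustration, not part of the original codebase.

class DummyAlgo:
    def __init__(self):
        self.post_epoch_funcs = []

    def _end_epoch(self, epoch):
        # Same hook-dispatch pattern as in the method above.
        for post_epoch_func in self.post_epoch_funcs:
            post_epoch_func(self, epoch)

def print_epoch_marker(algo, epoch):
    # Called as post_epoch_func(self, epoch) at the end of each epoch.
    print("Finished epoch %d" % epoch)

algo = DummyAlgo()
algo.post_epoch_funcs.append(print_epoch_marker)
algo._end_epoch(0)  # prints: Finished epoch 0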
Example #5
    def evaluate(self, epoch):
        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        test_paths = [None for _ in range(self._n_unintentional)]
        for demon in range(self._n_unintentional):
            logger.log("[U-%02d] Collecting samples for evaluation" % demon)
            test_paths[demon] = self.eval_samplers[demon].obtain_samples()

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths[demon],
                    stat_prefix="[U-%02d] Test" % demon,
                ))
            average_returns = eval_util.get_average_returns(test_paths[demon])
            statistics['[U-%02d] AverageReturn' % demon] = average_returns

        logger.log("[I] Collecting samples for evaluation")
        i_test_path = self.eval_sampler.obtain_samples()
        statistics.update(
            eval_util.get_generic_path_information(
                i_test_path,
                stat_prefix="[I] Test",
            ))

        statistics.update(
            eval_util.get_generic_path_information(
                self._exploration_paths,
                stat_prefix="Exploration",
            ))
        if hasattr(self.explo_env, "log_diagnostics"):
            # TODO: CHECK ENV LOG_DIAGNOSTICS
            print('TODO: WE NEED LOG_DIAGNOSTICS IN ENV')
            # self.env.log_diagnostics(test_paths[demon])

        # Record the data
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        for demon in range(self._n_unintentional):
            if self.render_eval_paths:
                # TODO: CHECK ENV RENDER_PATHS
                print('TODO: RENDER_PATHS')

        if self._epoch_plotter is not None:
            self._epoch_plotter.draw()
Example #6
    def _try_to_eval(self, epoch, eval_paths=None):
        """
        Check if the requirements are fulfilled to start or not an evaluation.
        Args:
            epoch (int): Epoch

        Returns:

        """
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
        if self._can_evaluate():
            # Call algorithm-specific evaluate method
            self.evaluate(epoch)

            self._log_data(epoch)
        else:
            logger.log("Skipping eval for now.")
Example #7
    def evaluate(self, epoch):
        """
        Evaluate the policy, e.g. save/print progress.
        :param epoch:
        :return:
        """
        if self.eval_statistics is None:
            self.eval_statistics = OrderedDict()

        statistics = OrderedDict()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        logger.log("Collecting samples for evaluation")
        test_paths = self.eval_sampler.obtain_samples()

        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test",
            ))

        if self._exploration_paths:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        else:
            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths,
                    stat_prefix="Exploration",
                ))

        if hasattr(self.explo_env, "log_diagnostics"):
            self.explo_env.log_diagnostics(test_paths)

        for key, value in statistics.items():
            logger.record_tabular(key, value)

        if self._epoch_plotter is not None:
            self._epoch_plotter.draw()
            self._epoch_plotter.save_figure(epoch)
Example #8
    def evaluate(self, epoch):
        statistics = OrderedDict()
        self._update_logging_data()
        statistics.update(self.eval_statistics)
        self.eval_statistics = None

        # Interaction Paths for each unintentional policy
        test_paths = [None for _ in range(self._n_unintentional)]
        for unint_idx in range(self._n_unintentional):
            logger.log("[U-%02d] Collecting samples for evaluation" %
                       unint_idx)
            test_paths[unint_idx] = \
                self.eval_u_samplers[unint_idx].obtain_samples()

            statistics.update(
                eval_util.get_generic_path_information(
                    test_paths[unint_idx],
                    stat_prefix="[U-%02d] Test" % unint_idx,
                ))

            if self._log_tensorboard:
                self._summary_writer.add_scalar(
                    'EvaluationU%02d/avg_return' % unint_idx,
                    statistics['[U-%02d] Test Returns Mean' % unint_idx],
                    self._n_epochs)

                self._summary_writer.add_scalar(
                    'EvaluationU%02d/avg_reward' % unint_idx,
                    statistics['[U-%02d] Test Rewards Mean' % unint_idx],
                    self._n_epochs)

        # Interaction Paths for the intentional policy
        logger.log("[I] Collecting samples for evaluation")
        i_test_paths = self.eval_sampler.obtain_samples()
        statistics.update(
            eval_util.get_generic_path_information(
                i_test_paths,
                stat_prefix="[I] Test",
            ))

        if self._exploration_paths:
            statistics.update(
                eval_util.get_generic_path_information(
                    self._exploration_paths,
                    stat_prefix="Exploration",
                ))
        else:
            statistics.update(
                eval_util.get_generic_path_information(
                    i_test_paths,
                    stat_prefix="Exploration",
                ))

        if self._log_tensorboard:
            self._summary_writer.add_scalar(
                'EvaluationI/avg_return', statistics['[I] Test Returns Mean'],
                self._n_epochs)

            self._summary_writer.add_scalar(
                'EvaluationI/avg_reward',
                statistics['[I] Test Rewards Mean'] * self.reward_scale,
                self._n_epochs)

        if hasattr(self.explo_env, "log_diagnostics"):
            pass
            # # TODO: CHECK ENV LOG_DIAGNOSTICS
            # print('TODO: WE NEED LOG_DIAGNOSTICS IN ENV')

        # Record the data
        for key, value in statistics.items():
            logger.record_tabular(key, value)

        # Epoch Plotter
        if self._epoch_plotter is not None:
            self._epoch_plotter.draw()

        # Reset log_data
        for key in self.log_data.keys():
            self.log_data[key].fill(0)
Example #9
def setup_logger(
    exp_prefix="default",
    exp_id=0,
    seed=0,
    variant=None,
    base_log_dir=None,
    text_log_file="debug.log",
    variant_log_file="variant.json",
    tabular_log_file="progress.csv",
    snapshot_mode="last",
    snapshot_gap=1,
    log_tabular_only=False,
    log_stdout=True,
    log_dir=None,
    git_info=None,
    script_name=None,
):
    """
    Set up logger to have some reasonable default settings.

    Will save log output to

        based_log_dir/exp_prefix/exp_name.

    exp_name will be auto-generated to be unique.

    If log_dir is specified, then that directory is used as the output dir.

    :param exp_prefix: The sub-directory for this specific experiment.
    :param exp_id: The number of the specific experiment run within this
    experiment.
    :param variant:
    :param base_log_dir: The directory where all log should be saved.
    :param text_log_file:
    :param variant_log_file:
    :param tabular_log_file:
    :param snapshot_mode:
    :param log_tabular_only:
    :param log_stdout:
    :param snapshot_gap:
    :param log_dir:
    :param git_info:
    :param script_name: If set, save the script name to this.
    :return:
    """
    first_time = log_dir is None
    if first_time:
        log_dir = create_log_dir(exp_prefix,
                                 exp_id=exp_id,
                                 seed=seed,
                                 base_log_dir=base_log_dir)

    if variant is not None:
        logger.log("Variant:")
        logger.log(json.dumps(dict_to_safe_json(variant), indent=2))
        variant_log_path = osp.join(log_dir, variant_log_file)
        logger.log_variant(variant_log_path, variant)

    tabular_log_path = osp.join(log_dir, tabular_log_file)
    text_log_path = osp.join(log_dir, text_log_file)

    logger.add_text_output(text_log_path)
    if first_time:
        logger.add_tabular_output(tabular_log_path)
    else:
        logger._add_output(tabular_log_path,
                           logger._tabular_outputs,
                           logger._tabular_fds,
                           mode='a')
        for tabular_fd in logger._tabular_fds:
            logger._tabular_header_written.add(tabular_fd)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_snapshot_gap(snapshot_gap)
    logger.set_log_tabular_only(log_tabular_only)
    logger.set_log_stdout(log_stdout)
    exp_name = log_dir.split("/")[-1]
    logger.push_prefix("[%s] " % exp_name)

    if git_info is not None:
        code_diff, commit_hash, branch_name = git_info
        if code_diff is not None:
            with open(osp.join(log_dir, "code.diff"), "w") as f:
                f.write(code_diff)
        with open(osp.join(log_dir, "git_info.txt"), "w") as f:
            f.write("git hash: {}".format(commit_hash))
            f.write('\n')
            f.write("git branch name: {}".format(branch_name))
    if script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(script_name)
    return log_dir
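
A hedged usage sketch for setup_logger: the variant contents and experiment names below are illustrative only, and the import of setup_logger from its hosting module is assumed rather than shown.

# Usage sketch; assumes setup_logger is importable from the module where
# it is defined. The variant values are illustrative only.
variant = dict(algo="SAC", seed=0)
log_dir = setup_logger(
    exp_prefix="my-experiment",
    exp_id=0,
    seed=variant["seed"],
    variant=variant,
    snapshot_mode="last",
    snapshot_gap=1,
)
print("Logging to:", log_dir)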