Пример #1
0
    def _training_summary(self, training_info, loss_info, grads_and_vars):
        """Write training summaries for the current train step.

        Depending on the summary flags set on this object, this records
        variable/gradient summaries, action and loss summaries, and action
        distribution summaries. It always records per-metric summaries and
        the process's resident memory usage.

        Args:
            training_info: training info structure; its ``action``,
                ``action_distribution`` and ``collect_action_distribution``
                fields are read.
            loss_info: loss info passed to ``common.add_loss_summaries``.
            grads_and_vars: (gradient, variable) pairs passed to the
                variable/gradient summary helpers.
        """
        if self._summarize_grads_and_vars:
            summary_utils.add_variables_summaries(grads_and_vars,
                                                  self._train_step_counter)
            summary_utils.add_gradients_summaries(grads_and_vars,
                                                  self._train_step_counter)
        if self._debug_summaries:
            common.add_action_summaries(training_info.action,
                                        self.env.action_spec())
            common.add_loss_summaries(loss_info)

        if self._summarize_action_distributions:
            summary_utils.summarize_action_dist(
                training_info.action_distribution, self.env.action_spec())
            # Only summarize the collect-time distribution when the field is
            # truthy (i.e. something was recorded for it).
            if training_info.collect_action_distribution:
                summary_utils.summarize_action_dist(
                    action_distributions=training_info.
                    collect_action_distribution,
                    action_specs=self.env.action_spec(),
                    name="collect_action_dist")

        # Query the metric list once; it is loop-invariant.
        metrics = self.get_metrics()
        # The first two metrics serve as the step metrics for every metric's
        # summaries. NOTE(review): presumably these are the step/episode
        # counters -- confirm the ordering returned by get_metrics().
        step_metrics = metrics[:2]
        for metric in metrics:
            metric.tf_summaries(
                train_step=self._train_step_counter,
                step_metrics=step_metrics)

        # Resident set size divided by 1e6 (i.e. roughly MB, assuming rss is
        # in bytes). NOTE(review): self._proc presumably is a psutil.Process
        # -- confirm. Wrapped in py_function so the Python call runs as an op.
        mem = tf.py_function(
            lambda: self._proc.memory_info().rss // 1e6, [],
            tf.float32,
            name='memory_usage')
        if not tf.executing_eagerly():
            # In graph mode the py_function output has no static shape;
            # declare it a scalar for the scalar summary below.
            mem.set_shape(())
        tf.summary.scalar(name='memory_usage', data=mem)
Пример #2
0
    def summarize_train(self, training_info, loss_info, grads_and_vars):
        """Emit summaries describing a training step and its loss.

        On-policy algorithms provide ``training_info.info``; off-policy
        algorithms provide both ``training_info.info`` and
        ``training_info.rollout_info``. In the off-policy case both structures
        describe the data batch sampled from the replay buffer, so their
        statistics do not reflect the rollout that is currently in progress.

        Args:
            training_info (TrainingInfo): TrainingInfo structure collected from
                rollout (on-policy training) or train_step (off-policy
                training).
            loss_info (LossInfo): the loss to summarize.
            grads_and_vars (sequence of (grad, var) pairs): gradients together
                with the variables they belong to.
        Returns:
            None
        """
        if self._summarize_grads_and_vars:
            # Variable summaries are written first, then gradient summaries.
            for add_summaries in (summary_utils.add_variables_summaries,
                                  summary_utils.add_gradients_summaries):
                add_summaries(grads_and_vars)

        if self._debug_summaries:
            summary_utils.add_action_summaries(
                training_info.action, self._action_spec)
            summary_utils.add_loss_summaries(loss_info)

        if self._summarize_action_distributions:
            dist_matches = nest_utils.find_field(
                training_info.info, 'action_distribution')
            # Summarize only when exactly one 'action_distribution' field is
            # found. NOTE(review): zero or multiple matches are silently
            # skipped -- confirm this is intended.
            if len(dist_matches) == 1:
                summary_utils.summarize_action_dist(
                    dist_matches[0], self._action_spec)