コード例 #1
0
    def summarize_rollout(self, training_info):
        """Write summaries describing the rollout phase.

        At this point ``training_info.info`` is empty, so every
        distribution summary is built from ``training_info.rollout_info``
        instead.

        Args:
            training_info (TrainingInfo): the TrainingInfo structure
                gathered while rolling out.
        Returns:
            None
        """
        if self._debug_summaries:
            summary_utils.add_action_summaries(
                training_info.action, self._action_spec, "rollout_action")
            self.add_reward_summary(
                "rollout_reward/extrinsic", training_info.reward)

        if not self._summarize_action_distributions:
            return

        dist_fields = nest_utils.find_field(
            training_info.rollout_info, 'action_distribution')
        # Summarize only when exactly one 'action_distribution' field is
        # found; zero or multiple matches are skipped.
        if len(dist_fields) == 1:
            summary_utils.summarize_action_dist(
                action_distributions=dist_fields[0],
                action_specs=self._action_spec,
                name="rollout_action_dist")
コード例 #2
0
    def summarize_train(self, training_info, loss_info, grads_and_vars):
        """Write summaries for the training step and its loss.

        On-policy algorithms populate ``training_info.info``. Off-policy
        algorithms populate both ``training_info.info`` and
        ``training_info.rollout_info``, but the statistics in those two
        structures describe the batch sampled from the replay buffer, not
        the rollout that is currently in progress.

        Args:
            training_info (TrainingInfo): TrainingInfo structure collected
                from rollout (on-policy training) or train_step (off-policy
                training).
            loss_info (LossInfo): the loss of this training step.
            grads_and_vars (collection of (grad, var) pairs): gradients
                paired with the variables they belong to.
        Returns:
            None
        """
        if self._summarize_grads_and_vars:
            summary_utils.add_variables_summaries(grads_and_vars)
            summary_utils.add_gradients_summaries(grads_and_vars)

        if self._debug_summaries:
            summary_utils.add_action_summaries(
                training_info.action, self._action_spec)
            summary_utils.add_loss_summaries(loss_info)

        if not self._summarize_action_distributions:
            return

        dist_fields = nest_utils.find_field(
            training_info.info, 'action_distribution')
        # Summarize only when exactly one 'action_distribution' field is
        # found; zero or multiple matches are skipped.
        if len(dist_fields) == 1:
            summary_utils.summarize_action_dist(
                dist_fields[0], self._action_spec)
コード例 #3
0
    def test_find_field(self):
        # The expected counts show that find_field does not descend into a
        # value it has already matched (e.g. ``nested.b.a.a`` is not
        # reported). It does search inside plain tuples and lists.
        nested = NTuple(a=1, b=NTuple(a=NTuple(a=2, b=3), b=2))
        in_tuple = (1, NTuple(a=NTuple(a=2, b=3), b=2))
        in_list = NTuple(a=1, b=[NTuple(a=2, b=3), 2])

        # Each case pairs an input nest with the ordered matches expected
        # for the field name 'a'.
        cases = [
            (nested, [nested.a, nested.b.a]),
            (in_tuple, [in_tuple[1].a]),
            (in_list, [in_list.a, in_list.b[0].a]),
        ]
        for nest, expected in cases:
            found = nest_utils.find_field(nest, 'a')
            self.assertEqual(len(found), len(expected))
            for got, want in zip(found, expected):
                self.assertEqual(got, want)