def summarize_rollout(self, training_info):
    """Write summaries describing the data gathered during rollout.

    ``training_info.info`` is empty at this point, so any per-step
    algorithm information has to be read from
    ``training_info.rollout_info`` instead.

    What gets written depends on two instance flags:

    * ``self._debug_summaries``: action summaries (via
      ``summary_utils.add_action_summaries`` with ``"rollout_action"``)
      and the ``"rollout_reward/extrinsic"`` reward summary.
    * ``self._summarize_action_distributions``: the value stored under the
      ``'action_distribution'`` field of ``rollout_info``, written only
      when exactly one such field is found.

    Args:
        training_info (TrainingInfo): TrainingInfo structure collected from
            rollout.
    Returns:
        None
    """
    if self._debug_summaries:
        summary_utils.add_action_summaries(
            training_info.action, self._action_spec, "rollout_action")
        self.add_reward_summary(
            "rollout_reward/extrinsic", training_info.reward)

    if not self._summarize_action_distributions:
        return

    matches = nest_utils.find_field(
        training_info.rollout_info, 'action_distribution')
    # Only summarize when the field occurs exactly once; zero or several
    # matches are skipped without any summary being written.
    if len(matches) != 1:
        return
    summary_utils.summarize_action_dist(
        action_distributions=matches[0],
        action_specs=self._action_spec,
        name="rollout_action_dist")
def summarize_train(self, training_info, loss_info, grads_and_vars):
    """Write summaries for a training step and its loss.

    For on-policy algorithms, ``training_info.info`` is available. For
    off-policy algorithms, both ``training_info.info`` and
    ``training_info.rollout_info`` are available. However, their statistics
    describe the data batch sampled from the replay buffer, not the
    ongoing rollout.

    What gets written depends on three instance flags:

    * ``self._summarize_grads_and_vars``: variable and gradient summaries
      for ``grads_and_vars``.
    * ``self._debug_summaries``: action summaries and loss summaries.
    * ``self._summarize_action_distributions``: the value stored under the
      ``'action_distribution'`` field of ``training_info.info``, written
      only when exactly one such field is found.

    Args:
        training_info (TrainingInfo): TrainingInfo structure collected from
            rollout (on-policy training) or train_step (off-policy
            training).
        loss_info (LossInfo): loss
        grads_and_vars (list or tuple of (grad, var) pairs): gradients and
            their corresponding variables
    Returns:
        None
    """
    if self._summarize_grads_and_vars:
        summary_utils.add_variables_summaries(grads_and_vars)
        summary_utils.add_gradients_summaries(grads_and_vars)

    if self._debug_summaries:
        summary_utils.add_action_summaries(
            training_info.action, self._action_spec)
        summary_utils.add_loss_summaries(loss_info)

    if not self._summarize_action_distributions:
        return

    candidates = nest_utils.find_field(
        training_info.info, 'action_distribution')
    # Zero or multiple matches are skipped without writing a summary.
    if len(candidates) == 1:
        summary_utils.summarize_action_dist(
            candidates[0], self._action_spec)
def test_find_field(self):
    """Check which fields named 'a' ``nest_utils.find_field`` returns.

    Covers three nests: a namedtuple root, a plain-tuple root, and a
    namedtuple holding a list. Matches must come back in the order listed
    for each case.
    """

    def check(root, expected):
        # Compare the length first, then each match against the expected
        # value at the same position.
        found = nest_utils.find_field(root, 'a')
        self.assertEqual(len(found), len(expected))
        for got, want in zip(found, expected):
            self.assertEqual(got, want)

    # Namedtuple root. Only two matches are expected: the top-level 'a' and
    # b.a. The 'a' inside b.a itself is not expected to be reported.
    root = NTuple(a=1, b=NTuple(a=NTuple(a=2, b=3), b=2))
    check(root, [root.a, root.b.a])

    # Plain-tuple root: the lookup reaches into the namedtuple at index 1.
    root = (1, NTuple(a=NTuple(a=2, b=3), b=2))
    check(root, [root[1].a])

    # A list nested inside a namedtuple is searched as well.
    root = NTuple(a=1, b=[NTuple(a=2, b=3), 2])
    check(root, [root.a, root.b[0].a])