def _after_eval_callback():
    n = 100
    packed_rollout_data = rollouts.rollout_n(n, env, training_policy)
    avg_rew = torch.sum(packed_rollout_data['rew'].data) / n
    trainer.print('Avg rollout reward: ', avg_rew)
    summaries.add_scalar('_performance/avg_rollout_reward', avg_rew,
                         trainer.global_step)
def _inference_and_loss(self, sample_data):
    # Compute bc_loss.
    pred_act = self.policy(sample_data['obs'])
    expert_act = sample_data['act']
    bc_loss = self._loss_fn(pred_act, expert_act)
    summaries.add_scalar('_performance/bc_loss', bc_loss, self.global_step)

    # Compute reflex_loss.
    reflex_outputs = self.policy.reflex_outputs(sample_data['obs'])
    unweighted_reflexes_loss = self._loss_fn(
        reflex_outputs, torch.unsqueeze(expert_act, dim=-1))
    reflex_softmax_weights = self.policy.reflex_softmax_weights(
        sample_data['obs'])
    weighted_reflexes_loss = unweighted_reflexes_loss * reflex_softmax_weights
    reflexes_loss = torch.mean(torch.sum(weighted_reflexes_loss, dim=-1))
    summaries.add_scalar('_performance/reflexes_loss', reflexes_loss,
                         self.global_step)
    summaries.add_histogram('reflexes/softmax_weights', reflex_softmax_weights,
                            self.global_step)

    # Per-example entropy of the reflex distribution (low when a single
    # reflex dominates an example's weighting).
    reflex_conditional_entropy = -torch.sum(
        reflex_softmax_weights * torch.log(reflex_softmax_weights), dim=-1)
    summaries.add_histogram('reflexes/reflex_conditional_entropy',
                            reflex_conditional_entropy, self.global_step)

    # Entropy of the batch-averaged reflex distribution (low when the same
    # few reflexes are used across the whole batch).
    reflex_marginals = torch.mean(reflex_softmax_weights, dim=0)
    reflex_marginal_entropy = -torch.sum(
        reflex_marginals * torch.log(reflex_marginals), dim=-1)
    summaries.add_histogram('reflexes/reflex_marginal_entropy',
                            reflex_marginal_entropy, self.global_step)

    supervisor_loss = bc_loss
    return supervisor_loss, reflexes_loss
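# --- Illustration (not from the original code) ---
# A minimal, self-contained sketch of the weighted reflex loss above, using
# dummy tensors in place of the real policy outputs. The shapes (batch B,
# K reflex heads, scalar actions) and the choice of MSE with reduction='none'
# are assumptions made for this sketch.
import torch

B, K = 4, 3
reflex_outputs = torch.randn(B, K)  # One predicted action per reflex head.
expert_act = torch.randn(B)         # One expert action per example.
reflex_softmax_weights = torch.softmax(torch.randn(B, K), dim=-1)  # Rows sum to 1.

loss_fn = torch.nn.MSELoss(reduction='none')
# Compare every reflex head's output against the same expert action: [B, K].
unweighted = loss_fn(reflex_outputs, expert_act.unsqueeze(-1).expand(B, K))
# Each example's reflex loss is a convex combination of its per-head losses.
reflexes_loss = torch.mean(torch.sum(unweighted * reflex_softmax_weights, dim=-1))
# Per-example entropy, as summarized above: near zero when one head dominates.
cond_entropy = -torch.sum(
    reflex_softmax_weights * torch.log(reflex_softmax_weights), dim=-1)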
def _inference_and_loss(self, sample_data):
    embeddings = self.policy.get_embedding(sample_data['obs'])  # Unused here.
    pred_act = self.policy(sample_data['obs'])
    expert_act = sample_data['act']
    loss = self._loss_fn(pred_act, expert_act)
    summaries.add_scalar('_performance/loss', loss, self.global_step)
    return loss
def _eval(self):
    with torch.no_grad():
        self._model.eval()  # Put model in eval mode.
        start_time = time.time()
        # batch_size=inf: sample the entire eval split as a single batch.
        sample_data = self._dataset.sample(batch_size=float('inf'), eval=True)
        losses = self._inference_and_loss(sample_data)
        total_time = time.time() - start_time

        # Summarize timing.
        steps_per_sec = 1 / total_time
        summaries.add_scalar('misc/eval_steps_per_sec', steps_per_sec,
                             self.global_step)

        if hasattr(self._model, 'reset'):
            self._model.reset()
        return losses
def _after_eval_callback():
    n = 100
    packed_rollout_data = rollouts.rollout_n(n, env, training_policy)
    avg_rew = torch.sum(packed_rollout_data['rew'].data) / n
    trainer.print('Avg rollout reward: ', avg_rew)
    summaries.add_scalar('_performance/avg_rollout_reward', avg_rew,
                         trainer.global_step)

    # Add oracle data to dataset: relabel the on-policy observations with the
    # oracle's actions (DAgger-style), keeping the packing structure intact.
    packed_oracle_actions = PackedSequence(
        data=torch.tensor(oracle_policy(packed_rollout_data['obs'].data)),
        batch_sizes=packed_rollout_data['obs'].batch_sizes,
    )
    oracle_data = {
        'obs': packed_rollout_data['obs'],
        'act': packed_oracle_actions,
    }
    dataset.add_data(oracle_data)
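# --- Illustration (not from the original code) ---
# How the PackedSequence relabeling above works, on dummy data. pack_sequence
# flattens variable-length trajectories into a single .data tensor, so the
# oracle can be applied to every step at once; reusing batch_sizes (and the
# sort indices) preserves the packing structure. The "oracle" below is a
# stand-in for the real oracle_policy.
import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence

obs_trajs = [torch.randn(5, 2), torch.randn(3, 2)]  # Two trajectories, obs dim 2.
packed_obs = pack_sequence(obs_trajs, enforce_sorted=False)

oracle_act = packed_obs.data.sum(dim=-1, keepdim=True)  # Stand-in oracle.
packed_oracle_act = PackedSequence(
    data=oracle_act,
    batch_sizes=packed_obs.batch_sizes,
    sorted_indices=packed_obs.sorted_indices,
    unsorted_indices=packed_obs.unsorted_indices,
)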
def _train(self, sample_data=None):
    start_time = time.time()
    self._model.train()  # Put model in train mode.
    if sample_data is None:
        sample_data = self._dataset.sample()
    losses = self._inference_and_loss(sample_data)

    for i, (opt, loss) in enumerate(zip(self._optimizers, losses)):
        loss = torch.mean(loss)
        # Tag per optimizer so successive losses do not overwrite each other.
        summaries.add_scalar(f'_performance/loss_{i}', loss, self.global_step)
        opt.zero_grad()
        # Retain the graph until the last loss has been backpropagated.
        loss.backward(retain_graph=(i + 1 != len(losses)))
        opt.step()

    # Summarize timing.
    total_time = time.time() - start_time
    steps_per_sec = 1 / total_time
    summaries.add_scalar('misc/train_steps_per_sec', steps_per_sec,
                         self.global_step)

    if hasattr(self._model, 'reset'):
        self._model.reset()
    return losses
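# --- Illustration (not from the original code) ---
# Why retain_graph is True for all but the last backward in _train: both
# losses' backward passes traverse the shared trunk's graph, which would
# otherwise be freed by the first call. The tiny modules and the split of
# parameters across optimizers are assumptions for this sketch; note that
# stepping an optimizer must not modify tensors a still-pending backward
# needs, so each optimizer here owns only parameters whose graph nodes have
# already been consumed by the time it steps.
import torch

trunk = torch.nn.Linear(4, 8)
head_a = torch.nn.Linear(8, 1)
head_b = torch.nn.Linear(8, 1)
optimizers = [
    torch.optim.SGD(head_a.parameters(), lr=1e-2),
    torch.optim.SGD(list(trunk.parameters()) + list(head_b.parameters()), lr=1e-2),
]

x, y = torch.randn(16, 4), torch.randn(16, 1)
h = trunk(x)  # Shared computation feeding both losses.
losses = [torch.mean((head_a(h) - y) ** 2), torch.mean((head_b(h) - y) ** 2)]
for i, (opt, loss) in enumerate(zip(optimizers, losses)):
    opt.zero_grad()
    # Keep the shared graph alive until the last loss has been backpropagated.
    loss.backward(retain_graph=(i + 1 != len(losses)))
    opt.step()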