def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Run one optimization step on a batch of experiences.

    Evaluates the current policy and critic on the observations, actions and
    memories stored in ``batch``. It then combines policy loss, value loss and
    an entropy term into one scalar, and takes a single optimizer step at the
    scheduled learning rate. Finally, every reward signal provider is updated
    on the same batch.

    :param batch: Batch of experiences.
    :param num_sequences: Number of sequences to process. This implementation
        never reads it; it is kept for interface compatibility.
    :return: Statistics for this update: policy/value losses, the scheduled
        learning rate, epsilon and beta, plus whatever each reward provider's
        ``update`` returns.
    """
    # Schedules are all evaluated at the policy's current step.
    decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
    decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
    decay_bet = self.decay_beta.get_value(self.policy.get_current_step())

    seq_len = self.policy.sequence_length

    def _stack_sequence_memories(key):
        # Take every seq_len-th entry (index 0, seq_len, 2*seq_len, ...) of
        # batch[key]. When any entries were taken, stack them and add a leading
        # dimension; otherwise hand back the empty list as-is.
        picked = [
            ModelUtils.list_to_tensor(batch[key][start])
            for start in range(0, len(batch[key]), seq_len)
        ]
        if not picked:
            return picked
        return torch.stack(picked).unsqueeze(0)

    # Per-reward-signal value estimates and returns recorded in the buffer.
    old_values = {}
    returns = {}
    for signal_name in self.reward_signals:
        old_values[signal_name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.value_estimates_key(signal_name)]
        )
        returns[signal_name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.returns_key(signal_name)]
        )

    obs_count = len(self.policy.behavior_spec.observation_specs)
    obs_tensors = [
        ModelUtils.list_to_tensor(obs)
        for obs in ObsUtil.from_buffer(batch, obs_count)
    ]
    act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
    actions = AgentAction.from_buffer(batch)

    # Policy and critic memories are sampled with the same stride.
    memories = _stack_sequence_memories(BufferKey.MEMORY)
    value_memories = _stack_sequence_memories(BufferKey.CRITIC_MEMORY)

    log_probs, entropy = self.policy.evaluate_actions(
        obs_tensors,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=seq_len,
    )
    values, _ = self.critic.critic_pass(
        obs_tensors,
        memories=value_memories,
        sequence_length=seq_len,
    )

    old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
    log_probs = log_probs.flatten()
    loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)

    value_loss = self.ppo_value_loss(
        values, old_values, returns, decay_eps, loss_masks
    )
    advantages = ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES])
    policy_loss = self.ppo_policy_loss(
        advantages, log_probs, old_log_probs, loss_masks
    )
    # Entropy averaged only over positions where loss_masks is True.
    entropy_bonus = ModelUtils.masked_mean(entropy, loss_masks)
    loss = policy_loss + 0.5 * value_loss - decay_bet * entropy_bonus

    # Apply the scheduled learning rate before stepping.
    ModelUtils.update_learning_rate(self.optimizer, decay_lr)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    stats = {
        # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
        # TODO: After PyTorch is default, change to something more correct.
        "Losses/Policy Loss": torch.abs(policy_loss).item(),
        "Losses/Value Loss": value_loss.item(),
        "Policy/Learning Rate": decay_lr,
        "Policy/Epsilon": decay_eps,
        "Policy/Beta": decay_bet,
    }
    for provider in self.reward_signals.values():
        stats.update(provider.update(batch))
    return stats
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Run one optimization step on a batch of experiences.

    Builds vector observations, visual observations (only when the policy
    uses them), action masks, actions and sampled memories from ``batch``.
    A single ``evaluate_actions`` call then yields log-probabilities, entropy
    and values. These are combined into policy loss + value loss - entropy
    term, and one optimizer step is taken at the scheduled learning rate.
    Every reward signal provider is updated on the same batch afterwards.

    :param batch: Batch of experiences.
    :param num_sequences: Number of sequences to process. This implementation
        never reads it; it is kept for interface compatibility.
    :return: Statistics for this update: policy/value losses, the scheduled
        learning rate, epsilon and beta, plus whatever each reward provider's
        ``update`` returns.
    """
    # Schedules are all evaluated at the policy's current step.
    decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
    decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
    decay_bet = self.decay_beta.get_value(self.policy.get_current_step())

    # Per-reward-signal value estimates and returns recorded in the buffer.
    old_values = {}
    returns = {}
    for signal_name in self.reward_signals:
        old_values[signal_name] = ModelUtils.list_to_tensor(
            batch[f"{signal_name}_value_estimates"]
        )
        returns[signal_name] = ModelUtils.list_to_tensor(
            batch[f"{signal_name}_returns"]
        )

    vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    if not self.policy.use_continuous_act:
        # Discrete branch: integer actions.
        actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
    else:
        # Continuous branch: pre-processed actions with a trailing dim added.
        actions = ModelUtils.list_to_tensor(batch["actions_pre"]).unsqueeze(-1)

    seq_len = self.policy.sequence_length
    # Every seq_len-th stored memory (index 0, seq_len, 2*seq_len, ...),
    # stacked with a leading dimension only when at least one was taken.
    memories = [
        ModelUtils.list_to_tensor(batch["memory"][start])
        for start in range(0, len(batch["memory"]), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    # One visual-observation tensor per visual processor, keyed by position.
    vis_obs = []
    if self.policy.use_vis_obs:
        processors = self.policy.actor_critic.network_body.visual_processors
        vis_obs = [
            ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
            for idx, _ in enumerate(processors)
        ]

    log_probs, entropy, values = self.policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=seq_len,
    )

    loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    value_loss = self.ppo_value_loss(
        values, old_values, returns, decay_eps, loss_masks
    )
    advantages = ModelUtils.list_to_tensor(batch["advantages"])
    stored_action_probs = ModelUtils.list_to_tensor(batch["action_probs"])
    policy_loss = self.ppo_policy_loss(
        advantages, log_probs, stored_action_probs, loss_masks
    )
    # Entropy averaged only over positions where loss_masks is True.
    entropy_bonus = ModelUtils.masked_mean(entropy, loss_masks)
    loss = policy_loss + 0.5 * value_loss - decay_bet * entropy_bonus

    # Apply the scheduled learning rate before stepping.
    ModelUtils.update_learning_rate(self.optimizer, decay_lr)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    stats = {
        # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
        # TODO: After PyTorch is default, change to something more correct.
        "Losses/Policy Loss": torch.abs(policy_loss).item(),
        "Losses/Value Loss": value_loss.item(),
        "Policy/Learning Rate": decay_lr,
        "Policy/Epsilon": decay_eps,
        "Policy/Beta": decay_bet,
    }
    for provider in self.reward_signals.values():
        stats.update(provider.update(batch))
    return stats