def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool = False):
    """
    Adapt the parameters of the advantage function estimator by minimizing the critic loss (`self.loss_fcn`)
    for the given samples.

    :param rollouts: batch of rollouts
    :param use_empirical_returns: use the return from the rollout (True) or the ones from the V-fcn (False)
    :return adv: tensor of advantages after V-function updates
    """
    # Turn the batch of rollouts into a list of steps
    concat_ros = StepSequence.concat(rollouts)
    concat_ros.torch(data_type=to.get_default_dtype())

    if use_empirical_returns:
        # Compute the value targets (empirical discounted returns) for all samples
        v_targ = discounted_values(rollouts, self.gamma).view(-1, 1)
    else:
        # Use the value function to compute the value targets (also called bootstrapping)
        v_targ = self.tdlamda_returns(concat_ros=concat_ros)
    concat_ros.add_data('v_targ', v_targ)

    # Evaluate the V-fcn before fitting; used as the baseline for the loss improvement logged below
    with to.no_grad():
        v_pred_old = self.values(concat_ros)
        loss_old = self.loss_fcn(v_pred_old, v_targ)
    vfcn_grad_norm = []

    # Iterate over all gathered samples num_epoch times
    for e in range(self.num_epoch):
        for batch in tqdm(
            concat_ros.split_shuffled_batches(
                self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
            ),
            total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
            desc=f'Epoch {e}',
            unit='batches',
            file=sys.stdout,
            leave=False,
        ):
            # Reset the gradients
            self.optim.zero_grad()

            # Make predictions for this mini-batch using values function
            v_pred = self.values(batch)

            # Compute estimator loss for this mini-batch and backpropagate
            vfcn_loss = self.loss_fcn(v_pred, batch.v_targ)
            vfcn_loss.backward()

            # Clip the gradients if desired
            vfcn_grad_norm.append(Algorithm.clip_grad(self.vfcn, self.max_grad_norm))

            # Call optimizer
            self.optim.step()

        # Update the learning rate if a scheduler has been specified (once per epoch)
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

    # Estimate the advantage after fitting the parameters of the V-fcn
    adv = self.gae(concat_ros)  # is done with to.no_grad()

    with to.no_grad():
        v_pred_new = self.values(concat_ros)
        loss_new = self.loss_fcn(v_pred_new, v_targ)
        vfcn_loss_impr = loss_old - loss_new  # positive values are desired
        explvar = explained_var(v_pred_new, v_targ)  # values close to 1 are desired

    # Log metrics computed from the updated value function (after fitting)
    self.logger.add_value('explained var critic', explvar, 4)
    self.logger.add_value('loss improv critic', vfcn_loss_impr, 4)
    self.logger.add_value('avg grad norm critic', np.mean(vfcn_grad_norm), 4)
    if self._lr_scheduler is not None:
        # get_last_lr() returns one value per param group; log a scalar (as the A2C update does)
        self.logger.add_value('lr critic', np.mean(self._lr_scheduler.get_last_lr()), 6)

    return adv
def update(self, rollouts: Sequence[StepSequence]):
    """
    Update the policy (through the exploration strategy) and the critic's V-fcn for `self.num_epoch` epochs of
    shuffled mini-batches, using the combined loss `self.loss_fcn` and the single optimizer `self.optim`.

    :param rollouts: batch of rollouts
    :raises RuntimeError: if a parameter of the exploration noise became NaN during an optimizer step
    """
    # Turn the batch of rollouts into a list of steps
    concat_ros = StepSequence.concat(rollouts)
    concat_ros.torch(data_type=to.get_default_dtype())

    with to.no_grad():
        # Compute the action probabilities using the old (before update) policy
        act_stats = compute_action_statistics(concat_ros, self._expl_strat)
        log_probs_old = act_stats.log_probs
        act_distr_old = act_stats.act_distr

        # Compute value predictions using the old (before update) value function
        v_pred_old = self._critic.values(concat_ros)

    # Attach the old log probs and the old value predictions to the rollout
    concat_ros.add_data('log_probs_old', log_probs_old)
    concat_ros.add_data('v_pred_old', v_pred_old)

    # For logging the gradient norms
    policy_grad_norm = []
    value_fcn_grad_norm = []

    # Compute the advantages and the value targets (empirical discounted returns) for all samples
    # before fitting the V-fcn parameters
    adv = self._critic.gae(concat_ros)  # done with to.no_grad()
    v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1)  # empirical discounted returns
    concat_ros.add_data('adv', adv)
    concat_ros.add_data('v_targ', v_targ)

    # NOTE(review): the GAE and A2C code in this file refer to the critic's network as `vfcn`;
    # confirm that `value_fcn` still exists on the critic
    complete_rollouts = self._policy.is_recurrent or isinstance(self._critic.value_fcn, RecurrentPolicy)

    # Iterations over the whole data set
    for e in range(self.num_epoch):
        for batch in tqdm(
            concat_ros.split_shuffled_batches(self.batch_size, complete_rollouts=complete_rollouts),
            total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
            desc=f'Epoch {e}',
            unit='batches',
            file=sys.stdout,
            leave=False,
        ):
            # Reset the gradients
            self.optim.zero_grad()

            # Compute log of the action probabilities for the mini-batch
            log_probs = compute_action_statistics(batch, self._expl_strat).log_probs.to(self.policy.device)

            # Compute value predictions for the mini-batch
            v_pred = self._critic.values(batch)

            # Compute combined loss and backpropagate
            loss = self.loss_fcn(log_probs, batch.log_probs_old, batch.adv, v_pred, batch.v_pred_old, batch.v_targ)
            loss.backward()

            # Clip the gradients if desired
            policy_grad_norm.append(self.clip_grad(self._expl_strat.policy, self.max_grad_norm))
            value_fcn_grad_norm.append(self.clip_grad(self._critic.value_fcn, self.max_grad_norm))

            # Call optimizer
            self.optim.step()

            # Read the std from `.noise` (as the check does) so building the message cannot itself fail
            if to.isnan(self._expl_strat.noise.std).any():
                raise RuntimeError(
                    f'At least one exploration parameter became NaN! The exploration parameters are'
                    f'\n{self._expl_strat.noise.std.detach().cpu().numpy()}'
                )

        # Update the learning rate if a scheduler has been specified (once per epoch)
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

    # Additional logging
    if self.log_loss:
        with to.no_grad():
            # Compute value predictions using the new (after the updates) value function approximator
            v_pred = self._critic.values(concat_ros).to(self.policy.device)
            v_loss_old = self._critic.loss_fcn(
                v_pred_old.to(self.policy.device), v_targ.to(self.policy.device)
            ).to(self.policy.device)
            v_loss_new = self._critic.loss_fcn(v_pred, v_targ).to(self.policy.device)
            value_fcn_loss_impr = v_loss_old - v_loss_new  # positive values are desired

            # Compute the action probabilities using the new (after the updates) policy
            act_stats = compute_action_statistics(concat_ros, self._expl_strat)
            log_probs_new = act_stats.log_probs
            act_distr_new = act_stats.act_distr
            loss_after = self.loss_fcn(log_probs_new, log_probs_old, adv, v_pred, v_pred_old, v_targ)
            kl_avg = to.mean(kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL

            # Compute explained variance (after the updates)
            explvar = explained_var(v_pred, v_targ)

        # .cpu() because the tensors above were explicitly moved to the policy's device (possibly a GPU)
        self.logger.add_value('explained var', explvar.detach().cpu().numpy())
        self.logger.add_value('V-fcn loss improvement', value_fcn_loss_impr.detach().cpu().numpy())
        self.logger.add_value('loss after', loss_after.detach().cpu().numpy())
        self.logger.add_value('KL(old_new)', kl_avg.item())

    # Logging
    self.logger.add_value('avg expl strat std', to.mean(self._expl_strat.noise.std.data).detach().cpu().numpy())
    self.logger.add_value('expl strat entropy', self._expl_strat.noise.get_entropy().item())
    self.logger.add_value('avg policy grad norm', np.mean(policy_grad_norm))
    self.logger.add_value('avg V-fcn grad norm', np.mean(value_fcn_grad_norm))
    if self._lr_scheduler is not None:
        # get_lr() is deprecated and may report a wrong value outside of step(); get_last_lr() is the supported API
        self.logger.add_value('learning rate', np.mean(self._lr_scheduler.get_last_lr()))
def update(self, rollouts: Sequence[StepSequence]):
    """
    Update the policy (through the exploration strategy) and the critic with one pass over shuffled mini-batches,
    using the combined loss `self.loss_fcn` and the single optimizer `self.optim`.

    :param rollouts: batch of rollouts
    :raises RuntimeError: if a parameter of the exploration noise became NaN during the update
    """
    # Turn the batch of rollouts into a list of steps
    concat_ros = StepSequence.concat(rollouts)
    concat_ros.torch(data_type=to.get_default_dtype())

    # Compute the value targets (empirical discounted returns) for all samples before fitting the V-fcn parameters
    adv = self._critic.gae(concat_ros)  # done with to.no_grad()
    v_targ = discounted_values(rollouts, self._critic.gamma).view(-1, 1).to(self.policy.device)  # empirical returns

    with to.no_grad():
        # Compute value predictions using the old (before the updates) value function approximator
        v_pred = self._critic.values(concat_ros)

        # Compute the action probabilities using the old (before update) policy
        act_stats = compute_action_statistics(concat_ros, self._expl_strat)
        log_probs_old = act_stats.log_probs
        act_distr_old = act_stats.act_distr
        loss_before = self.loss_fcn(log_probs_old, adv, v_pred, v_targ)
        self.logger.add_value('loss before', loss_before, 4)

    concat_ros.add_data('adv', adv)
    concat_ros.add_data('v_targ', v_targ)

    # For logging the gradients' norms
    policy_grad_norm = []

    for batch in tqdm(
        concat_ros.split_shuffled_batches(
            self.batch_size,
            complete_rollouts=self._policy.is_recurrent or isinstance(self._critic.vfcn, RecurrentPolicy),
        ),
        total=num_iter_from_rollouts(None, concat_ros, self.batch_size),
        desc='Updating',
        unit='batches',
        file=sys.stdout,
        leave=False,
    ):
        # Reset the gradients
        self.optim.zero_grad()

        # Compute log of the action probabilities for the mini-batch
        log_probs = compute_action_statistics(batch, self._expl_strat).log_probs

        # Compute value predictions for the mini-batch
        v_pred = self._critic.values(batch)

        # Compute combined loss and backpropagate
        loss = self.loss_fcn(log_probs, batch.adv, v_pred, batch.v_targ)
        loss.backward()

        # Clip the gradients if desired
        policy_grad_norm.append(self.clip_grad(self.expl_strat.policy, self.max_grad_norm))

        # Call optimizer
        self.optim.step()

    # Update the learning rate if a scheduler has been specified
    if self._lr_scheduler is not None:
        self._lr_scheduler.step()

    # Read the std from `.noise` (as the check does); `.item()` would raise for a multi-element std and
    # mask this error
    if to.isnan(self.expl_strat.noise.std).any():
        raise RuntimeError(
            f'At least one exploration parameter became NaN! The exploration parameters are'
            f'\n{self.expl_strat.noise.std.detach().cpu().numpy()}'
        )

    # Logging
    with to.no_grad():
        # Compute value predictions and the GAE using the new (after the updates) value function approximator
        v_pred = self._critic.values(concat_ros).to(self.policy.device)
        adv = self._critic.gae(concat_ros)  # done with to.no_grad()

        # Compute the action probabilities using the new (after the updates) policy
        act_stats = compute_action_statistics(concat_ros, self._expl_strat)
        log_probs_new = act_stats.log_probs
        act_distr_new = act_stats.act_distr
        loss_after = self.loss_fcn(log_probs_new, adv, v_pred, v_targ)
        kl_avg = to.mean(kl_divergence(act_distr_old, act_distr_new))  # mean seeking a.k.a. inclusive KL
        explvar = explained_var(v_pred, v_targ)  # values close to 1 are desired
        self.logger.add_value('loss after', loss_after, 4)
        self.logger.add_value('KL(old_new)', kl_avg, 4)
        self.logger.add_value('explained var', explvar, 4)

    ent = self.expl_strat.noise.get_entropy()
    self.logger.add_value('avg expl strat std', to.mean(self.expl_strat.noise.std), 4)
    self.logger.add_value('expl strat entropy', to.mean(ent), 4)
    self.logger.add_value('avg grad norm policy', np.mean(policy_grad_norm), 4)
    if self._lr_scheduler is not None:
        self.logger.add_value('avg lr', np.mean(self._lr_scheduler.get_last_lr()), 6)