def calc_policy_loss(self, batch, advs):
    '''
    The PPO loss function (subscript t is omitted)
    L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

    Breakdown piecewise,
    1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
    where ratio = pi(a|s) / pi_old(a|s)

    2. L^VF = E[ mse(V(s_t), V^target) ]

    3. S = E[ entropy ]
    '''
    # decay clip_eps by episode
    clip_eps = policy_util._linear_decay(self.clip_eps, 0.1 * self.clip_eps, self.clip_eps_anneal_epi, self.body.env.clock.get('epi'))

    # L^CLIP
    log_probs = policy_util.calc_log_probs(self, self.net, self.body, batch)
    old_log_probs = policy_util.calc_log_probs(self, self.old_net, self.body, batch)
    assert log_probs.shape == old_log_probs.shape
    assert advs.shape[0] == log_probs.shape[0]  # batch size
    ratios = torch.exp(log_probs - old_log_probs)
    logger.debug(f'ratios: {ratios}')
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # flip sign because need to maximize
    clip_loss = -torch.mean(torch.min(sur_1, sur_2))
    logger.debug(f'clip_loss: {clip_loss}')

    # L^VF (inherit from ActorCritic)

    # S entropy bonus
    entropies = torch.stack(self.body.entropies)
    ent_penalty = torch.mean(-self.entropy_coef * entropies)
    logger.debug(f'ent_penalty: {ent_penalty}')

    policy_loss = clip_loss + ent_penalty
    if torch.cuda.is_available() and self.net.gpu:
        policy_loss = policy_loss.cuda()
    logger.debug(f'Actor policy loss: {policy_loss:.4f}')
    return policy_loss
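The clip_eps schedule above comes from policy_util._linear_decay. As a rough sketch of what such a schedule does (a hypothetical stand-in, not SLM-Lab's actual helper), a linear decay interpolates from the starting value down to the end value over the annealing episodes and then holds:

# Hypothetical sketch of a linear decay schedule; the name, arguments, and
# behavior are assumptions, not the actual policy_util._linear_decay code.
def linear_decay(start_val, end_val, anneal_epi, epi):
    '''Interpolate linearly from start_val to end_val over anneal_epi episodes, then hold.'''
    if epi >= anneal_epi:
        return end_val
    slope = (end_val - start_val) / anneal_epi
    return start_val + slope * epi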
def calc_policy_loss(self, batch, advs):
    '''
    The PPO loss function (subscript t is omitted)
    L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

    Breakdown piecewise,
    1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
    where ratio = pi(a|s) / pi_old(a|s)

    2. L^VF = E[ mse(V(s_t), V^target) ]

    3. S = E[ entropy ]
    '''
    clip_eps = self.body.clip_eps

    # L^CLIP
    log_probs = policy_util.calc_log_probs(self, self.net, self.body, batch)
    old_log_probs = policy_util.calc_log_probs(self, self.old_net, self.body, batch).detach()
    assert log_probs.shape == old_log_probs.shape
    assert advs.shape[0] == log_probs.shape[0]  # batch size
    ratios = torch.exp(log_probs - old_log_probs)  # clip to prevent overflow
    logger.debug(f'ratios: {ratios}')
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # flip sign because need to maximize
    clip_loss = -torch.mean(torch.min(sur_1, sur_2))
    logger.debug(f'clip_loss: {clip_loss}')

    # L^VF (inherit from ActorCritic)

    # S entropy bonus
    entropies = torch.stack(self.body.entropies)
    ent_penalty = torch.mean(-self.body.entropy_coef * entropies)
    logger.debug(f'ent_penalty: {ent_penalty}')
    # store mean entropy for debug logging
    self.body.mean_entropy = torch.mean(torch.tensor(self.body.entropies)).item()

    policy_loss = clip_loss + ent_penalty
    logger.debug(f'PPO Actor policy loss: {policy_loss:.4f}')
    return policy_loss
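For intuition about the clipped surrogate term L^CLIP, here is a small standalone check with toy numbers (illustration only, assuming PyTorch; the values are made up): with a positive advantage the clip caps how much credit a large ratio can earn, while with a negative advantage taking the minimum keeps the full penalty of a large ratio.

# Toy illustration of the clipped surrogate objective; all numbers are made up.
import torch

clip_eps = 0.2
ratios = torch.tensor([0.5, 1.0, 1.5, 0.5, 1.5])  # pi(a|s) / pi_old(a|s)
advs = torch.tensor([1.0, 1.0, 1.0, -1.0, -1.0])  # advantages A

sur_1 = ratios * advs
sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
# ratio=1.5, A=1.0: min(1.5, 1.2) = 1.2, so the clip caps the credit for a large ratio
# ratio=1.5, A=-1.0: min(-1.5, -1.2) = -1.5, so the penalty is not capped
clip_loss = -torch.mean(torch.min(sur_1, sur_2))
print(clip_loss)  # scalar loss, sign flipped for gradient descent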
def calc_sil_policy_val_loss(self, batch):
    '''
    Calculate the SIL policy losses for actor and critic
    sil_policy_loss = -log_prob * max(R - v_pred, 0)
    sil_val_loss = (max(R - v_pred, 0)^2) / 2
    This is called on a randomly-sampled batch from experience replay
    '''
    returns = batch['rets']
    v_preds = self.calc_v(batch['states'], evaluate=False)
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)
    log_probs = policy_util.calc_log_probs(self, self.net, self.body, batch)

    sil_policy_loss = self.sil_policy_loss_coef * torch.mean(-log_probs * clipped_advs)
    sil_val_loss = self.sil_val_loss_coef * torch.pow(clipped_advs, 2) / 2
    sil_val_loss = torch.mean(sil_val_loss)
    logger.debug(f'SIL actor policy loss: {sil_policy_loss:.4f}')
    logger.debug(f'SIL critic value loss: {sil_val_loss:.4f}')
    return sil_policy_loss, sil_val_loss
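As a quick illustration of the clipped advantage used above (toy numbers only, assuming PyTorch), only transitions whose stored return exceeds the current value estimate contribute to the SIL losses; everything else is zeroed out by max(R - v_pred, 0):

# Toy illustration of the SIL clipped advantage; all numbers are made up.
import torch

returns = torch.tensor([2.0, 0.5, 1.0])  # returns R sampled from replay
v_preds = torch.tensor([1.0, 1.0, 1.0])  # current value estimates V(s)
clipped_advs = torch.clamp(returns - v_preds, min=0.0)
# tensor([1., 0., 0.]): only the first transition, where R > V(s), is imitated
sil_val_loss = torch.mean(torch.pow(clipped_advs, 2) / 2)  # mean of [0.5, 0, 0]
print(clipped_advs, sil_val_loss)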