def calc_policy_loss(self, batch, pdparams, advs):
    '''
    The PPO loss function (subscript t is omitted)
    L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

    Breakdown piecewise,
    1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
       where ratio = pi(a|s) / pi_old(a|s)

    2. L^VF = E[ mse(V(s_t), V^target) ]

    3. S = E[ entropy ]
    '''
    clip_eps = self.body.clip_eps
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    states = batch['states']
    actions = batch['actions']
    if self.body.env.is_venv:
        states = math_util.venv_unpack(states)
        actions = math_util.venv_unpack(actions)

    # L^CLIP
    log_probs = action_pd.log_prob(actions)
    with torch.no_grad():
        old_pdparams = self.calc_pdparam(states, net=self.old_net)
        old_action_pd = policy_util.init_action_pd(self.body.ActionPD, old_pdparams)
        old_log_probs = old_action_pd.log_prob(actions)
    assert log_probs.shape == old_log_probs.shape
    ratios = torch.exp(log_probs - old_log_probs)  # clip to prevent overflow
    logger.debug(f'ratios: {ratios}')
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # flip sign because we need to maximize the surrogate objective
    clip_loss = -torch.min(sur_1, sur_2).mean()
    logger.debug(f'clip_loss: {clip_loss}')

    # L^VF (inherited from ActorCritic)

    # S entropy bonus
    entropy = action_pd.entropy().mean()
    self.body.mean_entropy = entropy  # update logging variable
    ent_penalty = -self.body.entropy_coef * entropy
    logger.debug(f'ent_penalty: {ent_penalty}')

    policy_loss = clip_loss + ent_penalty
    logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
    return policy_loss
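
# A minimal, self-contained sketch (not part of the class above) of how the
# clipped surrogate L^CLIP behaves on plain tensors; the helper name
# `clipped_surrogate_loss` and the toy numbers are illustrative assumptions.
import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advs, clip_eps=0.2):
    ratios = torch.exp(log_probs - old_log_probs)  # pi(a|s) / pi_old(a|s)
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    return -torch.min(sur_1, sur_2).mean()  # negated so optimizers can minimize

# Usage sketch: a ratio of 1.5 with positive advantage is clipped to 1.2, so the
# surrogate (and hence the gradient) stops growing once the policy has moved far
# enough away from pi_old:
# clipped_surrogate_loss(torch.log(torch.tensor([1.5])), torch.zeros(1), torch.ones(1))  # -> tensor(-1.2)
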
def fetch_airl_reward(self, batch):
    self.disc_training_count += 1
    if self.disc_training_count >= self.disc_training_freq:
        self.policy_training_flag = True
        self.disc_training_count = 0
    self.discriminator.eval()
    pdparams, _ = self.calc_pdparam_v(batch)
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    log_probs = action_pd.log_prob(actions)
    weight = self.discriminator.get_reward(batch)  # discriminator reward term f(s, a)
    assert log_probs.shape == weight.shape
    # AIRL reward: r(s, a) = f(s, a) - log pi(a|s)
    reward = (weight - log_probs.view(-1)).detach()
    logger.debug(f'Disc reward: {weight.mean().item()}')
    batch['rewards'] = reward
    self.experience_buffer.append(copy.deepcopy(batch))
    self.discriminator.train()
    self.airl_train(self.disc_training_times)
    self.discriminator.eval()
    self.batch_count += 1
    return batch
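
# A minimal sketch (not SLM Lab API) of the AIRL reward computed above:
# r(s, a) = f(s, a) - log pi(a|s), where f(s, a) is the discriminator's learned
# reward term (returned by `self.discriminator.get_reward`). The standalone
# helper name `airl_reward` is an assumption for illustration.
import torch

def airl_reward(disc_reward, log_prob):
    # Detach so the policy update treats the learned reward as a fixed signal.
    return (disc_reward - log_prob).detach()

# Usage sketch:
# airl_reward(torch.tensor([0.3]), torch.tensor([-1.2]))  # -> tensor([1.5000])
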
def imitate_loop(self):
    '''Behavioral cloning loss: negative log-likelihood of expert actions under the current policy'''
    real_state, real_action = self.discriminator.sample_real_batch_id()
    batch = {'states': real_state, 'actions': real_action}
    pdparams, _ = self.calc_pdparam_v(batch)
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    log_probs = action_pd.log_prob(real_action)
    imitate_loss = -log_probs.mean()
    return imitate_loss
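
# A minimal sketch of the behavioral-cloning objective in imitate_loop above:
# minimize the negative log-likelihood of expert actions under the current
# policy. The discrete Categorical head and the names `logits`/`expert_actions`
# are illustrative assumptions; any torch.distributions policy head works alike.
import torch
from torch.distributions import Categorical

def imitation_loss(logits, expert_actions):
    action_pd = Categorical(logits=logits)              # pi(a|s)
    return -action_pd.log_prob(expert_actions).mean()   # NLL of expert actions

# Usage sketch:
# imitation_loss(torch.randn(32, 4), torch.randint(0, 4, (32,)))
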
def calc_policy_loss(self, batch, pdparams, advs):
    '''Calculate the actor's policy loss'''
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    if self.body.env.is_venv:
        actions = math_util.venv_unpack(actions)
    log_probs = action_pd.log_prob(actions)
    policy_loss = -self.policy_loss_coef * (log_probs * advs).mean()
    if self.entropy_coef_spec:
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        policy_loss += (-self.body.entropy_coef * entropy)
    logger.debug(f'Actor policy loss: {policy_loss:g}')
    return policy_loss
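
# A minimal sketch of the advantage-weighted policy-gradient loss above, with an
# optional entropy bonus. The coefficient names mirror the method's attributes;
# the standalone form and the Categorical head are illustrative assumptions.
import torch
from torch.distributions import Categorical

def actor_loss(logits, actions, advs, policy_loss_coef=1.0, entropy_coef=0.01):
    action_pd = Categorical(logits=logits)
    log_probs = action_pd.log_prob(actions)
    loss = -policy_loss_coef * (log_probs * advs).mean()  # maximize E[log pi * A]
    loss += -entropy_coef * action_pd.entropy().mean()    # entropy bonus for exploration
    return loss
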
def calc_sil_policy_val_loss(self, batch, pdparams):
    '''
    Calculate the SIL policy losses for actor and critic
    sil_policy_loss = -log_prob * max(R - v_pred, 0)
    sil_val_loss = (max(R - v_pred, 0)^2) / 2
    This is called on a randomly sampled batch from experience replay
    '''
    v_preds = self.calc_v(batch['states'], use_cache=False)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    clipped_advs = torch.clamp(rets - v_preds, min=0.0)

    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    if self.body.env.is_venv:
        actions = math_util.venv_unpack(actions)
    log_probs = action_pd.log_prob(actions)

    sil_policy_loss = -self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
    sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
    logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
    logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
    return sil_policy_loss, sil_val_loss
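
# A minimal sketch of the SIL losses above: only transitions whose return R
# exceeds the value estimate V(s) contribute, via the clipped advantage
# max(R - V(s), 0). Coefficient names mirror the method's attributes; the
# standalone form is an illustrative assumption.
import torch

def sil_losses(log_probs, returns, v_preds, policy_coef=1.0, val_coef=0.01):
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)  # max(R - V, 0)
    sil_policy_loss = -policy_coef * (log_probs * clipped_advs).mean()
    sil_val_loss = val_coef * clipped_advs.pow(2).mean() / 2
    return sil_policy_loss, sil_val_loss
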