Example #1
    def calc_policy_loss(self, batch, pdparams, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Broken down piecewise:
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        clip_eps = self.body.clip_eps
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        states = batch['states']
        actions = batch['actions']
        if self.body.env.is_venv:
            states = math_util.venv_unpack(states)
            actions = math_util.venv_unpack(actions)

        # L^CLIP
        log_probs = action_pd.log_prob(actions)
        with torch.no_grad():
            old_pdparams = self.calc_pdparam(states, net=self.old_net)
            old_action_pd = policy_util.init_action_pd(self.body.ActionPD,
                                                       old_pdparams)
            old_log_probs = old_action_pd.log_prob(actions)
        assert log_probs.shape == old_log_probs.shape
        # ratio = pi(a|s) / pi_old(a|s), computed in log space
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.min(sur_1, sur_2).mean()
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        ent_penalty = -self.body.entropy_coef * entropy
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
        return policy_loss
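
The clipped surrogate above can be reproduced in isolation. Below is a minimal, self-contained sketch on toy tensors; the function and variable names are illustrative and not part of the code above.

# Minimal sketch of the clipped surrogate L^CLIP on toy tensors (illustrative only).
import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advs, clip_eps=0.2):
    # ratio = pi(a|s) / pi_old(a|s), computed in log space
    ratios = torch.exp(log_probs - old_log_probs)
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    # negate because the surrogate objective is maximized
    return -torch.min(sur_1, sur_2).mean()

log_probs = torch.tensor([-0.9, -1.2, -0.4])
old_log_probs = torch.tensor([-1.0, -1.0, -1.0])
advs = torch.tensor([1.0, -0.5, 2.0])
print(clipped_surrogate_loss(log_probs, old_log_probs, advs))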
Example #2
    def fetch_airl_reward(self, batch):
        '''
        Replace the environment rewards in the batch with the AIRL-style reward
        r(s, a) = f(s, a) - log pi(a|s), where f is the discriminator score,
        then store the relabeled batch and train the discriminator.
        '''
        # schedule policy updates relative to discriminator updates
        self.disc_training_count += 1
        if self.disc_training_count >= self.disc_training_freq:
            self.policy_training_flag = True
            self.disc_training_count = 0

        self.discriminator.eval()
        pdparams, _ = self.calc_pdparam_v(batch)
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        log_probs = action_pd.log_prob(actions)
        weight = self.discriminator.get_reward(batch)  # discriminator score f(s, a)
        assert log_probs.shape == weight.shape
        # detach so no gradient flows through the relabeled reward
        reward = (weight - log_probs.view(-1)).detach()

        print(f'Disc reward: {weight.mean().item()}')

        batch['rewards'] = reward

        # store the relabeled batch and run discriminator training
        self.experience_buffer.append(copy.deepcopy(batch))
        self.discriminator.train()
        self.airl_train(self.disc_training_times)
        self.discriminator.eval()

        self.batch_count += 1
        return batch
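
The reward substitution in fetch_airl_reward reduces to f(s, a) - log pi(a|s). A minimal sketch with toy tensors follows; the names are illustrative and not the API above.

# Minimal sketch of the AIRL-style reward r = f(s, a) - log pi(a|s) on toy tensors (illustrative only).
import torch

disc_scores = torch.tensor([0.5, 1.2, -0.3])    # discriminator outputs f(s, a)
log_probs = torch.tensor([-1.1, -0.7, -2.0])    # log pi(a|s) under the current policy
# detach mirrors the method above: the relabeled reward is treated as a constant
rewards = (disc_scores - log_probs).detach()
print(rewards)  # tensor([1.6000, 1.9000, 1.7000])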
Example #3
    def imitate_loop(self):
        '''Behavioral-cloning loss: mean negative log-likelihood of expert actions under the current policy'''
        # sample a batch of expert ('real') state-action pairs held by the discriminator
        real_state, real_action = self.discriminator.sample_real_batch_id()
        batch = {'states': real_state, 'actions': real_action}
        pdparams, _ = self.calc_pdparam_v(batch)
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        log_probs = action_pd.log_prob(real_action)
        imitate_loss = -log_probs.mean()

        return imitate_loss
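
imitate_loop is a plain behavioral-cloning objective: the mean negative log-likelihood of expert actions. A toy version with a discrete action distribution is sketched below; the data and names are assumptions for illustration.

# Minimal sketch of the behavioral-cloning loss on toy data (illustrative only).
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 3, requires_grad=True)  # policy outputs for 4 states, 3 discrete actions
expert_actions = torch.tensor([0, 2, 1, 2])     # actions taken by the expert
action_pd = Categorical(logits=logits)
imitate_loss = -action_pd.log_prob(expert_actions).mean()
imitate_loss.backward()  # gradients flow into the policy logits
print(imitate_loss.item())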
Example #4
    def calc_policy_loss(self, batch, pdparams, advs):
        '''Calculate the actor's policy loss'''
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        if self.body.env.is_venv:
            actions = math_util.venv_unpack(actions)
        log_probs = action_pd.log_prob(actions)
        policy_loss = -self.policy_loss_coef * (log_probs * advs).mean()
        if self.entropy_coef_spec:
            entropy = action_pd.entropy().mean()
            self.body.mean_entropy = entropy  # update logging variable
            policy_loss += (-self.body.entropy_coef * entropy)
        logger.debug(f'Actor policy loss: {policy_loss:g}')
        return policy_loss
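
The actor loss above is the standard advantage-weighted policy gradient with an optional entropy bonus. A self-contained toy version follows; the names are illustrative, not the class API.

# Minimal sketch of the advantage-weighted policy-gradient loss with an entropy bonus (illustrative only).
import torch
from torch.distributions import Categorical

logits = torch.randn(5, 2, requires_grad=True)
actions = torch.tensor([0, 1, 1, 0, 1])
advs = torch.tensor([0.3, -1.0, 2.1, 0.0, 0.7])
entropy_coef = 0.01

action_pd = Categorical(logits=logits)
log_probs = action_pd.log_prob(actions)
policy_loss = -(log_probs * advs).mean()                               # maximize E[log pi(a|s) * A]
policy_loss = policy_loss - entropy_coef * action_pd.entropy().mean()  # entropy bonus encourages exploration
policy_loss.backward()
print(policy_loss.item())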
Example #5
    def calc_sil_policy_val_loss(self, batch, pdparams):
        '''
        Calculate the SIL policy losses for actor and critic
        sil_policy_loss = -log_prob * max(R - v_pred, 0)
        sil_val_loss = (max(R - v_pred, 0)^2) / 2
        This is called on a randomly sampled batch from experience replay
        '''
        v_preds = self.calc_v(batch['states'], use_cache=False)
        rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                      self.gamma)
        clipped_advs = torch.clamp(rets - v_preds, min=0.0)

        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        if self.body.env.is_venv:
            actions = math_util.venv_unpack(actions)
        log_probs = action_pd.log_prob(actions)

        sil_policy_loss = -self.sil_policy_loss_coef * (log_probs *
                                                        clipped_advs).mean()
        sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
        logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
        logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
        return sil_policy_loss, sil_val_loss
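
The SIL terms only learn from transitions whose return exceeded the value estimate, i.e. where max(R - V, 0) > 0. A toy reproduction of both losses with the loss coefficients set to 1 is sketched below; the names are illustrative only.

# Minimal sketch of the self-imitation learning (SIL) losses on toy tensors (illustrative only).
import torch

rets = torch.tensor([1.0, 0.2, 3.0])           # discounted returns R
v_preds = torch.tensor([0.5, 0.8, 1.0])        # value estimates V(s)
log_probs = torch.tensor([-0.7, -1.5, -0.3])   # log pi(a|s)

clipped_advs = torch.clamp(rets - v_preds, min=0.0)  # max(R - V, 0): keep only better-than-expected returns
sil_policy_loss = -(log_probs * clipped_advs).mean()
sil_val_loss = clipped_advs.pow(2).mean() / 2
print(sil_policy_loss.item(), sil_val_loss.item())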