Example #1
    def calc_policy_loss(self, batch, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Broken down piecewise:
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        # decay clip_eps by episode
        clip_eps = policy_util._linear_decay(self.clip_eps,
                                             0.1 * self.clip_eps,
                                             self.clip_eps_anneal_epi,
                                             self.body.env.clock.get('epi'))

        # L^CLIP
        log_probs = policy_util.calc_log_probs(self, self.net, self.body,
                                               batch)
        # detach so no gradient flows through the old policy network
        old_log_probs = policy_util.calc_log_probs(self, self.old_net,
                                                   self.body, batch).detach()
        assert log_probs.shape == old_log_probs.shape
        assert advs.shape[0] == log_probs.shape[0]  # batch size
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because the surrogate objective is maximized but the optimizer minimizes
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        entropies = torch.stack(self.body.entropies)
        ent_penalty = torch.mean(-self.entropy_coef * entropies)
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        if torch.cuda.is_available() and self.net.gpu:
            policy_loss = policy_loss.cuda()
        logger.debug(f'Actor policy loss: {policy_loss:.4f}')
        return policy_loss
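Example #1 anneals clip_eps linearly over episodes via policy_util._linear_decay. Below is a minimal, self-contained sketch of such a schedule, assuming the same argument order as the call above (start value, end value, annealing horizon, current episode); the actual SLM-Lab helper may differ in its details.

def linear_decay(start_val, end_val, anneal_epi, epi):
    """Illustrative linear schedule: interpolate from start_val to end_val
    over anneal_epi episodes, then hold end_val."""
    if epi >= anneal_epi:
        return end_val
    frac = epi / anneal_epi
    return start_val + frac * (end_val - start_val)

# e.g. clip_eps = 0.2 decaying toward 0.1 * 0.2 = 0.02 over 1000 episodes
assert linear_decay(0.2, 0.02, 1000, 0) == 0.2
assert abs(linear_decay(0.2, 0.02, 1000, 500) - 0.11) < 1e-9
assert linear_decay(0.2, 0.02, 1000, 2000) == 0.02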
Example #2
File: ppo.py  Project: wilson1yan/SLM-Lab
    def calc_policy_loss(self, batch, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Broken down piecewise:
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        clip_eps = self.body.clip_eps

        # L^CLIP
        log_probs = policy_util.calc_log_probs(self, self.net, self.body,
                                               batch)
        old_log_probs = policy_util.calc_log_probs(self, self.old_net,
                                                   self.body, batch).detach()
        assert log_probs.shape == old_log_probs.shape
        assert advs.shape[0] == log_probs.shape[0]  # batch size
        # ratio = pi(a|s) / pi_old(a|s), recovered from the log-prob difference
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because the surrogate objective is maximized but the optimizer minimizes
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        entropies = torch.stack(self.body.entropies)
        ent_penalty = torch.mean(-self.body.entropy_coef * entropies)
        logger.debug(f'ent_penalty: {ent_penalty}')
        # Store mean entropy for debug logging
        self.body.mean_entropy = entropies.mean().item()

        policy_loss = clip_loss + ent_penalty
        logger.debug(f'PPO Actor policy loss: {policy_loss:.4f}')
        return policy_loss
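Both PPO examples build the same clipped surrogate: exponentiate the log-probability difference to get the ratio, clamp it to [1 - eps, 1 + eps], and take the elementwise minimum against the unclipped term. The standalone sketch below (toy tensors, no SLM-Lab dependencies) isolates just L^CLIP and shows that once the ratio leaves the clip range, the gradient for that sample vanishes.

import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advs, clip_eps=0.2):
    """L^CLIP as a loss (negated so an optimizer can minimize it)."""
    ratios = torch.exp(log_probs - old_log_probs.detach())
    sur_1 = ratios * advs
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
    return -torch.min(sur_1, sur_2).mean()

# second sample's ratio is e^1 ~ 2.72 > 1 + eps, so it gets clipped
log_probs = torch.tensor([0.0, 1.0], requires_grad=True)
old_log_probs = torch.tensor([0.0, 0.0])
advs = torch.tensor([1.0, 1.0])
loss = clipped_surrogate_loss(log_probs, old_log_probs, advs)
loss.backward()
print(loss.item())      # -(1.0 + 1.2) / 2 = -1.1
print(log_probs.grad)   # gradient for the clipped sample is 0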
Example #3
File: sil.py  Project: wilson1yan/SLM-Lab
    def calc_sil_policy_val_loss(self, batch):
        '''
        Calculate the SIL policy losses for actor and critic
        sil_policy_loss = -log_prob * max(R - v_pred, 0)
        sil_val_loss = (max(R - v_pred, 0)^2) / 2
        This is called on a randomly-sampled batch from the experience replay
        '''
        returns = batch['rets']
        v_preds = self.calc_v(batch['states'], evaluate=False)
        clipped_advs = torch.clamp(returns - v_preds, min=0.0)
        log_probs = policy_util.calc_log_probs(self, self.net, self.body, batch)

        sil_policy_loss = self.sil_policy_loss_coef * torch.mean(
            -log_probs * clipped_advs)
        sil_val_loss = self.sil_val_loss_coef * torch.mean(
            torch.pow(clipped_advs, 2) / 2)
        logger.debug(f'SIL actor policy loss: {sil_policy_loss:.4f}')
        logger.debug(f'SIL critic value loss: {sil_val_loss:.4f}')
        return sil_policy_loss, sil_val_loss
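The SIL terms only reinforce actions whose return exceeded the current value estimate, because clamping R - V(s) at zero masks every other sample. A standalone sketch of both terms on toy tensors follows (no SLM-Lab dependencies; the coefficient defaults here are illustrative, not the project's settings).

import torch

def calc_sil_losses(log_probs, returns, v_preds,
                    policy_loss_coef=1.0, val_loss_coef=0.01):
    """Illustrative self-imitation losses: only samples with R > V(s) contribute."""
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)
    sil_policy_loss = policy_loss_coef * torch.mean(-log_probs * clipped_advs)
    sil_val_loss = val_loss_coef * torch.mean(clipped_advs.pow(2) / 2)
    return sil_policy_loss, sil_val_loss

# toy batch: the second sample has R < V(s), so it is masked out entirely
returns = torch.tensor([2.0, 0.5])
v_preds = torch.tensor([1.0, 1.0])
log_probs = torch.tensor([-0.5, -0.5])
policy_loss, val_loss = calc_sil_losses(log_probs, returns, v_preds)
print(policy_loss.item())  # mean(0.5 * 1.0, 0.5 * 0.0) = 0.25
print(val_loss.item())     # 0.01 * mean(0.5, 0.0) = 0.0025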