Example #1
    def calc_policy_loss(self, batch, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        # decay clip_eps by episode
        clip_eps = policy_util._linear_decay(self.clip_eps,
                                             0.1 * self.clip_eps,
                                             self.clip_eps_anneal_epi,
                                             self.body.env.clock.get('epi'))

        # L^CLIP
        log_probs = policy_util.calc_log_probs(self, self.net, self.body,
                                               batch)
        old_log_probs = policy_util.calc_log_probs(self, self.old_net,
                                                   self.body, batch)
        assert log_probs.shape == old_log_probs.shape
        assert advs.shape[0] == log_probs.shape[0]  # batch size
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        entropies = torch.stack(self.body.entropies)
        ent_penalty = torch.mean(-self.entropy_coef * entropies)
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        if torch.cuda.is_available() and self.net.gpu:
            policy_loss = policy_loss.cuda()
        logger.debug(f'Actor policy loss: {policy_loss:.4f}')
        return policy_loss
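
Stripped of the framework plumbing (policy_util, self.body, the logger), the L^CLIP term in the method above reduces to a handful of tensor operations. The following is a minimal, self-contained sketch with made-up log-probabilities, advantages, and a fixed clip_eps of 0.2; none of these numbers come from the example, they only illustrate the shape of the computation.

import torch

# made-up per-timestep quantities for 4 transitions
log_probs = torch.tensor([-0.9, -1.2, -0.4, -2.0])      # log pi(a|s) under the current net
old_log_probs = torch.tensor([-1.0, -1.0, -0.5, -1.8])  # log pi_old(a|s) under the old net
advs = torch.tensor([1.5, -0.3, 0.8, -1.1])             # advantage estimates
clip_eps = 0.2                                          # assumed fixed here, no annealing

ratios = torch.exp(log_probs - old_log_probs)           # pi(a|s) / pi_old(a|s)
sur_1 = ratios * advs                                   # unclipped surrogate
sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs  # clipped surrogate
clip_loss = -torch.mean(torch.min(sur_1, sur_2))        # negate so a minimizer maximizes L^CLIP
print(clip_loss.item())
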
Example #2
    def calc_loss(self, batch):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ (V(s_t) - V^target)^2 ]

        3. S = E[ entropy ]
        '''
        # decay clip_eps by episode
        clip_eps = policy_util._linear_decay(self.clip_eps,
                                             0.1 * self.clip_eps,
                                             self.clip_eps_anneal_epi,
                                             self.body.env.clock.get('epi'))

        with torch.no_grad():
            adv_targets, v_targets = self.calc_gae_advs_v_targets(batch)

        # L^CLIP
        log_probs = self.calc_log_probs(batch, use_old_net=False)
        old_log_probs = self.calc_log_probs(batch, use_old_net=True)
        assert log_probs.shape == old_log_probs.shape
        assert adv_targets.shape == log_probs.shape
        ratios = torch.exp(log_probs - old_log_probs)
        sur_1 = ratios * adv_targets
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps,
                            1.0 + clip_eps) * adv_targets
        # flip sign because need to maximize
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))

        # L^VF
        val_loss = self.calc_val_loss(batch, v_targets)  # from critic

        # S entropy bonus
        # torch.stack keeps the autograd graph (torch.tensor would detach the entropies,
        # leaving the bonus with no gradient)
        ent_mean = torch.mean(torch.stack(self.body.entropies))
        ent_penalty = -self.entropy_coef * ent_mean
        loss = clip_loss + val_loss + ent_penalty
        return loss
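
Example #2 also pulls adv_targets and v_targets out of self.calc_gae_advs_v_targets(batch) under torch.no_grad(). That method is not shown on this page; the sketch below is a generic GAE(lambda) computation that yields the same pair of quantities (advantages for L^CLIP, value targets for L^VF) from rewards, value predictions, and done flags. The function name, argument layout, and the gamma/lam defaults are illustrative assumptions, not the library's API.

import torch

def gae_advs_v_targets(rewards, v_preds, dones, gamma=0.99, lam=0.95):
    '''Generic GAE(lambda) sketch; v_preds holds T+1 values, the last one being the bootstrap.'''
    T = len(rewards)
    advs = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * v_preds[t + 1] * not_done - v_preds[t]
        gae = delta + gamma * lam * not_done * gae
        advs[t] = gae
    v_targets = advs + v_preds[:T]  # regression targets for the critic
    return advs, v_targets

advs, v_targets = gae_advs_v_targets(
    rewards=torch.tensor([1.0, 0.0, 1.0]),
    v_preds=torch.tensor([0.5, 0.6, 0.4, 0.0]),
    dones=torch.tensor([0.0, 0.0, 1.0]))
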
Example #3
    def calc_policy_loss(self, batch, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        # decay clip_eps by episode
        clip_eps = policy_util._linear_decay(self.clip_eps,
                                             0.1 * self.clip_eps,
                                             self.clip_eps_anneal_epi,
                                             self.body.env.clock.get('epi'))

        # L^CLIP
        log_probs = self.calc_log_probs(batch, use_old_net=False)
        old_log_probs = self.calc_log_probs(batch, use_old_net=True)
        assert log_probs.shape == old_log_probs.shape
        assert advs.shape == log_probs.shape
        ratios = torch.exp(log_probs - old_log_probs)
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        ent_penalty = 0
        for e in self.body.entropies:
            ent_penalty += (-self.entropy_coef * e)
        ent_penalty /= len(self.body.entropies)

        policy_loss = clip_loss + ent_penalty
        return policy_loss
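
All of the examples anneal clip_eps linearly by episode through policy_util._linear_decay(start_val, end_val, anneal_epi, epi). The helper itself is not shown on this page; a function consistent with how it is called above might look like the sketch below, where the body (clamp the interpolation fraction, then hold the end value) is an assumption rather than the library's implementation.

def linear_decay(start_val, end_val, anneal_steps, step):
    '''Interpolate linearly from start_val to end_val over anneal_steps, then hold end_val.'''
    if anneal_steps <= 0:
        return end_val
    frac = min(max(step / anneal_steps, 0.0), 1.0)
    return start_val + frac * (end_val - start_val)

# e.g. clip_eps starting at 0.2 decays toward 0.1 * 0.2 = 0.02:
clip_eps = linear_decay(0.2, 0.02, 100, 50)  # 0.11 halfway through the anneal
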
Example #4
    def calc_policy_loss(self, batch, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * S[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. S = E[ entropy ]
        '''
        # decay clip_eps by episode
        clip_eps = policy_util._linear_decay(self.clip_eps,
                                             0.1 * self.clip_eps,
                                             self.clip_eps_anneal_epi,
                                             self.body.env.clock.get('epi'))

        # L^CLIP
        log_probs = self.calc_log_probs(batch, use_old_net=False)
        old_log_probs = self.calc_log_probs(batch, use_old_net=True)
        assert log_probs.shape == old_log_probs.shape
        assert advs.shape == log_probs.shape
        ratios = torch.exp(log_probs - old_log_probs)
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.mean(torch.min(sur_1, sur_2))

        # L^VF (inherit from ActorCritic)

        # S entropy bonus
        ent_penalty = 0
        for e in self.body.entropies:
            ent_penalty += (-self.entropy_coef * e)
        ent_penalty /= len(self.body.entropies)

        policy_loss = clip_loss + ent_penalty
        return policy_loss
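
Because this variant (like Examples #1 and #3) returns only the policy half of the loss, the surrounding update step has to add the critic's value loss, backpropagate, and refresh the old network so that the next batch's ratios start out at 1.0. The sketch below shows one hypothetical way to wire that together; the optimizer, the agent object, and the exact signatures of calc_gae_advs_v_targets and calc_val_loss (borrowed from Example #2) are assumptions, not the framework's actual training loop.

import torch

def ppo_training_step(agent, batch, optimizer):
    '''Hypothetical outer update step around the methods shown above.'''
    with torch.no_grad():
        advs, v_targets = agent.calc_gae_advs_v_targets(batch)  # assumed, as in Example #2
    policy_loss = agent.calc_policy_loss(batch, advs)  # clipped surrogate + entropy bonus
    val_loss = agent.calc_val_loss(batch, v_targets)   # critic's L^VF
    loss = policy_loss + val_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # re-sync the old policy so the next batch starts with ratio = 1.0 everywhere
    agent.old_net.load_state_dict(agent.net.state_dict())
    return loss.item()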