Example #1
    def calc_policy_loss(self, batch, pdparams, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * H[pi](s) ]

        Broken down piecewise:
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. H = E[ entropy ]
        '''
        clip_eps = self.body.clip_eps
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        states = batch['states']
        actions = batch['actions']
        if self.body.env.is_venv:
            states = math_util.venv_unpack(states)
            actions = math_util.venv_unpack(actions)

        # L^CLIP
        log_probs = action_pd.log_prob(actions)
        with torch.no_grad():
            old_pdparams = self.calc_pdparam(states, net=self.old_net)
            old_action_pd = policy_util.init_action_pd(self.body.ActionPD,
                                                       old_pdparams)
            old_log_probs = old_action_pd.log_prob(actions)
        assert log_probs.shape == old_log_probs.shape
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because need to maximize
        clip_loss = -torch.min(sur_1, sur_2).mean()
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF (inherit from ActorCritic)

        # H entropy regularization
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        ent_penalty = -self.body.entropy_coef * entropy
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
        return policy_loss
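
The clipped surrogate in L^CLIP can be sanity-checked on toy tensors. A minimal sketch, with made-up values for a 3-sample batch (not taken from the code above):

    import torch

    log_probs = torch.tensor([-0.9, -1.2, -0.4])       # log pi(a|s) under the current net
    old_log_probs = torch.tensor([-1.0, -1.0, -1.0])    # log pi_old(a|s) under the old net
    advs = torch.tensor([1.0, -0.5, 2.0])               # advantages A
    clip_eps = 0.2

    ratios = torch.exp(log_probs - old_log_probs)                        # pi / pi_old
    sur_1 = ratios * advs                                                # unclipped surrogate
    sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs   # clipped surrogate
    clip_loss = -torch.min(sur_1, sur_2).mean()  # negate: maximize the objective via gradient descent
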
Example #2
    def calc_q_targets(self, batch):
        '''Q_tar = r + gamma * (target_Q(s', a') - alpha * log pi(a'|s'))'''
        next_states = batch['next_states']
        with torch.no_grad():
            pdparams = self.calc_pdparam(next_states)
            action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
            next_log_probs, next_actions = self.calc_log_prob_action(action_pd)
            next_actions = self.guard_q_actions(next_actions)  # non-reparam discrete actions need to be converted into one-hot

            next_target_q1_preds = self.calc_q(next_states, next_actions, self.target_q1_net)
            next_target_q2_preds = self.calc_q(next_states, next_actions, self.target_q2_net)
            next_target_q_preds = torch.min(next_target_q1_preds, next_target_q2_preds)
            q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * (next_target_q_preds - self.alpha * next_log_probs)
        return q_targets
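
The target above is the standard soft Bellman backup, with the (1 - done) mask zeroing the bootstrap term on terminal transitions. A minimal sketch with illustrative scalars (the gamma and alpha values are assumptions):

    import torch

    rewards = torch.tensor([1.0, 0.0])
    dones = torch.tensor([0.0, 1.0])                  # second transition is terminal
    next_target_q_preds = torch.tensor([2.0, 3.0])    # min over the two target Q nets
    next_log_probs = torch.tensor([-1.5, -0.7])
    gamma, alpha = 0.99, 0.2

    q_targets = rewards + gamma * (1 - dones) * (next_target_q_preds - alpha * next_log_probs)
    # tensor([3.2770, 0.0000])
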
Example #3
 def calc_log_probs(self, batch):
     '''Helper method to calculate log_probs for a randomly sampled batch'''
     states, actions = batch['states'], batch['actions']
     # get ActionPD, don't append to state_buffer
     ActionPD, _pdparam, _body = policy_util.init_action_pd(states[0].cpu().numpy(), self, self.body, append=False)
     # construct log_probs for each state-action
     pdparams = self.calc_pdparam(states, evaluate=False)
     log_probs = []
     for idx, pdparam in enumerate(pdparams):
         _action, action_pd = policy_util.sample_action_pd(ActionPD, pdparam, self.body)
         log_prob = action_pd.log_prob(actions[idx])
         log_probs.append(log_prob)
     log_probs = torch.stack(log_probs)
     return log_probs
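
Later versions of this helper (see Examples #1 and #4) drop the per-sample loop in favor of a single batched distribution. A minimal sketch, assuming pdparams are the logits of a categorical policy:

    import torch
    from torch.distributions import Categorical

    pdparams = torch.randn(4, 3)              # hypothetical logits: 4 states, 3 discrete actions
    actions = torch.tensor([0, 2, 1, 1])
    action_pd = Categorical(logits=pdparams)  # one batched distribution object
    log_probs = action_pd.log_prob(actions)   # shape (4,), no Python loop needed
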
Example #4
 def calc_policy_loss(self, batch, pdparams, advs):
     '''Calculate the actor's policy loss'''
     action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
     actions = batch['actions']
     if self.body.env.is_venv:
         actions = math_util.venv_unpack(actions)
     log_probs = action_pd.log_prob(actions)
     policy_loss = - self.policy_loss_coef * (log_probs * advs).mean()
     if self.entropy_coef_spec:
         entropy = action_pd.entropy().mean()
         self.body.mean_entropy = entropy  # update logging variable
         policy_loss += (-self.body.entropy_coef * entropy)
     logger.debug(f'Actor policy loss: {policy_loss:g}')
     return policy_loss
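
This is the standard advantage-weighted policy gradient with an optional entropy bonus. A minimal sketch with a categorical policy and made-up advantages (the coefficient values are illustrative):

    import torch
    from torch.distributions import Categorical

    logits = torch.randn(4, 2, requires_grad=True)
    actions = torch.tensor([0, 1, 1, 0])
    advs = torch.tensor([0.5, -1.0, 2.0, 0.3])
    policy_loss_coef, entropy_coef = 1.0, 0.01

    action_pd = Categorical(logits=logits)
    log_probs = action_pd.log_prob(actions)
    policy_loss = -policy_loss_coef * (log_probs * advs).mean()
    policy_loss += -entropy_coef * action_pd.entropy().mean()  # entropy bonus encourages exploration
    policy_loss.backward()                                     # gradients flow back into the logits
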
Example #5
File: sil.py Project: shlpu/SLM-Lab
 def calc_log_probs(self, batch):
     '''Helper method to calculate log_probs for a randomly sampled batch'''
     states, actions = batch['states'], batch['actions']
     # get ActionPD, don't append to state_buffer
     ActionPD, _pdparam, _body = policy_util.init_action_pd(
         states[0].cpu().numpy(), self, self.body, append=False)
     # construct log_probs for each state-action
     pdparams = self.calc_pdparam(states)
     log_probs = []
     for idx, pdparam in enumerate(pdparams):
         _action, action_pd = policy_util.sample_action_pd(
             ActionPD, pdparam, self.body)
         log_prob = action_pd.log_prob(actions[idx])
         log_probs.append(log_prob)
     log_probs = torch.stack(log_probs)  # stack keeps the autograd graph; torch.tensor() would detach it
     return log_probs
Example #6
    def train(self):
        '''Train actor critic by computing the loss in batch efficiently'''
        if util.in_eval_lab_modes():
            return np.nan
        clock = self.body.env.clock
        if self.to_train == 1:
            for _ in range(self.training_iter):
                batch = self.sample()
                clock.set_batch_size(len(batch))

                states = batch['states']
                actions = self.guard_q_actions(batch['actions'])
                q_targets = self.calc_q_targets(batch)
                # Q-value loss for both Q nets
                q1_preds = self.calc_q(states, actions, self.q1_net)
                q1_loss = self.calc_reg_loss(q1_preds, q_targets)
                self.q1_net.train_step(q1_loss, self.q1_optim, self.q1_lr_scheduler, clock=clock, global_net=self.global_q1_net)

                q2_preds = self.calc_q(states, actions, self.q2_net)
                q2_loss = self.calc_reg_loss(q2_preds, q_targets)
                self.q2_net.train_step(q2_loss, self.q2_optim, self.q2_lr_scheduler, clock=clock, global_net=self.global_q2_net)

                # policy loss
                action_pd = policy_util.init_action_pd(self.body.ActionPD, self.calc_pdparam(states))
                log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
                policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
                self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)

                # alpha loss
                alpha_loss = self.calc_alpha_loss(log_probs)
                self.train_alpha(alpha_loss)

                loss = q1_loss + q2_loss + policy_loss + alpha_loss
                # update target networks
                self.update_nets()
                # update PER priorities if available
                self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)

            # reset
            self.to_train = 0
            logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}')
            return loss.item()
        else:
            return np.nan
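
calc_alpha_loss and train_alpha are not shown here; in standard SAC with automatic temperature tuning, the temperature loss is derived from a target entropy. A minimal sketch under that assumption (target_entropy and the log_probs values are made up):

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)    # learn log(alpha) to keep alpha positive
    target_entropy = -2.0                             # commonly -dim(action) for continuous control
    log_probs = torch.tensor([-1.3, -0.8, -2.1])      # log pi(a|s) from the current policy

    alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
    alpha_loss.backward()                             # the gradient updates log_alpha only
    alpha = log_alpha.exp()                           # temperature used in the Q target and policy loss
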
Example #7
File: sil.py Project: kengz/SLM-Lab
    def calc_sil_policy_val_loss(self, batch, pdparams):
        '''
        Calculate the SIL policy losses for actor and critic
        sil_policy_loss = -log_prob * max(R - v_pred, 0)
        sil_val_loss = (max(R - v_pred, 0)^2) / 2
        This is called on a randomly sampled batch from experience replay
        '''
        v_preds = self.calc_v(batch['states'], use_cache=False)
        rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
        clipped_advs = torch.clamp(rets - v_preds, min=0.0)

        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        if self.body.env.is_venv:
            actions = math_util.venv_unpack(actions)
        log_probs = action_pd.log_prob(actions)

        sil_policy_loss = - self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
        sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
        logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
        logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
        return sil_policy_loss, sil_val_loss
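
The clamp at zero is what makes SIL imitate only transitions where the realized return beat the current value estimate. A minimal sketch with made-up returns and value predictions:

    import torch

    rets = torch.tensor([3.0, 1.0, 5.0])      # discounted returns R
    v_preds = torch.tensor([2.5, 2.0, 4.0])   # value estimates V(s)
    clipped_advs = torch.clamp(rets - v_preds, min=0.0)
    # tensor([0.5000, 0.0000, 1.0000]); the second transition contributes nothing to either SIL loss
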
Example #8
 def calc_log_probs(self, batch, use_old_net=False):
     '''Helper method to calculate log_probs with the option to switch nets'''
     if use_old_net:
         # temporarily swap to do calc
         self.tmp_net = self.net
         self.net = self.old_net
     states, actions = batch['states'], batch['actions']
     # get ActionPD, don't append to state_buffer
     ActionPD, _pdparam, _body = policy_util.init_action_pd(states[0].cpu().numpy(), self, self.body, append=False)
     # construct log_probs for each state-action
     pdparams = self.calc_pdparam(states, evaluate=False)
     log_probs = []
     for idx, pdparam in enumerate(pdparams):
         _action, action_pd = policy_util.sample_action_pd(ActionPD, pdparam, self.body)
         log_prob = action_pd.log_prob(actions[idx])
         log_probs.append(log_prob)
     log_probs = torch.stack(log_probs)
     if use_old_net:
         # swap back
         self.old_net = self.net
         self.net = self.tmp_net
     return log_probs
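
The temporary attribute swap restores self.net only if execution reaches the end of the method; an exception raised between the two swaps would leave the nets switched. A hypothetical context-manager variant (not part of SLM-Lab) makes the restore automatic:

    from contextlib import contextmanager

    @contextmanager
    def use_net(algo, net):
        '''Temporarily point algo.net at another net, restoring it on exit.'''
        saved = algo.net
        algo.net = net
        try:
            yield
        finally:
            algo.net = saved

    # usage: compute pdparams under the old policy
    # with use_net(self, self.old_net):
    #     pdparams = self.calc_pdparam(states)
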
Example #9
 def calc_log_probs(self, batch, use_old_net=False):
     '''Helper method to calculate log_probs with the option to switch nets'''
     if use_old_net:
         # temporarily swap to do calc
         self.tmp_net = self.net
         self.net = self.old_net
     states, actions = batch['states'], batch['actions']
     # get ActionPD, don't append to state_buffer
     ActionPD, _pdparam, _body = policy_util.init_action_pd(states[0].cpu().numpy(), self, self.body, append=False)
     # construct log_probs for each state-action
     pdparams = self.calc_pdparam(states)
     log_probs = []
     for idx, pdparam in enumerate(pdparams):
         _action, action_pd = policy_util.sample_action_pd(ActionPD, pdparam, self.body)
         log_prob = action_pd.log_prob(actions[idx])
         log_probs.append(log_prob)
     log_probs = torch.stack(log_probs)  # stack keeps the autograd graph; torch.tensor() would detach it
     if use_old_net:
         # swap back
         self.old_net = self.net
         self.net = self.tmp_net
     return log_probs
Example #10
    def train(self):
        '''Train actor critic by computing the loss in batch efficiently'''
        if util.in_eval_lab_modes():
            return np.nan
        clock = self.body.env.clock
        if self.to_train == 1:
            for _ in range(self.training_iter):
                batch = self.sample()
                clock.set_batch_size(len(batch))

                # forward passes for losses
                states = batch['states']
                actions = batch['actions']
                if self.body.is_discrete:
                    # to one-hot discrete action for Q input.
                    # TODO support multi-discrete actions
                    actions = torch.eye(self.body.action_dim)[actions.long()]
                pdparams = self.calc_pdparam(states)
                action_pd = policy_util.init_action_pd(self.body.ActionPD,
                                                       pdparams)

                # V-value loss
                v_preds = self.calc_v(states, net=self.critic_net)
                v_targets = self.calc_v_targets(batch, action_pd)
                val_loss = self.calc_reg_loss(v_preds, v_targets)
                self.critic_net.train_step(val_loss,
                                           self.critic_optim,
                                           self.critic_lr_scheduler,
                                           clock=clock,
                                           global_net=self.global_critic_net)

                # Q-value loss for both Q nets
                q_targets = self.calc_q_targets(batch)
                q1_preds = self.calc_q(states, actions, self.q1_net)
                q1_loss = self.calc_reg_loss(q1_preds, q_targets)
                self.q1_net.train_step(q1_loss,
                                       self.q1_optim,
                                       self.q1_lr_scheduler,
                                       clock=clock,
                                       global_net=self.global_q1_net)
                q2_preds = self.calc_q(states, actions, self.q2_net)
                q2_loss = self.calc_reg_loss(q2_preds, q_targets)
                self.q2_net.train_step(q2_loss,
                                       self.q2_optim,
                                       self.q2_lr_scheduler,
                                       clock=clock,
                                       global_net=self.global_q2_net)

                # policy loss
                policy_loss = self.calc_policy_loss(batch, action_pd)
                self.net.train_step(policy_loss,
                                    self.optim,
                                    self.lr_scheduler,
                                    clock=clock,
                                    global_net=self.global_net)

                loss = policy_loss + val_loss + q1_loss + q2_loss

                # update target_critic_net
                self.update_nets()
                # update PER priorities if available
                self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)

            # reset
            self.to_train = 0
            logger.debug(
                f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}'
            )
            return loss.item()
        else:
            return np.nan
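
The torch.eye trick used above for discrete actions simply indexes rows of an identity matrix to build one-hot vectors. A minimal sketch:

    import torch

    action_dim = 4
    actions = torch.tensor([2, 0, 3])
    one_hot = torch.eye(action_dim)[actions.long()]
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 0., 0., 1.]])
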