Example #1
 def calc_gae_advs_v_targets(self, batch, v_preds):
     '''
     Calculate GAE, and advs = GAE, v_targets = advs + v_preds
     See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
         next_v_pred = next_v_pred.unsqueeze(dim=0)
     v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
     advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                v_preds_all, self.gamma, self.lam)
     v_targets = advs + v_preds
     advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
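
For orientation, here is a minimal sketch of what GAE and standardization helpers along the lines of math_util.calc_gaes and math_util.standardize could look like; this is an assumption, not the library's exact code. The recursion follows Schulman et al.: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * A_{t+1}, with the done mask zeroing the bootstrap across episode boundaries.

import torch

def calc_gaes(rewards, dones, v_preds, gamma, lam):
    # rewards, dones: shape (T, ...); v_preds: shape (T + 1, ...), last entry is the bootstrap value
    T = len(rewards)
    assert len(v_preds) == T + 1
    gaes = torch.zeros_like(rewards)
    future_gae = torch.zeros_like(rewards[0])
    not_dones = 1.0 - dones  # dones assumed to be float tensors of 0/1
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * v_preds[t + 1] * not_dones[t] - v_preds[t]
        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
    return gaes

def standardize(v):
    # zero-mean, unit-variance normalization with a small epsilon for numerical stability
    return (v - v.mean()) / (v.std() + 1e-8)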
Example #2
 def calc_ret_advs_v_targets(self, batch, v_preds):
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
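
A minimal sketch of a discounted-return helper in the spirit of math_util.calc_returns (assumed behavior, not the library's exact code): accumulate r_t + gamma * future_ret backwards through time, resetting at episode boundaries via the done mask.

import torch

def calc_returns(rewards, dones, gamma):
    # rewards, dones: float tensors of shape (T, ...) with dones in {0, 1}
    T = len(rewards)
    rets = torch.zeros_like(rewards)
    future_ret = torch.zeros_like(rewards[0])
    not_dones = 1.0 - dones
    for t in reversed(range(T)):
        rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
    return rets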
Example #3
 def train(self):
     if util.in_eval_lab_modes():
         return np.nan
     clock = self.body.env.clock
     if self.to_train == 1:
         net_util.copy(self.net, self.old_net)  # update old net
         batch = self.sample()
         clock.set_batch_size(len(batch))
         with torch.no_grad():
             states = batch['states']
             if self.body.env.is_venv:
                 states = math_util.venv_unpack(states)
             # NOTE: states is large here (batch_size = time_horizon * num_envs). Chunk it so the forward pass fits on the device, especially GPU memory
             num_chunks = int(len(states) / self.minibatch_size)
             v_preds_chunks = [self.calc_v(states_chunk, use_cache=False) for states_chunk in torch.chunk(states, num_chunks)]
             v_preds = torch.cat(v_preds_chunks)
             advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
         # piggyback on batch, but remember not to pack or unpack these again
         batch['advs'], batch['v_targets'] = advs, v_targets
         if self.body.env.is_venv:  # unpack if venv for minibatch sampling
             for k, v in batch.items():
                 if k not in ('advs', 'v_targets'):
                     batch[k] = math_util.venv_unpack(v)
         total_loss = torch.tensor(0.0)
         for _ in range(self.training_epoch):
             minibatches = util.split_minibatch(batch, self.minibatch_size)
             for minibatch in minibatches:
                 if self.body.env.is_venv:  # re-pack to restore proper shape
                     for k, v in minibatch.items():
                         if k not in ('advs', 'v_targets'):
                             minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                 advs, v_targets = minibatch['advs'], minibatch['v_targets']
                 pdparams, v_preds = self.calc_pdparam_v(minibatch)
                 policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                 val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
                     self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                 else:
                     self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                     self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                     loss = policy_loss + val_loss
                 total_loss += loss
         loss = total_loss / self.training_epoch / len(minibatches)
         # reset
         self.to_train = 0
         logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}')
         return loss.item()
     else:
         return np.nan
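
util.split_minibatch is assumed to cut the (already unpacked) batch dict into shuffled minibatches of minibatch_size entries; a minimal sketch of such a helper, under that assumption:

import torch

def split_minibatch(batch, mb_size):
    # batch: dict of equal-length tensors; returns a list of minibatch dicts
    size = len(next(iter(batch.values())))
    idxs = torch.randperm(size)  # shuffle so each epoch sees different minibatches
    minibatches = []
    for start in range(0, size, mb_size):
        mb_idxs = idxs[start:start + mb_size]
        minibatches.append({k: v[mb_idxs] for k, v in batch.items()})
    return minibatches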
Example #4
    def calc_policy_loss(self, batch, pdparams, advs):
        '''
        The PPO loss function (subscript t is omitted)
        L^{CLIP+VF+S} = E[ L^CLIP - c1 * L^VF + c2 * H[pi](s) ]

        Breakdown piecewise,
        1. L^CLIP = E[ min(ratio * A, clip(ratio, 1-eps, 1+eps) * A) ]
        where ratio = pi(a|s) / pi_old(a|s)

        2. L^VF = E[ mse(V(s_t), V^target) ]

        3. H = E[ entropy ]
        '''
        clip_eps = self.body.clip_eps
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        states = batch['states']
        actions = batch['actions']
        if self.body.env.is_venv:
            states = math_util.venv_unpack(states)
            actions = math_util.venv_unpack(actions)

        # L^CLIP
        log_probs = action_pd.log_prob(actions)
        with torch.no_grad():
            old_pdparams = self.calc_pdparam(states, net=self.old_net)
            old_action_pd = policy_util.init_action_pd(self.body.ActionPD,
                                                       old_pdparams)
            old_log_probs = old_action_pd.log_prob(actions)
        assert log_probs.shape == old_log_probs.shape
        ratios = torch.exp(log_probs - old_log_probs)
        logger.debug(f'ratios: {ratios}')
        sur_1 = ratios * advs
        sur_2 = torch.clamp(ratios, 1.0 - clip_eps, 1.0 + clip_eps) * advs
        # flip sign because the objective needs to be maximized
        clip_loss = -torch.min(sur_1, sur_2).mean()
        logger.debug(f'clip_loss: {clip_loss}')

        # L^VF is computed by calc_val_loss, inherited from ActorCritic

        # H entropy regularization
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        ent_penalty = -self.body.entropy_coef * entropy
        logger.debug(f'ent_penalty: {ent_penalty}')

        policy_loss = clip_loss + ent_penalty
        logger.debug(f'PPO Actor policy loss: {policy_loss:g}')
        return policy_loss
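
To see the clipping mechanics of L^CLIP in isolation, here is a tiny standalone snippet with made-up numbers (not part of the repo):

import torch

clip_eps = 0.2
ratios = torch.tensor([0.5, 1.0, 1.5])
advs = torch.tensor([1.0, -1.0, -1.0])
sur_1 = ratios * advs                                           # unclipped surrogate
sur_2 = torch.clamp(ratios, 1 - clip_eps, 1 + clip_eps) * advs  # clipped surrogate
clip_loss = -torch.min(sur_1, sur_2).mean()
print(sur_1, sur_2, clip_loss)
# sur_1 = [0.5, -1.0, -1.5], sur_2 = [0.8, -1.0, -1.2]
# the elementwise min keeps the pessimistic value, so clip_loss = -(0.5 - 1.0 - 1.5) / 3 ≈ 0.6667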
Example #5
 def calc_pdparam_batch(self, batch):
     '''Efficiently forward to get pdparam by batch for loss computation'''
     states = batch['states']
     if self.body.env.is_venv:
         states = math_util.venv_unpack(states)
     pdparam = self.calc_pdparam(states)
     return pdparam
Example #6
def test_venv_pack(base_shape):
    batch_size = 5
    num_envs = 4
    batch_arr = torch.zeros([batch_size, num_envs] + base_shape)
    unpacked_arr = math_util.venv_unpack(batch_arr)
    packed_arr = math_util.venv_pack(unpacked_arr, num_envs)
    assert list(packed_arr.shape) == [batch_size, num_envs] + base_shape
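
The test above pins down the shape contract: venv_unpack folds a (batch_size, num_envs, *base_shape) tensor into (batch_size * num_envs, *base_shape), and venv_pack reverses it. A minimal sketch consistent with that contract (the library's actual implementation may differ):

import torch

def venv_unpack(batch_tensor):
    # (batch_size, num_envs, *base_shape) -> (batch_size * num_envs, *base_shape)
    shape = list(batch_tensor.shape)
    return batch_tensor.reshape([-1] + shape[2:])

def venv_pack(batch_tensor, num_envs):
    # (batch_size * num_envs, *base_shape) -> (batch_size, num_envs, *base_shape)
    shape = list(batch_tensor.shape)
    return batch_tensor.reshape([-1, num_envs] + shape[1:])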
Example #7
 def calc_pdparam_v(self, batch):
     '''Efficiently forward to get pdparam and v by batch for loss computation'''
     states = batch['states']
     if self.body.env.is_venv:
         states = math_util.venv_unpack(states)
     pdparam = self.calc_pdparam(states)
     v_pred = self.calc_v(states)  # uses self.v_pred from calc_pdparam if self.shared
     return pdparam, v_pred
Example #8
 def calc_ret_advs(self, batch):
     '''Calculate plain returns, which are generalized to advantages in ActorCritic'''
     rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
     if self.center_return:
         rets = math_util.center_mean(rets)
     advs = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
     logger.debug(f'advs: {advs}')
     return advs
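
math_util.center_mean presumably just subtracts the mean so the centered returns act as a simple baseline; a one-line sketch under that assumption:

def center_mean(v):
    # subtract the batch mean so returns are centered around zero
    return v - v.mean()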
Example #9
 def calc_q_loss(self, batch):
     '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
     states = batch['states']
     next_states = batch['next_states']
     if self.body.env.is_venv:
         states = math_util.venv_unpack(states)
         next_states = math_util.venv_unpack(next_states)
     q_preds = self.net(states)
     with torch.no_grad():
         next_q_preds = self.net(next_states)
     if self.body.env.is_venv:
         q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
         next_q_preds = math_util.venv_pack(next_q_preds, self.body.env.num_envs)
     act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
     act_next_q_preds = next_q_preds.gather(-1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
     act_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * act_next_q_preds
     logger.debug(f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
     q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
     return q_loss
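
The gather/squeeze pattern in the loss above just selects, for each transition, the Q-value of the action that was taken. A standalone illustration with made-up numbers:

import torch

q_preds = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])  # (batch, num_actions)
actions = torch.tensor([2, 0])             # action index taken in each transition
act_q_preds = q_preds.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
print(act_q_preds)  # tensor([3., 4.])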
Example #10
 def calc_nstep_advs_v_targets(self, batch, v_preds):
     '''
     Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
     See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['dones'], next_v_pred, self.gamma, self.num_step_returns)
     advs = nstep_rets - v_preds
     v_targets = nstep_rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
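
Assuming the sampled batch spans exactly num_step_returns time steps, a backward recursion bootstrapped from next_v_pred yields the n-step targets; a minimal sketch (assumed behavior, not the library's exact calc_nstep_returns):

import torch

def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
    # rewards, dones: shape (n, ...); next_v_pred bootstraps the value after the last step
    rets = torch.zeros_like(rewards)
    future_ret = next_v_pred
    not_dones = 1.0 - dones
    for t in reversed(range(n)):
        rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
    return rets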
Example #11
 def calc_policy_loss(self, batch, pdparams, advs):
     '''Calculate the actor's policy loss'''
     action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
     actions = batch['actions']
     if self.body.env.is_venv:
         actions = math_util.venv_unpack(actions)
     log_probs = action_pd.log_prob(actions)
     policy_loss = - self.policy_loss_coef * (log_probs * advs).mean()
     if self.entropy_coef_spec:
         entropy = action_pd.entropy().mean()
         self.body.mean_entropy = entropy  # update logging variable
         policy_loss += (-self.body.entropy_coef * entropy)
     logger.debug(f'Actor policy loss: {policy_loss:g}')
     return policy_loss
Example #12
    def calc_sil_policy_val_loss(self, batch, pdparams):
        '''
        Calculate the SIL policy losses for actor and critic
        sil_policy_loss = -log_prob * max(R - v_pred, 0)
        sil_val_loss = (max(R - v_pred, 0)^2) / 2
        This is called on a randomly-sampled batch from experience replay
        '''
        v_preds = self.calc_v(batch['states'], use_cache=False)
        rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
        clipped_advs = torch.clamp(rets - v_preds, min=0.0)

        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        actions = batch['actions']
        if self.body.env.is_venv:
            actions = math_util.venv_unpack(actions)
        log_probs = action_pd.log_prob(actions)

        sil_policy_loss = - self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
        sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
        logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
        logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
        return sil_policy_loss, sil_val_loss
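
The clamp at zero is what makes this "self-imitation": only transitions whose return exceeded the current value estimate contribute to either SIL loss. A made-up-numbers illustration:

import torch

rets = torch.tensor([1.0, 2.0, 3.0])
v_preds = torch.tensor([1.5, 1.0, 4.0])
clipped_advs = torch.clamp(rets - v_preds, min=0.0)
print(clipped_advs)  # tensor([0., 1., 0.])  -- only the second transition is imitated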