Example #1
 def calc_ret_advs(self, batch):
     '''Calculate plain returns, which are generalized to advantages in ActorCritic'''
     rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
     if self.center_return:
         rets = math_util.center_mean(rets)
     advs = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
     logger.debug(f'advs: {advs}')
     return advs
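
For context, the `calc_returns` call above produces discounted returns from rewards and done flags. Below is a minimal, illustrative sketch of that recursion (an assumption about its semantics, not the library's actual `math_util.calc_returns`):

    import torch

    def discounted_returns(rewards, dones, gamma):
        # Illustrative recursion: R_t = r_t + gamma * R_{t+1} * (1 - done_t)
        rets = torch.zeros_like(rewards)
        future_ret = torch.zeros_like(rewards[0])
        for t in reversed(range(len(rewards))):
            future_ret = rewards[t] + gamma * future_ret * (1 - dones[t])
            rets[t] = future_ret
        return rets

    # e.g. a 3-step episode that terminates on the last step
    rewards = torch.tensor([1.0, 1.0, 1.0])
    dones = torch.tensor([0.0, 0.0, 1.0])
    print(discounted_returns(rewards, dones, 0.99))  # tensor([2.9701, 1.9900, 1.0000])
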
Example #2
 def sample(self):
     '''Modify the onpolicy sample to also append to replay'''
     batch = self.body.memory.sample()
     batch = {k: np.concatenate(v) for k, v in batch.items()}  # concat episodic memory
     batch['rets'] = math_util.calc_returns(batch, self.gamma)
     for idx in range(len(batch['dones'])):
         tuples = [batch[k][idx] for k in self.body.replay_memory.data_keys]
         self.body.replay_memory.add_experience(*tuples)
     if self.normalize_state:
         batch = policy_util.normalize_states_and_next_states(self.body, batch)
     batch = util.to_torch_batch(batch, self.net.device, self.body.replay_memory.is_episodic)
     return batch
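
The loop above copies the on-policy batch into replay memory one transition at a time. A small standalone illustration of that per-index unpacking pattern, using made-up arrays and a hypothetical `data_keys` ordering rather than the project's actual replay memory class:

    import numpy as np

    batch = {
        'states': np.arange(6).reshape(3, 2),
        'actions': np.array([0, 1, 0]),
        'rewards': np.array([1.0, 0.5, 2.0]),
    }
    data_keys = ['states', 'actions', 'rewards']  # hypothetical key ordering
    for idx in range(len(batch['rewards'])):
        experience = [batch[k][idx] for k in data_keys]
        # each experience is one transition, e.g. [array([0, 1]), 0, 1.0]
        print(experience)
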
Example #3
 def calc_ret_advs_v_targets(self, batch, v_preds):
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
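
As a quick numeric check of the advantage arithmetic above (`advs = rets - v_preds`, `v_targets = rets`), with illustrative tensors rather than real rollout data:

    import torch

    rets = torch.tensor([2.97, 1.99, 1.00])
    v_preds = torch.tensor([2.50, 2.10, 0.80]).detach()
    advs = rets - v_preds  # positive where the return beat the value estimate
    v_targets = rets       # the critic regresses toward the plain returns
    print(advs)            # approx. tensor([0.47, -0.11, 0.20])
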
Example #4
 def calc_policy_loss(self, batch):
     '''Calculate the policy loss for a batch of data.'''
     # use simple returns as advs
     advs = math_util.calc_returns(batch, self.gamma)
     advs = math_util.standardize(advs)
     logger.debug(f'advs: {advs}')
     assert len(self.body.log_probs) == len(advs), f'batch_size of log_probs {len(self.body.log_probs)} vs advs: {len(advs)}'
     log_probs = torch.stack(self.body.log_probs)
     policy_loss = -log_probs * advs
     if self.entropy_coef_spec is not None:
         entropies = torch.stack(self.body.entropies)
         policy_loss += (-self.body.entropy_coef * entropies)
     policy_loss = torch.sum(policy_loss)
     logger.debug(f'Actor policy loss: {policy_loss:g}')
     return policy_loss
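
The `standardize` call above rescales the returns to zero mean and unit variance before they are used as advantages. A minimal sketch of that normalization (an assumed behavior, not the actual `math_util.standardize`):

    import torch

    def standardize_sketch(v, eps=1e-8):
        # zero-mean, unit-variance rescaling with an epsilon for numerical stability
        return (v - v.mean()) / (v.std() + eps)

    advs = torch.tensor([3.0, 1.0, 2.0])
    print(standardize_sketch(advs))  # approx. tensor([1., -1., 0.])
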
Example #5
 def calc_sil_policy_val_loss(self, batch, pdparams):
     '''
     Calculate the SIL policy losses for actor and critic
     sil_policy_loss = -log_prob * max(R - v_pred, 0)
     sil_val_loss = (max(R - v_pred, 0)^2) / 2
     This is called on a randomly-sampled batch from experience replay
     '''
     v_preds = self.calc_v(batch['states'], use_cache=False)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
     clipped_advs = torch.clamp(rets - v_preds, min=0.0)

     action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
     actions = batch['actions']
     if self.body.env.is_venv:
         actions = math_util.venv_unpack(actions)
     log_probs = action_pd.log_prob(actions)

     sil_policy_loss = -self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
     sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
     logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
     logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
     return sil_policy_loss, sil_val_loss
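
The self-imitation terms above only learn from transitions whose return exceeded the value estimate; `torch.clamp(..., min=0.0)` zeroes out the rest. A small check with made-up numbers:

    import torch

    rets = torch.tensor([3.0, 0.5, 2.0])
    v_preds = torch.tensor([2.0, 1.0, 2.5])
    clipped_advs = torch.clamp(rets - v_preds, min=0.0)
    print(clipped_advs)                    # tensor([1., 0., 0.])
    print(clipped_advs.pow(2).mean() / 2)  # tensor(0.1667), before the loss coefficient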