Example #1
def calc_gae_advs_v_targets(self, batch, v_preds):
    '''
    Calculate GAE, and advs = GAE, v_targets = advs + v_preds
    See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
    '''
    next_states = batch['next_states'][-1]
    if not self.body.env.is_venv:
        next_states = next_states.unsqueeze(dim=0)
    with torch.no_grad():
        next_v_pred = self.calc_v(next_states, use_cache=False)
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
        next_v_pred = next_v_pred.unsqueeze(dim=0)
    v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
    advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                               v_preds_all, self.gamma, self.lam)
    v_targets = advs + v_preds
    advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
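In the vectorized-env branch above, venv_pack/venv_unpack convert between the flat (T * num_envs, ...) batch layout and a (T, num_envs, ...) layout, so that calc_gaes can run its recursion along the time dimension for each environment separately. The helpers below are a minimal sketch of that reshaping, written to match the call signatures seen above; the actual math_util implementations may differ in detail.

import torch

def venv_pack(batch_tensor, num_envs):
    # sketch (assumed implementation): fold a flat (T * num_envs, ...) tensor
    # into (T, num_envs, ...) so time is the leading dimension
    shape = batch_tensor.shape[1:]
    return batch_tensor.view(-1, num_envs, *shape)

def venv_unpack(batch_tensor):
    # sketch (assumed implementation): flatten (T, num_envs, ...) back
    # into the flat (T * num_envs, ...) batch layout
    shape = batch_tensor.shape[2:]
    return batch_tensor.view(-1, *shape)

For instance, a flat tensor of shape (6,) packs with num_envs=2 to shape (3, 2) and unpacks back to (6,).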
Example #2
def test_calc_gaes():
    rewards = torch.tensor([1., 0., 1., 1., 0., 1., 1., 1.])
    dones = torch.tensor([0., 0., 1., 1., 0., 0., 0., 0.])
    v_preds = torch.tensor([1.1, 0.1, 1.1, 1.1, 0.1, 1.1, 1.1, 1.1, 1.1])
    assert len(v_preds) == len(rewards) + 1  # includes last state
    gamma = 0.99
    lam = 0.95
    gaes = math_util.calc_gaes(rewards, dones, v_preds, gamma, lam)
    res = torch.tensor([
        0.84070045, 0.89495, -0.1, -0.1, 3.616724, 2.7939649, 1.9191545, 0.989
    ])
    # use allclose instead of equal to allow a small numerical tolerance (atol)
    assert torch.allclose(gaes, res)
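The recursion this test exercises is A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}, with delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t). The standalone sketch below is not the math_util source, but it implements that standard GAE recursion and reproduces the expected res tensor above to within allclose tolerance; the name calc_gaes_sketch is hypothetical.

import torch

def calc_gaes_sketch(rewards, dones, v_preds, gamma, lam):
    # v_preds has length T + 1: it includes the value of the final next state
    T = len(rewards)
    assert len(v_preds) == T + 1
    gaes = torch.zeros_like(rewards)
    future_gae = torch.tensor(0.0, dtype=rewards.dtype)
    # (1 - dones) cuts both bootstrapping and GAE accumulation at episode ends
    not_dones = 1 - dones
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * v_preds[t + 1] * not_dones[t] - v_preds[t]
        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
    return gaes

Note how the done flags show up in the expected values: at t = 2 and t = 3 (done = 1) the GAE collapses to the bare TD error r_t - V(s_t) = -0.1, since no future term is carried across the episode boundary.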
Example #3
def calc_gae_advs_v_targets(self, batch):
    '''
    Calculate the GAE advantages and value targets for training actor and critic respectively
    adv_targets = GAE (see math_util method)
    v_targets = adv_targets + v_preds
    Before output, adv_targets is standardized (so v_targets uses the unstandardized version)
    Used for training with GAE
    '''
    # concat so one forward pass covers states and the final next_state (prevents a double pass)
    states = torch.cat((batch['states'], batch['next_states'][-1:]), dim=0)
    v_preds = self.calc_v(states)
    next_v_preds = v_preds[1:]  # shift by one to get the values of the next states
    # v_target = r_t + gamma * V(s_(t+1)), i.e. the 1-step return
    v_targets = math_util.calc_nstep_returns(batch['rewards'],
                                             batch['dones'], self.gamma, 1,
                                             next_v_preds)
    adv_targets = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                      v_preds, self.gamma, self.lam)
    adv_targets = math_util.standardize(adv_targets)
    logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
    return adv_targets, v_targets
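With n = 1, the n-step return used here collapses to the TD(0) target r_t + gamma * (1 - done_t) * V(s_{t+1}). The sketch below implements only that special case; the argument order (rewards, dones, gamma, n, next_v_preds) is taken from the call above, but the real calc_nstep_returns is more general and handles n > 1, and the helper name is hypothetical.

import torch

def calc_1step_returns_sketch(rewards, dones, gamma, next_v_preds):
    # 1-step TD target: r_t + gamma * V(s_{t+1}),
    # with bootstrapping cut at episode ends via (1 - dones)
    return rewards + gamma * (1 - dones) * next_v_preds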
Example #4
def calc_gae_advs_v_targets(self, batch):
    '''
    Calculate the GAE advantages and value targets for training actor and critic respectively
    adv_targets = GAE (see math_util method)
    v_targets = adv_targets + v_preds
    Before output, adv_targets is standardized (so v_targets uses the unstandardized version)
    Used for training with GAE
    '''
    v_preds = self.calc_v(batch['states'])
    # calc the boundary value for the final next_state and concat with the shifted v_preds for efficiency
    next_v_pred_tail = self.calc_v(batch['next_states'][-1:])
    next_v_preds = torch.cat([v_preds[1:], next_v_pred_tail], dim=0)
    # v_targets = r_t + gamma * V(s_(t+1))
    v_targets = math_util.calc_nstep_returns(batch, self.gamma, 1,
                                             next_v_preds)
    # ensure the value for next_state is 0 at done
    next_v_preds = next_v_preds * (1 - batch['dones'])
    adv_targets = math_util.calc_gaes(batch['rewards'], v_preds,
                                      next_v_preds, self.gamma, self.lam)
    adv_targets = math_util.standardize(adv_targets)
    logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
    return adv_targets, v_targets
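All four examples standardize the advantages but never the value targets, since v_targets must stay on the scale of the critic's raw predictions. A plausible minimal sketch of such a standardize helper, assuming the conventional mean/std normalization (the actual math_util.standardize may handle the epsilon differently):

import torch

def standardize_sketch(v):
    # zero-mean, unit-variance normalization; the epsilon guards against
    # division by zero when all advantages are equal
    return (v - v.mean()) / (v.std() + 1e-8)

Note also that Example #4 calls calc_gaes with a (rewards, v_preds, next_v_preds, gamma, lam) signature and pre-masks next_v_preds at done, whereas Examples #1 and #3 pass (rewards, dones, v_preds, gamma, lam) and let the helper apply the done mask; the snippets evidently come from different versions of the same library.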