Example 1
 def calc_gae_advs_v_targets(self, batch, v_preds):
     '''
     Calculate the GAE advantages and value targets: advs = GAE, v_targets = advs + v_preds
     See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
         next_v_pred = next_v_pred.unsqueeze(dim=0)
     v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
     advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                v_preds_all, self.gamma, self.lam)
     v_targets = advs + v_preds
     advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
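For reference, math_util.calc_gaes above computes the GAE recursion from the cited Schulman et al. paper: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * A_{t+1}, with the bootstrap cut at episode ends. The standalone sketch below mirrors the call signature used in this example (rewards, dones, v_preds with one extra bootstrap entry, gamma, lam); it is an illustrative reimplementation under those assumptions, not the library's actual code.

 import torch

 def calc_gaes_sketch(rewards, dones, v_preds, gamma, lam):
     # GAE (Schulman et al. 2015): delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
     # A_t = delta_t + gamma * lam * A_{t+1}, accumulated backwards in time.
     # Assumes v_preds carries one extra trailing entry for the bootstrap value,
     # matching the torch.cat((v_preds, next_v_pred), dim=0) call above.
     T = len(rewards)
     assert len(v_preds) == T + 1
     not_dones = 1 - dones  # cut the bootstrap at episode boundaries
     gaes = torch.zeros_like(rewards)
     future_gae = torch.zeros_like(rewards[0])
     for t in reversed(range(T)):
         delta = rewards[t] + gamma * not_dones[t] * v_preds[t + 1] - v_preds[t]
         gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
     return gaes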
Example 2
 def calc_policy_loss(self, batch):
     '''Calculate the policy loss for a batch of data.'''
     # use simple returns as advs
     advs = math_util.calc_returns(batch, self.gamma)
     advs = math_util.standardize(advs)
     logger.debug(f'advs: {advs}')
     assert len(self.body.log_probs) == len(advs), \
         f'batch_size of log_probs {len(self.body.log_probs)} vs advs: {len(advs)}'
     log_probs = torch.stack(self.body.log_probs)
     policy_loss = -log_probs * advs
     if self.entropy_coef_spec is not None:
         entropies = torch.stack(self.body.entropies)
         policy_loss += (-self.body.entropy_coef * entropies)
     policy_loss = torch.sum(policy_loss)
     logger.debug(f'Actor policy loss: {policy_loss:g}')
     return policy_loss
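This is the REINFORCE-style loss: the estimator is -sum_t log pi(a_t|s_t) * R_t, with the discounted returns standardized to reduce variance and an optional entropy bonus. The sketch below shows a plausible returns-to-go computation behind math_util.calc_returns (which in this example actually takes the batch dict, not separate tensors) and how the pieces combine; the helper name and exact semantics are assumptions for illustration.

 import torch

 def calc_returns_sketch(rewards, dones, gamma):
     # Discounted return-to-go per step, reset at episode boundaries.
     # Illustrative only: the real math_util.calc_returns takes the batch dict.
     T = len(rewards)
     rets = torch.zeros_like(rewards)
     future_ret = torch.zeros_like(rewards[0])
     not_dones = 1 - dones
     for t in reversed(range(T)):
         rets[t] = future_ret = rewards[t] + gamma * not_dones[t] * future_ret
     return rets

 # Assembling the loss as in calc_policy_loss above:
 # advs = standardize(calc_returns_sketch(rewards, dones, gamma))
 # policy_loss = -(log_probs * advs).sum() - entropy_coef * entropies.sum()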
Example 3
 def calc_gae_advs_v_targets(self, batch):
     '''
     Calculate the GAE advantages and value targets for training the actor and critic respectively
     adv_targets = GAE (see the math_util method)
     v_targets = r_t + gamma * V(s_(t+1)), i.e. the 1-step return
     adv_targets is standardized before output; v_targets is not
     Used for training with GAE
     '''
     states = torch.cat((batch['states'], batch['next_states'][-1:]),
                        dim=0)  # one forward pass covers both states and the final next_state
     v_preds = self.calc_v(states)
     next_v_preds = v_preds[1:]  # shift for only the next states
     # v_target = r_t + gamma * V(s_(t+1)), i.e. 1-step return
     v_targets = math_util.calc_nstep_returns(batch['rewards'],
                                              batch['dones'], self.gamma, 1,
                                              next_v_preds)
     adv_targets = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                       v_preds, self.gamma, self.lam)
     adv_targets = math_util.standardize(adv_targets)
     logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
     return adv_targets, v_targets
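With n = 1, calc_nstep_returns reduces to the TD(0) target r_t + gamma * V(s_{t+1}), with the bootstrap zeroed when the episode ends. A minimal sketch of that n = 1 case, assuming rewards, dones, and next_v_preds are aligned per time step:

 import torch

 def calc_1step_returns_sketch(rewards, dones, gamma, next_v_preds):
     # TD(0) target: r_t + gamma * V(s_{t+1}), with the bootstrap removed
     # when the episode terminates at step t.
     return rewards + gamma * (1 - dones) * next_v_preds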
Example 4
 def calc_gae_advs_v_targets(self, batch):
     '''
     Calculate the GAE advantages and value targets for training the actor and critic respectively
     adv_targets = GAE (see the math_util method)
     v_targets = r_t + gamma * V(s_(t+1)), i.e. the 1-step return
     adv_targets is standardized before output; v_targets is not
     Used for training with GAE
     '''
     v_preds = self.calc_v(batch['states'])
     # calc next_state boundary value and concat with above for efficiency
     next_v_pred_tail = self.calc_v(batch['next_states'][-1:])
     next_v_preds = torch.cat([v_preds[1:], next_v_pred_tail], dim=0)
     # v targets = r_t + gamma * V(s_(t+1))
     v_targets = math_util.calc_nstep_returns(batch, self.gamma, 1,
                                              next_v_preds)
     # ensure val for next_state is 0 at done
     next_v_preds = next_v_preds * (1 - batch['dones'])
     adv_targets = math_util.calc_gaes(batch['rewards'], v_preds,
                                       next_v_preds, self.gamma, self.lam)
     adv_targets = math_util.standardize(adv_targets)
     logger.debug(f'adv_targets: {adv_targets}\nv_targets: {v_targets}')
     return adv_targets, v_targets
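The standardize step that appears in every example is plain advantage normalization: zero mean and unit variance over the batch so the scale of the policy gradient stays comparable across updates, while v_targets is left on the raw return scale for the critic (hence the "standardize only for advs, not v_targets" comment above). A plausible stand-in, not the library's implementation:

 import torch

 def standardize_sketch(v, eps=1e-8):
     # Zero-mean, unit-variance normalization with a small epsilon for
     # numerical stability; a guess at math_util.standardize.
     return (v - v.mean()) / (v.std() + eps)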