def calc_ret_advs(self, batch):
    '''Calculate plain returns, which generalize to advantages in ActorCritic'''
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    if self.center_return:
        rets = math_util.center_mean(rets)  # center returns to zero mean
    advs = rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)  # flatten per-env data for loss computation
    logger.debug(f'advs: {advs}')
    return advs
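
# Every method here leans on math_util.calc_returns. Below is a minimal sketch
# of the discounted-return recursion it is assumed to compute, written as a
# hypothetical standalone function (the library helper may differ in details
# such as dtype, device, and venv handling); it uses the module-level torch import.
def _calc_returns_sketch(rewards, dones, gamma):
    '''G_t = r_t + gamma * G_{t+1}, with the recursion reset at episode ends.'''
    rets = torch.zeros_like(rewards)
    future_ret = torch.zeros_like(rewards[0])
    not_dones = 1 - dones
    for t in reversed(range(len(rewards))):
        future_ret = rewards[t] + gamma * future_ret * not_dones[t]
        rets[t] = future_ret
    return rets
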
def sample(self):
    '''Modify the on-policy sample to also append to replay'''
    batch = self.body.memory.sample()
    batch = {k: np.concatenate(v) for k, v in batch.items()}  # concat episodic memory
    batch['rets'] = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    # copy each transition into replay memory for later self-imitation updates
    for idx in range(len(batch['dones'])):
        tuples = [batch[k][idx] for k in self.body.replay_memory.data_keys]
        self.body.replay_memory.add_experience(*tuples)
    if self.normalize_state:
        batch = policy_util.normalize_states_and_next_states(self.body, batch)
    batch = util.to_torch_batch(batch, self.net.device, self.body.replay_memory.is_episodic)
    return batch
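
# A minimal sketch of the numpy-to-torch conversion assumed of
# util.to_torch_batch (hypothetical; the real helper presumably also handles
# the episodic nesting that is_episodic signals, which this sketch elides):
def _to_torch_batch_sketch(batch, device, is_episodic):
    '''Convert each numpy array in the batch dict to a float tensor on device.'''
    return {k: torch.from_numpy(np.asarray(v)).float().to(device) for k, v in batch.items()}
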
def calc_ret_advs_v_targets(self, batch, v_preds):
    '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    advs = rets - v_preds
    v_targets = rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
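
# Sketch of the venv reshaping pair used above, as assumed here (hypothetical
# standalone versions): venv_pack folds a flat (T * num_envs, ...) tensor into
# (T, num_envs, ...) so it lines up with per-env rewards and dones, and
# venv_unpack flattens it back for loss computation.
def _venv_pack_sketch(batch_tensor, num_envs):
    return batch_tensor.view([-1, num_envs] + list(batch_tensor.shape[1:]))

def _venv_unpack_sketch(batch_tensor):
    return batch_tensor.view([-1] + list(batch_tensor.shape[2:]))
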
def calc_policy_loss(self, batch):
    '''Calculate the policy loss for a batch of data.'''
    # use simple returns as advs
    advs = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    advs = math_util.standardize(advs)
    logger.debug(f'advs: {advs}')
    assert len(self.body.log_probs) == len(advs), f'batch_size of log_probs {len(self.body.log_probs)} vs advs: {len(advs)}'
    log_probs = torch.stack(self.body.log_probs)
    policy_loss = - log_probs * advs
    if self.entropy_coef_spec is not None:
        entropies = torch.stack(self.body.entropies)
        policy_loss += (- self.body.entropy_coef * entropies)  # entropy bonus encourages exploration
    policy_loss = torch.sum(policy_loss)
    logger.debug(f'Actor policy loss: {policy_loss:g}')
    return policy_loss
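
# Sketch of the advantage standardization assumed of math_util.standardize
# (the epsilon value is hypothetical; it guards against zero variance):
def _standardize_sketch(v):
    '''Scale v to zero mean and unit variance.'''
    return (v - v.mean()) / (v.std() + 1e-08)
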
def calc_sil_policy_val_loss(self, batch, pdparams):
    '''
    Calculate the SIL policy losses for actor and critic
    sil_policy_loss = -log_prob * max(R - v_pred, 0)
    sil_val_loss = (max(R - v_pred, 0)^2) / 2
    This is called on a randomly sampled batch from experience replay
    '''
    v_preds = self.calc_v(batch['states'], use_cache=False)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    clipped_advs = torch.clamp(rets - v_preds, min=0.0)
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    if self.body.env.is_venv:
        actions = math_util.venv_unpack(actions)
    log_probs = action_pd.log_prob(actions)
    sil_policy_loss = - self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
    sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
    logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
    logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
    return sil_policy_loss, sil_val_loss
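
# Toy illustration of the clipped self-imitation advantage above: only
# transitions whose sampled return beat the critic's estimate contribute to
# either SIL loss, e.g.
#   rets = torch.tensor([2.0, 0.5, 1.0])
#   v_preds = torch.tensor([1.0, 1.0, 1.0])
#   torch.clamp(rets - v_preds, min=0.0)  # -> tensor([1., 0., 0.])
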