def _calc_returns_and_advantages(self, experience, value):
    returns = value_ops.discounted_return(
        rewards=experience.reward,
        values=value,
        step_types=experience.step_type,
        discounts=experience.discount * self._gamma)
    # Append the bootstrap value so returns has the same time length as value.
    returns = tensor_utils.tensor_extend(returns, value[-1])
    if not self._use_gae:
        advantages = returns - value
    else:
        advantages = value_ops.generalized_advantage_estimation(
            rewards=experience.reward,
            values=value,
            step_types=experience.step_type,
            discounts=experience.discount * self._gamma,
            td_lambda=self._lambda)
        # Append a zero step so advantages has the same time length as value.
        advantages = tensor_utils.tensor_extend_zero(advantages)
        if self._use_td_lambda_return:
            returns = advantages + value
    return returns, advantages
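# A minimal sketch of the two helpers used above, assuming (from how they are
# used here, not from the actual alf.utils.tensor_utils implementation) that
# both operate on the time dimension (dim 0): tensor_extend appends one given
# row, tensor_extend_zero appends a row of zeros, so the [T-1, B] outputs of
# discounted_return() / generalized_advantage_estimation() line up with the
# [T, B] value tensor.
import torch


def _tensor_extend_sketch(x: torch.Tensor, row: torch.Tensor) -> torch.Tensor:
    # Append ``row`` as one extra step along the time dimension.
    return torch.cat([x, row.unsqueeze(0)], dim=0)


def _tensor_extend_zero_sketch(x: torch.Tensor) -> torch.Tensor:
    # Append one step of zeros along the time dimension.
    return torch.cat([x, torch.zeros_like(x[:1])], dim=0)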
def _backup(self, trees: _MCTSTree, search_paths, path_lengths, values):
    # Accumulate the discounted backup value along every search path and
    # update the visit counts and value sums of the visited nodes.
    B = trees.B.unsqueeze(0)
    T = search_paths.shape[0]
    depth = torch.arange(T).unsqueeze(-1)
    if trees.reward is not None:
        reward = trees.reward[B, search_paths]
        reward[depth > path_lengths] = 0.
        # [T+1, batch_size]
        reward = tensor_utils.tensor_extend_zero(reward)
        reward[path_lengths, B] = values
        discounts = (self._discount**torch.arange(
            T + 1, dtype=torch.float32)).unsqueeze(-1)
        # discounted_return[t] = discount^t * reward[t]
        discounted_return = reward * discounts
        # discounted_return[t] = \sum_{s=t}^T discount^s * reward[s]
        discounted_return = discounted_return.flip(0).cumsum(dim=0).flip(0)
        # discounted_return[t] = \sum_{s=t}^T discount^(s-t) * reward[s]
        discounted_return = discounted_return / discounts
        discounted_return = discounted_return[1:]
    else:
        # [T, 1]
        steps = torch.arange(1, T + 1, dtype=torch.float32).unsqueeze(-1)
        # [T, B]
        discounts = self._discount**(path_lengths.unsqueeze(0) - steps)
        discounted_return = values.unsqueeze(0) * discounts

    value_sum = trees.value_sum[B, search_paths] + discounted_return
    valid = depth < path_lengths
    nodes = (B.expand(T, -1)[valid], search_paths[valid])
    trees.visit_count[nodes] += 1
    trees.value_sum[nodes] = value_sum[valid]
    trees.update_value_stats((B, search_paths), valid)
    self._update_best_child(trees, nodes)
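# A standalone sketch (hypothetical toy values, single path, no batching) of the
# flip-cumsum-flip trick used in _backup() above: accumulating discount^t * r_t
# from the end and dividing by discount^t gives, at every step t,
# \sum_{s=t}^T discount^(s-t) * r_s.
import torch


def _discounted_return_sketch(reward: torch.Tensor, discount: float) -> torch.Tensor:
    T = reward.shape[0]
    discounts = discount**torch.arange(T, dtype=torch.float32)
    weighted = reward * discounts                         # discount^t * r_t
    suffix_sum = weighted.flip(0).cumsum(dim=0).flip(0)   # \sum_{s>=t} discount^s * r_s
    return suffix_sum / discounts                         # \sum_{s>=t} discount^(s-t) * r_s


# For reward = [1, 1, 1] and discount = 0.5 this yields [1.75, 1.5, 1.0]:
# 1 + 0.5 + 0.25, then 1 + 0.5, then 1.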
def forward(self, experience, value, target_value):
    """Calculate the loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        experience (Experience): experience collected from ``unroll()`` or
            a replay buffer. All tensors are time-major.
        value (torch.Tensor): the time-major tensor for the value at each
            time step. The loss is between this and the calculated return.
        target_value (torch.Tensor): the time-major tensor for the value at
            each time step. This is used to calculate the return.
            ``target_value`` can be the same as ``value``.
    Returns:
        LossInfo: with the ``extra`` field the same as ``loss``.
    """
    if self._lambda == 1.0:
        returns = value_ops.discounted_return(
            rewards=experience.reward,
            values=target_value,
            step_types=experience.step_type,
            discounts=experience.discount * self._gamma)
    elif self._lambda == 0.0:
        returns = value_ops.one_step_discounted_return(
            rewards=experience.reward,
            values=target_value,
            step_types=experience.step_type,
            discounts=experience.discount * self._gamma)
    else:
        advantages = value_ops.generalized_advantage_estimation(
            rewards=experience.reward,
            values=target_value,
            step_types=experience.step_type,
            discounts=experience.discount * self._gamma,
            td_lambda=self._lambda)
        returns = advantages + target_value[:-1]

    value = value[:-1]
    if self._debug_summaries and alf.summary.should_record_summaries():
        mask = experience.step_type[:-1] != StepType.LAST
        with alf.summary.scope(self._name):

            def _summarize(v, r, td, suffix):
                alf.summary.scalar(
                    "explained_variance_of_return_by_value" + suffix,
                    tensor_utils.explained_variance(v, r, mask))
                safe_mean_hist_summary('values' + suffix, v, mask)
                safe_mean_hist_summary('returns' + suffix, r, mask)
                safe_mean_hist_summary("td_error" + suffix, td, mask)

            if value.ndim == 2:
                _summarize(value, returns, returns - value, '')
            else:
                td = returns - value
                for i in range(value.shape[2]):
                    suffix = '/' + str(i)
                    _summarize(value[..., i], returns[..., i], td[..., i],
                               suffix)

    loss = self._td_error_loss_fn(returns.detach(), value)

    if loss.ndim == 3:
        # Multidimensional reward. Average the critic loss over all
        # reward dimensions.
        loss = loss.mean(dim=2)

    # The shape of the loss expected by Algorithm.update_with_gradient is
    # [T, B], so we need to augment it with additional zeros.
    loss = tensor_utils.tensor_extend_zero(loss)
    return LossInfo(loss=loss, extra=loss)
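# A pure-torch sketch (hypothetical values, single trajectory, ignoring the
# episode-boundary / step_type masking handled by value_ops) of the lambda
# return that forward() computes for 0 < lambda < 1:
#   delta_t = r_t + gamma * V_{t+1} - V_t
#   A_t     = \sum_{s>=t} (gamma * lambda)^(s-t) * delta_s
#   G_t     = A_t + V_t
# lambda = 1 recovers the full discounted return; lambda = 0 recovers the
# one-step return r_t + gamma * V_{t+1}.
import torch


def _lambda_return_sketch(reward, value, gamma, td_lambda):
    # reward: [T], value: [T + 1] (the extra entry is the bootstrap value).
    T = reward.shape[0]
    delta = reward + gamma * value[1:] - value[:-1]
    advantage = torch.zeros(T)
    acc = torch.tensor(0.)
    for t in reversed(range(T)):
        acc = delta[t] + gamma * td_lambda * acc
        advantage[t] = acc
    return advantage + value[:-1]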
def calc_loss(self, experience, train_info: td.Distribution):
    dist: td.Distribution = train_info
    log_prob = dist.log_prob(experience.action)
    # Reinforce the action at step t with the reward received at step t + 1.
    loss = -log_prob[:-1] * experience.reward[1:]
    # Extend with zeros so the loss has the full [T, B] shape.
    loss = tensor_utils.tensor_extend_zero(loss)
    return LossInfo(loss=loss)
def calc_loss(self, experience, info: SarsaInfo):
    loss = info.actor_loss
    if self._log_alpha is not None:
        alpha = self._log_alpha.exp().detach()
        alpha_loss = self._log_alpha * (
            -info.neg_entropy - self._target_entropy).detach()
        loss = loss + alpha * info.neg_entropy + alpha_loss
    else:
        alpha_loss = ()

    # For SARSA, info.critics is actually the critics for the previous step,
    # and info.target_critics is the critics for the current step. So we
    # need to rearrange ``experience`` to match the requirement of
    # ``OneStepTDLoss``.
    step_type0 = experience.step_type[0]
    step_type0 = torch.where(step_type0 == StepType.LAST,
                             torch.tensor(StepType.MID), step_type0)
    step_type0 = torch.where(step_type0 == StepType.FIRST,
                             torch.tensor(StepType.LAST), step_type0)
    reward = experience.reward
    if self._use_entropy_reward:
        reward -= (self._log_alpha.exp() * info.neg_entropy).detach()
    shifted_experience = experience._replace(
        discount=tensor_utils.tensor_prepend_zero(experience.discount),
        reward=tensor_utils.tensor_prepend_zero(reward),
        step_type=tensor_utils.tensor_prepend(experience.step_type,
                                              step_type0))
    critic_losses = []
    for i in range(self._num_critic_replicas):
        critic = tensor_utils.tensor_extend_zero(info.critics[..., i])
        target_critic = tensor_utils.tensor_prepend_zero(
            info.target_critics)
        loss_info = self._critic_losses[i](shifted_experience, critic,
                                           target_critic)
        critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

    critic_loss = math_ops.add_n(critic_losses)

    not_first_step = (experience.step_type != StepType.FIRST).to(
        torch.float32)
    critic_loss = critic_loss * not_first_step
    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_n = torch.clamp(not_first_step.sum(dim=0), min=1.0)
        priority = (critic_loss.sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    # Put critic_loss into scalar_loss because loss will be masked by
    # ~is_last at train_complete(). The critic_loss here should be masked
    # by ~is_first instead, which is done above.
    scalar_loss = critic_loss.mean()

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            if self._log_alpha is not None:
                alf.summary.scalar("alpha", alpha)

    return LossInfo(
        loss=loss,
        scalar_loss=scalar_loss,
        priority=priority,
        extra=SarsaLossInfo(
            actor=info.actor_loss,
            critic=critic_loss,
            alpha=alpha_loss,
            neg_entropy=info.neg_entropy))
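# Illustrative note on the shifting above (inferred from the comment in
# calc_loss, not from OneStepTDLoss internals): after prepending one step to
# the experience and to ``target_critics`` and appending one step to
# ``critics``, index t of ``critic`` holds Q(s_{t-1}, a_{t-1}), index t + 1 of
# the shifted experience carries the original step-t reward and discount, and
# index t + 1 of ``target_critic`` holds Q(s_t, a_t). A one-step TD loss at
# index t therefore regresses Q(s_{t-1}, a_{t-1}) towards
# r_t + gamma * d_t * Q(s_t, a_t), the SARSA target; the meaningless index-0
# term is zeroed out by the ``not_first_step`` mask.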