def _summarize(v, r, td, suffix):
    # ``mask`` is captured from the enclosing scope (this is a nested helper).
    alf.summary.scalar(
        "explained_variance_of_return_by_value" + suffix,
        tensor_utils.explained_variance(v, r, mask))
    safe_mean_hist_summary('values' + suffix, v, mask)
    safe_mean_hist_summary('returns' + suffix, r, mask)
    safe_mean_hist_summary("td_error" + suffix, td, mask)
def _summarize1(pred, tgt, loss, mask, suffix):
    alf.summary.scalar(
        "explained_variance" + suffix,
        tensor_utils.explained_variance(pred, tgt, mask))
    safe_mean_hist_summary('predict' + suffix, pred, mask)
    safe_mean_hist_summary('target' + suffix, tgt, mask)
    safe_mean_summary("loss" + suffix, loss, mask)
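# Hedged sketch, NOT ALF's actual implementation: ``tensor_utils.explained_variance``
# is assumed here to follow the standard definition 1 - Var(target - pred) / Var(target),
# restricted to valid steps by ``mask``. The helper name and masked handling below are
# illustrative assumptions only; ``torch`` is assumed imported at module level as elsewhere
# in this file.
def _explained_variance_sketch(pred, target, mask):
    # Keep only the valid (unmasked) elements before computing variances.
    pred = pred[mask]
    target = target[mask]
    var_target = target.var(unbiased=False)
    if var_target == 0:
        # Degenerate case: a constant target explains nothing.
        return torch.zeros(())
    return 1.0 - (target - pred).var(unbiased=False) / var_target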
def forward(self, experience, train_info):
    """Calculate the actor critic loss.

    The first dimension of all the tensors is the time dimension and the
    second dimension is the batch dimension.

    Args:
        experience (nest): experience used for training. All tensors are
            time-major.
        train_info (nest): information collected for training. It is batched
            from each ``AlgStep.info`` returned by ``rollout_step()``
            (on-policy training) or ``train_step()`` (off-policy training).
            All tensors in ``train_info`` are time-major.
    Returns:
        LossInfo: with ``extra`` being ``ActorCriticLossInfo``.
    """
    value = train_info.value
    returns, advantages = self._calc_returns_and_advantages(
        experience, value)

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            alf.summary.scalar("values", value.mean())
            alf.summary.scalar("returns", returns.mean())
            alf.summary.scalar("advantages/mean", advantages.mean())
            alf.summary.histogram("advantages/value", advantages)
            alf.summary.scalar(
                "explained_variance_of_return_by_value",
                tensor_utils.explained_variance(value, returns))

    if self._normalize_advantages:
        advantages = _normalize_advantages(advantages)

    if self._advantage_clip:
        advantages = torch.clamp(advantages, -self._advantage_clip,
                                 self._advantage_clip)

    pg_loss = self._pg_loss(experience, train_info, advantages.detach())
    td_loss = self._td_error_loss_fn(returns.detach(), value)
    loss = pg_loss + self._td_loss_weight * td_loss

    entropy_loss = ()
    if self._entropy_regularization is not None:
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            train_info.action_distribution)
        entropy_loss = -entropy
        loss -= self._entropy_regularization * entropy_for_gradient

    return LossInfo(
        loss=loss,
        extra=ActorCriticLossInfo(
            td_loss=td_loss, pg_loss=pg_loss, neg_entropy=entropy_loss))
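# Hedged sketch, an assumption rather than ALF's actual ``_normalize_advantages``:
# advantage normalization is commonly implemented by standardizing the advantages
# to zero mean and unit variance before the policy-gradient loss, which keeps the
# gradient scale stable across batches. The epsilon value is illustrative.
def _normalize_advantages_sketch(advantages, eps=1e-8):
    # Normalize over all (time, batch) elements of the advantage tensor.
    mean = advantages.mean()
    std = advantages.std()
    return (advantages - mean) / (std + eps)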
def calc_loss(self, model_output: ModelOutput,
              target: ModelTarget) -> LossInfo:
    """Calculate the loss.

    The shapes of the tensors in ``model_output`` are
    ``[B, unroll_steps+1, ...]``.

    Returns:
        LossInfo
    """
    num_unroll_steps = target.value.shape[1] - 1
    loss_scale = torch.ones((num_unroll_steps + 1, )) / num_unroll_steps
    loss_scale[0] = 1.0

    value_loss = element_wise_squared_loss(target.value, model_output.value)
    value_loss = (loss_scale * value_loss).sum(dim=1)
    loss = value_loss

    reward_loss = ()
    if self._train_reward_function:
        reward_loss = element_wise_squared_loss(target.reward,
                                                model_output.reward)
        reward_loss = (loss_scale * reward_loss).sum(dim=1)
        loss = loss + reward_loss

    # target.action.shape is [B, unroll_steps+1, num_candidate]
    # log_prob needs the sample shape in the beginning
    if isinstance(target.action, tuple) and target.action == ():
        # This condition is only possible for Categorical distribution
        assert isinstance(model_output.action_distribution, td.Categorical)
        policy_loss = -(target.action_policy *
                        model_output.action_distribution.logits).sum(dim=2)
    else:
        action = target.action.permute(2, 0, 1,
                                       *list(range(3, target.action.ndim)))
        action_log_probs = model_output.action_distribution.log_prob(action)
        action_log_probs = action_log_probs.permute(1, 2, 0)
        policy_loss = -(target.action_policy * action_log_probs).sum(dim=2)

    game_over_loss = ()
    if self._train_game_over_function:
        game_over_loss = F.binary_cross_entropy_with_logits(
            input=model_output.game_over_logit,
            target=target.game_over.to(torch.float),
            reduction='none')
        # no need to train policy after game over.
        policy_loss = policy_loss * (~target.game_over).to(torch.float32)
        unscaled_game_over_loss = game_over_loss
        game_over_loss = (loss_scale * game_over_loss).sum(dim=1)
        loss = loss + game_over_loss

    policy_loss = (loss_scale * policy_loss).sum(dim=1)
    loss = loss + policy_loss

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            alf.summary.scalar(
                "explained_variance_of_value0",
                tensor_utils.explained_variance(model_output.value[:, 0],
                                                target.value[:, 0]))
            alf.summary.scalar(
                "explained_variance_of_value1",
                tensor_utils.explained_variance(model_output.value[:, 1:],
                                                target.value[:, 1:]))
            if self._train_reward_function:
                alf.summary.scalar(
                    "explained_variance_of_reward0",
                    tensor_utils.explained_variance(
                        model_output.reward[:, 0], target.reward[:, 0]))
                alf.summary.scalar(
                    "explained_variance_of_reward1",
                    tensor_utils.explained_variance(
                        model_output.reward[:, 1:], target.reward[:, 1:]))
            if self._train_game_over_function:

                def _entropy(events):
                    p = events.to(torch.float32).mean()
                    p = torch.tensor([p, 1 - p])
                    return -(p * (p + 1e-30).log()).sum(), p[0]

                h0, p0 = _entropy(target.game_over[:, 0])
                alf.summary.scalar("game_over0", p0)
                h1, p1 = _entropy(target.game_over[:, 1:])
                alf.summary.scalar("game_over1", p1)
                alf.summary.scalar(
                    "explained_entropy_of_game_over0",
                    torch.where(
                        h0 == 0, h0,
                        1. - unscaled_game_over_loss[:, 0].mean() /
                        (h0 + 1e-30)))
                alf.summary.scalar(
                    "explained_entropy_of_game_over1",
                    torch.where(
                        h1 == 0, h1,
                        1. - unscaled_game_over_loss[:, 1:].mean() /
                        (h1 + 1e-30)))
            summary_utils.add_mean_hist_summary("target_value", target.value)
            summary_utils.add_mean_hist_summary("value", model_output.value)
            summary_utils.add_mean_hist_summary(
                "td_error", target.value - model_output.value)

    return LossInfo(
        loss=loss,
        extra=dict(
            value=value_loss,
            reward=reward_loss,
            policy=policy_loss,
            game_over=game_over_loss))
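# Worked example (illustrative only) of the ``loss_scale`` weighting used in
# ``calc_loss`` above: the initial step keeps weight 1, while each of the
# ``num_unroll_steps`` unrolled steps is down-weighted by 1/num_unroll_steps,
# so the unrolled steps together contribute roughly as much as the initial
# step. Assumes ``torch`` is imported at module level as elsewhere in this file.
num_unroll_steps = 5
loss_scale = torch.ones((num_unroll_steps + 1, )) / num_unroll_steps
loss_scale[0] = 1.0
print(loss_scale)  # tensor([1.0000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000])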