def actor_loss_fn(dqda, action):
    if self._dqda_clipping:
        dqda = torch.clamp(dqda, -self._dqda_clipping, self._dqda_clipping)
    loss = 0.5 * losses.element_wise_squared_loss(
        (dqda + action).detach(), action)
    return loss.sum(list(range(1, loss.ndim)))
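# Sketch (not from the source): a quick check of the DPG surrogate used above.
# Assuming losses.element_wise_squared_loss(x, y) computes (x - y)**2, the loss
# 0.5 * ((dqda + action).detach() - action)**2 has gradient -dqda with respect
# to action, so minimizing it performs gradient ascent on Q(s, a) along dQ/da.
import torch

action = torch.tensor([0.3, -0.7], requires_grad=True)
dqda = torch.tensor([1.5, -0.2])  # stand-in for dQ/da from the critic
loss = 0.5 * ((dqda + action).detach() - action)**2
loss.sum().backward()
assert torch.allclose(action.grad, -dqda)  # d(loss)/d(action) == -dQ/da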
def calc_loss(self, training_info: TrainingInfo):
    info = training_info.info  # SarsaInfo
    critic_loss = losses.element_wise_squared_loss(info.returns, info.critic)
    not_first_step = tf.not_equal(training_info.step_type, StepType.FIRST)
    critic_loss *= tf.cast(not_first_step, tf.float32)

    def _summary():
        with self.name_scope:
            tf.summary.scalar("values", tf.reduce_mean(info.critic))
            tf.summary.scalar("returns", tf.reduce_mean(info.returns))
            safe_mean_hist_summary("td_error", info.returns - info.critic)
            tf.summary.scalar(
                "explained_variance_of_return_by_value",
                common.explained_variance(info.critic, info.returns))

    if self._debug_summaries:
        common.run_if(common.should_record_summaries(), _summary)

    return LossInfo(
        loss=info.actor_loss,
        # Put critic_loss in scalar_loss because loss will be masked by
        # ~is_last at train_complete(). The critic_loss here should be
        # masked by ~is_first instead, which is done above.
        scalar_loss=tf.reduce_mean(critic_loss),
        extra=SarsaLossInfo(actor=info.actor_loss, critic=critic_loss))
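# Sketch (not from the source, PyTorch used for brevity): the
# "explained_variance_of_return_by_value" summary above is assumed to follow
# the standard definition 1 - Var[target - pred] / Var[target]; a value near 1
# means the critic explains most of the variance of the returns.
import torch

def explained_variance_sketch(pred, target):
    return 1. - torch.var(target - pred) / (torch.var(target) + 1e-30)

returns = torch.randn(128)
values = returns + 0.1 * torch.randn(128)  # a good value estimate
print(explained_variance_sketch(values, returns))  # close to 1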
def actor_loss_fn(dqda, action):
    if self._dqda_clipping:
        dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                self._dqda_clipping)
    loss = 0.5 * losses.element_wise_squared_loss(
        tf.stop_gradient(dqda + action), action)
    loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
    return loss
def actor_loss_fn(dqda, action):
    if self._dqda_clipping:
        dqda = torch.clamp(dqda, -self._dqda_clipping, self._dqda_clipping)
    loss = 0.5 * losses.element_wise_squared_loss(
        (dqda + action).detach(), action)
    if self._action_l2 > 0:
        assert action.requires_grad
        loss += self._action_l2 * (action**2)
    loss = loss.sum(list(range(1, loss.ndim)))
    return loss
def _predict_skill_loss(self, observation, prev_action, prev_skill, steps,
                        state):
    # steps -> {1,2,3}
    if self._skill_type == "action":
        subtrajectory = (state.first_observation, prev_action)
    elif self._skill_type == "action_difference":
        action_difference = prev_action - state.subtrajectory[:, 1, :]
        subtrajectory = (state.first_observation, action_difference)
    elif self._skill_type == "state_action":
        subtrajectory = (observation, prev_action)
    elif self._skill_type == "state":
        subtrajectory = observation
    elif self._skill_type == "state_difference":
        subtrajectory = observation - state.untrans_observation
    elif "concatenation" in self._skill_type:
        subtrajectory = alf.nest.map_structure(
            lambda traj: traj.reshape(observation.shape[0], -1),
            state.subtrajectory)
        if is_action_skill(self._skill_type):
            subtrajectory = (state.first_observation,
                             subtrajectory.prev_action)
        else:
            subtrajectory = subtrajectory.observation

    if self._skill_encoder is not None:
        steps = self._num_steps_per_skill - steps
        if not isinstance(subtrajectory, tuple):
            subtrajectory = (subtrajectory, )
        subtrajectory = subtrajectory + (steps, )
        with torch.no_grad():
            prev_skill, _ = self._skill_encoder((prev_skill, steps))

    if isinstance(self._skill_discriminator, EncodingNetwork):
        pred_skill, _ = self._skill_discriminator(subtrajectory)
        loss = torch.sum(
            losses.element_wise_squared_loss(prev_skill, pred_skill), dim=-1)
    else:
        pred_skill_dist, _ = self._skill_discriminator(subtrajectory)
        loss = -dist_utils.compute_log_probability(pred_skill_dist,
                                                   prev_skill)
    return loss
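# Sketch (not from the source): the two discriminator branches above reduce to
# an MSE loss for a deterministic predictor (EncodingNetwork) and a negative
# log-likelihood for a distributional predictor. Shapes and the Normal
# distribution below are hypothetical stand-ins.
import torch
import torch.distributions as td

prev_skill = torch.randn(8, 4)  # [B, skill_dim]
pred_mean = torch.randn(8, 4)   # hypothetical discriminator output

mse_loss = ((prev_skill - pred_mean)**2).sum(dim=-1)  # EncodingNetwork branch
nll_loss = -td.Normal(pred_mean, 1.0).log_prob(prev_skill).sum(dim=-1)  # distribution branch
print(mse_loss.shape, nll_loss.shape)  # both torch.Size([8])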
def calc_loss(self, model_output: ModelOutput,
              target: ModelTarget) -> LossInfo:
    """Calculate the loss.

    The shapes of the tensors in model_output are [B, unroll_steps+1, ...].

    Returns:
        LossInfo
    """
    num_unroll_steps = target.value.shape[1] - 1
    loss_scale = torch.ones((num_unroll_steps + 1, )) / num_unroll_steps
    loss_scale[0] = 1.0

    value_loss = element_wise_squared_loss(target.value, model_output.value)
    value_loss = (loss_scale * value_loss).sum(dim=1)
    loss = value_loss

    reward_loss = ()
    if self._train_reward_function:
        reward_loss = element_wise_squared_loss(target.reward,
                                                model_output.reward)
        reward_loss = (loss_scale * reward_loss).sum(dim=1)
        loss = loss + reward_loss

    # target.action.shape is [B, unroll_steps+1, num_candidate]
    # log_prob needs the sample shape in the beginning
    if isinstance(target.action, tuple) and target.action == ():
        # This condition is only possible for a Categorical distribution
        assert isinstance(model_output.action_distribution, td.Categorical)
        policy_loss = -(target.action_policy *
                        model_output.action_distribution.logits).sum(dim=2)
    else:
        action = target.action.permute(2, 0, 1,
                                       *list(range(3, target.action.ndim)))
        action_log_probs = model_output.action_distribution.log_prob(action)
        action_log_probs = action_log_probs.permute(1, 2, 0)
        policy_loss = -(target.action_policy * action_log_probs).sum(dim=2)

    game_over_loss = ()
    if self._train_game_over_function:
        game_over_loss = F.binary_cross_entropy_with_logits(
            input=model_output.game_over_logit,
            target=target.game_over.to(torch.float),
            reduction='none')
        # No need to train the policy after game over.
        policy_loss = policy_loss * (~target.game_over).to(torch.float32)
        unscaled_game_over_loss = game_over_loss
        game_over_loss = (loss_scale * game_over_loss).sum(dim=1)
        loss = loss + game_over_loss

    policy_loss = (loss_scale * policy_loss).sum(dim=1)
    loss = loss + policy_loss

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            alf.summary.scalar(
                "explained_variance_of_value0",
                tensor_utils.explained_variance(model_output.value[:, 0],
                                                target.value[:, 0]))
            alf.summary.scalar(
                "explained_variance_of_value1",
                tensor_utils.explained_variance(model_output.value[:, 1:],
                                                target.value[:, 1:]))
            if self._train_reward_function:
                alf.summary.scalar(
                    "explained_variance_of_reward0",
                    tensor_utils.explained_variance(
                        model_output.reward[:, 0], target.reward[:, 0]))
                alf.summary.scalar(
                    "explained_variance_of_reward1",
                    tensor_utils.explained_variance(
                        model_output.reward[:, 1:], target.reward[:, 1:]))
            if self._train_game_over_function:

                def _entropy(events):
                    p = events.to(torch.float32).mean()
                    p = torch.tensor([p, 1 - p])
                    return -(p * (p + 1e-30).log()).sum(), p[0]

                h0, p0 = _entropy(target.game_over[:, 0])
                alf.summary.scalar("game_over0", p0)
                h1, p1 = _entropy(target.game_over[:, 1:])
                alf.summary.scalar("game_over1", p1)
                alf.summary.scalar(
                    "explained_entropy_of_game_over0",
                    torch.where(
                        h0 == 0, h0,
                        1. - unscaled_game_over_loss[:, 0].mean() /
                        (h0 + 1e-30)))
                alf.summary.scalar(
                    "explained_entropy_of_game_over1",
                    torch.where(
                        h1 == 0, h1,
                        1. - unscaled_game_over_loss[:, 1:].mean() /
                        (h1 + 1e-30)))
            summary_utils.add_mean_hist_summary("target_value", target.value)
            summary_utils.add_mean_hist_summary("value", model_output.value)
            summary_utils.add_mean_hist_summary(
                "td_error", target.value - model_output.value)

    return LossInfo(
        loss=loss,
        extra=dict(
            value=value_loss,
            reward=reward_loss,
            policy=policy_loss,
            game_over=game_over_loss))
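# Sketch (not from the source): in the categorical branch above,
# td.Categorical normalizes its logits into log-probabilities, so
# -(target.action_policy * logits).sum(dim=2) is a proper cross-entropy
# between the target policy and the predicted policy.
import torch
import torch.distributions as td

raw_logits = torch.randn(3, 4)
dist = td.Categorical(logits=raw_logits)
target_policy = torch.softmax(torch.randn(3, 4), dim=-1)

ce = -(target_policy * dist.logits).sum(dim=-1)
ref = -(target_policy * torch.log_softmax(raw_logits, dim=-1)).sum(dim=-1)
assert torch.allclose(ce, ref)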