Example #1
 def actor_loss_fn(dqda, action):
     if self._dqda_clipping:
         dqda = torch.clamp(dqda, -self._dqda_clipping,
                            self._dqda_clipping)
     # Treat (dqda + action) as a fixed target, so the gradient of this loss
     # w.r.t. action is exactly -dqda (i.e. gradient ascent on Q).
     loss = 0.5 * losses.element_wise_squared_loss(
         (dqda + action).detach(), action)
     return loss.sum(list(range(1, loss.ndim)))
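
A minimal sketch verifying that identity numerically, assuming element_wise_squared_loss(x, y) is simply the unreduced (x - y)**2 as in ALF's losses module:

    import torch

    action = torch.tensor([0.3, -0.7], requires_grad=True)
    dqda = torch.tensor([1.5, -2.0])  # stand-in for dQ/da from a critic

    # Same surrogate as above, with element_wise_squared_loss written out.
    loss = 0.5 * ((dqda + action).detach() - action) ** 2
    loss.sum().backward()

    # The gradient w.r.t. the action equals -dqda, i.e. gradient ascent on Q.
    assert torch.allclose(action.grad, -dqda)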
Example #2
    def calc_loss(self, training_info: TrainingInfo):
        info = training_info.info  # SarsaInfo
        critic_loss = losses.element_wise_squared_loss(info.returns,
                                                       info.critic)
        not_first_step = tf.not_equal(training_info.step_type, StepType.FIRST)
        critic_loss *= tf.cast(not_first_step, tf.float32)

        def _summary():
            with self.name_scope:
                tf.summary.scalar("values", tf.reduce_mean(info.critic))
                tf.summary.scalar("returns", tf.reduce_mean(info.returns))
                safe_mean_hist_summary("td_error", info.returns - info.critic)
                tf.summary.scalar(
                    "explained_variance_of_return_by_value",
                    common.explained_variance(info.critic, info.returns))

        if self._debug_summaries:
            common.run_if(common.should_record_summaries(), _summary)

        return LossInfo(
            loss=info.actor_loss,
            # put critic_loss to scalar_loss because loss will be masked by
            # ~is_last at train_complete(). The critic_loss here should be
            # masked by ~is_first instead, which is done above.
            scalar_loss=tf.reduce_mean(critic_loss),
            extra=SarsaLossInfo(actor=info.actor_loss, critic=critic_loss))
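
For reference, element_wise_squared_loss(x, y) returns the unreduced squared error; masking and reduction are left to the caller, which is why the SARSA critic loss above is multiplied by a not-first-step mask before being averaged into scalar_loss. A minimal PyTorch sketch of the same masking pattern (the TensorFlow code above does the equivalent with tf.cast and tf.reduce_mean), assuming element_wise_squared_loss is (x - y)**2:

    import torch

    def element_wise_squared_loss(x, y):
        # Assumed definition: unreduced squared error, shape preserved.
        return (x - y) ** 2

    returns = torch.tensor([[1.0, 0.5], [0.2, 0.9]])   # [T, B]
    values = torch.tensor([[0.8, 0.5], [0.1, 1.0]])
    is_first = torch.tensor([[True, False], [False, False]])

    critic_loss = element_wise_squared_loss(returns, values)
    critic_loss = critic_loss * (~is_first).to(torch.float32)  # drop first steps
    scalar_loss = critic_loss.mean()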
Example #3
 def actor_loss_fn(dqda, action):
     if self._dqda_clipping:
         dqda = tf.clip_by_value(dqda, -self._dqda_clipping,
                                 self._dqda_clipping)
     loss = 0.5 * losses.element_wise_squared_loss(
         tf.stop_gradient(dqda + action), action)
     loss = tf.reduce_sum(loss, axis=list(range(1, len(loss.shape))))
     return loss
Example #4
 def actor_loss_fn(dqda, action):
     if self._dqda_clipping:
         dqda = torch.clamp(dqda, -self._dqda_clipping,
                            self._dqda_clipping)
     loss = 0.5 * losses.element_wise_squared_loss(
         (dqda + action).detach(), action)
     if self._action_l2 > 0:
         assert action.requires_grad
         loss += self._action_l2 * (action**2)
     loss = loss.sum(list(range(1, loss.ndim)))
     return loss
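
Compared to Example #1, this variant adds an L2 penalty on the action itself, which discourages large or saturated actions; the assert guards that the penalty can actually propagate gradients back to the actor. A small sketch of the extra term in isolation (the weight 0.01 is an arbitrary stand-in for self._action_l2):

    import torch

    action_l2 = 0.01  # arbitrary example weight for the penalty
    action = torch.tensor([[0.3, -0.7]], requires_grad=True)

    penalty = action_l2 * (action ** 2)   # element-wise, same shape as action
    penalty = penalty.sum(dim=-1)         # reduce over action dims, keep batch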
Example #5
    def _predict_skill_loss(self, observation, prev_action, prev_skill, steps,
                            state):
        # steps -> {1,2,3}
        if self._skill_type == "action":
            subtrajectory = (state.first_observation, prev_action)
        elif self._skill_type == "action_difference":
            action_difference = prev_action - state.subtrajectory[:, 1, :]
            subtrajectory = (state.first_observation, action_difference)
        elif self._skill_type == "state_action":
            subtrajectory = (observation, prev_action)
        elif self._skill_type == "state":
            subtrajectory = observation
        elif self._skill_type == "state_difference":
            subtrajectory = observation - state.untrans_observation
        elif "concatenation" in self._skill_type:
            subtrajectory = alf.nest.map_structure(
                lambda traj: traj.reshape(observation.shape[0], -1),
                state.subtrajectory)
            if is_action_skill(self._skill_type):
                subtrajectory = (state.first_observation,
                                 subtrajectory.prev_action)
            else:
                subtrajectory = subtrajectory.observation

        if self._skill_encoder is not None:
            steps = self._num_steps_per_skill - steps
            if not isinstance(subtrajectory, tuple):
                subtrajectory = (subtrajectory, )
            subtrajectory = subtrajectory + (steps, )
            with torch.no_grad():
                prev_skill, _ = self._skill_encoder((prev_skill, steps))

        if isinstance(self._skill_discriminator, EncodingNetwork):
            pred_skill, _ = self._skill_discriminator(subtrajectory)
            loss = torch.sum(losses.element_wise_squared_loss(
                prev_skill, pred_skill),
                             dim=-1)
        else:
            pred_skill_dist, _ = self._skill_discriminator(subtrajectory)
            loss = -dist_utils.compute_log_probability(pred_skill_dist,
                                                       prev_skill)
        return loss
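
The two branches at the end correspond to a deterministic discriminator (an EncodingNetwork regressed onto the skill with a squared error) and a stochastic one (a distribution whose log-likelihood of the skill is maximized). A minimal sketch of the two loss forms for a 2-dimensional continuous skill, using a Normal distribution as a stand-in for the discriminator output:

    import torch
    import torch.distributions as td

    prev_skill = torch.tensor([[0.2, -0.4]])

    # Deterministic discriminator: squared error summed over skill dims.
    pred_skill = torch.tensor([[0.1, -0.5]])
    mse_loss = ((prev_skill - pred_skill) ** 2).sum(dim=-1)

    # Stochastic discriminator: negative log-probability of the true skill.
    pred_dist = td.Independent(
        td.Normal(pred_skill, torch.ones_like(pred_skill)), 1)
    nll_loss = -pred_dist.log_prob(prev_skill)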
Example #6
    def calc_loss(self, model_output: ModelOutput,
                  target: ModelTarget) -> LossInfo:
        """Calculate the loss.

        The shapes of the tensors in model_output are [B, unroll_steps+1, ...]
        Returns:
            LossInfo
        """
        num_unroll_steps = target.value.shape[1] - 1
        loss_scale = torch.ones((num_unroll_steps + 1, )) / num_unroll_steps
        loss_scale[0] = 1.0

        value_loss = element_wise_squared_loss(target.value,
                                               model_output.value)
        value_loss = (loss_scale * value_loss).sum(dim=1)
        loss = value_loss

        reward_loss = ()
        if self._train_reward_function:
            reward_loss = element_wise_squared_loss(target.reward,
                                                    model_output.reward)
            reward_loss = (loss_scale * reward_loss).sum(dim=1)
            loss = loss + reward_loss

        # target_action.shape is [B, unroll_steps+1, num_candidate]
        # log_prob needs sample shape in the beginning
        if isinstance(target.action, tuple) and target.action == ():
            # This condition is only possible for Categorical distribution
            assert isinstance(model_output.action_distribution, td.Categorical)
            policy_loss = -(target.action_policy *
                            model_output.action_distribution.logits).sum(dim=2)
        else:
            action = target.action.permute(2, 0, 1,
                                           *list(range(3, target.action.ndim)))
            action_log_probs = model_output.action_distribution.log_prob(
                action)
            action_log_probs = action_log_probs.permute(1, 2, 0)
            policy_loss = -(target.action_policy * action_log_probs).sum(dim=2)

        game_over_loss = ()
        if self._train_game_over_function:
            game_over_loss = F.binary_cross_entropy_with_logits(
                input=model_output.game_over_logit,
                target=target.game_over.to(torch.float),
                reduction='none')
            # no need to train policy after game over.
            policy_loss = policy_loss * (~target.game_over).to(torch.float32)
            unscaled_game_over_loss = game_over_loss
            game_over_loss = (loss_scale * game_over_loss).sum(dim=1)
            loss = loss + game_over_loss

        policy_loss = (loss_scale * policy_loss).sum(dim=1)
        loss = loss + policy_loss

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                alf.summary.scalar(
                    "explained_variance_of_value0",
                    tensor_utils.explained_variance(model_output.value[:, 0],
                                                    target.value[:, 0]))
                alf.summary.scalar(
                    "explained_variance_of_value1",
                    tensor_utils.explained_variance(model_output.value[:, 1:],
                                                    target.value[:, 1:]))
                if self._train_reward_function:
                    alf.summary.scalar(
                        "explained_variance_of_reward0",
                        tensor_utils.explained_variance(
                            model_output.reward[:, 0], target.reward[:, 0]))
                    alf.summary.scalar(
                        "explained_variance_of_reward1",
                        tensor_utils.explained_variance(
                            model_output.reward[:, 1:], target.reward[:, 1:]))

                if self._train_game_over_function:

                    def _entropy(events):
                        p = events.to(torch.float32).mean()
                        p = torch.tensor([p, 1 - p])
                        return -(p * (p + 1e-30).log()).sum(), p[0]

                    h0, p0 = _entropy(target.game_over[:, 0])
                    alf.summary.scalar("game_over0", p0)
                    h1, p1 = _entropy(target.game_over[:, 1:])
                    alf.summary.scalar("game_over1", p1)

                    alf.summary.scalar(
                        "explained_entropy_of_game_over0",
                        torch.where(
                            h0 == 0, h0,
                            1. - unscaled_game_over_loss[:, 0].mean() /
                            (h0 + 1e-30)))
                    alf.summary.scalar(
                        "explained_entropy_of_game_over1",
                        torch.where(
                            h1 == 0, h1,
                            1. - unscaled_game_over_loss[:, 1:].mean() /
                            (h1 + 1e-30)))
                summary_utils.add_mean_hist_summary("target_value",
                                                    target.value)
                summary_utils.add_mean_hist_summary("value",
                                                    model_output.value)
                summary_utils.add_mean_hist_summary(
                    "td_error", target.value - model_output.value)

        return LossInfo(
            loss=loss,
            extra=dict(
                value=value_loss,
                reward=reward_loss,
                policy=policy_loss,
                game_over=game_over_loss))
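
The loss_scale used throughout calc_loss gives the initial (real) step a weight of 1 and spreads a combined weight of 1 over the unrolled prediction steps, so the total loss does not grow with the unroll length. A small numeric sketch of that scaling for num_unroll_steps = 5:

    import torch

    num_unroll_steps = 5
    loss_scale = torch.ones(num_unroll_steps + 1) / num_unroll_steps
    loss_scale[0] = 1.0
    # loss_scale == [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]: step 0 gets full weight,
    # the 5 unrolled steps share a combined weight of 1.

    per_step_loss = torch.rand(4, num_unroll_steps + 1)    # [B, unroll_steps+1]
    scaled_loss = (loss_scale * per_step_loss).sum(dim=1)  # [B]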