Example #1
    def _calc_critic_loss(self, experience, train_info: SacInfo):
        critic_info = train_info.critic

        critic_losses = []
        for i, l in enumerate(self._critic_losses):
            critic_losses.append(
                l(experience=experience,
                  value=critic_info.critics[:, :, i, ...],
                  target_value=critic_info.target_critic).loss)

        critic_loss = math_ops.add_n(critic_losses)

        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_masks = (experience.step_type != StepType.LAST).to(
                torch.float32)
            valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
            priority = ((critic_loss * valid_masks).sum(dim=0) /
                        valid_n).sqrt()
        else:
            priority = ()

        return LossInfo(loss=critic_loss,
                        priority=priority,
                        extra=critic_loss / float(self._num_critic_replicas))
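
The per-sample priority used for prioritized replay in Example #1 (and in several of the examples below) is the square root of the TD loss averaged over the time steps that are not terminal. A minimal sketch of that masked reduction on toy tensors; the shapes and the numeric value used for StepType.LAST are illustrative only:

import torch

# Toy stand-ins: critic_loss and step_type have shape [T, B] (time, batch).
T, B = 4, 3
LAST = 2  # illustrative stand-in for StepType.LAST
critic_loss = torch.rand(T, B)
step_type = torch.randint(0, 3, (T, B))

valid_masks = (step_type != LAST).to(torch.float32)      # 1 for usable steps
valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)   # avoid division by zero
priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
print(priority.shape)  # torch.Size([3]): one priority per batch sample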
Example #2
    def calc_loss(self, experience, train_info: DdpgInfo):

        critic_losses = [None] * self._num_critic_replicas
        for i in range(self._num_critic_replicas):
            critic_losses[i] = self._critic_losses[i](
                experience=experience,
                value=train_info.critic.q_values[:, :, i, ...],
                target_value=train_info.critic.target_q_values).loss

        critic_loss = math_ops.add_n(critic_losses)

        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_masks = (experience.step_type != StepType.LAST).to(
                torch.float32)
            valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
            priority = ((critic_loss * valid_masks).sum(dim=0) /
                        valid_n).sqrt()
        else:
            priority = ()

        actor_loss = train_info.actor_loss

        return LossInfo(loss=critic_loss + actor_loss.loss,
                        priority=priority,
                        extra=DdpgLossInfo(critic=critic_loss,
                                           actor=actor_loss.extra))
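
Both of the examples above reduce the list of per-replica losses with math_ops.add_n, which, like TensorFlow's tf.add_n, sums a list of same-shaped tensors element-wise. If ALF is not at hand, a drop-in sketch of the same reduction (the helper name here is only a stand-in):

import torch

def add_n(tensors):
    # Element-wise sum of a list of same-shaped tensors.
    return torch.stack(tensors, dim=0).sum(dim=0)

losses = [torch.full((2, 3), float(i)) for i in range(4)]
print(add_n(losses))  # every entry equals 0 + 1 + 2 + 3 = 6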
Example #3
        def _benchmark(pnet, name):
            t0 = time.time()
            outputs = []
            for _ in range(1000):
                embedding = input_spec.randn(outer_dims=(batch_size, ))
                output, _ = pnet(embedding)
                outputs.append(output)
            o = math_ops.add_n(outputs).sum()
            logging.info("%s time=%s %s" % (name, time.time() - t0, float(o)))

            self.assertEqual(output.shape, (batch_size, replicas, 1))
            self.assertEqual(pnet.output_spec.shape, (replicas, 1))
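
The benchmark above asserts that the parallel network maps a [batch_size, ...] input to a [batch_size, replicas, 1] output, i.e. every replica produces its own scalar head. A self-contained illustration of where that replica dimension comes from, using a plain batched matmul rather than ALF's network classes:

import torch

B, D, replicas = 8, 16, 3
x = torch.randn(B, D)            # one embedding per batch element
w = torch.randn(replicas, D, 1)  # separate weights for each replica
b = torch.randn(replicas, 1)

out = torch.einsum('bd,rdo->bro', x, w) + b
print(out.shape)  # torch.Size([8, 3, 1]) == (batch_size, replicas, 1)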
Example #4
    def _actor_train_step(self, exp: Experience, state, action, critics,
                          log_pi, action_distribution):
        neg_entropy = sum(nest.flatten(log_pi))

        if self._act_type == ActionType.Discrete:
            # Pure discrete case doesn't need to learn an actor network
            return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

        if self._act_type == ActionType.Continuous:
            critics, critics_state = self._compute_critics(
                self._critic_networks, exp.observation, action, state)
            if critics.ndim == 3:
                # Multi-dimensional reward: [B, num_critic_replicas, reward_dim]
                if self._reward_weights is None:
                    critics = critics.sum(dim=2)
                else:
                    critics = torch.tensordot(critics,
                                              self._reward_weights,
                                              dims=1)

            target_q_value = critics.min(dim=1)[0]
            continuous_log_pi = log_pi
            cont_alpha = torch.exp(self._log_alpha).detach()
        else:
            # For the Mixed action type, use the critics computed during action prediction.
            critics_state = ()
            discrete_act_dist = action_distribution[0]
            discrete_entropy = discrete_act_dist.entropy()
            # ``critics`` has already been min-reduced over the replicas.
            weighted_q_value = torch.sum(discrete_act_dist.probs * critics,
                                         dim=-1)
            discrete_alpha = torch.exp(self._log_alpha[0]).detach()
            target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
            action, continuous_log_pi = action[1], log_pi[1]
            cont_alpha = torch.exp(self._log_alpha[1]).detach()

        dqda = nest_utils.grad(action, target_q_value.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = torch.clamp(dqda, -self._dqda_clipping,
                                   self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            return loss.sum(list(range(1, loss.ndim)))

        actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(nest.flatten(actor_loss))
        actor_info = LossInfo(loss=actor_loss + cont_alpha * continuous_log_pi,
                              extra=SacActorInfo(actor_loss=actor_loss,
                                                 neg_entropy=neg_entropy))
        return critics_state, actor_info
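
The surrogate actor loss 0.5 * ((dqda + action).detach() - action)**2 used above (and again in Example #6) has gradient -dqda with respect to the action, so minimizing it pushes the action along dqda, i.e. performs gradient ascent on the critic value. A quick check with autograd:

import torch

action = torch.tensor([0.3, -1.2], requires_grad=True)
dqda = torch.tensor([0.5, 2.0])

loss = 0.5 * ((dqda + action).detach() - action).pow(2).sum()
loss.backward()
print(action.grad)  # tensor([-0.5000, -2.0000]), exactly -dqda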
Example #5
    def decode_step(self, latent_vector, observations):
        """Calculate decoding loss."""
        decoders = flatten(self._decoders)
        observations = flatten(observations)
        decoder_losses = [
            decoder.train_step((latent_vector, obs)).info
            for decoder, obs in zip(decoders, observations)
        ]
        loss = math_ops.add_n(
            [decoder_loss.loss for decoder_loss in decoder_losses])
        decoder_losses = alf.nest.pack_sequence_as(self._decoders,
                                                   decoder_losses)
        return LossInfo(loss=loss, extra=decoder_losses)
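
decode_step follows the usual nest pattern: flatten the structure of decoders and observations, compute a loss per leaf, then pack the per-decoder results back into the original structure so the extra field mirrors self._decoders. A toy round trip with a plain dict and hand-rolled stand-ins for the nest helpers (not ALF's actual alf.nest implementation):

decoders = {'image': 'image_decoder', 'state': 'state_decoder'}
observations = {'image': 1.0, 'state': 2.0}

flat_keys = sorted(decoders)                             # deterministic leaf order
flat_losses = [observations[k] * 10 for k in flat_keys]  # fake per-decoder losses
total = sum(flat_losses)
packed = dict(zip(flat_keys, flat_losses))               # back to the original structure
print(total, packed)  # 30.0 {'image': 10.0, 'state': 20.0}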
Example #6
    def _train_step(self, time_step: TimeStep, state: SarsaState):
        not_first_step = time_step.step_type != StepType.FIRST
        prev_critics, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._actor_network, time_step, state)

        critics, _ = self._critic_networks((time_step.observation, action),
                                           critic_states)
        critic = critics.min(dim=1)[0]
        dqda = nest_utils.grad(action, critic.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest_map(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

        neg_entropy = ()
        if self._log_alpha is not None:
            neg_entropy = dist_utils.compute_log_probability(
                action_distribution, action)

        target_critics, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution,
                         actor_loss=actor_loss,
                         critics=prev_critics,
                         neg_entropy=neg_entropy,
                         target_critics=target_critics.min(dim=1)[0])

        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
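
common.reset_state_if_necessary is used above so that, presumably, the recurrent critic state is reset at episode boundaries rather than carried over from the previous episode. A minimal per-sample version of that masking idea (a hypothetical helper with its own signature, not ALF's implementation):

import torch

def keep_state_where(keep_mask, carried_state, fresh_state):
    # Keep carried_state where keep_mask is True, use fresh_state elsewhere.
    mask = keep_mask.view(-1, *([1] * (carried_state.ndim - 1)))
    return torch.where(mask, carried_state, fresh_state)

carried = torch.ones(4, 2)
fresh = torch.zeros(4, 2)
not_first_step = torch.tensor([True, False, True, False])
print(keep_state_where(not_first_step, carried, fresh))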
Example #7
    def _calc_critic_loss(self, experience, train_info: MdqInfo):
        critic_info = train_info.critic

        # [t, B, n]
        critic_free_form = critic_info.critic_free_form
        # [t, B, n, action_dim]
        critic_adv_form = critic_info.critic_adv_form
        target_critic_free_form = critic_info.target_critic_free_form
        distill_target = critic_info.distill_target

        num_critic_replicas = critic_free_form.shape[2]

        alpha = torch.exp(self._log_alpha).detach()
        kl_wrt_prior = critic_info.kl_wrt_prior

        # [t, B, n, action_dim] -> [t, B]
        # Note that currently kl_wrt_prior is independent of the ensemble,
        # so we slice over the ensemble dimension by taking the first element;
        # for the action dimension, the first element is the full KL.
        kl_wrt_prior = kl_wrt_prior[..., 0, 0]

        # [t, B, n] -> [t, B]
        target_critic, min_target_ind = torch.min(
            target_critic_free_form, dim=2)

        # [t, B, n] -> [t, B]
        distill_target, _ = torch.min(distill_target, dim=2)

        target_critic_corrected = target_critic - alpha * kl_wrt_prior

        critic_losses = []
        for j in range(num_critic_replicas):
            critic_losses.append(self._critic_losses[j](
                experience=experience,
                value=critic_free_form[:, :, j],
                target_value=target_critic_corrected).loss)

        critic_loss = math_ops.add_n(critic_losses)

        distill_loss = (
            critic_adv_form[..., -1] - distill_target.unsqueeze(2).detach())**2
        # mean over replicas
        distill_loss = distill_loss.mean(dim=2)

        return LossInfo(
            loss=critic_loss,
            extra=critic_loss / len(critic_losses)), distill_loss
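
The reductions over the replica dimension above rely on torch.min(tensor, dim=...), which returns a (values, indices) pair; the code keeps the minimum values and, for the distillation target, discards the argmin indices. For reference:

import torch

target = torch.tensor([[1.0, 3.0],
                       [4.0, 2.0]])          # [B, num_replicas]
values, indices = torch.min(target, dim=1)   # min over replicas
print(values)   # tensor([1., 2.])
print(indices)  # tensor([0, 1])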
Example #8
    def calc_loss(self, experience, info: SarsaInfo):
        loss = info.actor_loss
        if self._log_alpha is not None:
            alpha = self._log_alpha.exp().detach()
            alpha_loss = self._log_alpha * (-info.neg_entropy -
                                            self._target_entropy).detach()
            loss = loss + alpha * info.neg_entropy + alpha_loss
        else:
            alpha_loss = ()

        # For SARSA, info.critics actually holds the critics for the previous
        # step, while info.target_critics holds the critics for the current
        # step. So we need to rearrange ``experience`` to match what
        # ``OneStepTDLoss`` expects.
        step_type0 = experience.step_type[0]
        step_type0 = torch.where(step_type0 == StepType.LAST,
                                 torch.tensor(StepType.MID), step_type0)
        step_type0 = torch.where(step_type0 == StepType.FIRST,
                                 torch.tensor(StepType.LAST), step_type0)

        reward = experience.reward
        if self._use_entropy_reward:
            reward -= (self._log_alpha.exp() * info.neg_entropy).detach()
        shifted_experience = experience._replace(
            discount=tensor_utils.tensor_prepend_zero(experience.discount),
            reward=tensor_utils.tensor_prepend_zero(reward),
            step_type=tensor_utils.tensor_prepend(experience.step_type,
                                                  step_type0))
        critic_losses = []
        for i in range(self._num_critic_replicas):
            critic = tensor_utils.tensor_extend_zero(info.critics[..., i])
            target_critic = tensor_utils.tensor_prepend_zero(
                info.target_critics)
            loss_info = self._critic_losses[i](shifted_experience, critic,
                                               target_critic)
            critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

        critic_loss = math_ops.add_n(critic_losses)

        not_first_step = (experience.step_type != StepType.FIRST).to(
            torch.float32)
        critic_loss = critic_loss * not_first_step
        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_n = torch.clamp(not_first_step.sum(dim=0), min=1.0)
            priority = (critic_loss.sum(dim=0) / valid_n).sqrt()
        else:
            priority = ()

        # Put critic_loss into scalar_loss because ``loss`` will be masked by
        # ~is_last in train_complete(). critic_loss should instead be masked
        # by ~is_first, which is done above.
        scalar_loss = critic_loss.mean()

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                if self._log_alpha is not None:
                    alf.summary.scalar("alpha", alpha)

        return LossInfo(loss=loss,
                        scalar_loss=scalar_loss,
                        priority=priority,
                        extra=SarsaLossInfo(actor=info.actor_loss,
                                            critic=critic_loss,
                                            alpha=alpha_loss,
                                            neg_entropy=info.neg_entropy))
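
The time shift in Example #8 assumes tensor_utils.tensor_prepend_zero adds an all-zero frame at the front of the time axis, tensor_extend_zero appends one at the end, and tensor_prepend puts a given frame at the front; that shift is what lines up the previous-step critics with the current-step targets so OneStepTDLoss can be reused unchanged. Under that assumption, plain-torch equivalents of the first two helpers look like this:

import torch

def tensor_prepend_zero(t):
    # Assumed behavior: one all-zero frame added at the front of dim 0 (time).
    return torch.cat([torch.zeros_like(t[:1]), t], dim=0)

def tensor_extend_zero(t):
    # Assumed behavior: one all-zero frame appended at the end of dim 0 (time).
    return torch.cat([t, torch.zeros_like(t[:1])], dim=0)

reward = torch.arange(1., 4.).unsqueeze(1)      # [T=3, B=1]
print(tensor_prepend_zero(reward).squeeze(1))   # tensor([0., 1., 2., 3.])
print(tensor_extend_zero(reward).squeeze(1))    # tensor([1., 2., 3., 0.])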