Example #1
    def _calc_returns_and_advantages(self, experience, value):
        returns = value_ops.discounted_return(rewards=experience.reward,
                                              values=value,
                                              step_types=experience.step_type,
                                              discounts=experience.discount *
                                              self._gamma)
        returns = tensor_utils.tensor_extend(returns, value[-1])

        if not self._use_gae:
            advantages = returns - value
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=experience.reward,
                values=value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                td_lambda=self._lambda)
            advantages = tensor_utils.tensor_extend_zero(advantages)
            if self._use_td_lambda_return:
                returns = advantages + value

        return returns, advantages
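
Both branches pad the time dimension at the end so that returns and advantages come back with the same [T, B] shape as value. A minimal plain-PyTorch sketch of the assumed behavior of the two padding helpers (stand-ins, not the real ALF implementations):

import torch

# Stand-ins for what tensor_extend / tensor_extend_zero are used for above:
# append one extra time step along dim 0, so a [T-1, B] tensor becomes [T, B].
def tensor_extend(x, last):
    # append ``last`` (same shape as x[0]) as the final time step
    return torch.cat([x, last.unsqueeze(0)], dim=0)

def tensor_extend_zero(x):
    # append a time step of zeros
    return torch.cat([x, torch.zeros_like(x[:1])], dim=0)

value = torch.randn(6, 3)                        # [T, B] value estimates
returns = torch.randn(5, 3)                      # [T-1, B] discounted returns
print(tensor_extend(returns, value[-1]).shape)   # torch.Size([6, 3])
print(tensor_extend_zero(returns).shape)         # torch.Size([6, 3])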
Example #2
    def _backup(self, trees: _MCTSTree, search_paths, path_lengths, values):
        B = trees.B.unsqueeze(0)
        T = search_paths.shape[0]
        depth = torch.arange(T).unsqueeze(-1)

        if trees.reward is not None:
            reward = trees.reward[B, search_paths]
            reward[depth > path_lengths] = 0.
            # [T+1, batch_size]
            reward = tensor_utils.tensor_extend_zero(reward)
            reward[path_lengths, B] = values
            discounts = (self._discount**torch.arange(
                T + 1, dtype=torch.float32)).unsqueeze(-1)

            # discounted_return[t] = discount^t * reward[t]
            discounted_return = reward * discounts
            # discounted_return[t] = \sum_{s=t}^T discount^s * reward[s]
            discounted_return = discounted_return.flip(0).cumsum(
                dim=0).flip(0)

            # discounted_return[t] = \sum_{s=t}^T discount^(s-t) * reward[s]
            discounted_return = discounted_return / discounts
            discounted_return = discounted_return[1:]
        else:
            # [T, 1]
            steps = torch.arange(1, T + 1, dtype=torch.float32).unsqueeze(-1)
            # [T, B]
            discounts = self._discount**(path_lengths.unsqueeze(0) - steps)
            discounted_return = values.unsqueeze(0) * discounts

        value_sum = trees.value_sum[B, search_paths] + discounted_return

        valid = depth < path_lengths
        nodes = (B.expand(T, -1)[valid], search_paths[valid])
        trees.visit_count[nodes] += 1
        trees.value_sum[nodes] = value_sum[valid]
        trees.update_value_stats((B, search_paths), valid)
        self._update_best_child(trees, nodes)
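
In the reward branch, the n-step discounted return along each search path is computed with a multiply / reverse-cumsum / divide trick. A small self-contained check of that identity, with toy shapes and a scalar discount assumed:

import torch

# discounted_return[t] should equal sum_{s=t}^{T} discount^(s-t) * reward[s].
# Multiplying by discount^t, tail-cumsumming, then dividing by discount^t
# gives exactly that.
T, B, discount = 4, 2, 0.9
reward = torch.rand(T + 1, B)
discounts = (discount**torch.arange(T + 1, dtype=torch.float32)).unsqueeze(-1)

g = reward * discounts                  # discount^t * reward[t]
g = g.flip(0).cumsum(dim=0).flip(0)     # sum_{s=t}^{T} discount^s * reward[s]
g = g / discounts                       # sum_{s=t}^{T} discount^(s-t) * reward[s]

# Brute-force check against the definition.
expected = torch.stack(
    [sum(discount**(s - t) * reward[s] for s in range(t, T + 1))
     for t in range(T + 1)])
assert torch.allclose(g, expected, atol=1e-5)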
Example #3
    def forward(self, experience, value, target_value):
        """Cacluate the loss.

        The first dimension of all the tensors is the time dimension and the
        second dimension is the batch dimension.

        Args:
            experience (Experience): experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major.
            value (torch.Tensor): the time-major tensor for the value at each time
                step. The loss is between this and the calculated return.
            target_value (torch.Tensor): the time-major tensor for the value at
                each time step. This is used to calculate return. ``target_value``
                can be the same as ``value``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
        if self._lambda == 1.0:
            returns = value_ops.discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        elif self._lambda == 0.0:
            returns = value_ops.one_step_discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                td_lambda=self._lambda)
            returns = advantages + target_value[:-1]

        value = value[:-1]

        if self._debug_summaries and alf.summary.should_record_summaries():
            mask = experience.step_type[:-1] != StepType.LAST
            with alf.summary.scope(self._name):

                def _summarize(v, r, td, suffix):
                    alf.summary.scalar(
                        "explained_variance_of_return_by_value" + suffix,
                        tensor_utils.explained_variance(v, r, mask))
                    safe_mean_hist_summary('values' + suffix, v, mask)
                    safe_mean_hist_summary('returns' + suffix, r, mask)
                    safe_mean_hist_summary("td_error" + suffix, td, mask)

                if value.ndim == 2:
                    _summarize(value, returns, returns - value, '')
                else:
                    td = returns - value
                    for i in range(value.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(value[..., i], returns[..., i], td[..., i],
                                   suffix)

        loss = self._td_error_loss_fn(returns.detach(), value)

        if loss.ndim == 3:
            # Multi-dimensional reward: average the critic loss over all
            # reward dimensions.
            loss = loss.mean(dim=2)

        # The loss shape expected by Algorithm.update_with_gradient is [T, B],
        # so we append a row of zeros to restore the time dimension.
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss, extra=loss)
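
The three branches are the usual special cases of the lambda-return: td_lambda = 1 gives the full discounted (Monte Carlo) return, td_lambda = 0 gives the one-step TD target, and anything in between is GAE plus the value baseline. A toy check of the two limiting cases, with the ALF specifics (step types, per-step discounts) assumed away:

import torch

# Lambda-return recursion, bootstrapped from value[-1]; no episode
# boundaries, constant discount gamma. Inputs are time-major [T, B].
def lambda_return(reward, value, gamma, lam):
    T = reward.shape[0] - 1
    ret = torch.zeros(T, reward.shape[1])
    next_ret = value[-1]
    for t in reversed(range(T)):
        td_target = reward[t + 1] + gamma * value[t + 1]
        next_ret = td_target + gamma * lam * (next_ret - value[t + 1])
        ret[t] = next_ret
    return ret

reward = torch.rand(6, 2)
value = torch.rand(6, 2)
gamma = 0.99

# lambda = 0: one-step TD target r[t+1] + gamma * V[t+1].
assert torch.allclose(lambda_return(reward, value, gamma, 0.0),
                      reward[1:] + gamma * value[1:])

# lambda = 1: Monte Carlo discounted return bootstrapped from value[-1].
mc = torch.zeros(5, 2)
acc = value[-1]
for t in reversed(range(5)):
    acc = reward[t + 1] + gamma * acc
    mc[t] = acc
assert torch.allclose(lambda_return(reward, value, gamma, 1.0), mc)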
Example #4
    def calc_loss(self, experience, train_info: td.Distribution):
        dist: td.Distribution = train_info
        log_prob = dist.log_prob(experience.action)
        loss = -log_prob[:-1] * experience.reward[1:]
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss)
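
This is a REINFORCE-style loss: the log-probability of the action taken at step t is weighted by the reward received at step t+1, and one zero row is appended to restore the [T, B] shape. A toy sketch with a Categorical policy (all shapes hypothetical):

import torch
import torch.distributions as td

T, B, n_actions = 5, 3, 4
logits = torch.randn(T, B, n_actions)
dist = td.Categorical(logits=logits)
action = dist.sample()               # [T, B]
reward = torch.rand(T, B)

log_prob = dist.log_prob(action)     # [T, B]
loss = -log_prob[:-1] * reward[1:]   # [T-1, B]
loss = torch.cat([loss, torch.zeros(1, B)], dim=0)   # back to [T, B]
print(loss.shape)  # torch.Size([5, 3])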
Example #5
    def calc_loss(self, experience, info: SarsaInfo):
        loss = info.actor_loss
        if self._log_alpha is not None:
            alpha = self._log_alpha.exp().detach()
            alpha_loss = self._log_alpha * (-info.neg_entropy -
                                            self._target_entropy).detach()
            loss = loss + alpha * info.neg_entropy + alpha_loss
        else:
            alpha_loss = ()

        # For SARSA, info.critics actually holds the critics for the previous
        # step, and info.target_critics holds the critics for the current
        # step. So we need to rearrange ``experience`` to match what
        # ``OneStepTDLoss`` expects.
        step_type0 = experience.step_type[0]
        step_type0 = torch.where(step_type0 == StepType.LAST,
                                 torch.tensor(StepType.MID), step_type0)
        step_type0 = torch.where(step_type0 == StepType.FIRST,
                                 torch.tensor(StepType.LAST), step_type0)

        reward = experience.reward
        if self._use_entropy_reward:
            # Out-of-place subtraction so experience.reward is not modified
            # in place.
            reward = reward - (self._log_alpha.exp() * info.neg_entropy).detach()
        shifted_experience = experience._replace(
            discount=tensor_utils.tensor_prepend_zero(experience.discount),
            reward=tensor_utils.tensor_prepend_zero(reward),
            step_type=tensor_utils.tensor_prepend(experience.step_type,
                                                  step_type0))
        critic_losses = []
        for i in range(self._num_critic_replicas):
            critic = tensor_utils.tensor_extend_zero(info.critics[..., i])
            target_critic = tensor_utils.tensor_prepend_zero(
                info.target_critics)
            loss_info = self._critic_losses[i](shifted_experience, critic,
                                               target_critic)
            critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))

        critic_loss = math_ops.add_n(critic_losses)

        not_first_step = (experience.step_type != StepType.FIRST).to(
            torch.float32)
        critic_loss = critic_loss * not_first_step
        if (experience.batch_info != ()
                and experience.batch_info.importance_weights != ()):
            valid_n = torch.clamp(not_first_step.sum(dim=0), min=1.0)
            priority = (critic_loss.sum(dim=0) / valid_n).sqrt()
        else:
            priority = ()

        # Put critic_loss into scalar_loss because ``loss`` will be masked by
        # ~is_last in train_complete(). The critic_loss here should instead be
        # masked by ~is_first, which is done above.
        scalar_loss = critic_loss.mean()

        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):
                if self._log_alpha is not None:
                    alf.summary.scalar("alpha", alpha)

        return LossInfo(loss=loss,
                        scalar_loss=scalar_loss,
                        priority=priority,
                        extra=SarsaLossInfo(actor=info.actor_loss,
                                            critic=critic_loss,
                                            alpha=alpha_loss,
                                            neg_entropy=info.neg_entropy))
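
The alpha branch is the SAC-style temperature adjustment: minimizing log_alpha * (H(pi) - target_entropy).detach() raises alpha when the policy entropy is below the target and lowers it when above. A tiny gradient check with toy numbers (plain PyTorch, not ALF):

import torch

# neg_entropy stands for -H(pi) estimated on a batch. The gradient of
# alpha_loss w.r.t. log_alpha is mean(H - target_entropy), so gradient
# descent lowers alpha when the entropy exceeds the target and raises it
# when the entropy falls short.
log_alpha = torch.zeros((), requires_grad=True)
target_entropy = -2.0
neg_entropy = torch.tensor([-1.5, -2.5, -2.0])   # H = [1.5, 2.5, 2.0]

alpha_loss = (log_alpha * (-neg_entropy - target_entropy).detach()).mean()
alpha_loss.backward()
print(log_alpha.grad)   # tensor(4.) > 0: entropy above target, alpha will shrink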