class SimpleDecoder(Algorithm):
    """A simple decoder with elementwise loss between the target and the predicted value.

    It is used to predict the target value from the given representation. Its
    loss can be used to train the representation.
    """
    def __init__(self,
                 input_tensor_spec,
                 target_field,
                 decoder_net_ctor,
                 loss_ctor=partial(torch.nn.SmoothL1Loss, reduction='none'),
                 loss_weight=1.0,
                 summarize_each_dimension=False,
                 optimizer=None,
                 normalize_target=False,
                 append_target_field_to_name=True,
                 debug_summaries=False,
                 name="SimpleDecoder"):
        """
        Args:
            input_tensor_spec (TensorSpec): describing the input tensor.
            target_field (str): name of the field in the experience to be used
                as the decoding target.
            decoder_net_ctor (Callable): called as ``decoder_net_ctor(input_tensor_spec=input_tensor_spec)``
                to construct an instance of ``Network`` for decoding. The network
                should take the latent representation as input and output the
                predicted value of the target.
            loss_ctor (Callable): loss function with signature
                ``loss(y_pred, y_true)``. Note that it should not reduce to a
                scalar. It should at least keep the batch dimension in the
                returned loss.
            loss_weight (float): weight for the loss.
            summarize_each_dimension (bool): whether to generate separate debug
                summaries for each dimension of the predicted target.
            optimizer (Optimizer|None): if provided, it will be used to optimize
                the parameters of ``decoder_net``.
            normalize_target (bool): whether to normalize target.
                Note that the effect of this is to change the loss. The predicted
                value itself is not normalized.
            append_target_field_to_name (bool): whether to append ``target_field``
                to the name of the decoder. If True, the actual name used will be
                ``name.target_field``.
            debug_summaries (bool): whether to generate debug summaries
            name (str): name of this instance
        """
        if append_target_field_to_name:
            name = name + "." + target_field

        super().__init__(optimizer=optimizer,
                         debug_summaries=debug_summaries,
                         name=name)
        self._decoder_net = decoder_net_ctor(
            input_tensor_spec=input_tensor_spec)
        assert self._decoder_net.state_spec == (), \
            "RNN decoder is not supported"
        self._summarize_each_dimension = summarize_each_dimension
        self._target_field = target_field
        self._loss = loss_ctor()
        self._loss_weight = loss_weight
        if normalize_target:
            self._target_normalizer = AdaptiveNormalizer(
                self._decoder_net.output_spec,
                auto_update=False,
                name=name + ".target_normalizer")
        else:
            self._target_normalizer = None

    def get_target_fields(self):
        """Return the name of the experience field used as the decoding target."""
        return self._target_field

    def train_step(self, repr, state=()):
        """Predict the target from the latent representation ``repr``."""
        predicted_target = self._decoder_net(repr)[0]
        return AlgStep(output=predicted_target,
                       state=state,
                       info=predicted_target)

    def predict_step(self, repr, state=()):
        """Predict the target from the latent representation ``repr``."""
        predicted_target = self._decoder_net(repr)[0]
        return AlgStep(output=predicted_target,
                       state=state,
                       info=predicted_target)

    def calc_loss(self, target, predicted, mask=None):
        """Calculate the loss between ``target`` and ``predicted``.

        Args:
            target (Tensor): target to be predicted. Its shape is [T, B, ...]
            predicted (Tensor): predicted target. Its shape is [T, B, ...]
            mask (bool Tensor): indicating which target should be predicted.
                Its shape is [T, B].
        Returns:
            LossInfo
        """
        if self._target_normalizer:
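            # Both the target and the prediction are normalized with the same
            # running statistics, so only the loss is affected; the decoder's
            # raw output stays unnormalized.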
            self._target_normalizer.update(target)
            target = self._target_normalizer.normalize(target)
            predicted = self._target_normalizer.normalize(predicted)

        loss = self._loss(predicted, target)
        if self._debug_summaries and alf.summary.should_record_summaries():
            with alf.summary.scope(self._name):

                def _summarize1(pred, tgt, loss, mask, suffix):
                    alf.summary.scalar(
                        "explained_variance" + suffix,
                        tensor_utils.explained_variance(pred, tgt, mask))
                    safe_mean_hist_summary('predict' + suffix, pred, mask)
                    safe_mean_hist_summary('target' + suffix, tgt, mask)
                    safe_mean_summary("loss" + suffix, loss, mask)

                def _summarize(pred, tgt, loss, mask, suffix):
                    _summarize1(pred[0], tgt[0], loss[0], mask[0],
                                suffix + "/current")
                    if pred.shape[0] > 1:
                        _summarize1(pred[1:], tgt[1:], loss[1:], mask[1:],
                                    suffix + "/future")

                if loss.ndim == 2:
                    _summarize(predicted, target, loss, mask, '')
                elif not self._summarize_each_dimension:
                    m = mask
                    if m is not None:
                        m = m.unsqueeze(-1).expand_as(predicted)
                    _summarize(predicted, target, loss, m, '')
                else:
                    for i in range(predicted.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(predicted[..., i], target[..., i],
                                   loss[..., i], mask, suffix)

        if loss.ndim == 3:
            loss = loss.mean(dim=2)

        if mask is not None:
            loss = loss * mask

        return LossInfo(loss=loss * self._loss_weight, extra=loss)
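
# A minimal plain-PyTorch sketch of the loss contract SimpleDecoder relies on:
# loss_ctor must build an elementwise loss (reduction='none') so that calc_loss
# can average any extra dimensions and apply a per-step [T, B] mask. The names
# and shapes below (T, B, target_dim) are illustrative only.
import torch
from functools import partial

loss_ctor = partial(torch.nn.SmoothL1Loss, reduction='none')
loss_fn = loss_ctor()

T, B, target_dim = 5, 4, 3                 # time, batch and target dimensions
predicted = torch.randn(T, B, target_dim)  # what decoder_net would output
target = torch.randn(T, B, target_dim)     # the experience field named by target_field

loss = loss_fn(predicted, target)          # elementwise, shape [T, B, target_dim]
loss = loss.mean(dim=2)                    # calc_loss averages the extra dim -> [T, B]
mask = torch.ones(T, B, dtype=torch.bool)  # which steps contribute to the loss
loss = loss * mask                         # per-step masking, as in calc_loss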
Example #2
class TDLoss(nn.Module):
    def __init__(self,
                 gamma=0.99,
                 td_error_loss_fn=element_wise_squared_loss,
                 td_lambda=0.95,
                 normalize_target=False,
                 debug_summaries=False,
                 name="TDLoss"):
        r"""Create a TDLoss object.

        Let :math:`G_{t:T}` be the bootstrapped return from t to T:
            :math:`G_{t:T} = \sum_{i=t+1}^T \gamma^{i-t-1}R_i + \gamma^{T-t} V(s_T)`
        If ``td_lambda`` = 1, the target for step t is :math:`G_{t:T}`.
        If ``td_lambda`` = 0, the target for step t is :math:`G_{t:t+1}`.
        If 0 < ``td_lambda`` < 1, the target for step t is the :math:`\lambda`-return:
            :math:`G_t^\lambda = (1 - \lambda) \sum_{i=t+1}^{T-1} \lambda^{i-t-1}G_{t:i} + \lambda^{T-t-1} G_{t:T}`
        There is a simple relationship between :math:`\lambda`-return and
        the generalized advantage estimation :math:`\hat{A}^{GAE}_t`:
            :math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
        where the generalized advantage estimation is defined as:
            :math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

        References:

        Schulman et al. `High-Dimensional Continuous Control Using Generalized Advantage Estimation
        <https://arxiv.org/abs/1506.02438>`_

        Sutton et al. `Reinforcement Learning: An Introduction
        <http://incompleteideas.net/book/the-book.html>`_, Chapter 12, 2018

        Args:
            gamma (float): A discount factor for future rewards.
            td_error_loss_fn (Callable): A function for computing the TD error
                loss. This function takes as input the target and the estimated
                value and returns the loss for each element of the batch.
            td_lambda (float): Lambda parameter for TD-lambda computation.
            normalize_target (bool): whether to normalize target.
                Note that the effect of this is to change the loss. The critic
                value itself is not normalized.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this loss.
        """
        super().__init__()

        self._name = name
        self._gamma = gamma
        self._td_error_loss_fn = td_error_loss_fn
        self._lambda = td_lambda
        self._debug_summaries = debug_summaries
        self._normalize_target = normalize_target
        self._target_normalizer = None

    def forward(self, experience, value, target_value):
        """Cacluate the loss.

        The first dimension of all the tensors is time dimension and the second
        dimesion is the batch dimension.

        Args:
            experience (Experience): experience collected from ``unroll()`` or
                a replay buffer. All tensors are time-major.
            value (torch.Tensor): the time-major tensor for the value at each time
                step. The loss is between this and the calculated return.
            target_value (torch.Tensor): the time-major tensor for the value at
                each time step. This is used to calculate the return. ``target_value``
                can be the same as ``value``.
        Returns:
            LossInfo: with the ``extra`` field same as ``loss``.
        """
        if self._lambda == 1.0:
            returns = value_ops.discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        elif self._lambda == 0.0:
            returns = value_ops.one_step_discounted_return(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma)
        else:
            advantages = value_ops.generalized_advantage_estimation(
                rewards=experience.reward,
                values=target_value,
                step_types=experience.step_type,
                discounts=experience.discount * self._gamma,
                td_lambda=self._lambda)
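            # lambda-return target: G_t^lambda = A^GAE_t + V(s_t)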
            returns = advantages + target_value[:-1]

        value = value[:-1]
        if self._normalize_target:
            if self._target_normalizer is None:
                self._target_normalizer = AdaptiveNormalizer(
                    alf.TensorSpec(value.shape[2:]),
                    auto_update=False,
                    debug_summaries=self._debug_summaries,
                    name=self._name + ".target_normalizer")

            self._target_normalizer.update(returns)
            returns = self._target_normalizer.normalize(returns)
            value = self._target_normalizer.normalize(value)

        if self._debug_summaries and alf.summary.should_record_summaries():
            mask = experience.step_type[:-1] != StepType.LAST
            with alf.summary.scope(self._name):

                def _summarize(v, r, td, suffix):
                    alf.summary.scalar(
                        "explained_variance_of_return_by_value" + suffix,
                        tensor_utils.explained_variance(v, r, mask))
                    safe_mean_hist_summary('values' + suffix, v, mask)
                    safe_mean_hist_summary('returns' + suffix, r, mask)
                    safe_mean_hist_summary("td_error" + suffix, td, mask)

                if value.ndim == 2:
                    _summarize(value, returns, returns - value, '')
                else:
                    td = returns - value
                    for i in range(value.shape[2]):
                        suffix = '/' + str(i)
                        _summarize(value[..., i], returns[..., i], td[..., i],
                                   suffix)

        # Gradients do not flow through the bootstrapped returns; only ``value``
        # receives gradients from this loss.
        loss = self._td_error_loss_fn(returns.detach(), value)

        if loss.ndim == 3:
            # Multi-dimensional reward: average the critic loss over all dimensions.
            loss = loss.mean(dim=2)

        # The loss shape expected by Algorithm.update_with_gradient is [T, B],
        # so pad the [T-1, B] loss with a row of zeros along the time dimension.
        loss = tensor_utils.tensor_extend_zero(loss)
        return LossInfo(loss=loss, extra=loss)
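

# A self-contained plain-PyTorch illustration of the lambda-return / GAE
# relationship stated in the TDLoss docstring: G_t^lambda = A^GAE_t + V(s_t).
# It ignores episode boundaries (step_type) and the per-step discount tensor
# that value_ops.generalized_advantage_estimation handles, and uses [T] tensors
# instead of time-major [T, B] ones. All names and numbers are illustrative.
import torch


def gae(rewards, values, gamma, td_lambda):
    # delta_i = R_{i+1} + gamma * V(s_{i+1}) - V(s_i), for i = 0 .. T-2
    deltas = rewards[1:] + gamma * values[1:] - values[:-1]
    advantages = torch.zeros_like(deltas)
    acc = 0.0
    # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in reversed(range(deltas.shape[0])):
        acc = deltas[t] + gamma * td_lambda * acc
        advantages[t] = acc
    return advantages


rewards = torch.tensor([0.0, 1.0, 0.5, 2.0])
values = torch.tensor([0.3, 0.6, 0.2, 1.0])
advantages = gae(rewards, values, gamma=0.99, td_lambda=0.95)
# The targets TDLoss would regress value[:-1] towards when 0 < td_lambda < 1:
lambda_returns = advantages + values[:-1]
print(lambda_returns)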