Example #1
    def test_kl_loss(self):
        input_logits = torch.rand(self._batch_size, self._d,
                                  self._distribution_dim)
        target_logits = torch.rand(self._batch_size, self._d,
                                   self._distribution_dim)

        kld = info_loss.kl_divg_loss_with_logits(
            target_logits,
            input_logits,
            softmax_temperature=0.4,
            confidence_threshold=0.3,
            reduction="mean"
        )

        rank = get_rank(kld)
        self.assertEqual(rank, 0)

        kld = info_loss.kl_divg_loss_with_logits(
            target_logits,
            input_logits,
            softmax_temperature=0.4,
            confidence_threshold=0.3,
            reduction="sum"
        )

        rank = get_rank(kld)
        self.assertEqual(rank, 0)

        kld = info_loss.kl_divg_loss_with_logits(
            target_logits,
            input_logits,
            softmax_temperature=0.4,
            confidence_threshold=0.3,
            reduction="none"
        )

        rank = get_rank(kld)
        self.assertEqual(rank, get_rank(input_logits))

        kld = info_loss.kl_divg_loss_with_logits(
            target_logits,
            input_logits,
            softmax_temperature=0.4,
            confidence_threshold=0.3,
            reduction="batchmean"
        )

        rank = get_rank(kld)
        self.assertEqual(rank, 0)

        kld = info_loss.kl_divg_loss_with_logits(
            target_logits,
            target_logits,
            softmax_temperature=1.0,
            confidence_threshold=-1,
            reduction="mean"
        )

        self.assertLess(kld, 1e-5)
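
For intuition, here is a minimal standalone sketch of the reduction behaviour the test above checks, written with plain PyTorch. It only mirrors the semantics of the `reduction` argument; the softmax temperature and confidence-threshold handling that `kl_divg_loss_with_logits` adds are omitted, and the shapes are placeholders.

import torch
import torch.nn.functional as F

batch_size, num_classes = 4, 10
input_logits = torch.rand(batch_size, num_classes)
target_logits = torch.rand(batch_size, num_classes)

# KL divergence between the two softmax distributions.
log_p = F.log_softmax(input_logits, dim=-1)
q = F.softmax(target_logits, dim=-1)

for reduction in ("mean", "sum", "batchmean", "none"):
    kld = F.kl_div(log_p, q, reduction=reduction)
    # "none" keeps the input shape; the other modes reduce to a scalar,
    # matching the rank assertions in the test above.
    print(reduction, kld.shape)
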
Example #2
    def _test_sequence_loss(self, loss_fn, actions, logits, advantages, batched,
                            sequence_length):
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length)
        rank = get_rank(loss)
        self.assertEqual(rank, 0)

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False)
        rank = get_rank(loss)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._max_time]))

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_timesteps=True,
                       average_across_batch=False)
        rank = get_rank(loss)
        if batched:
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape, torch.Size([self._batch_size]))
        else:
            self.assertEqual(rank, 0)

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_batch=False)
        rank = get_rank(loss)
        if batched:
            self.assertEqual(rank, 2)
            self.assertEqual(loss.shape,
                             torch.Size([self._batch_size, self._max_time]))
        else:
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape,
                             torch.Size([self._max_time]))

        sequence_length_time = torch.randint(
            high=self._batch_size, size=(self._max_time,))
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length_time,
                       sum_over_timesteps=False,
                       average_across_batch=False,
                       time_major=True)
        if batched:
            self.assertEqual(loss.shape, torch.Size([self._batch_size,
                                                     self._max_time]))
        else:
            self.assertEqual(loss.shape, torch.Size([self._max_time]))
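
The helper above is parameterized over `loss_fn`; the call pattern `loss_fn(actions, logits, advantages, ...)` matches a policy-gradient loss. A brief usage sketch, assuming `texar.torch.losses.pg_loss_with_logits` is the concrete loss being exercised (an assumption, since the helper only receives `loss_fn`):

import torch
import texar.torch as tx

batch_size, max_time, num_actions = 8, 16, 5
actions = torch.randint(num_actions, (batch_size, max_time))
logits = torch.rand(batch_size, max_time, num_actions)
advantages = torch.rand(batch_size, max_time)
sequence_length = torch.randint(1, max_time + 1, (batch_size,))

# Default flags (sum over time, average over batch) give a scalar loss.
loss = tx.losses.pg_loss_with_logits(
    actions, logits, advantages,
    batched=True, sequence_length=sequence_length)
print(loss.shape)  # torch.Size([])

# Keeping the batch and time dimensions gives a [batch_size, max_time] loss.
loss = tx.losses.pg_loss_with_logits(
    actions, logits, advantages,
    batched=True, sequence_length=sequence_length,
    sum_over_timesteps=False, average_across_batch=False)
print(loss.shape)  # torch.Size([8, 16])
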
Example #3
    def _test_sequence_loss(self, loss_fn, labels, logits, sequence_length):
        loss = loss_fn(labels, logits, sequence_length)
        rank = get_rank(loss)
        self.assertEqual(rank, 0)

        loss = loss_fn(labels,
                       logits,
                       sequence_length,
                       sum_over_timesteps=False)
        rank = get_rank(loss)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._max_time]))

        loss = loss_fn(labels,
                       logits,
                       sequence_length,
                       sum_over_timesteps=False,
                       average_across_timesteps=True,
                       average_across_batch=False)
        rank = get_rank(loss)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._batch_size]))

        loss = loss_fn(labels,
                       logits,
                       sequence_length,
                       sum_over_timesteps=False,
                       average_across_batch=False)
        rank = get_rank(loss)
        self.assertEqual(rank, 2)
        self.assertEqual(loss.shape,
                         torch.Size([self._batch_size, self._max_time]))

        sequence_length_time = torch.randint(size=(self._max_time, ),
                                             high=self._batch_size)
        loss = loss_fn(labels,
                       logits,
                       sequence_length_time,
                       sum_over_timesteps=False,
                       average_across_batch=False,
                       time_major=True)
        self.assertEqual(loss.shape,
                         torch.Size([self._batch_size, self._max_time]))
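
This helper accepts any `loss_fn(labels, logits, sequence_length, ...)`. A brief usage sketch, assuming `texar.torch.losses.sequence_sparse_softmax_cross_entropy` as one concrete loss with that signature (an assumption; the suite passes in the loss functions under test):

import torch
import texar.torch as tx

batch_size, max_time, num_classes = 8, 16, 10
labels = torch.randint(num_classes, (batch_size, max_time))
logits = torch.rand(batch_size, max_time, num_classes)
sequence_length = torch.randint(1, max_time + 1, (batch_size,))

# Default flags (sum over time, average over batch) produce a scalar.
loss = tx.losses.sequence_sparse_softmax_cross_entropy(
    labels=labels, logits=logits, sequence_length=sequence_length)
print(loss.shape)  # torch.Size([])

# Disabling all reduction keeps the per-step, per-example losses.
loss = tx.losses.sequence_sparse_softmax_cross_entropy(
    labels=labels, logits=logits, sequence_length=sequence_length,
    sum_over_timesteps=False, average_across_batch=False)
print(loss.shape)  # torch.Size([8, 16])
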
Example #4
    def test_sequence_sigmoid_cross_entropy(self):
        """Tests `texar.torch.losses.sequence_sigmoid_cross_entropy`.
        """
        self._test_sequence_loss(mle_losses.sequence_sigmoid_cross_entropy,
                                 self._one_hot_labels, self._logits,
                                 self._sequence_length)

        self._test_sequence_loss(mle_losses.sequence_sigmoid_cross_entropy,
                                 self._one_hot_labels[:, :, 0],
                                 self._logits[:, :, 0], self._sequence_length)

        loss = mle_losses.sequence_sigmoid_cross_entropy(
            logits=self._logits[:, :, 0],
            labels=torch.ones([self._batch_size, self._max_time]),
            sequence_length=self._sequence_length)
        rank = get_rank(loss)
        self.assertEqual(rank, 0)
Example #5
    def _test_entropy(self, entropy_fn, logits, sequence_length=None):
        if sequence_length is None:
            entropy = entropy_fn(logits)
            rank = get_rank(entropy)
            self.assertEqual(rank, 0)

            entropy = entropy_fn(logits, average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._batch_size]))
        else:
            entropy = entropy_fn(logits, sequence_length=sequence_length)
            rank = get_rank(entropy)
            self.assertEqual(rank, 0)

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._max_time]))

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False,
                                 average_across_timesteps=True,
                                 average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._batch_size]))

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False,
                                 average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 2)
            self.assertEqual(entropy.shape,
                             torch.Size([self._batch_size, self._max_time]))

            sequence_length_time = torch.randint(size=(self._max_time, ),
                                                 high=self._batch_size)
            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length_time,
                                 sum_over_timesteps=False,
                                 average_across_batch=False,
                                 time_major=True)
            self.assertEqual(entropy.shape,
                             torch.Size([self._batch_size, self._max_time]))
Example #6
    def UDA_pipeline(self, train_num, test_num, unsup_num):
        self.output_sample_features_to_file(
            self.sample_feature,
            self.feature_types,
            self.train_path,
            dup_num=train_num,
        )

        self.output_sample_features_to_file(
            self.sample_feature,
            self.feature_types,
            self.test_path,
            dup_num=test_num,
        )

        self.output_sample_features_to_file(
            self.unsup_sample_feature,
            self.unsup_feature_types,
            self.unsup_path,
            dup_num=unsup_num,
        )

        train_dataset = tx.data.RecordData(hparams=self.train_hparam,
                                           device=torch.device("cpu"))
        test_dataset = tx.data.RecordData(hparams=self.test_hparam,
                                          device=torch.device("cpu"))
        unsup_dataset = tx.data.RecordData(hparams=self.unsup_hparam,
                                           device=torch.device("cpu"))
        sup_iterator = tx.data.DataIterator({
            "train": train_dataset,
            "test": test_dataset,
        })
        unsup_iterator = tx.data.DataIterator({"unsup": unsup_dataset})

        def unsup_forward_fn(batch):
            orig_input = batch["input_ids"]
            aug_input = batch["aug_input_ids"]

            orig_batch_size = orig_input.size(0)
            aug_batch_size = aug_input.size(0)
            num_category = 2
            orig_logits = torch.ones(orig_batch_size, num_category)
            aug_logits = torch.ones(aug_batch_size, num_category)
            return orig_logits, aug_logits

        iterator = UDAIterator(
            sup_iterator,
            unsup_iterator,
            softmax_temperature=1.0,
            confidence_threshold=-1,
            reduction="mean",
        )

        num_epoch = 10
        iterator.switch_to_dataset_unsup("unsup")

        for epoch in range(num_epoch):
            iterator.switch_to_dataset("train", use_unsup=True)

            for batch, unsup_batch in iterator:
                orig_loss, aug_loss = unsup_forward_fn(unsup_batch)
                unsup_loss = iterator.calculate_uda_loss(orig_loss, aug_loss)
                self.assertLess(unsup_loss, 1e-5)

                sup_rank = get_rank(batch["input_ids"])
                self.assertEqual(sup_rank, 2)

                unsup_orig_rank = get_rank(unsup_batch["input_ids"])
                self.assertEqual(unsup_orig_rank, 2)

                unsup_aug_rank = get_rank(unsup_batch["aug_input_ids"])
                self.assertEqual(unsup_aug_rank, 2)

            iterator.switch_to_dataset("test", use_unsup=False)
            for batch, _ in iterator:
                sup_rank = get_rank(batch["input_ids"])
                self.assertEqual(sup_rank, 2)
Example #7
def sequence_sigmoid_cross_entropy(
        labels: torch.Tensor,
        logits: torch.Tensor,
        sequence_length: Optional[torch.LongTensor],
        average_across_batch: bool = True,
        average_across_timesteps: bool = False,
        average_across_classes: bool = True,
        sum_over_batch: bool = False,
        sum_over_timesteps: bool = True,
        sum_over_classes: bool = False,
        time_major: bool = False,
        stop_gradient_to_label: bool = False) -> torch.Tensor:
    r"""Computes sigmoid cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a
              Tensor of shape `[batch_size, max_time(, num_classes)]`.

            - If `time_major` is `True`, this must be a Tensor of shape
              `[max_time, batch_size(, num_classes)]`.

            Each row of `labels` should be a valid probability
            distribution; otherwise, the gradient computation will be
            incorrect.
        logits: Unscaled log probabilities of the same shape as
            :attr:`labels`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if :attr:`logits` is a 2D Tensor.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            :attr:`logits` is a 2D Tensor.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments
        :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`.
        For example, if the class dimension does not exist, and

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`
          are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are
          `False`, the return Tensor is of shape `[max_time]`.
    """
    if stop_gradient_to_label:
        labels = labels.detach()
    losses = F.binary_cross_entropy_with_logits(logits,
                                                labels.type(logits.dtype),
                                                reduction='none')

    rank = shapes.get_rank(logits) or shapes.get_rank(labels)

    losses = mask_and_reduce(losses,
                             sequence_length,
                             rank=rank,
                             average_across_batch=average_across_batch,
                             average_across_timesteps=average_across_timesteps,
                             average_across_remaining=average_across_classes,
                             sum_over_batch=sum_over_batch,
                             sum_over_timesteps=sum_over_timesteps,
                             sum_over_remaining=sum_over_classes,
                             time_major=time_major)

    return losses
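
A brief usage sketch of the function above, exposed as `texar.torch.losses.sequence_sigmoid_cross_entropy` (the path referenced in Example #4); shapes are placeholders:

import torch
import texar.torch as tx

batch_size, max_time = 8, 16
labels = torch.randint(2, (batch_size, max_time)).float()  # binary targets, no class dim
logits = torch.rand(batch_size, max_time)
sequence_length = torch.randint(1, max_time + 1, (batch_size,))

# Default reduction (sum over time, average over batch) -> scalar.
loss = tx.losses.sequence_sigmoid_cross_entropy(
    labels=labels, logits=logits, sequence_length=sequence_length)
print(loss.shape)  # torch.Size([])

# Per-time-step loss averaged across the batch -> shape [max_time].
loss = tx.losses.sequence_sigmoid_cross_entropy(
    labels=labels, logits=logits, sequence_length=sequence_length,
    sum_over_timesteps=False)
print(loss.shape)  # torch.Size([16])
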
Example #8
def entropy_with_logits(logits: torch.Tensor,
                        rank: Optional[int] = None,
                        average_across_batch: bool = True,
                        average_across_remaining: bool = False,
                        sum_over_batch: bool = False,
                        sum_over_remaining: bool = True) -> torch.Tensor:
    r"""Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimension. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if both batch and remaining dimensions are reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    rank -= 1

    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(entropy,
                                average_axes=average_axes,
                                sum_axes=sum_axes)

    return entropy
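
A brief usage sketch of `entropy_with_logits`, assuming it is exposed as `texar.torch.losses.entropy_with_logits` (consistent with the `texar.torch` namespace used elsewhere in these examples):

import torch
import texar.torch as tx

batch_size, distribution_dim = 8, 10
logits = torch.rand(batch_size, distribution_dim)

# Default flags average over the batch -> scalar entropy.
entropy = tx.losses.entropy_with_logits(logits)
print(entropy.shape)  # torch.Size([])

# Keeping the batch dimension yields one entropy value per example.
entropy = tx.losses.entropy_with_logits(logits, average_across_batch=False)
print(entropy.shape)  # torch.Size([8])
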
Example #9
def sequence_entropy_with_logits(logits: torch.Tensor,
                                 rank: Optional[int] = None,
                                 sequence_length: Optional[
                                     torch.LongTensor] = None,
                                 average_across_batch: bool = True,
                                 average_across_timesteps: bool = False,
                                 average_across_remaining: bool = False,
                                 sum_over_batch: bool = False,
                                 sum_over_timesteps: bool = True,
                                 sum_over_remaining: bool = True,
                                 time_major: bool = False) -> torch.Tensor:
    r"""Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, max_time, d_3, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.

            The batch and time dimensions are exchanged if :attr:`time_major`
            is `True`.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is
            set to 3, i.e., assuming `logits` is of shape
            `[batch_size, max_time, distribution_dim]`
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will have
            zero entropy.
        average_across_timesteps (bool): If set, average the entropy across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        sum_over_timesteps (bool): If set, sum the entropy across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimension. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`logits` must have shape `[max_time, batch_size, ...]`.
            If `False` (default), it must have shape
            `[batch_size, max_time, ...]`.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if batch, time, and remaining dimensions are all reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 3
    rank -= 1

    entropy = mask_and_reduce(
        entropy,
        sequence_length,
        rank=rank,
        average_across_batch=average_across_batch,
        average_across_timesteps=average_across_timesteps,
        average_across_remaining=average_across_remaining,
        sum_over_batch=sum_over_batch,
        sum_over_timesteps=sum_over_timesteps,
        sum_over_remaining=sum_over_remaining,
        time_major=time_major)

    return entropy
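
A brief usage sketch of `sequence_entropy_with_logits`, again assuming the `texar.torch.losses` path:

import torch
import texar.torch as tx

batch_size, max_time, distribution_dim = 8, 16, 10
logits = torch.rand(batch_size, max_time, distribution_dim)
sequence_length = torch.randint(1, max_time + 1, (batch_size,))

# Default flags (sum over time, average over batch) -> scalar.
entropy = tx.losses.sequence_entropy_with_logits(
    logits, sequence_length=sequence_length)
print(entropy.shape)  # torch.Size([])

# Keeping both the batch and time dimensions -> [batch_size, max_time].
entropy = tx.losses.sequence_entropy_with_logits(
    logits, sequence_length=sequence_length,
    sum_over_timesteps=False, average_across_batch=False)
print(entropy.shape)  # torch.Size([8, 16])
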
Example #10
def pg_loss_with_log_probs(log_probs: torch.Tensor,
                           advantages: torch.Tensor,
                           rank: Optional[int] = None,
                           batched: bool = False,
                           sequence_length: Optional[torch.LongTensor] = None,
                           average_across_batch: bool = True,
                           average_across_timesteps: bool = False,
                           average_across_remaining: bool = False,
                           sum_over_batch: bool = False,
                           sum_over_timesteps: bool = True,
                           sum_over_remaining: bool = True,
                           time_major: bool = False) -> torch.Tensor:
    r"""Policy gradient loss with log probabilities of actions.

    `pg_loss = reduce(advantages * -log_probs)`,
    where `advantages` does not back-propagate gradients.

    All arguments except :attr:`log_probs` are the same as
    :func:`pg_loss_with_logits`.

    Args:
        log_probs: Log probabilities of shape
            `[(batch_size,) max_time, ..., d_rank]` and dtype `float32`
            or `float64`. The rank of the Tensor is specified
            with :attr:`rank`.

            The batch dimension exists only if :attr:`batched` is `True`.

            The batch and time dimensions are exchanged, i.e.,
            `[max_time, batch_size, ...]` if :attr:`time_major` is `True`.
        advantages: Tensor of shape
            `[(batch_size,) max_time, d_3, ..., d_rank]` and
            dtype `float32` or `float64`.
            The batch dimension exists only if `batched` is `True`.
            The batch and time dimensions
            are exchanged if `time_major` is `True`.
        rank (int, optional): The rank of :attr:`log_probs`.
            If `None` (default), rank is automatically inferred from
            `log_probs` or `advantages`. If the inference fails,
            `rank` is set to 1 if `batched` is `False`,
            and to 2 if `batched` is `True`.
        batched (bool): `True` if the inputs are batched.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will have zero
            losses. Used if :attr:`batched` is `True`.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        average_across_remaining (bool): If set, average the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if
            there are no dimensions other than the batch and time dimensions.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        sum_over_remaining (bool): If set, sum the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if
            there are no dimensions other than the batch and time dimensions.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`log_probs` and :attr:`advantages` must have shape
            `[max_time, batch_size, ...]`. If `False` (default),
            they must have shape `[batch_size, max_time, ...]`.
            Ignored if :attr:`batched` is `False`.

    Returns:
        A Tensor containing the loss to minimize, whose rank depends on the
        reduce arguments. For example, the batch dimension is reduced if
        either :attr:`average_across_batch` or :attr:`sum_over_batch` is
        `True`, which decreases the rank of output tensor by 1.
    """
    advantages = advantages.detach()

    losses = -log_probs * advantages

    if rank is None:
        rank = get_rank(log_probs) or get_rank(advantages)
    if rank is None:
        rank = 2 if batched else 1

    if batched:
        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=rank,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            average_across_remaining=average_across_remaining,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            sum_over_remaining=sum_over_remaining,
            time_major=time_major)
    elif rank > 1:
        if average_across_remaining and sum_over_remaining:
            raise ValueError("Only one of `average_across_remaining` and "
                             "`sum_over_remaining` can be set.")
        if average_across_remaining:
            for average_axis in sorted(list(range(1, rank)), reverse=True):
                losses = torch.mean(losses, dim=average_axis)
        elif sum_over_remaining:
            for sum_axis in sorted(list(range(1, rank)), reverse=True):
                losses = torch.sum(losses, dim=sum_axis)

    if not batched:
        if average_across_timesteps and sum_over_timesteps:
            raise ValueError("Only one of `average_across_timesteps` and "
                             "`sum_over_timesteps` can be set.")
        if average_across_timesteps:
            losses = torch.mean(losses, dim=0)
        elif sum_over_timesteps:
            losses = torch.sum(losses, dim=0)

    return losses
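
A brief usage sketch of `pg_loss_with_log_probs`, assuming the `texar.torch.losses` path, covering both the batched and unbatched branches of the code above:

import torch
import texar.torch as tx

batch_size, max_time = 8, 16
log_probs = -torch.rand(batch_size, max_time)  # mock per-step log-probabilities
advantages = torch.rand(batch_size, max_time)
sequence_length = torch.randint(1, max_time + 1, (batch_size,))

# Batched inputs with the default reduction (sum over time, average over
# batch) -> scalar loss.
loss = tx.losses.pg_loss_with_log_probs(
    log_probs, advantages,
    batched=True, sequence_length=sequence_length)
print(loss.shape)  # torch.Size([])

# Unbatched [max_time] inputs with all reduction disabled keep their shape.
loss = tx.losses.pg_loss_with_log_probs(
    log_probs[0], advantages[0], sum_over_timesteps=False)
print(loss.shape)  # torch.Size([16])
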