Example #1
def sequence_entropy_with_logits(logits,
                                 rank=None,
                                 sequence_length=None,
                                 average_across_batch=True,
                                 average_across_timesteps=False,
                                 average_across_remaining=False,
                                 sum_over_batch=False,
                                 sum_over_timesteps=True,
                                 sum_over_remaining=True,
                                 time_major=False):
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, max_time, d_3, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered to have `[batch_size, ..., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.

            The batch and time dimensions are exchanged if :attr:`time_major`
            is `True`.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), :attr:`rank` is inferred automatically from
            :attr:`logits`. If the inferred rank is `None`, :attr:`rank` is
            set to 3, i.e., assuming :attr:`logits` is of shape
            `[batch_size, max_time, distribution_dim]`.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths are
            masked out and do not contribute to the entropy.
        average_across_timesteps (bool): If set, average the entropy across
            the time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        sum_over_timesteps (bool): If set, sum the entropy across the
            time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`logits` must have shape `[max_time, batch_size, ...]`.
            If `False` (default), it must have shape
            `[batch_size, max_time, ...]`.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 3
    rank -= 1  # the last (distribution) dimension is summed out

    entropy = mask_and_reduce(
        entropy,
        sequence_length,
        rank=rank,
        average_across_batch=average_across_batch,
        average_across_timesteps=average_across_timesteps,
        average_across_remaining=average_across_remaining,
        sum_over_batch=sum_over_batch,
        sum_over_timesteps=sum_over_timesteps,
        sum_over_remaining=sum_over_remaining,
        time_major=time_major)

    return entropy
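
A minimal usage sketch; the sizes are illustrative, and it assumes a TF1-style
graph context with the function and its helpers (e.g. `mask_and_reduce`) in
scope. Per step, the entropy is `-sum(softmax(logits) * log_softmax(logits))`
over the last dimension.

import tensorflow as tf

# Hypothetical sizes, for illustration only.
batch_size, max_time, vocab_size = 2, 5, 7
logits = tf.random.normal([batch_size, max_time, vocab_size])
lengths = tf.constant([5, 3])

# Defaults sum entropy over valid time steps and average over the
# batch, yielding a scalar.
entropy = sequence_entropy_with_logits(logits, sequence_length=lengths)
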
Example #2
def pg_loss_with_log_probs(log_probs,
                           advantages,
                           rank=None,
                           batched=False,
                           sequence_length=None,
                           average_across_batch=True,
                           average_across_timesteps=False,
                           average_across_remaining=False,
                           sum_over_batch=False,
                           sum_over_timesteps=True,
                           sum_over_remaining=True,
                           time_major=False):
    """Policy gradient loss with log probs of actions.

    `pg_loss = reduce(advantages * -log_probs)`,
    where gradients do not propagate through `advantages`.

    All arguments except :attr:`log_probs` are the same as
    :func:`pg_loss_with_logits`.

    Args:
        log_probs: Log probabilities of shape
            `[(batch_size,) max_time, ..., d_rank]` and dtype `float32`
            or `float64`. The rank of the Tensor is optionally specified
            by :attr:`rank`.

            The batch dimension exists only if :attr:`batched` is `True`.

            The batch and time dimensions are exchanged, i.e.,
            `[max_time, batch_size, ...]` if :attr:`time_major` is `True`.
        advantages: Tensor of shape
            `[(batch_size,) max_time, d_3, ..., d_rank]` and
            dtype `float32` or `float64`.
            The batch dimension exists only if `batched` is `True`.
            The batch and time dimensions
            are exchanged if `time_major` is `True`.
        rank (int, optional): The rank of :attr:`log_probs`.
            If `None` (default), rank is automatically inferred from
            `log_probs` or `advantages`. If the inference fails,
            `rank` is set to 1 if `batched` is `False`,
            and to 2 if `batched` is `True`.
        batched (bool): `True` if the inputs are batched.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will have zero
            losses. Used if :attr:`batched` is `True`.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        average_across_remaining (bool): If set, average the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if there
            are no dimensions other than the batch and time dimensions.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        sum_over_remaining (bool): If set, sum the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if there
            are no dimensions other than the batch and time dimensions.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`log_probs` and :attr:`advantages` must have shape
            `[max_time, batch_size, ...]`. If `False` (default),
            they must have shape `[batch_size, max_time, ...]`.
            Ignored if :attr:`batched` is `False`.

    Returns:
        A Tensor containing the loss to minimize, whose rank depends on the
        reduce arguments. For example, the batch dimension is reduced if
        either :attr:`average_across_batch` or :attr:`sum_over_batch` is
        `True`, which decreases the rank of the output tensor by 1.
    """
    advantages = tf.stop_gradient(advantages)

    losses = -log_probs * advantages

    if rank is None:
        rank = get_rank(log_probs) or get_rank(advantages)
    if rank is None:
        rank = 2 if batched else 1

    if batched:
        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=rank,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            average_across_remaining=average_across_remaining,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            sum_over_remaining=sum_over_remaining,
            time_major=time_major)
    elif rank > 1:
        if average_across_remaining and sum_over_remaining:
            raise ValueError("Only one of `average_across_remaining` and "
                             "`sum_over_remaining` can be set.")
        if average_across_remaining:
            losses = tf.reduce_mean(losses, axis=list(range(1, rank)))
        elif sum_over_remaining:
            losses = tf.reduce_sum(losses, axis=list(range(1, rank)))

    if not batched:
        if average_across_timesteps and sum_over_timesteps:
            raise ValueError("Only one of `average_across_timesteps` and "
                             "`sum_over_timesteps` can be set.")
        if average_across_timesteps:
            losses = tf.reduce_mean(losses, axis=0)
        elif sum_over_timesteps:
            losses = tf.reduce_sum(losses, axis=0)

    return losses
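
A hedged usage sketch for the batched case; the shapes and values are
placeholders, and it assumes the function and its helpers are in scope.

import tensorflow as tf

batch_size, max_time = 4, 6
# Placeholder per-step log-probabilities of sampled actions and
# their precomputed advantages.
log_probs = tf.random.normal([batch_size, max_time])
advantages = tf.random.normal([batch_size, max_time])
lengths = tf.constant([6, 4, 5, 6])

# Defaults mask steps beyond each length, sum over time, and
# average over the batch.
loss = pg_loss_with_log_probs(
    log_probs, advantages, batched=True, sequence_length=lengths)
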
Example #3
def sequence_softmax_cross_entropy(labels,
                                   logits,
                                   sequence_length,
                                   average_across_batch=True,
                                   average_across_timesteps=False,
                                   sum_over_batch=False,
                                   sum_over_timesteps=True,
                                   time_major=False,
                                   stop_gradient_to_label=False,
                                   name=None):
    """Computes softmax cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a\
            Tensor of shape `[batch_size, max_time, num_classes]`.

            - If `time_major` is `True`, this must be a Tensor of shape\
            `[max_time, batch_size, num_classes]`.

            Each row of `labels` should be a valid probability
            distribution; otherwise the computation of the gradient will
            be incorrect.
        logits: Unscaled log probabilities. This must have the shape of
            `[max_time, batch_size, num_classes]` or
            `[batch_size, max_time, num_classes]` according to
            the value of `time_major`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.
        name (str, optional): A name for the operation.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments :attr:`{average_across}/{sum_over}_{timesteps}/{batch}`.
        For example:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`  \
        are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are \
        `False`, the return Tensor is of shape `[max_time]`.
    """
    with tf.name_scope(name, "sequence_softmax_cross_entropy"):
        if stop_gradient_to_label:
            labels = tf.stop_gradient(labels)

        losses = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                            logits=logits)

        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=2,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            time_major=time_major)

        return losses
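
A minimal sketch with uniform soft targets; the sizes are illustrative, and a
real call would use model logits and data-derived label distributions.

import tensorflow as tf

batch_size, max_time, num_classes = 2, 4, 5
logits = tf.random.normal([batch_size, max_time, num_classes])
# Soft targets; each row is a valid distribution (uniform here).
labels = tf.fill([batch_size, max_time, num_classes], 1.0 / num_classes)
lengths = tf.constant([4, 2])

loss = sequence_softmax_cross_entropy(
    labels=labels, logits=logits, sequence_length=lengths)
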
Example #4
def sequence_sparse_softmax_cross_entropy(labels,
                                          logits,
                                          sequence_length,
                                          average_across_batch=True,
                                          average_across_timesteps=False,
                                          sum_over_batch=False,
                                          sum_over_timesteps=True,
                                          time_major=False,
                                          name=None):
    """Computes sparse softmax cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class indexes. I.e., classes are mutually exclusive
            (each entry is in exactly one class).

            - If :attr:`time_major` is `False` (default), this must be\
            a Tensor of shape `[batch_size, max_time]`.

            - If `time_major` is `True`, this must be a Tensor of shape\
            `[max_time, batch_size]`.
        logits: Unscaled log probabilities. This must have the shape of
            `[max_time, batch_size, num_classes]` or
            `[batch_size, max_time, num_classes]` according to
            the value of `time_major`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        name (str, optional): A name for the operation.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments :attr:`{average_across}/{sum_over}_{timesteps}/{batch}`.
        For example:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`  \
        are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are \
        `False`, the return Tensor is of shape `[max_time]`.

    Example:

        .. code-block:: python

            embedder = WordEmbedder(vocab_size=data.vocab.size)
            decoder = BasicRNNDecoder(vocab_size=data.vocab.size)
            outputs, _, _ = decoder(
                decoding_strategy='train_greedy',
                inputs=embedder(data_batch['text_ids']),
                sequence_length=data_batch['length']-1)

            loss = sequence_sparse_softmax_cross_entropy(
                labels=data_batch['text_ids'][:, 1:],
                logits=outputs.logits,
                sequence_length=data_batch['length']-1)

    """
    with tf.name_scope(name, "sequence_sparse_softmax_cross_entropy"):
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                                logits=logits)

        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=2,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            time_major=time_major)

        return losses
Example #5
def sequence_sigmoid_cross_entropy(labels,
                                   logits,
                                   sequence_length,
                                   average_across_batch=True,
                                   average_across_timesteps=False,
                                   average_across_classes=True,
                                   sum_over_batch=False,
                                   sum_over_timesteps=True,
                                   sum_over_classes=False,
                                   time_major=False,
                                   stop_gradient_to_label=False,
                                   name=None):
    """Computes sigmoid cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a\
            Tensor of shape `[batch_size, max_time(, num_classes)]`.

            - If :attr:`time_major` is `True`, this must be a Tensor of shape\
            `[max_time, batch_size(, num_classes)]`.

            Each entry of :attr:`labels` should be a valid probability in
            `[0, 1]`; classes are treated as independent, so the entries
            need not sum to 1.
        logits: Unscaled log probabilities having the same shape as
            :attr:`labels`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            :attr:`average_across_classes` and :attr:`sum_over_classes` at
            the same time. Ignored if :attr:`logits` is a 2D Tensor.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set :attr:`average_across_classes`
            and :attr:`sum_over_classes` at the same time. Ignored if
            :attr:`logits` is a 2D Tensor.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.
        name (str, optional): A name for the operation.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments
        :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`.
        For example, if the class dimension does not exist:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`  \
        are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are \
        `False`, the return Tensor is of shape `[max_time]`.
    """

    with tf.name_scope(name, "sequence_sigmoid_cross_entropy"):
        if stop_gradient_to_label:
            labels = tf.stop_gradient(labels)

        losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                         logits=logits)

        rank = shapes.get_rank(logits) or shapes.get_rank(labels)
        if rank is None:
            raise ValueError(
                'Cannot determine the rank of `logits` or `labels`.')

        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=rank,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            average_across_remaining=average_across_classes,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            sum_over_remaining=sum_over_classes,
            time_major=time_major)

        return losses
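
A minimal multi-label sketch with illustrative sizes; the labels are
independent per-class probabilities rather than a distribution over classes.

import tensorflow as tf

batch_size, max_time, num_classes = 2, 4, 3
logits = tf.random.normal([batch_size, max_time, num_classes])
# Independent per-class targets in [0, 1]; rows need not sum to 1.
labels = tf.random.uniform([batch_size, max_time, num_classes])
lengths = tf.constant([4, 3])

loss = sequence_sigmoid_cross_entropy(
    labels=labels, logits=logits, sequence_length=lengths)
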
Example #6
def sequence_softmax_cross_entropy(
        labels: torch.Tensor,
        logits: torch.Tensor,
        sequence_length: Optional[torch.LongTensor],
        average_across_batch: bool = True,
        average_across_timesteps: bool = False,
        sum_over_batch: bool = False,
        sum_over_timesteps: bool = True,
        time_major: bool = False,
        stop_gradient_to_label: bool = False) -> torch.Tensor:
    r"""Computes softmax cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a
              Tensor of shape `[batch_size, max_time, num_classes]`.

            - If `time_major` is `True`, this must be a Tensor of shape
              `[max_time, batch_size, num_classes]`.

            Each row of `labels` should be a valid probability
            distribution; otherwise the computation of the gradient will
            be incorrect.
        logits: Unscaled log probabilities. This must have the shape of
            `[max_time, batch_size, num_classes]` or
            `[batch_size, max_time, num_classes]` according to
            the value of `time_major`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments :attr:`{average_across}/{sum_over}_{timesteps}/{batch}`.
        For example:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`
          are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are
          `False`, the return Tensor is of shape `[max_time]`.
    """
    if stop_gradient_to_label:
        labels = labels.detach()

    losses = (-labels.type(logits.dtype) *
              F.log_softmax(logits, -1)).sum(dim=-1)

    losses = mask_and_reduce(losses,
                             sequence_length,
                             rank=2,
                             average_across_batch=average_across_batch,
                             average_across_timesteps=average_across_timesteps,
                             sum_over_batch=sum_over_batch,
                             sum_over_timesteps=sum_over_timesteps,
                             time_major=time_major)
    return losses
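
A minimal PyTorch sketch using one-hot targets; the sizes are illustrative,
and it assumes the function and `mask_and_reduce` are in scope.

import torch
import torch.nn.functional as F

batch_size, max_time, num_classes = 2, 4, 5
logits = torch.randn(batch_size, max_time, num_classes)
# One-hot rows are valid per-step distributions.
targets = torch.randint(num_classes, (batch_size, max_time))
labels = F.one_hot(targets, num_classes).float()
lengths = torch.tensor([4, 2])

loss = sequence_softmax_cross_entropy(
    labels=labels, logits=logits, sequence_length=lengths)
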
Example #7
def sequence_sigmoid_cross_entropy(
        labels: torch.Tensor,
        logits: torch.Tensor,
        sequence_length: Optional[torch.LongTensor],
        average_across_batch: bool = True,
        average_across_timesteps: bool = False,
        average_across_classes: bool = True,
        sum_over_batch: bool = False,
        sum_over_timesteps: bool = True,
        sum_over_classes: bool = False,
        time_major: bool = False,
        stop_gradient_to_label: bool = False) -> torch.Tensor:
    r"""Computes sigmoid cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a
              Tensor of shape `[batch_size, max_time(, num_classes)]`.

            - If `time_major` is `True`, this must be a Tensor of shape
              `[max_time, batch_size(, num_classes)]`.

            Each entry of `labels` should be a valid probability in
            `[0, 1]`; classes are treated as independent, so the entries
            need not sum to 1.
        logits: Unscaled log probabilities having the same shape as
            :attr:`labels`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if :attr:`logits` is a 2D Tensor.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            :attr:`logits` is a 2D Tensor.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments
        :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`.
        For example, if the class dimension does not exist:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`
          are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are
          `False`, the return Tensor is of shape `[max_time]`.
    """
    if stop_gradient_to_label:
        labels = labels.detach()
    losses = F.binary_cross_entropy_with_logits(logits,
                                                labels.type(logits.dtype),
                                                reduction='none')

    rank = shapes.get_rank(logits) or shapes.get_rank(labels)

    losses = mask_and_reduce(losses,
                             sequence_length,
                             rank=rank,
                             average_across_batch=average_across_batch,
                             average_across_timesteps=average_across_timesteps,
                             average_across_remaining=average_across_classes,
                             sum_over_batch=sum_over_batch,
                             sum_over_timesteps=sum_over_timesteps,
                             sum_over_remaining=sum_over_classes,
                             time_major=time_major)

    return losses
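
A minimal PyTorch sketch for the 2D case (no class dimension), with
illustrative sizes; for 2D inputs the class-reduction flags are ignored.

import torch

batch_size, max_time = 2, 4
logits = torch.randn(batch_size, max_time)
# Binary targets; any values in [0, 1] are valid here.
labels = torch.randint(2, (batch_size, max_time)).float()
lengths = torch.tensor([4, 3])

loss = sequence_sigmoid_cross_entropy(
    labels=labels, logits=logits, sequence_length=lengths)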