Exemple #1
0
def entropy_with_logits(logits,
                        rank=None,
                        average_across_batch=True,
                        average_across_remaining=False,
                        sum_over_batch=False,
                        sum_over_remaining=True):
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), :attr:`rank` is inferred automatically from
            :attr:`logits`. If the inferred rank is `None`, :attr:`rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set
            :attr:`average_across_remaining` and :attr:`sum_over_remaining`
            at the same time.
            Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set
            :attr:`average_across_remaining` and :attr:`sum_over_remaining`
            at the same time.
            Used only when :attr:`logits` has rank >= 3.

    Returns:
        The entropy after the requested reductions, i.e., the result of
        passing the per-element entropy and the selected axes to
        `reduce_dimensions`.

    Raises:
        ValueError: If both the averaging and the summing flag are set for
            the batch dimension, or for the remaining dimensions.
    """
    # Reject contradictory flags up front, before any computation is done
    # on `logits`.
    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")

    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    # The entropy has one dimension fewer than `logits` (the distribution
    # dimension is gone).
    entropy_rank = rank - 1

    # Axes between the batch axis and the end of the entropy tensor. The
    # range is empty unless `logits` has rank >= 3, so no extra guard is
    # needed.
    remaining_axes = list(range(1, entropy_rank))

    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining:
        sum_axes += remaining_axes
    if average_across_remaining:
        average_axes += remaining_axes

    return reduce_dimensions(
        entropy, average_axes=average_axes, sum_axes=sum_axes)
Exemple #2
0
def entropy_with_logits(logits: torch.Tensor,
                        rank: Optional[int] = None,
                        average_across_batch: bool = True,
                        average_across_remaining: bool = False,
                        sum_over_batch: bool = False,
                        sum_over_remaining: bool = True) -> torch.Tensor:
    r"""Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set
            `average_across_remaining` and `sum_over_remaining` at the same
            time. Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if both batch and remaining dimensions are reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.

    Raises:
        ValueError: If both the averaging and the summing flag are set for
            the batch dimension, or for the remaining dimensions.
    """
    # Reject contradictory flags up front, before any computation is done
    # on `logits`.
    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")

    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    # The entropy has one dimension fewer than `logits` (the distribution
    # dimension is gone).
    entropy_rank = rank - 1

    # Axes between the batch axis and the end of the entropy tensor. The
    # range is empty unless `logits` has rank >= 3, so no extra guard is
    # needed.
    remaining_axes = list(range(1, entropy_rank))

    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining:
        sum_axes += remaining_axes
    if average_across_remaining:
        average_axes += remaining_axes

    return reduce_dimensions(entropy,
                             average_axes=average_axes,
                             sum_axes=sum_axes)
def binary_sigmoid_cross_entropy(pos_logits=None,
                                 neg_logits=None,
                                 average_across_batch=True,
                                 average_across_classes=True,
                                 sum_over_batch=False,
                                 sum_over_classes=False,
                                 return_pos_neg_losses=False,
                                 name=None):
    """Computes sigmoid cross entropy of binary predictions.

    Args:
        pos_logits: The logits of predicting positive on positive data. A
            tensor of shape `[batch_size(, num_classes)]`.
        neg_logits: The logits of predicting positive on negative data. A
            tensor of shape `[batch_size(, num_classes)]`.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if the logits are a 1D Tensor.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            the logits are a 1D Tensor.
        return_pos_neg_losses (bool): If set, additionally returns the losses
            on :attr:`pos_logits` and :attr:`neg_logits`, respectively.
        name (str, optional): A name for the operation.

    Returns:
        By default, a Tensor containing the loss, of rank 0, 1, or 2 depending
        on the arguments :attr:`{average_across}/{sum_over}_{batch}/{classes}`.
        For example:

            - If :attr:`average_across_batch` and
              :attr:`average_across_classes` are `True` (default), the
              return Tensor is of rank 0.

            - If all reduction arguments are `False`, the return Tensor is
              of shape `[batch_size(, num_classes)]`.

        If :attr:`return_pos_neg_losses` is `True`, returns a tuple
        `(loss, pos_loss, neg_loss)`, where `loss` is the loss above;
        `pos_loss` is the loss on `pos_logits` only; and
        `neg_loss` is the loss on `neg_logits` only. They have
        `loss = pos_loss + neg_loss`.

        A side whose logits are `None` contributes the Python int `0`.
    """
    # NOTE(review): the "must not set both" constraints above are documented
    # but not enforced here; the axis lists are handed to `reduce_dimensions`
    # as built. Confirm whether conflicting flags should raise.

    def _reduce(elementwise_loss):
        # Static rank of the loss; `None` when the rank is unknown while the
        # graph is being built, in which case the class axis is kept.
        ndims = elementwise_loss.get_shape().ndims
        has_class_axis = ndims is None or ndims >= 2

        average_axes, sum_axes = [], []
        average_axes += [0] if average_across_batch else []
        sum_axes += [0] if sum_over_batch else []
        # Only reduce over axis 1 when it exists, so that the class flags
        # are ignored for 1D logits as documented.
        if has_class_axis:
            average_axes += [1] if average_across_classes else []
            sum_axes += [1] if sum_over_classes else []

        return reduce_dimensions(elementwise_loss, average_axes, sum_axes)

    with tf.name_scope(name, "binary_sigmoid_cross_entropy"):
        pos_loss = 0
        if pos_logits is not None:
            # Positive data: the target label is 1 everywhere.
            pos_loss = _reduce(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=pos_logits, labels=tf.ones_like(pos_logits)))

        neg_loss = 0
        if neg_logits is not None:
            # Negative data: the target label is 0 everywhere.
            neg_loss = _reduce(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=neg_logits, labels=tf.zeros_like(neg_logits)))

    loss = pos_loss + neg_loss

    if return_pos_neg_losses:
        return loss, pos_loss, neg_loss
    return loss
Exemple #4
0
def binary_sigmoid_cross_entropy(
        pos_logits: Optional[torch.Tensor] = None,
        neg_logits: Optional[torch.Tensor] = None,
        average_across_batch: bool = True,
        average_across_classes: bool = True,
        sum_over_batch: bool = False,
        sum_over_classes: bool = False,
        return_pos_neg_losses: bool = False) \
        -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    r"""Computes sigmoid cross entropy of binary predictions.

    Args:
        pos_logits: The logits of predicting positive on positive data. A
            tensor of shape `[batch_size(, num_classes)]`.
        neg_logits: The logits of predicting positive on negative data. A
            tensor of shape `[batch_size(, num_classes)]`.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if the logits are a 1D Tensor.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            the logits are a 1D Tensor.
        return_pos_neg_losses (bool): If set, additionally returns the losses
            on :attr:`pos_logits` and :attr:`neg_logits`, respectively.

    Returns:
        By default, a Tensor containing the loss, of rank 0, 1, or 2 depending
        on the arguments :attr:`{average_across}/{sum_over}_{batch}/{classes}`.
        For example:

            - If :attr:`average_across_batch` and
              :attr:`average_across_classes` are `True` (default), the
              return Tensor is of rank 0.

            - If all reduction arguments are `False`, the return Tensor is
              of shape `[batch_size(, num_classes)]`.

        If :attr:`return_pos_neg_losses` is `True`, returns a tuple
        `(loss, pos_loss, neg_loss)`, where `loss` is the loss above;
        `pos_loss` is the loss on `pos_logits` only; and
        `neg_loss` is the loss on `neg_logits` only. They have
        `loss = pos_loss + neg_loss`.

        A side whose logits are `None` contributes the Python int `0`.
    """
    # NOTE(review): the "must not set both" constraints above are documented
    # but not enforced here; the axis lists are handed to `reduce_dimensions`
    # as built. Confirm whether conflicting flags should raise.

    def _reduce(elementwise_loss: torch.Tensor) -> torch.Tensor:
        average_axes = [0] if average_across_batch else []
        sum_axes = [0] if sum_over_batch else []
        # Only reduce over axis 1 when it exists, so that the class flags
        # are ignored for 1D logits as documented.
        if elementwise_loss.dim() >= 2:
            average_axes += [1] if average_across_classes else []
            sum_axes += [1] if sum_over_classes else []
        return reduce_dimensions(elementwise_loss, average_axes, sum_axes)

    if pos_logits is not None:
        # Positive data: the target label is 1 everywhere.
        pos_loss = _reduce(F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits), reduction='none'))
    else:
        pos_loss = 0  # type: ignore

    if neg_logits is not None:
        # Negative data: the target label is 0 everywhere.
        neg_loss = _reduce(F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits), reduction='none'))
    else:
        neg_loss = 0  # type: ignore

    loss = pos_loss + neg_loss

    if return_pos_neg_losses:
        return loss, pos_loss, neg_loss
    return loss