def entropy_with_logits(logits,
                        rank=None,
                        average_across_batch=True,
                        average_across_remaining=False,
                        sum_over_batch=False,
                        sum_over_remaining=True):
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, ..., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), :attr:`rank` is inferred automatically from
            :attr:`logits`. If the inferred rank is `None`, :attr:`rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set
            :attr:`average_across_remaining` and :attr:`sum_over_remaining`
            at the same time. Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the batch
            dimension. Must not set :attr:`average_across_batch` and
            :attr:`sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set
            :attr:`average_across_remaining` and :attr:`sum_over_remaining`
            at the same time. Used only when :attr:`logits` has rank >= 3.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    rank -= 1  # `_get_entropy` already summed out the last dimension

    # Reduce over the requested axes.
    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(
        entropy, average_axes=average_axes, sum_axes=sum_axes)

    return entropy
def entropy_with_logits(logits: torch.Tensor,
                        rank: Optional[int] = None,
                        average_across_batch: bool = True,
                        average_across_remaining: bool = False,
                        sum_over_batch: bool = False,
                        sum_over_remaining: bool = True) -> torch.Tensor:
    r"""Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, ..., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is set to 2, i.e.,
            assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across
            the remaining dimensions. Must not set
            `average_across_remaining` and `sum_over_remaining`
            at the same time. Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the batch
            dimension. Must not set `average_across_batch` and
            `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimensions. Must not set
            `average_across_remaining` and `sum_over_remaining`
            at the same time. Used only when :attr:`logits` has rank >= 3.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if both batch and remaining dimensions are reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    rank -= 1  # `_get_entropy` already summed out the last dimension

    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(entropy, average_axes=average_axes,
                                sum_axes=sum_axes)
    return entropy
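# Usage sketch (illustrative, not part of the library). With the default
# flags, the PyTorch `entropy_with_logits` above averages over the batch
# dimension and sums over any remaining dimensions. `_get_entropy` is an
# internal helper not shown here; the stand-in below assumes it computes
# softmax entropy over the last dimension, consistent with the docstring.

import torch
import torch.nn.functional as F

def _entropy_last_dim(logits: torch.Tensor) -> torch.Tensor:
    # H = -sum_i p_i * log p_i over the last (distribution) dimension.
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)

logits = torch.randn(4, 7, 10)  # [batch_size, d_2, distribution_dim]
per_element = _entropy_last_dim(logits)  # shape [4, 7]
# Under the defaults, the result should match summing over d_2 and then
# averaging over the batch, yielding a scalar:
expected = per_element.sum(dim=1).mean(dim=0)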
def binary_sigmoid_cross_entropy(pos_logits=None,
                                 neg_logits=None,
                                 average_across_batch=True,
                                 average_across_classes=True,
                                 sum_over_batch=False,
                                 sum_over_classes=False,
                                 return_pos_neg_losses=False,
                                 name=None):
    """Computes sigmoid cross entropy of binary predictions.

    Args:
        pos_logits: The logits of predicting positive on positive data.
            A tensor of shape `[batch_size(, num_classes)]`.
        neg_logits: The logits of predicting positive on negative data.
            A tensor of shape `[batch_size(, num_classes)]`.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes`
            at the same time. Ignored if :attr:`logits` is a 1D Tensor.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time.
            Ignored if :attr:`logits` is a 1D Tensor.
        return_pos_neg_losses (bool): If set, additionally returns the
            losses on :attr:`pos_logits` and :attr:`neg_logits`,
            respectively.
        name (str, optional): A name for the operation.

    Returns:
        By default, a Tensor containing the loss, of rank 0, 1, or 2
        depending on the arguments
        :attr:`{average_across}/{sum_over}_{batch}/{classes}`.
        For example:

        - If :attr:`average_across_batch` and :attr:`average_across_classes` \
        are `True` (default), the return Tensor is of rank 0.
        - If all arguments are `False`, the return Tensor is of shape \
        `[batch_size(, num_classes)]`.

        If :attr:`return_pos_neg_losses` is `True`, returns a tuple
        `(loss, pos_loss, neg_loss)`, where `loss` is the loss above;
        `pos_loss` is the loss on `pos_logits` only; and `neg_loss` is
        the loss on `neg_logits` only. They satisfy
        `loss = pos_loss + neg_loss`.
    """
    with tf.name_scope(name, "binary_sigmoid_cross_entropy"):
        average_axes, sum_axes = [], []
        average_axes += [0] if average_across_batch else []
        average_axes += [1] if average_across_classes else []
        sum_axes += [0] if sum_over_batch else []
        sum_axes += [1] if sum_over_classes else []

        pos_loss = 0
        if pos_logits is not None:
            # Positive data should be predicted as positive (label 1).
            pos_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=pos_logits, labels=tf.ones_like(pos_logits))
            pos_loss = reduce_dimensions(pos_loss, average_axes, sum_axes)

        neg_loss = 0
        if neg_logits is not None:
            # Negative data should be predicted as negative (label 0).
            neg_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=neg_logits, labels=tf.zeros_like(neg_logits))
            neg_loss = reduce_dimensions(neg_loss, average_axes, sum_axes)

        loss = pos_loss + neg_loss

        if return_pos_neg_losses:
            return loss, pos_loss, neg_loss
        else:
            return loss
def binary_sigmoid_cross_entropy(
        pos_logits: Optional[torch.Tensor] = None,
        neg_logits: Optional[torch.Tensor] = None,
        average_across_batch: bool = True,
        average_across_classes: bool = True,
        sum_over_batch: bool = False,
        sum_over_classes: bool = False,
        return_pos_neg_losses: bool = False) \
        -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                 torch.Tensor]:
    r"""Computes sigmoid cross entropy of binary predictions.

    Args:
        pos_logits: The logits of predicting positive on positive data.
            A tensor of shape `[batch_size(, num_classes)]`.
        neg_logits: The logits of predicting positive on negative data.
            A tensor of shape `[batch_size(, num_classes)]`.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes`
            at the same time. Ignored if :attr:`logits` is a 1D Tensor.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time.
            Ignored if :attr:`logits` is a 1D Tensor.
        return_pos_neg_losses (bool): If set, additionally returns the
            losses on :attr:`pos_logits` and :attr:`neg_logits`,
            respectively.

    Returns:
        By default, a Tensor containing the loss, of rank 0, 1, or 2
        depending on the arguments
        :attr:`{average_across}/{sum_over}_{batch}/{classes}`.
        For example:

        - If :attr:`average_across_batch` and :attr:`average_across_classes`
          are `True` (default), the return Tensor is of rank 0.
        - If all arguments are `False`, the return Tensor is of shape
          `[batch_size(, num_classes)]`.

        If :attr:`return_pos_neg_losses` is `True`, returns a tuple
        `(loss, pos_loss, neg_loss)`, where `loss` is the loss above;
        `pos_loss` is the loss on `pos_logits` only; and `neg_loss` is
        the loss on `neg_logits` only. They satisfy
        `loss = pos_loss + neg_loss`.
    """
    average_axes = [0] if average_across_batch else []
    average_axes += [1] if average_across_classes else []
    sum_axes = [0] if sum_over_batch else []
    sum_axes += [1] if sum_over_classes else []

    if pos_logits is not None:
        # Positive data should be predicted as positive (label 1).
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits), reduction='none')
        pos_loss = reduce_dimensions(pos_loss, average_axes, sum_axes)
    else:
        pos_loss = 0  # type: ignore

    if neg_logits is not None:
        # Negative data should be predicted as negative (label 0).
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits), reduction='none')
        neg_loss = reduce_dimensions(neg_loss, average_axes, sum_axes)
    else:
        neg_loss = 0  # type: ignore

    loss = pos_loss + neg_loss

    if return_pos_neg_losses:
        return loss, pos_loss, neg_loss
    else:
        return loss
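# Usage sketch (illustrative, not part of the library): a typical use is a
# GAN-style discriminator loss, where `pos_logits` come from real data and
# `neg_logits` from generated data. Under the default reduction flags the
# direct computation below should match what the PyTorch version above
# returns; it is written without library helpers so it is self-contained.

import torch
import torch.nn.functional as F

pos_logits = torch.randn(8, 3)  # discriminator logits on real (positive) data
neg_logits = torch.randn(8, 3)  # discriminator logits on fake (negative) data

# The defaults average across both the batch and class dimensions, so each
# term reduces to a scalar; reduction='mean' does exactly that here.
pos_loss = F.binary_cross_entropy_with_logits(
    pos_logits, torch.ones_like(pos_logits), reduction='mean')
neg_loss = F.binary_cross_entropy_with_logits(
    neg_logits, torch.zeros_like(neg_logits), reduction='mean')
loss = pos_loss + neg_loss  # scalar discriminator loss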