def cosine_distance(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
  r"""Computes the cosine distance between targets and predictions.

  The cosine **distance**, implemented here, measures the **dissimilarity**
  of two vectors as the opposite of cosine **similarity**: `1 - cos(\theta)`.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine
      similarity.

  Returns:
    cosine distance values.
  """
  chex.assert_is_broadcastable(targets.shape, predictions.shape)
  chex.assert_type([predictions, targets], float)
  # cosine distance = 1 - cosine similarity.
  return 1. - cosine_similarity(predictions, targets, epsilon)
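

# A minimal usage sketch for `cosine_distance` (illustrative only; the
# `_example_*` helper name and the toy inputs below are hypothetical and
# not part of the library API):
def _example_cosine_distance() -> chex.Array:
  predictions = jnp.array([[1.0, 0.0], [0.0, 1.0]])
  targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])
  # Identical directions give distance 0; orthogonal directions give 1.
  return cosine_distance(predictions, targets)  # approx. [0., 1.]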


def cosine_similarity(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
  r"""Computes the cosine similarity between targets and predictions.

  The cosine **similarity** is a measure of similarity between vectors defined
  as the cosine of the angle between them, which is also the inner product of
  those vectors normalized to have unit norm.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine
      similarity.

  Returns:
    cosine similarity values.
  """
  chex.assert_is_broadcastable(targets.shape, predictions.shape)
  chex.assert_type([predictions, targets], float)
  # vectorize norm fn, to treat all dimensions except the last as batch dims.
  batched_norm_fn = jnp.vectorize(
      utils.safe_norm, signature='(k)->()', excluded={1})
  # normalise the last dimension of targets and predictions.
  unit_targets = targets / jnp.expand_dims(
      batched_norm_fn(targets, epsilon), axis=-1)
  unit_predictions = predictions / jnp.expand_dims(
      batched_norm_fn(predictions, epsilon), axis=-1)
  # return cosine similarity.
  return jnp.sum(unit_targets * unit_predictions, axis=-1)
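

# A minimal usage sketch for `cosine_similarity` (illustrative only; the
# helper name and inputs are hypothetical). Vectors pointing in the same
# direction have similarity 1 regardless of their magnitude:
def _example_cosine_similarity() -> chex.Array:
  predictions = jnp.array([[3.0, 4.0]])
  targets = jnp.array([[6.0, 8.0]])
  return cosine_similarity(predictions, targets)  # approx. [1.]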


def softmax_cross_entropy(
    logits: chex.Array,
    labels: chex.Array,
) -> chex.Array:
  """Computes the softmax cross entropy between sets of logits and labels.

  Measures the probability error in discrete classification tasks in which
  the classes are mutually exclusive (each entry is in exactly one class).
  For example, each CIFAR-10 image is labeled with one and only one label:
  an image can be a dog or a truck, but not both.

  References:
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

  Args:
    logits: unnormalized log probabilities.
    labels: a valid probability distribution (non-negative, sum to 1), e.g. a
      one hot encoding of which class is the correct one for each input.

  Returns:
    the cross entropy loss.
  """
  chex.assert_is_broadcastable(labels.shape, logits.shape)
  chex.assert_type([logits], float)
  return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
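

# A minimal usage sketch for `softmax_cross_entropy` (illustrative only;
# names and values are hypothetical). With one-hot labels the loss reduces
# to the negative log-probability assigned to the correct class:
def _example_softmax_cross_entropy() -> chex.Array:
  logits = jnp.array([[1.0, 2.0, 0.5]])   # unnormalized scores
  labels = jnp.array([[0.0, 1.0, 0.0]])   # one-hot: class 1 is correct
  return softmax_cross_entropy(logits, labels)  # -log_softmax(logits)[..., 1]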


def unitwise_norm(x: chex.Array) -> chex.Array:
  """Computes norms of each output unit separately."""
  if jnp.squeeze(x).ndim <= 1:  # Scalars and vectors
    squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True)
  # Note that this assumes parameters with a shape of length 3 are multihead
  # linear parameters--if you wish to apply AGC to 1D convs, you may need
  # to modify this line.
  elif x.ndim in (2, 3):  # Linear layers of shape IO or multihead linear
    squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True)
  elif x.ndim == 4:  # Conv kernels of shape HWIO
    squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True)
  else:
    raise ValueError(
        f'Expected parameter with ndim in {1, 2, 3, 4}, got shape {x.shape}.')
  chex.assert_is_broadcastable(squared_norm.shape, x.shape)
  return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape)
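

# A minimal usage sketch for `unitwise_norm` (illustrative only; the input
# shape is hypothetical). For a 2D weight of shape IO, the norm is taken
# over the input axis and broadcast back to the parameter's shape:
def _example_unitwise_norm() -> chex.Array:
  w = jnp.ones((3, 4))     # linear kernel with 3 inputs, 4 outputs
  return unitwise_norm(w)  # every entry equals sqrt(3.)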


def sigmoid_binary_cross_entropy(logits, labels):
  """Computes sigmoid cross entropy given logits and multiple class labels.

  Measures the probability error in discrete classification tasks in which
  each class is an independent binary prediction and different classes are
  not mutually exclusive. This may be used for multilabel image
  classification; for instance, a model may predict that an image contains
  both a cat and a dog.

  References:
    [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

  Args:
    logits: unnormalized log probabilities.
    labels: the probability for each class.

  Returns:
    a sigmoid cross entropy loss.
  """
  chex.assert_is_broadcastable(labels.shape, logits.shape)
  chex.assert_type([logits], float)
  log_p = jax.nn.log_sigmoid(logits)
  # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically
  # stable.
  log_not_p = jax.nn.log_sigmoid(-logits)
  return -labels * log_p - (1. - labels) * log_not_p
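

# A minimal usage sketch for `sigmoid_binary_cross_entropy` (illustrative
# only; names and values are hypothetical). Each class is scored as an
# independent binary prediction:
def _example_sigmoid_binary_cross_entropy():
  logits = jnp.array([[-1.0, 2.0]])  # e.g. scores for ['cat', 'dog']
  labels = jnp.array([[0.0, 1.0]])   # the image contains a dog, not a cat
  return sigmoid_binary_cross_entropy(logits, labels)  # per-class losses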