import warnings
from typing import Union

import jax.numpy as jn
from jax.scipy.special import logsumexp

# Type alias for JAX arrays; assumed here as a stand-in for the library's
# own JaxArray definition.
JaxArray = jn.ndarray

# Prefix for shape-mismatch warnings. The exact wording is an assumption,
# standing in for the module-level constant referenced below.
WARN_SHAPE_MISMATCH = 'Shape mismatch in'


def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the softmax cross-entropy loss on n-dimensional data.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ..., #class) tensor of label probabilities
            (e.g. labels.sum(axis=-1) must be 1).

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    FUNC_NAME = 'cross_entropy_logits'
    if logits.shape != labels.shape:
        warnings.warn('{} {}: arg1 {} and arg2 {}'.format(
            WARN_SHAPE_MISMATCH, FUNC_NAME, logits.shape, labels.shape))
    return logsumexp(logits, axis=-1) - (logits * labels).sum(-1)
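# A minimal usage sketch for cross_entropy_logits (illustrative values and
# helper name, not part of the module API): dense labels are per-class
# probability distributions, e.g. one-hot rows summing to 1.
def _demo_cross_entropy_logits():
    logits = jn.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # (batch=2, #class=3)
    labels = jn.array([[1.0, 0.0, 0.0],
                       [0.0, 0.0, 1.0]])  # one-hot label probabilities
    losses = cross_entropy_logits(logits, labels)  # shape (2,)
    return losses.mean()  # scalar loss, e.g. for an optimizer step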
def cross_entropy_logits_sparse(logits: JaxArray, labels: Union[JaxArray, int]) -> JaxArray:
    """Computes the softmax cross-entropy loss.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ...) integer tensor of label indexes in
            {0, ..., #class-1} or just a single integer.

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    if isinstance(labels, int):
        labeled_logits = logits[..., labels]
    else:
        labeled_logits = jn.take_along_axis(logits, labels[..., None], -1).squeeze(-1)
    return logsumexp(logits, axis=-1) - labeled_logits
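# A minimal usage sketch for cross_entropy_logits_sparse (illustrative
# values and helper name, not part of the module API): sparse labels are
# integer class indices, and for one-hot labels the dense and sparse
# variants compute the same values.
def _demo_cross_entropy_logits_sparse():
    logits = jn.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # (batch=2, #class=3)
    labels = jn.array([0, 2])  # integer class indices, shape (batch,)
    sparse = cross_entropy_logits_sparse(logits, labels)  # shape (2,)
    dense = cross_entropy_logits(logits, jn.eye(3)[labels])  # equals `sparse`
    return sparse, dense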