Example #1
from jax.scipy.special import logsumexp

from objax.typing import JaxArray


def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the softmax cross-entropy loss on n-dimensional data.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ..., #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    # For normalized labels, -sum(labels * log_softmax(logits)) reduces to
    # logsumexp(logits) - sum(logits * labels).
    return logsumexp(logits, axis=-1) - (logits * labels).sum(-1)
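
A quick sanity check, not part of the objax sources: with made-up values, the result matches the textbook definition -sum(labels * log_softmax(logits)).

import jax.numpy as jn

logits = jn.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]])          # (batch=2, #class=3)
labels = jn.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])          # one-hot probabilities

loss = cross_entropy_logits(logits, labels)   # shape (2,)

# Cross-check against -sum(p * log_softmax(logits)).
log_softmax = logits - logsumexp(logits, axis=-1, keepdims=True)
assert jn.allclose(loss, -(labels * log_softmax).sum(-1))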
Example #2
from typing import Union

import jax.numpy as jn
from jax.scipy.special import logsumexp

from objax.typing import JaxArray


def cross_entropy_logits_sparse(logits: JaxArray,
                                labels: Union[JaxArray, int]) -> JaxArray:
    """Computes the softmax cross-entropy loss.

    Args:
        logits: (batch, #class) tensor of logits.
        labels: (batch,) integer tensor of label indexes in {0, ..., #class-1} or just a single integer.

    Returns:
        (batch,) tensor of the cross-entropies for each entry.
    """
    # Select the logit of the correct class for each row; this variant only
    # supports 2D logits (see Example #4 for the n-dimensional version).
    return logsumexp(logits, axis=1) - logits[jn.arange(logits.shape[0]),
                                              labels]
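
A hypothetical call (values invented for illustration) showing that the sparse variant agrees with cross_entropy_logits applied to one-hot labels:

import jax.numpy as jn

logits = jn.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]])                  # (batch=2, #class=3)
labels = jn.array([0, 2])                             # correct class per row

loss = cross_entropy_logits_sparse(logits, labels)    # shape (2,)

# Equivalent dense call: one-hot encode the indexes first.
one_hot = jn.eye(logits.shape[-1])[labels]
assert jn.allclose(loss, cross_entropy_logits(logits, one_hot))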
Example #3
File: loss.py Project: google/objax
import warnings

from jax.scipy.special import logsumexp

from objax.typing import JaxArray


def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the softmax cross-entropy loss on n-dimensional data.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ..., #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    FUNC_NAME = 'cross_entropy_logits'
    # Warn (but do not fail) when logits and labels disagree in shape, since
    # broadcasting could silently produce an unintended result.
    # WARN_SHAPE_MISMATCH is a message string defined earlier in loss.py (not shown here).
    if logits.shape != labels.shape:
        warnings.warn(' {} {} : arg1 {} and arg2 {}'.format(
            WARN_SHAPE_MISMATCH, FUNC_NAME, logits.shape, labels.shape))
    return logsumexp(logits, axis=-1) - (logits * labels).sum(-1)
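
A sketch of how the warning fires, assuming the snippet above is pasted as one module; WARN_SHAPE_MISMATCH is stubbed with a placeholder string here because its real value is not shown above:

import jax.numpy as jn

WARN_SHAPE_MISMATCH = 'shape mismatch in'   # placeholder; the real value lives in loss.py

logits = jn.zeros((4, 3))
labels = jn.zeros((4, 1))                   # wrong: class axis should be 3

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    cross_entropy_logits(logits, labels)    # broadcasts, but warns first
assert len(caught) == 1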
Example #4
from typing import Union

import jax.numpy as jn
from jax.scipy.special import logsumexp

from objax.typing import JaxArray


def cross_entropy_logits_sparse(logits: JaxArray,
                                labels: Union[JaxArray, int]) -> JaxArray:
    """Computes the softmax cross-entropy loss.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ...) integer tensor of label indexes in {0, ..., #class-1} or just a single integer.

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    if isinstance(labels, int):
        # A single integer selects the same class for every entry.
        labeled_logits = logits[..., labels]
    else:
        # Gather the logit of the correct class along the last (class) axis.
        labeled_logits = jn.take_along_axis(logits, labels[..., None],
                                            -1).squeeze(-1)

    return logsumexp(logits, axis=-1) - labeled_logits
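
A made-up n-dimensional call (e.g. per-pixel segmentation logits): batch and spatial axes are preserved and only the class axis is reduced:

import jax.numpy as jn

logits = jn.zeros((2, 4, 4, 10))                      # (batch, H, W, #class)
labels = jn.ones((2, 4, 4), dtype=jn.int32)           # class index per pixel

loss = cross_entropy_logits_sparse(logits, labels)
assert loss.shape == (2, 4, 4)

# A plain int broadcasts the same class to every entry.
assert cross_entropy_logits_sparse(logits, 3).shape == (2, 4, 4)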