Example #1
File: loss.py Project: rwightman/objax
# Imports added so the snippet runs standalone.
from jax.scipy.special import logsumexp
from objax.typing import JaxArray


def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the softmax cross-entropy loss on n-dimensional data.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ..., #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    return logsumexp(logits, axis=-1) - (logits * labels).sum(-1)
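A quick sanity check (my own sketch, not part of the project): with cross_entropy_logits as defined above in scope, the result should match -sum(labels * log_softmax(logits), axis=-1), since logsumexp(logits, -1) - (logits * labels).sum(-1) equals -(labels * (logits - logsumexp(logits, -1))).sum(-1) whenever the label probabilities sum to 1.

import jax
import jax.numpy as jn

logits = jn.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]])
labels = jn.array([[1.0, 0.0, 0.0],   # one-hot: class 0
                   [0.0, 1.0, 0.0]])  # one-hot: class 1

loss = cross_entropy_logits(logits, labels)  # shape (2,)

# Cross-check against the textbook form of softmax cross-entropy.
reference = -(labels * jax.nn.log_softmax(logits)).sum(-1)
assert jn.allclose(loss, reference)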
Example #2
File: loss.py Project: rwightman/objax
from typing import Union

import jax.numpy as jn
from jax.scipy.special import logsumexp
from objax.typing import JaxArray


def cross_entropy_logits_sparse(logits: JaxArray,
                                labels: Union[JaxArray, int]) -> JaxArray:
    """Computes the softmax cross-entropy loss.

    Args:
        logits: (batch, #class) tensor of logits.
        labels: (batch,) integer tensor of label indexes in {0, ..., #class-1} or just a single integer.

    Returns:
        (batch,) tensor of the cross-entropies for each entry.
    """
    # Row-wise gather: this indexing assumes 2-D logits (see Example #4 for
    # the n-dimensional generalization using take_along_axis).
    return logsumexp(logits, axis=1) - logits[jn.arange(logits.shape[0]),
                                              labels]
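A minimal usage sketch (my addition, assuming both functions above are in scope): for 2-D logits the sparse and dense variants agree once the integer labels are one-hot encoded.

import jax
import jax.numpy as jn

logits = jn.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 0.3]])
labels = jn.array([0, 2])  # one class index per row

sparse = cross_entropy_logits_sparse(logits, labels)             # shape (2,)
dense = cross_entropy_logits(logits, jax.nn.one_hot(labels, 3))  # same values
assert jn.allclose(sparse, dense)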
Example #3
File: loss.py Project: google/objax
import warnings

from jax.scipy.special import logsumexp
from objax.typing import JaxArray


def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the softmax cross-entropy loss on n-dimensional data.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ..., #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """

    FUNC_NAME = 'cross_entropy_logits'
    # WARN_SHAPE_MISMATCH is a module-level message constant defined elsewhere
    # in loss.py; the warning fires whenever the two shapes are not identical,
    # even if they would still broadcast.
    if logits.shape != labels.shape:
        warnings.warn(' {} {} : arg1 {} and arg2 {}'.format(
            WARN_SHAPE_MISMATCH, FUNC_NAME, logits.shape, labels.shape))
    return logsumexp(logits, axis=-1) - (logits * labels).sum(-1)
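To see the warning fire (a sketch of mine; the WARN_SHAPE_MISMATCH value below is a hypothetical stand-in, since the real constant is defined elsewhere in loss.py): shapes that still broadcast but are not identical trigger it.

import warnings

import jax.numpy as jn

WARN_SHAPE_MISMATCH = 'shape mismatch in'  # hypothetical stand-in value

logits = jn.zeros((4, 10))
labels = jn.ones((10,)) / 10.0  # broadcasts against logits, but shapes differ

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    cross_entropy_logits(logits, labels)
assert caught and 'cross_entropy_logits' in str(caught[0].message)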
Example #4
from typing import Union

import jax.numpy as jn
from jax.scipy.special import logsumexp
from objax.typing import JaxArray


def cross_entropy_logits_sparse(logits: JaxArray,
                                labels: Union[JaxArray, int]) -> JaxArray:
    """Computes the softmax cross-entropy loss.

    Args:
        logits: (batch, ..., #class) tensor of logits.
        labels: (batch, ...) integer tensor of label indexes in {0, ..., #class-1} or just a single integer.

    Returns:
        (batch, ...) tensor of the cross-entropies for each entry.
    """
    if isinstance(labels, int):
        # A single integer selects the same class for every entry.
        labeled_logits = logits[..., labels]
    else:
        # Gather each entry's logit at its label index along the class axis.
        labeled_logits = jn.take_along_axis(logits, labels[..., None],
                                            -1).squeeze(-1)

    return logsumexp(logits, axis=-1) - labeled_logits
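A usage sketch (my addition, assuming the function above is in scope) showing the n-dimensional case this version was generalized to handle, e.g. a per-pixel label map:

import jax.numpy as jn

# Per-pixel classification: (batch=2, height=4, width=4, #class=3) logits.
logits = jn.zeros((2, 4, 4, 3))
labels = jn.ones((2, 4, 4), dtype=jn.int32)  # one class index per pixel

loss = cross_entropy_logits_sparse(logits, labels)
print(loss.shape)  # (2, 4, 4): one cross-entropy per pixel

# A plain int is applied to every entry.
loss_fixed_class = cross_entropy_logits_sparse(logits, 0)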