Esempio n. 1
0
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,
        dtype=np.int32,
        name=None):  # pylint: disable=unused-argument
    """Counts number of occurrences of each value in `arr`.

    Args:
      arr: Array of values to count; presumably non-negative integers (as
        `np.bincount` requires) -- confirm with callers.
      weights: Optional array with the same shape as `arr`; when given, each
        occurrence adds its weight instead of 1.
      minlength: Optional minimum number of output bins.
      maxlength: Optional maximum number of output bins; counts for values
        `>= maxlength` are dropped.
      dtype: Dtype of the result (converted with `utils.numpy_dtype`).
      name: Unused; accepted for signature compatibility.

    Returns:
      1-D array of per-value counts (or weight sums), cast to `dtype`.
    """
    # TODO(https://github.com/google/jax/issues/5719): Use np.bincount directly?
    if not JAX_MODE:
        # `np.bincount` rejects `minlength=None` on recent NumPy versions and
        # has no `maxlength`, so normalize the former and emulate the latter.
        counts = np.bincount(arr, weights,
                             0 if minlength is None else minlength)
        if maxlength is not None:
            # Values >= maxlength land in bins past the cut and are dropped,
            # matching the bucket cap applied in the JAX path below.
            counts = counts[:maxlength]
        return counts.astype(utils.numpy_dtype(dtype))

    dtype = utils.numpy_dtype(dtype)
    # An empty `arr` has no maximum, so it contributes zero buckets.
    num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
    if minlength is not None and maxlength is not None and minlength == maxlength:
        # In the case where we can use minlength directly, this helps avoid the
        # use of an abstract value, which prevents JAX JIT.
        num_buckets = minlength
    else:
        if minlength is not None:
            num_buckets = np.maximum(num_buckets, minlength)
        if maxlength is not None:
            num_buckets = np.minimum(num_buckets, maxlength)
    one_hots = one_hot(arr, num_buckets)
    # Reduce over every dimension except the last (bucket) one.
    axes = tuple(range(0, one_hots.ndim - 1))
    if weights is not None:
        return np.sum(one_hots * weights[..., np.newaxis],
                      axis=axes).astype(dtype)
    return np.sum(one_hots, axis=axes).astype(dtype)
Esempio n. 2
0
def _bincount(
        arr,
        weights=None,
        minlength=None,
        maxlength=None,
        dtype=np.int32,
        name=None):  # pylint: disable=unused-argument
    """Counts number of occurrences of each value in `arr`.

    Args:
      arr: Array of values to count; presumably non-negative integers (as
        `np.bincount` requires) -- confirm with callers.
      weights: Optional array with the same shape as `arr`; when given, each
        occurrence adds its weight instead of 1.
      minlength: Optional minimum number of output bins.
      maxlength: Optional maximum number of output bins; counts for values
        `>= maxlength` are dropped.
      dtype: Dtype of the result (converted with `utils.numpy_dtype`).
      name: Unused; accepted for signature compatibility.

    Returns:
      1-D array of per-value counts (or weight sums), cast to `dtype`.
    """
    if not JAX_MODE:
        # `np.bincount` rejects `minlength=None` on recent NumPy versions and
        # has no `maxlength`, so normalize the former and emulate the latter.
        counts = np.bincount(arr, weights,
                             0 if minlength is None else minlength)
        if maxlength is not None:
            # Values >= maxlength land in bins past the cut and are dropped,
            # matching the bucket cap applied in the JAX path below.
            counts = counts[:maxlength]
        return counts.astype(utils.numpy_dtype(dtype))

    dtype = utils.numpy_dtype(dtype)
    # `np.max` of an empty array raises, so an empty `arr` gets zero buckets.
    num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
    if minlength is not None and maxlength is not None and minlength == maxlength:
        # Same value as the clamp below, but a concrete (non-traced) bucket
        # count, which keeps the one-hot shape static under JAX JIT.
        num_buckets = minlength
    else:
        if minlength is not None:
            num_buckets = np.maximum(num_buckets, minlength)
        if maxlength is not None:
            num_buckets = np.minimum(num_buckets, maxlength)
    one_hots = one_hot(arr, num_buckets)
    # Reduce over every dimension except the last (bucket) one.
    axes = tuple(range(0, one_hots.ndim - 1))
    if weights is not None:
        return np.sum(one_hots * weights[..., np.newaxis],
                      axis=axes).astype(dtype)
    return np.sum(one_hots, axis=axes).astype(dtype)
Esempio n. 3
0
def _sparse_softmax_cross_entropy_with_logits(  # pylint: disable=invalid-name,unused-argument
        labels, logits, name=None):
    """Computes sparse softmax cross entropy between `logits` and `labels`.

    Args:
      labels: Class indices (one-hot encoded internally against the last
        dimension of `logits`). The result takes this shape.
      logits: Unnormalized scores whose last dimension is the number of
        classes; leading dimensions presumably flatten to the same size as
        `labels` -- confirm with callers.
      name: Unused.

    Returns:
      Per-example cross entropy, reshaped to `labels.shape`.
    """
    out_shape = labels.shape
    depth = logits.shape[-1]
    flat_logits = np.reshape(logits, [-1, depth])
    flat_labels = np.reshape(labels, [-1])

    targets = numpy_array.one_hot(flat_labels, depth)
    # Log-softmax over the class axis.
    log_probs = flat_logits - reduce_logsumexp(
        flat_logits, axis=-1, keepdims=True)
    # Non-target entries are forced to exact zero rather than relying on
    # `0 * log_prob`, presumably so a `-inf` logit cannot turn into NaN.
    per_class = np.where(targets == 0, np.zeros_like(targets),
                         targets * log_probs)
    return np.reshape(-np.sum(per_class, axis=-1), out_shape)
Esempio n. 4
0
def _sparse_softmax_cross_entropy_with_logits(  # pylint: disable=invalid-name,unused-argument
        _sentinel=None,
        labels=None,
        logits=None,
        name=None):
    """Sparse Softmax cross entropy with logits."""
    if _sentinel is not None:
        raise ValueError('Pass in `label` and `logits` parameters as kwargs')
    labels_shape = labels.shape
    num_classes = logits.shape[-1]
    logits = np.reshape(logits, [-1, num_classes])
    labels = np.reshape(labels, [-1])

    labels = numpy_array.one_hot(labels, num_classes)

    cost = -np.sum(np.where(
        labels == 0, np.zeros_like(labels),
        labels * (logits - reduce_logsumexp(logits, axis=-1, keepdims=True))),
                   axis=-1)
    cost = np.reshape(cost, labels_shape)
    return cost