def _bincount(  # pylint: disable=unused-argument
    arr,
    weights=None,
    minlength=None,
    maxlength=None,
    dtype=np.int32,
    name=None):
  """Counts the number of occurrences of each value in `arr`."""
  # TODO(https://github.com/google/jax/issues/5719): Use np.bincount directly?
  out_dtype = utils.numpy_dtype(dtype)
  if not JAX_MODE:
    return np.bincount(arr, weights, minlength).astype(out_dtype)
  if minlength is not None and maxlength is not None and minlength == maxlength:
    # Static bucket count: avoids an abstract value, which would otherwise
    # prevent JAX JIT compilation.
    num_buckets = minlength
  else:
    # Empty input produces zero buckets (np.max would raise on an empty array).
    num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
    if minlength is not None:
      num_buckets = np.maximum(num_buckets, minlength)
    if maxlength is not None:
      num_buckets = np.minimum(num_buckets, maxlength)
  one_hots = one_hot(arr, num_buckets)
  # Sum across every leading axis, keeping only the trailing bucket axis.
  lead_axes = tuple(range(one_hots.ndim - 1))
  if weights is None:
    return np.sum(one_hots, axis=lead_axes).astype(out_dtype)
  weighted = one_hots * weights[..., np.newaxis]
  return np.sum(weighted, axis=lead_axes).astype(out_dtype)
def _bincount(  # pylint: disable=unused-argument
    arr, weights=None, minlength=None, maxlength=None,
    dtype=np.int32, name=None):
  """Counts number of occurrences of each value in `arr`.

  Args:
    arr: Array of non-negative integers to count.
    weights: Optional array broadcastable against `arr`; when given, each
      entry contributes its weight rather than 1 to its bucket.
    minlength: Optional lower bound on the number of output buckets.
    maxlength: Optional upper bound on the number of output buckets.
    dtype: Output dtype.
    name: Ignored; present for TF API compatibility.

  Returns:
    A 1-D array of per-bucket (possibly weighted) counts, cast to `dtype`.
  """
  if not JAX_MODE:
    return np.bincount(arr, weights, minlength).astype(utils.numpy_dtype(dtype))
  dtype = utils.numpy_dtype(dtype)
  # Bug fix: np.max raises on an empty array; guard with np.size so an empty
  # `arr` yields zero buckets (or `minlength` buckets via the maximum below).
  num_buckets = (np.max(arr) + 1) if np.size(arr) else 0
  if minlength is not None:
    num_buckets = np.maximum(num_buckets, minlength)
  if maxlength is not None:
    num_buckets = np.minimum(num_buckets, maxlength)
  one_hots = one_hot(arr, num_buckets)
  # Reduce over every dimension except the last (bucket) one.
  axes = tuple(range(0, one_hots.ndim - 1))
  if weights is not None:
    return np.sum(one_hots * weights[..., np.newaxis], axis=axes).astype(dtype)
  return np.sum(one_hots, axis=axes).astype(dtype)
def _sparse_softmax_cross_entropy_with_logits(  # pylint: disable=invalid-name,unused-argument
    labels, logits, name=None):
  """Sparse Softmax cross entropy with logits."""
  out_shape = labels.shape
  num_classes = logits.shape[-1]
  # Flatten to a [batch, num_classes] problem, then restore the shape at the
  # end so arbitrary leading dimensions are supported.
  flat_logits = np.reshape(logits, [-1, num_classes])
  flat_labels = numpy_array.one_hot(np.reshape(labels, [-1]), num_classes)
  log_probs = flat_logits - reduce_logsumexp(
      flat_logits, axis=-1, keepdims=True)
  # Zero out entries where the one-hot label is 0 before multiplying, so that
  # 0 * -inf cannot surface as NaN.
  contributions = np.where(
      flat_labels == 0, np.zeros_like(flat_labels), flat_labels * log_probs)
  per_example_cost = -np.sum(contributions, axis=-1)
  return np.reshape(per_example_cost, out_shape)
def _sparse_softmax_cross_entropy_with_logits(  # pylint: disable=invalid-name,unused-argument
    _sentinel=None, labels=None, logits=None, name=None):
  """Sparse Softmax cross entropy with logits.

  Args:
    _sentinel: Internal guard forcing keyword-only use; must be left unset.
    labels: Integer class-index array of any shape.
    logits: Float array of shape `labels.shape + [num_classes]`.
    name: Ignored; present for TF API compatibility.

  Returns:
    Per-example cross-entropy cost, shaped like `labels`.

  Raises:
    ValueError: If called positionally (i.e. `_sentinel` was set).
  """
  if _sentinel is not None:
    # Bug fix: the message previously said `label`, which does not match the
    # actual parameter name `labels`.
    raise ValueError('Pass in `labels` and `logits` parameters as kwargs')
  labels_shape = labels.shape
  num_classes = logits.shape[-1]
  # Flatten so arbitrary leading dimensions reduce to a [batch, classes] case.
  logits = np.reshape(logits, [-1, num_classes])
  labels = np.reshape(labels, [-1])
  labels = numpy_array.one_hot(labels, num_classes)
  # Mask zero-label entries before multiplying so 0 * -inf cannot yield NaN.
  cost = -np.sum(
      np.where(labels == 0, np.zeros_like(labels),
               labels * (logits - reduce_logsumexp(
                   logits, axis=-1, keepdims=True))),
      axis=-1)
  cost = np.reshape(cost, labels_shape)
  return cost