def _moveaxis(a, source: int, destination: int):
  """Move one axis of ``a`` to a new position (simplified jnp.moveaxis()).

  Intended for local use only: unlike the general function it accepts a single
  source axis and a single destination axis.

  Args:
    a: array-like input; validated with ``_check_arraylike``.
    source: axis to move; canonicalized against the rank of ``a``.
    destination: position the axis should occupy in the result; canonicalized
      against the rank of ``a``.

  Returns:
    The result of ``lax.transpose`` applied with the computed permutation.
  """
  _check_arraylike("moveaxis", a)
  arr = _asarray(a)
  rank = np.ndim(arr)
  src = _canonicalize_axis(source, rank)
  dst = _canonicalize_axis(destination, rank)
  # All axes except the moved one, in their original relative order ...
  others = [ax for ax in range(rank) if ax != src]
  # ... with the moved axis spliced back in at the requested position.
  order = others[:dst] + [src] + others[dst:]
  return lax.transpose(arr, order)
def _cumulative_reduction(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None):
  """Run the cumulative ``reduction`` over ``a`` along a single axis.

  NOTE(review): ``np_reduction``, ``reduction``, ``fill_nan`` and
  ``fill_value`` are not defined inside this function; presumably they are
  bound by an enclosing factory that is not visible here -- confirm against
  the surrounding code.

  Args:
    a: array-like input; validated with ``_check_arraylike``.
    axis: axis to accumulate along. When it is ``None`` (or ``a`` is a
      scalar) the input is flattened to one dimension and axis 0 is used.
    dtype: optional dtype the input is converted to before accumulating.
      When omitted and the input is boolean, the canonicalized default
      integer dtype is used instead.
    out: unsupported; must be ``None``.

  Returns:
    Whatever ``reduction(a, axis)`` returns for the prepared input.

  Raises:
    NotImplementedError: if ``out`` is not ``None``.
  """
  _check_arraylike(np_reduction.__name__, a)
  if out is not None:
    raise NotImplementedError(
        f"The 'out' argument to jnp.{np_reduction.__name__} "
        f"is not supported.")
  lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

  if axis is None or _isscalar(a):
    a = lax.reshape(a, (np.size(a),))
    axis = 0

  axis = _canonicalize_axis(axis, len(np.shape(a)))

  if fill_nan:
    # Replace NaNs by the fill value so they do not affect the accumulation.
    a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

  # Compare against None explicitly: relying on the truth value of a dtype
  # specifier could silently skip a conversion the caller asked for.
  if dtype is None and dtypes.dtype(a) == np.bool_:
    dtype = dtypes.canonicalize_dtype(dtypes.int_)
  if dtype is not None:
    a = lax.convert_element_type(a, dtype)

  return reduction(a, axis)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, returned=False):
  """Compute the (optionally weighted) average of ``a`` over ``axis``.

  Args:
    a: input, converted with ``_asarray``.
    axis: ``None`` to average over every element, a single axis, or a tuple
      of axes. A tuple is only accepted with ``weights`` when ``weights`` has
      the same shape as ``a``.
    weights: optional weights. Either the same shape as ``a``, or
      one-dimensional with length equal to the size of ``a`` along the single
      integer ``axis``. ``None`` means every element has weight 1.
    returned: when true, also return the sum of the weights.

  Returns:
    ``avg`` alone, or the pair ``(avg, weights_sum)`` when ``returned`` is
    true; ``weights_sum`` is broadcast to the shape of ``avg`` if they differ.

  Raises:
    ValueError: if ``weights`` and ``a`` differ in shape and ``axis`` is
      ``None`` or a tuple, ``weights`` is not one-dimensional, or its length
      does not match the size of ``a`` along ``axis``.
  """
  a = _asarray(a)

  if weights is None:  # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)), dtype=avg.dtype)
    elif isinstance(axis, tuple):
      # Several axes are reduced at once, so the implicit total weight is the
      # product of the sizes of the reduced dimensions. Indexing the shape
      # with the tuple itself (as a single-axis lookup would) is an error.
      count = 1
      for d in axis:
        count = count * core.dimension_as_value(a.shape[d])
      weights_sum = lax.full_like(avg, count, dtype=avg.dtype)
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    # Integer/boolean inputs are promoted with the default float type so the
    # division below is carried out in floating point.
    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)

    if isinstance(axis, tuple):
      axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis)
    elif axis is not None:
      axis = _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if isinstance(axis, tuple):
        # 1-D weights can only be aligned with one dimension of ``a``.
        raise ValueError("Single axis expected when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      # Give the weights the rank of ``a`` (size-1 leading dims), then move
      # the populated dimension to ``axis`` so ``a * weights`` broadcasts.
      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
def _canonicalize_axis_allow_named(x, rank):
  """Canonicalize axis ``x`` against ``rank`` while letting named axes through.

  Delegates to ``maybe_named_axis`` with two callbacks: one that canonicalizes
  its argument with ``_canonicalize_axis`` and one that returns its argument
  unchanged. NOTE(review): presumably the first is used for positional axes
  and the second for named axes -- confirm against ``maybe_named_axis``.
  """
  def _positional(i):
    return _canonicalize_axis(i, rank)

  def _named(name):
    return name

  return maybe_named_axis(x, _positional, _named)