def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False, where=None):
  _check_arraylike("nanvar", a)
  lax_internal._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
  a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

  centered = _where(lax_internal._isnan(a), 0, a - a_mean)  # double-where trick for gradients.
  if dtypes.issubdtype(centered.dtype, np.complexfloating):
    centered = lax.real(lax.mul(centered, lax.conj(centered)))
  else:
    centered = lax.square(centered)

  # Count the non-NaN contributions; `sum` here is this module's reduction,
  # not the Python builtin.
  normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                   axis=axis, keepdims=keepdims, where=where)
  normalizer = normalizer - ddof
  # Where at most ddof valid values remain, emit NaN and divide by 1 so no
  # spurious divide-by-zero is generated.
  normalizer_mask = lax.le(normalizer, 0)
  result = sum(centered, axis, keepdims=keepdims, where=where)
  result = _where(normalizer_mask, np.nan, result)
  divisor = _where(normalizer_mask, 1, normalizer)
  out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
  return lax.convert_element_type(out, dtype)
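# Usage sketch (illustrative, not part of the module; assumes only the public
# jax.numpy API): NaNs are excluded from both the mean and the normalizer, and
# a row with at most `ddof` valid entries yields NaN rather than a division error.
#
#   import jax.numpy as jnp
#   x = jnp.array([[1.0, jnp.nan, 3.0],
#                  [jnp.nan, jnp.nan, 6.0]])
#   jnp.nanvar(x, axis=1)          # [1., 0.]  -- variance over valid entries per row
#   jnp.nanvar(x, axis=1, ddof=1)  # [2., nan] -- row 1 has only one valid entry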
def _cumulative_reduction(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                          dtype=None, out=None):
  # `np_reduction`, `reduction`, `fill_nan`, and `fill_value` are closed over
  # from the enclosing factory.
  _check_arraylike(np_reduction.__name__, a)
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                              f"is not supported.")
  lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

  if axis is None or _isscalar(a):
    a = lax.reshape(a, (np.size(a),))
    axis = 0

  a_shape = list(np.shape(a))
  num_dims = len(a_shape)
  axis = _canonicalize_axis(axis, num_dims)

  if fill_nan:
    a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

  if not dtype and dtypes.dtype(a) == np.bool_:
    dtype = dtypes.canonicalize_dtype(dtypes.int_)
  if dtype:
    a = lax.convert_element_type(a, dtype)

  return reduction(a, axis)
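# Sketch of the assumed enclosing factory (the name and exact wiring below are
# an assumption, not verbatim source; they match the free variables used above):
#
#   def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False,
#                                  fill_value=0):
#     def _cumulative_reduction(a, axis=None, dtype=None, out=None):
#       ...  # body as above
#     return _cumulative_reduction
#
#   # NaN-aware variants substitute the operation's identity for NaNs:
#   nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
#                                          fill_nan=True, fill_value=0)
#   nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
#                                           fill_nan=True, fill_value=1)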
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=False, where=None):
  _check_arraylike("nanmean", a)
  lax_internal._check_user_dtype_supported(dtype, "nanmean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  # Boolean and integer inputs cannot contain NaN, so defer to plain mean.
  if (dtypes.issubdtype(dtypes.dtype(a), np.bool_) or
      dtypes.issubdtype(dtypes.dtype(a), np.integer)):
    return mean(a, axis, dtype, out, keepdims, where=where)
  if dtype is None:
    dtype = dtypes.dtype(a)
  nan_mask = lax_internal.bitwise_not(lax_internal._isnan(a))
  normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims,
                   where=where)
  normalizer = lax.convert_element_type(normalizer, dtype)
  td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
               normalizer)
  return td
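# Usage sketch (illustrative, not part of the module): the normalizer counts
# non-NaN entries per output element, so the result is the mean over valid
# values only.
#
#   import jax.numpy as jnp
#   jnp.nanmean(jnp.array([1.0, jnp.nan, 3.0]))  # 2.0
#   jnp.mean(jnp.array([1.0, jnp.nan, 3.0]))     # nan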
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    # Complex inputs: delta = min - max, so out = amax + log1p(exp(min - max));
    # the imaginary part is then wrapped to the principal range.
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
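# The real branch uses the standard stabilization
#   logaddexp(x1, x2) = max(x1, x2) + log1p(exp(-|x1 - x2|)),
# which never exponentiates a positive number. Illustrative sketch (public
# jax.numpy API only):
#
#   import jax.numpy as jnp
#   jnp.logaddexp(0.0, 0.0)        # log(2) ~= 0.6931
#   jnp.logaddexp(1000.0, 1000.0)  # ~= 1000.6931; the naive
#                                  # log(exp(x1) + exp(x2)) would overflow to inf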
def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  _check_arraylike(name, a)
  # Only inexact dtypes can hold NaN; otherwise defer to the plain reduction.
  if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
    return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

  # Replace NaNs with the reduction's identity, then reduce.
  out = jnp_reduction(_where(lax_internal._isnan(a),
                             _reduction_init_val(a, init_val), a),
                      axis=axis, keepdims=keepdims, **kwargs)
  if nan_if_all_nan:
    # `all` here is this module's reduction, not the Python builtin.
    return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
                  _lax_const(a, np.nan), out)
  else:
    return out
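# Sketch of the assumed wiring (illustrative, not verbatim source): concrete
# NaN-aware reductions pass the identity of their operation as `init_val`,
# e.g. -inf for max, so replaced NaNs can never win the reduction. `max` and
# `min` below would be this module's reductions, not the Python builtins.
#
#   from functools import partial
#   nanmax = partial(_nan_reduction, name='nanmax', jnp_reduction=max,
#                    init_val=-np.inf, nan_if_all_nan=True)
#   nanmin = partial(_nan_reduction, name='nanmin', jnp_reduction=min,
#                    init_val=np.inf, nan_if_all_nan=True)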