def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None, ddof=0, keepdims=False, where=None): _check_arraylike("nanvar", a) lax_internal._check_user_dtype_supported(dtype, "nanvar") if out is not None: raise NotImplementedError( "The 'out' argument to jnp.nanvar is not supported.") a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where) centered = _where(lax_internal._isnan(a), 0, a - a_mean) # double-where trick for gradients. if dtypes.issubdtype(centered.dtype, np.complexfloating): centered = lax.real(lax.mul(centered, lax.conj(centered))) else: centered = lax.square(centered) normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, keepdims=keepdims, where=where) normalizer = normalizer - ddof normalizer_mask = lax.le(normalizer, 0) result = sum(centered, axis, keepdims=keepdims, where=where) result = _where(normalizer_mask, np.nan, result) divisor = _where(normalizer_mask, 1, normalizer) out = lax.div(result, lax.convert_element_type(divisor, result.dtype)) return lax.convert_element_type(out, dtype)
def _irfft_transpose(t, fft_lengths): # The transpose of IRFFT is the RFFT of the cotangent times a scaling # factor and a mask. The mask scales the cotangent for the Hermitian # symmetric components of the RFFT by a factor of two, since these components # are de-duplicated in the RFFT. x = fft(t, xla_client.FftType.RFFT, fft_lengths) n = x.shape[-1] is_odd = fft_lengths[-1] % 2 full = partial(lax.full_like, t, dtype=t.dtype) mask = lax.concatenate( [full(1.0, shape=(1,)), full(2.0, shape=(n - 2 + is_odd,)), full(1.0, shape=(1 - is_odd,))], dimension=0) scale = 1 / prod(fft_lengths) out = scale * mask * x assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) # Use JAX's convention for complex gradients # https://github.com/google/jax/issues/6223#issuecomment-807740707 return lax.conj(out)
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None): _check_arraylike("var", a) lax_internal._check_user_dtype_supported(dtype, "var") if out is not None: raise NotImplementedError( "The 'out' argument to jnp.var is not supported.") computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) a = a.astype(computation_dtype) a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where) centered = lax.sub(a, a_mean) if dtypes.issubdtype(centered.dtype, np.complexfloating): centered = lax.real(lax.mul(centered, lax.conj(centered))) else: centered = lax.square(centered) if where is None: if axis is None: normalizer = core.dimension_as_value(np.size(a)) else: normalizer = core.dimension_as_value(_axis_size(a, axis)) else: normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) normalizer = normalizer - ddof result = sum(centered, axis, keepdims=keepdims, where=where) out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) return lax.convert_element_type(out, dtype)
def conjugate(x): _check_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else x
def conjugate(x): return lax.conj(x) if iscomplexobj(x) else x