コード例 #1
0
ファイル: reductions.py プロジェクト: cloudhan/jax
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Compute the variance of ``a`` along ``axis``, skipping NaN entries.

    Args:
        a: Array-like input (validated by ``_check_arraylike``).
        axis: Axis or axes to reduce over; forwarded to ``nanmean`` and
            ``sum``.
        dtype: Optional dtype of the returned array; the final result is
            converted to the dtype chosen by ``_var_promote_types``.
        out: Unsupported; must be ``None``.
        ddof: Value subtracted from the count of non-NaN entries to form
            the divisor.
        keepdims: Forwarded to the count and sum reductions.
        where: Optional mask forwarded to ``nanmean`` and ``sum``.

    Returns:
        The sum of squared deviations from the NaN-ignoring mean divided by
        ``(number of non-NaN entries - ddof)``. Positions where that divisor
        is less than or equal to zero are NaN.

    Raises:
        NotImplementedError: If ``out`` is not ``None``.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    compute_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    # The mean keeps its reduced dimensions so it lines up with ``a`` below.
    nan_free_mean = nanmean(a,
                            axis,
                            dtype=compute_dtype,
                            keepdims=True,
                            where=where)

    is_nan = lax_internal._isnan(a)
    # Double-where trick for gradients: NaN slots contribute a plain zero.
    deviations = _where(is_nan, 0, a - nan_free_mean)
    if dtypes.issubdtype(deviations.dtype, np.complexfloating):
        # |z|^2 computed as real(z * conj(z)) for complex inputs.
        squared = lax.real(lax.mul(deviations, lax.conj(deviations)))
    else:
        squared = lax.square(deviations)

    valid_count = sum(lax_internal.bitwise_not(is_nan),
                      axis=axis,
                      keepdims=keepdims,
                      where=where)
    degrees = valid_count - ddof
    too_few = lax.le(degrees, 0)

    numerator = sum(squared, axis, keepdims=keepdims, where=where)
    numerator = _where(too_few, np.nan, numerator)
    # Divide by 1 where the result is already NaN, so the divisor is never
    # zero or negative.
    denominator = _where(too_few, 1, degrees)
    quotient = lax.div(
        numerator, lax.convert_element_type(denominator, numerator.dtype))
    return lax.convert_element_type(quotient, dtype)
コード例 #2
0
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        """Apply the enclosing scope's cumulative ``reduction`` to ``a``.

        ``np_reduction``, ``reduction``, ``fill_nan`` and ``fill_value`` are
        free variables supplied by the enclosing scope (not visible here).

        Args:
            a: Array-like input (validated by ``_check_arraylike``).
            axis: Axis to accumulate along. When ``None``, or when ``a`` is
                a scalar, ``a`` is reshaped to one dimension and axis 0 is
                used.
            dtype: Optional dtype that ``a`` is converted to before the
                reduction. When omitted and ``a`` is boolean, the canonical
                default integer dtype is used.
            out: Unsupported; must be ``None``.

        Returns:
            Whatever ``reduction(a, axis)`` returns for the prepared input.

        Raises:
            NotImplementedError: If ``out`` is not ``None``.
        """
        _check_arraylike(np_reduction.__name__, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{np_reduction.__name__} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

        if axis is None or _isscalar(a):
            a = lax.reshape(a, (np.size(a), ))
            axis = 0

        num_dims = len(np.shape(a))
        axis = _canonicalize_axis(axis, num_dims)

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        # Compare against None explicitly rather than by truthiness: a
        # non-structured ``np.dtype`` instance has length 0 and is therefore
        # falsy, so ``if dtype:`` would silently ignore a dtype passed as
        # e.g. ``np.dtype('float32')``.
        if dtype is None and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype is not None:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
コード例 #3
0
def nanmean(a,
            axis: Optional[Union[int, Tuple[int, ...]]] = None,
            dtype=None,
            out=None,
            keepdims=False,
            where=None):
    """Compute the mean of ``a`` along ``axis``, skipping NaN entries.

    Args:
        a: Array-like input (validated by ``_check_arraylike``).
        axis: Axis or axes to reduce over.
        dtype: Optional result dtype; defaults to the dtype of ``a``.
        out: Unsupported; must be ``None``.
        keepdims: Forwarded to the underlying reductions.
        where: Optional mask forwarded to the underlying reductions.

    Returns:
        For boolean or integer input, the result of ``mean``. Otherwise
        ``nansum`` of ``a`` divided by the number of non-NaN entries.

    Raises:
        NotImplementedError: If ``out`` is not ``None``.
    """
    _check_arraylike("nanmean", a)
    lax_internal._check_user_dtype_supported(dtype, "nanmean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanmean is not supported.")

    input_dtype = dtypes.dtype(a)
    is_bool_or_int = (dtypes.issubdtype(input_dtype, np.bool_)
                      or dtypes.issubdtype(input_dtype, np.integer))
    if is_bool_or_int:
        # Boolean and integer inputs take the plain mean path.
        return mean(a, axis, dtype, out, keepdims, where=where)

    result_dtype = input_dtype if dtype is None else dtype

    # Count the entries that are not NaN, then cast the count so it can be
    # used as the divisor.
    not_nan = lax_internal.bitwise_not(lax_internal._isnan(a))
    valid_count = sum(not_nan,
                      axis=axis,
                      dtype=np.int32,
                      keepdims=keepdims,
                      where=where)
    valid_count = lax.convert_element_type(valid_count, result_dtype)

    total = nansum(a,
                   axis,
                   dtype=result_dtype,
                   keepdims=keepdims,
                   where=where)
    return lax.div(total, valid_count)
コード例 #4
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` element-wise.

  Both arguments are first promoted by ``_promote_args_inexact``. The result
  is formed as ``max(x1, x2) + log1p(exp(...))`` rather than by evaluating
  the exponentials of the inputs directly.

  Args:
    x1: First operand.
    x2: Second operand.

  Returns:
    For floating-point inputs, ``max + log1p(exp(-|x1 - x2|))``, except where
    ``x1 - x2`` is NaN, in which case ``x1 + x2`` is returned. For all other
    promoted dtypes, a complex value whose imaginary part is passed through
    ``_wrap_between(..., np.pi)`` (presumably wrapping it into a range bounded
    by pi; confirm against that helper).
  """
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  larger = lax.max(x1, x2)

  if not dtypes.issubdtype(x1.dtype, np.floating):
    # Non-floating (complex) path: exponent is (x1 + x2) - 2 * max.
    twice_larger = lax.mul(larger, _constant_like(larger, 2))
    exponent = lax.sub(lax.add(x1, x2), twice_larger)
    result = lax.add(larger, lax.log1p(lax.exp(exponent)))
    wrapped_imag = _wrap_between(lax.imag(result), np.pi)
    return lax.complex(lax.real(result), wrapped_imag)

  difference = lax.sub(x1, x2)
  difference_is_nan = lax_internal._isnan(difference)
  # NaNs or infinities of the same sign: fall back to the plain sum.
  plain_sum = lax.add(x1, x2)
  shifted = lax.add(larger,
                    lax.log1p(lax.exp(lax.neg(lax.abs(difference)))))
  return lax.select(difference_is_nan, plain_sum, shifted)
コード例 #5
0
def _nan_reduction(a,
                   name,
                   jnp_reduction,
                   init_val,
                   nan_if_all_nan,
                   axis=None,
                   keepdims=None,
                   **kwargs):
    """Run ``jnp_reduction`` on ``a`` with NaN entries replaced beforehand.

    Args:
        a: Array-like input (validated by ``_check_arraylike``).
        name: Function name passed to ``_check_arraylike``.
        jnp_reduction: Callable invoked as
            ``jnp_reduction(x, axis=..., keepdims=..., **kwargs)``.
        init_val: Value passed to ``_reduction_init_val`` to build the
            replacement used at NaN positions.
        nan_if_all_nan: When true, outputs whose reduced entries were all
            NaN are set to NaN.
        axis: Forwarded to ``jnp_reduction`` (and to ``all``).
        keepdims: Forwarded to ``jnp_reduction`` (and to ``all``).
        **kwargs: Extra keyword arguments forwarded to ``jnp_reduction``.

    Returns:
        The reduction of ``a``. Inputs whose dtype is not inexact are reduced
        directly with no NaN handling.
    """
    _check_arraylike(name, a)
    if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
        return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

    nan_positions = lax_internal._isnan(a)
    filled = _where(nan_positions, _reduction_init_val(a, init_val), a)
    reduced = jnp_reduction(filled, axis=axis, keepdims=keepdims, **kwargs)
    if not nan_if_all_nan:
        return reduced

    # Where every reduced entry was NaN, report NaN instead of the
    # reduction of the replacement values.
    every_entry_nan = all(nan_positions, axis=axis, keepdims=keepdims)
    return _where(every_entry_nan, _lax_const(a, np.nan), reduced)