Ejemplo n.º 1
0
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        """Apply the enclosing scope's cumulative ``reduction`` to ``a``.

        Validates the arguments the same way the NumPy-named wrapper
        ``np_reduction`` would. When ``axis`` is None or ``a`` is a scalar,
        ``a`` is flattened and the reduction runs along axis 0. If the
        enclosing ``fill_nan`` flag is set, NaNs are first replaced by
        ``fill_value``. Boolean input with no (falsy) ``dtype`` is promoted
        to the canonical default integer dtype before reducing.

        Raises:
          NotImplementedError: if ``out`` is not None.
        """
        fn_name = np_reduction.__name__
        _check_arraylike(fn_name, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{fn_name} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, fn_name)

        # No axis (or a scalar input): reduce over the flattened array.
        flatten = axis is None or _isscalar(a)
        if flatten:
            a = lax.reshape(a, (np.size(a), ))
        axis = _canonicalize_axis(0 if flatten else axis, np.ndim(a))

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        # Booleans are accumulated as integers unless a dtype was requested.
        if not dtype and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
Ejemplo n.º 2
0
def nanmean(a,
            axis: Optional[Union[int, Tuple[int, ...]]] = None,
            dtype=None,
            out=None,
            keepdims=False,
            where=None):
    """Mean of ``a`` along ``axis``, ignoring NaN entries.

    Boolean and integer inputs cannot contain NaN, so they are delegated to
    ``mean`` unchanged. For other inputs the result is ``nansum(a)`` divided
    by the number of non-NaN entries, both computed in ``dtype`` (default:
    the dtype of ``a``). A slice with no non-NaN entries divides by a zero
    count.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("nanmean", a)
    lax_internal._check_user_dtype_supported(dtype, "nanmean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanmean is not supported.")

    input_dtype = dtypes.dtype(a)
    nan_free = (dtypes.issubdtype(input_dtype, np.bool_)
                or dtypes.issubdtype(input_dtype, np.integer))
    if nan_free:
        return mean(a, axis, dtype, out, keepdims, where=where)

    result_dtype = input_dtype if dtype is None else dtype

    # Number of non-NaN entries contributing to each output element.
    is_valid = lax_internal.bitwise_not(lax_internal._isnan(a))
    valid_count = sum(is_valid,
                      axis=axis,
                      dtype=np.int32,
                      keepdims=keepdims,
                      where=where)
    valid_count = lax.convert_element_type(valid_count, result_dtype)

    total = nansum(a, axis, dtype=result_dtype, keepdims=keepdims, where=where)
    return lax.div(total, valid_count)
Ejemplo n.º 3
0
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    """Arithmetic mean of ``a`` along ``axis``.

    Without ``where``, the divisor is the number of elements reduced over
    (all of ``a`` when ``axis`` is None). With ``where``, it is the number
    of True entries of ``where`` broadcast to ``a``'s shape. The sum and
    the division happen in ``dtype``, which defaults to the inexact version
    of ``a``'s dtype and is canonicalized.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    if where is not None:
        count = sum(_broadcast_to(where, np.shape(a)),
                    axis,
                    dtype=dtype,
                    keepdims=keepdims)
    else:
        reduced_size = np.size(a) if axis is None else _axis_size(a, axis)
        count = core.dimension_as_value(reduced_size)

    if dtype is None:
        result_dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    else:
        result_dtype = dtype
    result_dtype = dtypes.canonicalize_dtype(result_dtype)

    total = sum(a, axis, dtype=result_dtype, keepdims=keepdims, where=where)
    return lax.div(total, lax.convert_element_type(count, result_dtype))
Ejemplo n.º 4
0
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Variance of ``a`` along ``axis``, ignoring NaN entries.

    The mean is taken with ``nanmean`` in the promoted computation dtype.
    NaN positions contribute zero to the sum of squared deviations, and the
    divisor is the count of non-NaN entries minus ``ddof``. Where that
    divisor is <= 0 the result is NaN. The output is cast to the dtype
    chosen by ``_var_promote_types``.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    mean_val = nanmean(a, axis, dtype=computation_dtype, keepdims=True,
                       where=where)
    is_nan = lax_internal._isnan(a)

    # NaN entries are set to 0 so they add nothing to the sum of squares.
    deviation = _where(is_nan, 0, a - mean_val)
    if dtypes.issubdtype(deviation.dtype, np.complexfloating):
        # |z|^2 computed as z * conj(z), keeping the real part.
        squared = lax.real(lax.mul(deviation, lax.conj(deviation)))
    else:
        squared = lax.square(deviation)

    dof = sum(lax_internal.bitwise_not(is_nan),
              axis=axis,
              keepdims=keepdims,
              where=where) - ddof
    no_dof = lax.le(dof, 0)
    total = _where(no_dof, np.nan,
                   sum(squared, axis, keepdims=keepdims, where=where))
    # Divide by 1 where there are no degrees of freedom; those slices are
    # already NaN.
    safe_dof = _where(no_dof, 1, dof)
    variance = lax.div(total, lax.convert_element_type(safe_dof, total.dtype))
    return lax.convert_element_type(variance, dtype)
Ejemplo n.º 5
0
def nanstd(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Standard deviation of ``a`` along ``axis``, ignoring NaN entries.

    Computed as the square root of ``nanvar`` with the same arguments.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("nanstd", a)
    lax_internal._check_user_dtype_supported(dtype, "nanstd")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanstd is not supported.")
    variance = nanvar(a,
                      axis=axis,
                      dtype=dtype,
                      ddof=ddof,
                      keepdims=keepdims,
                      where=where)
    return lax.sqrt(variance)
Ejemplo n.º 6
0
def nanprod(a,
            axis: Optional[Union[int, Tuple[int, ...]]] = None,
            dtype=None,
            out=None,
            keepdims=None,
            initial=None,
            where=None):
    """Product of ``a`` along ``axis``, ignoring NaN entries.

    Delegates to ``_nan_reduction`` with ``prod`` as the underlying
    reduction, ``1`` as the value passed for NaN handling (presumably the
    substitute for NaN entries; confirm in ``_nan_reduction``), and
    ``nan_if_all_nan=False``.
    """
    lax_internal._check_user_dtype_supported(dtype, "nanprod")
    reduction_kwargs = dict(axis=axis,
                            dtype=dtype,
                            out=out,
                            keepdims=keepdims,
                            initial=initial,
                            where=where)
    return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                          **reduction_kwargs)
Ejemplo n.º 7
0
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    """Variance of ``a`` along ``axis``.

    ``a`` is cast to the computation dtype chosen by ``_var_promote_types``.
    Squared deviations from the (masked) mean are summed, then divided by
    the number of reduced elements minus ``ddof``. Without ``where`` that
    number is the size of the reduced axes. With ``where`` it is the count
    of True entries of the broadcast mask. The result is cast to the output
    dtype returned by ``_var_promote_types``.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    # NB: from here on ``dtype`` is the *output* dtype, which may be narrower
    # than ``computation_dtype`` (e.g. float16 output, float32 computation).
    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        # |z|^2 computed as z * conj(z), keeping the real part.
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        # Count the selected elements in the default integer dtype, as
        # nanvar does. Summing the mask in the output dtype would be wrong:
        # a float16 count saturates to inf past 65504 elements (giving a
        # variance of 0), a small int dtype overflows, a bool dtype turns
        # the count into an ``any``, and a complex dtype forces a lossy
        # complex->real cast below.
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)
Ejemplo n.º 8
0
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    """Shared implementation of jnp reductions built on ``lax.reduce``.

    Args:
      a: array-like input. It is converted with ``_asarray`` unless it is
        already an ``ndarray``.
      name: public function name, used in error messages.
      np_fun: NumPy counterpart. It is applied to a 0-d array of ``a``'s
        dtype to infer the default result dtype.
      op: binary reduction operator passed to ``lax.reduce``.
      init_val: value handed to ``_reduction_init_val`` to build the
        reduction's initial value, which must be an identity for ``op``.
      has_identity: if False, an empty reduction, or a ``where_`` mask given
        without ``initial``, raises ValueError.
      preproc: optional callable applied to the array before reducing.
      bool_op: operator used instead of ``op`` when the computation dtype is
        bool. Defaults to ``op``.
      upcast_f16_for_computation: if true and the result dtype is inexact,
        reduce in ``_upcast_f16(result_dtype)`` and cast back at the end.
        (Presumably a wider float for 16-bit inputs; confirm in
        ``_upcast_f16``.)
      axis: axis or axes to reduce. Must be concrete, not traced.
      dtype: requested result dtype, or None to infer it via ``np_fun``.
      out: must be None. It is accepted only for NumPy-API compatibility.
      keepdims: if true, reduced axes are re-inserted as size-1 dimensions.
      initial: optional extra value combined into the result with ``op``.
      where_: optional mask. Masked-out elements are replaced by the
        identity value before reducing.
      parallel_reduce: callable used when ``_reduction_dims`` returns
        ``dims`` distinct from ``pos_dims`` (named reductions, per the error
        message below).

    Returns:
      The reduced array, cast to ``dtype`` if given, else to the inferred
      result dtype.

    Raises:
      NotImplementedError: if ``out`` is given, or a named reduction is
        requested without ``parallel_reduce``.
      ValueError: for an identity-less reduction over an empty axis, or one
        given a ``where_`` mask but no ``initial``.
    """
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    # ``axis`` determines the output shape, so it must be a concrete value.
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    # Without an identity there is nothing to put in masked-out positions.
    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    # Identity-less reductions are undefined over an empty axis unless
    # ``initial`` supplies a starting value.
    if initial is None and not has_identity:
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    # Default result dtype: whatever ``np_fun`` returns for a 0-d array of
    # ``a``'s dtype (e.g. sums of bools come back as an integer dtype).
    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        # Masked-out elements become the identity, so they do not affect the result.
        a = _where(where_, a, init_val)
    if pos_dims is not dims:
        # ``dims`` contains non-positional (named) axes that lax.reduce cannot handle.
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    # Cast back from the (possibly upcast) computation dtype.
    return lax.convert_element_type(result, dtype or result_dtype)