Esempio n. 1
0
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
Esempio n. 2
0
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)

    if dtype is None:
        dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    dtype = dtypes.canonicalize_dtype(dtype)

    return lax.div(sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
                   lax.convert_element_type(normalizer, dtype))
Esempio n. 3
0
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)