예제 #1
0
def _moveaxis(a, source: int, destination: int):
    """Move one axis of ``a`` to a new position (local, single-axis moveaxis).

    A simplified stand-in for ``jnp.moveaxis`` that handles exactly one
    ``source``/``destination`` pair. Both indices may be negative; they are
    canonicalized against the rank of ``a`` (source first, then destination).

    Returns:
      ``a`` transposed so that the axis formerly at ``source`` sits at
      ``destination`` and all other axes keep their relative order.
    """
    _check_arraylike("moveaxis", a)
    arr = _asarray(a)
    rank = np.ndim(arr)
    src = _canonicalize_axis(source, rank)
    dst = _canonicalize_axis(destination, rank)
    # Start from the identity permutation, pull the moved axis out, and
    # re-insert it at its target slot.
    order = list(range(rank))
    order.remove(src)
    order.insert(dst, src)
    return lax.transpose(arr, order)
예제 #2
0
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        """Run the enclosing factory's cumulative ``reduction`` along one axis.

        ``np_reduction``, ``reduction``, ``fill_nan`` and ``fill_value`` come
        from the enclosing scope (not visible here). ``np_reduction.__name__``
        is used for error messages, presumably naming the NumPy function being
        mirrored (TODO confirm against the factory).

        Args:
          a: array-like input.
          axis: axis to accumulate along. ``None`` flattens ``a`` to 1-D and
            accumulates over the flattened array. Scalars are also reshaped to
            1-D, but an explicitly given axis is still validated against that
            1-D shape.
          dtype: optional dtype ``a`` is converted to before reducing. When it
            is ``None`` and ``a`` is boolean, ``a`` is converted to the
            canonical default integer dtype.
          out: unsupported; must be ``None``.

        Returns:
          ``reduction(a, axis)`` with the canonicalized ``axis``.

        Raises:
          NotImplementedError: if ``out`` is not ``None``.
        """
        _check_arraylike(np_reduction.__name__, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{np_reduction.__name__} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

        if axis is None or _isscalar(a):
            a = lax.reshape(a, (np.size(a), ))
        # Only default the axis when none was given. Previously a scalar input
        # also overwrote an explicit axis with 0, silently accepting invalid
        # axes (e.g. axis=1) that canonicalization below now rejects.
        if axis is None:
            axis = 0

        a_shape = list(np.shape(a))
        num_dims = len(a_shape)
        axis = _canonicalize_axis(axis, num_dims)

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        # Compare against None rather than relying on truthiness: a np.dtype
        # instance may be falsy (its __len__ is 0 for non-structured dtypes),
        # which would make a user-supplied dtype be ignored or overridden.
        if dtype is None and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype is not None:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
예제 #3
0
파일: reductions.py 프로젝트: cloudhan/jax
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    """Compute the (optionally weighted) average of ``a`` along ``axis``.

    Args:
      a: array-like input.
      axis: ``None`` to average over all elements, an int, or a tuple of ints.
      weights: optional array-like weights. They must either have the same
        shape as ``a``, or be 1-D with length ``a.shape[axis]`` when ``axis``
        is a single int.
      returned: if True, also return the sum of the weights, broadcast to the
        shape of the average.

    Returns:
      ``avg``, or ``(avg, weights_sum)`` when ``returned`` is True.

    Raises:
      ValueError: if ``weights`` and ``a`` have different shapes and ``axis``
        is None or a tuple, ``weights`` is not 1-D, or the length of
        ``weights`` does not match ``a.shape[axis]``.
    """
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        elif isinstance(axis, tuple):
            # Each reduced element contributes weight 1, so the weight sum is
            # the number of elements across all reduced axes. Indexing
            # a.shape with the tuple directly would raise TypeError.
            count = 1
            for ax in axis:
                count = count * core.dimension_as_value(a.shape[ax])
            weights_sum = lax.full_like(avg, count, dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        # Integer inputs are averaged in a floating dtype.
        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        if isinstance(axis, tuple):
            axis = tuple(_canonicalize_axis(ax, a_ndim) for ax in axis)
        elif axis is not None:
            axis = _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if isinstance(axis, tuple):
                raise ValueError("Single axis expected when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Lift the 1-D weights to rank a_ndim (size-1 leading dims), then
            # move their single real axis into position `axis` so they
            # broadcast against `a`.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
예제 #4
0
def _canonicalize_axis_allow_named(x, rank):
    """Canonicalize ``x`` as an axis of an array of rank ``rank``.

    ``x`` is handed to ``maybe_named_axis`` with two handlers. The first
    canonicalizes a positional axis against ``rank``; the second returns a
    name unchanged. Which handler applies is decided by ``maybe_named_axis``
    (presumably by whether ``x`` is a named axis — TODO confirm).
    """
    def _positional(index):
        return _canonicalize_axis(index, rank)

    def _named(axis_name):
        return axis_name

    return maybe_named_axis(x, _positional, _named)