Пример #1
0
def _dct_ortho_norm(out, axis):
    """Rescale an unnormalized transform along ``axis`` to orthonormal scaling.

    With ``n = out.shape[axis]``, the first entry along ``axis`` is divided by
    ``sqrt(4 * n)`` and every other entry by ``sqrt(2 * n)``.

    Args:
      out: array holding the unnormalized transform output.
      axis: axis the transform was taken along; negative values count from
        the end, as in NumPy.

    Returns:
      ``out`` divided elementwise by the normalization factors.
    """
    # Read the length first so an out-of-range axis still raises IndexError,
    # then normalize negative axes: the comparison below needs a
    # non-negative index to skip the transformed axis.
    n = out.shape[axis]
    axis = axis % out.ndim
    factor = lax.concatenate([
        lax.full((1, ), 4, out.dtype),
        lax.full((n - 1, ), 2, out.dtype)
    ], 0)
    # Insert unit dims everywhere except ``axis`` so the 1-D factor
    # broadcasts along the transformed axis only.
    factor = lax.expand_dims(factor, [d for d in range(out.ndim) if d != axis])
    return out / lax.sqrt(factor * n)
Пример #2
0
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate values intended to lie outside the support of ``constraint``.

    The branch is selected by the class of ``constraint``.

    Args:
      constraint: constraint instance whose support should be violated.
      size: shape tuple for the generated values. The simplex, multinomial
        and correlation-Cholesky branches index into it, so it is presumably
        expected to have at least one (or two) dimensions — TODO confirm.
      key: JAX PRNG key used by the ``random``-based branches. The default
        key is created once, when the function is defined. The simplex
        branch samples via ``osp.dirichlet.rvs`` and does not use ``key``.

    Returns:
      An array of out-of-support values.

    Raises:
      NotImplementedError: if ``constraint`` is of an unhandled type.
    """
    if isinstance(constraint, constraints._Boolean):
        # bernoulli gives {0, 1}; shifting by -2 gives {-2, -1}.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(...) > 0, so every value is strictly below the bound.
        return constraint.lower_bound - np.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        # randint's upper limit is exclusive, so this is always lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, which would return lower_bound itself.
        # Subtract one more so every value is strictly below the bound,
        # whether the bound is inclusive or exclusive.
        return constraint.lower_bound - 1 - poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key,
                              size,
                              minval=upper_bound,
                              maxval=upper_bound + 1.)
    elif isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Adding 1e-2 to each entry pushes the sum of each row above 1.
        return osp.dirichlet.rvs(alpha=np.ones(
            (size[-1], )), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # Adding 1 to each count makes the total exceed upper_bound.
        return multinomial(key,
                           p=np.ones((n, )) / n,
                           n=constraint.upper_bound,
                           shape=size[:-1]) + 1
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key,
                           size[:-2] + (size[-1] * (size[-1] - 1) // 2, ),
                           minval=-1,
                           maxval=1)) + 1e-2
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Пример #3
0
def _tscan_impl(a, bs, fields, consts, aval_out, jaxpr):
    """Evaluate ``jaxpr`` once per row of ``bs`` and stack selected carry fields.

    The loop runs ``length`` steps, where ``length`` is the leading dimension
    of the first element of ``bs``. Every element of ``bs`` is presumably
    expected to share that leading dimension — TODO confirm.

    At step ``i``, ``jaxpr`` is evaluated on the current carry and the
    ``i``-th slice of each element of ``bs``. The result becomes the next
    carry. For every index in ``fields``, that component of the new carry is
    written into row ``i`` of a zero-initialized buffer.

    Args:
      a: initial carry, indexable by the entries of ``fields``.
      bs: arrays that are sliced along axis 0, one slice per step.
      fields: indices of the carry components to record.
      consts: constants forwarded to ``core.eval_jaxpr``.
      aval_out: not used by this implementation.
      jaxpr: the loop body.

    Returns:
      ``core.pack`` of a list with ``len(a)`` entries. It holds the stacked
      buffer at each position listed in ``fields`` and ``None`` elsewhere.
    """
    length = tuple(bs)[0].shape[0]
    # One buffer of shape (length,) + field.shape per recorded field.
    init_state = [
        lax.full((length, ) + a[i].shape, 0, lax._dtype(a[i])) for i in fields
    ]

    def body_fun(i, vals):
        carry, state = vals
        # Select the i-th element from each b.
        b_slices = [lax.dynamic_index_in_dim(b, i, keepdims=False) for b in bs]
        carry_out = core.eval_jaxpr(jaxpr, consts, (), carry,
                                    core.pack(b_slices))
        carry_parts = tuple(carry_out)
        # Write each selected field of the new carry into row i of its buffer.
        state_out = [
            lax.dynamic_update_index_in_dim(buf, carry_parts[field][None, ...],
                                            i, axis=0)
            for field, buf in zip(fields, state)
        ]
        return carry_out, state_out

    _, state = lax.fori_loop(0, length, body_fun, (a, init_state))

    # Non-selected fields are reported as None.
    out = [None] * len(a)
    for k, field in enumerate(fields):
        out[field] = state[k]
    return core.pack(out)
Пример #4
0
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    """Compute the (optionally weighted) average of ``a`` along ``axis``.

    Args:
      a: array-like input, converted with ``_asarray``.
      axis: ``None`` to average over every element, an int, or a tuple of
        ints naming the axes to reduce.
      weights: optional array-like of weights. It must either have the same
        shape as ``a``, or be 1-D with length ``a.shape[axis]`` when a single
        ``axis`` is given.
      returned: when true, also return the sum of the weights.

    Returns:
      ``avg``, or ``(avg, weights_sum)`` when ``returned`` is true. In that
      case ``weights_sum`` is broadcast to ``avg.shape``.

    Raises:
      ValueError: if ``weights`` and ``a`` differ in shape and ``axis`` is
        not a single axis, ``weights`` is not 1-D, or its length does not
        match ``a.shape[axis]``.
    """
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        elif isinstance(axis, tuple):
            # A tuple cannot index ``a.shape``. Each output element averages
            # over the product of all reduced dimensions.
            count = 1
            for ax in axis:
                count *= a.shape[ax]
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(count),
                                        dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        # Integer inputs are averaged in a floating dtype.
        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        if isinstance(axis, tuple):
            axis = tuple(_canonicalize_axis(ax, a_ndim) for ax in axis)
        elif axis is not None:
            axis = _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if isinstance(axis, tuple):
                raise ValueError("Single axis expected when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Place the 1-D weights along ``axis`` so they broadcast against a.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg