def _dct_ortho_norm(out, axis):
    """Divide ``out`` by per-index normalization factors along ``axis``.

    With ``n = out.shape[axis]``, index 0 along ``axis`` is divided by
    ``sqrt(4 * n)`` and every other index by ``sqrt(2 * n)``. (The name
    suggests this is the "ortho" scaling of a DCT result.)

    Args:
        out: array to rescale. Its dtype must be accepted by ``lax.sqrt``.
        axis: integer axis along which the factors vary. Negative values
            count from the last axis.

    Returns:
        An array with the same shape as ``out``.
    """
    # The comprehension below compares ``axis`` against the non-negative
    # positions produced by ``range``; a negative axis would never match,
    # so every dimension would be expanded and the result would gain a
    # spurious leading axis. Normalize first.
    if axis < 0:
        axis += out.ndim
    n = out.shape[axis]
    factor = lax.concatenate(
        [lax.full((1,), 4, out.dtype), lax.full((n - 1,), 2, out.dtype)], 0)
    # Insert singleton dimensions everywhere except ``axis`` so the 1-D
    # factor broadcasts against ``out``.
    factor = lax.expand_dims(
        factor, [a for a in range(out.ndim) if a != axis])
    return out / lax.sqrt(factor * n)
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate values of shape ``size`` intended to fall outside ``constraint``.

    The constraint is matched against the ``constraints._*`` classes in a
    fixed order; the first matching class decides how the values are drawn.

    Args:
        constraint: a constraint object from the ``constraints`` module.
        size: tuple giving the shape of the requested values.
        key: PRNG key used for the random draws.

    Returns:
        An array of generated values.

    Raises:
        NotImplementedError: if ``constraint`` matches none of the handled
            constraint classes.
    """
    # NOTE: the order of these checks is kept exactly as-is; the class
    # hierarchy of ``constraints`` is not visible here, so reordering could
    # change which branch a subclass instance hits.
    if isinstance(constraint, constraints._Boolean):
        # Bernoulli draws shifted down by 2.
        return random.bernoulli(key, shape=size) - 2

    if isinstance(constraint, constraints._GreaterThan):
        # Subtract a strictly positive amount from the lower bound.
        return constraint.lower_bound - np.exp(random.normal(key, size))

    if isinstance(constraint, constraints._IntegerInterval):
        lo = np.broadcast_to(constraint.lower_bound, size)
        # Integers drawn from [lo - 1, lo).
        return random.randint(key, size, lo - 1, lo)

    if isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - poisson(key, 5, shape=size)

    if isinstance(constraint, constraints._Interval):
        hi = np.broadcast_to(constraint.upper_bound, size)
        # Uniform draws starting at the upper bound.
        return random.uniform(key, size, minval=hi, maxval=hi + 1.)

    if isinstance(constraint, constraints._Real):
        return lax.full(size, np.nan)

    if isinstance(constraint, constraints._Simplex):
        concentration = np.ones((size[-1], ))
        # Dirichlet samples shifted by a small constant.
        return osp.dirichlet.rvs(alpha=concentration, size=size[:-1]) + 1e-2

    if isinstance(constraint, constraints._Multinomial):
        num_categories = size[-1]
        uniform_probs = np.ones((num_categories, )) / num_categories
        # Multinomial counts with one added to every entry.
        return multinomial(key, p=uniform_probs, n=constraint.upper_bound,
                           shape=size[:-1]) + 1

    if isinstance(constraint, constraints._CorrCholesky):
        dim = size[-1]
        num_offdiag = dim * (dim - 1) // 2
        raw = random.uniform(key, size[:-2] + (num_offdiag, ),
                             minval=-1, maxval=1)
        # Lower-triangular factor shifted by a small constant.
        return signed_stick_breaking_tril(raw) + 1e-2

    raise NotImplementedError('{} not implemented.'.format(constraint))
def _tscan_impl(a, bs, fields, consts, aval_out, jaxpr):
    """Iterate ``jaxpr`` over the leading axis of ``bs`` and record selected outputs.

    For each index along the leading axis of the elements of ``bs``, the
    jaxpr is evaluated on the current carry and the sliced elements; the
    entries of its result listed in ``fields`` are written into per-field
    buffers at that index.

    Args:
        a: initial carry; indexable, with ``len(a)`` entries.
        bs: collection of arrays sharing the same leading-axis length.
        fields: positions of the carry whose per-step values are recorded.
        consts: constants forwarded to ``core.eval_jaxpr``.
        aval_out: not used in this function.
        jaxpr: the jaxpr evaluated at every step.

    Returns:
        ``core.pack`` of a list with ``len(a)`` entries: the stacked history
        for each position in ``fields`` and ``None`` everywhere else.
    """
    num_steps = tuple(bs)[0].shape[0]
    # One zero-filled buffer per recorded field, with a leading step axis.
    buffers = [
        lax.full((num_steps, ) + a[idx].shape, 0, lax._dtype(a[idx]))
        for idx in fields
    ]

    def body_fun(step, carry):
        carry_a, carry_buffers = carry
        # Take the step-th slice of every element of bs.
        sliced = [
            lax.dynamic_index_in_dim(seq, step, keepdims=False) for seq in bs
        ]
        next_a = core.eval_jaxpr(jaxpr, consts, (), carry_a, core.pack(sliced))
        # Write the recorded fields of the new carry into row `step`.
        next_entries = tuple(next_a)
        next_buffers = [
            lax.dynamic_update_index_in_dim(
                buf, next_entries[j][None, ...], step, axis=0)
            for j, buf in zip(fields, carry_buffers)
        ]
        return next_a, next_buffers

    _, final_buffers = lax.fori_loop(0, num_steps, body_fun, (a, buffers))

    # Positions that were not recorded stay None.
    result = [None] * len(a)
    for slot, field in enumerate(fields):
        result[field] = final_buffers[slot]
    return core.pack(result)
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    """Compute the (optionally weighted) average of ``a`` along ``axis``.

    Args:
        a: input values; converted with ``_asarray``.
        axis: axis or axes to reduce over, or ``None`` to reduce over all
            elements. A tuple is handled explicitly on the unweighted path;
            on the weighted path ``axis`` is forwarded to
            ``_canonicalize_axis`` (defined outside this block — whether it
            accepts tuples is not visible here).
        weights: optional weights. Must either have the same shape as ``a``
            or be 1-D with length matching ``a.shape[axis]``.
        returned: if true, also return the sum of the weights.

    Returns:
        ``avg`` or, when ``returned`` is true, ``(avg, weights_sum)`` with
        ``weights_sum`` broadcast to the shape of ``avg``.

    Raises:
        ValueError: if ``weights`` and ``a`` differ in shape and ``axis`` is
            ``None``, ``weights`` is not 1-D, or its length does not match
            the size of ``a`` along ``axis``.
    """
    a = _asarray(a)

    if weights is None:
        # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        elif isinstance(axis, tuple):
            # ``a.shape[axis]`` raises TypeError for a tuple; the number of
            # averaged elements is the product of the reduced dimensions.
            reduced_count = 1
            for dim in axis:
                reduced_count = reduced_count * core.dimension_as_value(
                    a.shape[dim])
            weights_sum = lax.full_like(avg, reduced_count, dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(
                avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        # Inexact inputs keep their promoted dtype; anything else is
        # promoted together with the default float type.
        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Pad the 1-D weights with leading singleton dimensions, then
            # move the weight axis into position so it lines up with `axis`.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg