def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a threefry PRNG key; rejected unless ``_is_threefry_prng_key(key)``.
    bit_width: width in bits of each sample; must be one of 8, 16, 32, 64.
    shape: requested output shape. It is converted with
      ``core.as_named_shape``. Each named axis is checked against
      ``lax.psum(1, name)``, and the positional dims determine how many
      values are generated.

  Returns:
    Unsigned integers of dtype ``UINT_DTYPES[bit_width]``, reshaped to
    ``shape``.

  Raises:
    TypeError: if ``key`` is not a threefry key, or ``bit_width`` is not one
      of 8, 16, 32, 64.
    ValueError: if a named axis size declared in ``shape`` differs from the
      size reported by ``lax.psum(1, name)``.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  # For each named axis: verify the declared size, then fold this position's
  # axis index into the key so every index along the axis uses its own
  # derived key.
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(
          f"The shape of axis {name} was specified as {size}, "
          f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  # Only the positional dims count toward the number of values drawn here.
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  # Split the 32-bit word count into `nblocks` full chunks of
  # iinfo(uint32).max counters plus a remainder chunk of `rem` counters.
  # Presumably this keeps each uint32 counter iota within range.
  # A symbolic (non-constant) count cannot be divided, so it is generated
  # as a single chunk.
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    # One subkey per full chunk, plus a final key for the remainder chunk.
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # The first half of the 32-bit words becomes the high 32 bits and the
    # second half the low 32 bits of each 64-bit value.
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    # Each 32-bit word is shifted right by 0, bit_width, 2*bit_width, ... and
    # masked, giving 32 // bit_width sub-words per word.
    # NOTE(review): the intermediate shape is assumed to be
    # (32 // bit_width, number_of_words) — confirm.
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1, )),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
    # Transpose (dimensions=(1, 0)) before flattening so the sub-words of
    # one 32-bit word end up adjacent in the output.
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ), (1, 0))
    # Drop the padding produced by rounding up to whole 32-bit words.
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
def value_and_jacfwd_f(*args, **kwargs):
  """Evaluate ``fun`` at ``args`` and return ``(value, forward-mode jacobian)``.

  All standard-basis tangents are pushed through ``_jvp`` at once with
  ``vmap``. The primal output is shared across that batch
  (``out_axes`` ``None``), while the tangent outputs are stacked on the last
  axis (``-1``) before being unravelled into the output/argument structure.
  """
  wrapped_fun = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      wrapped_fun, argnums, args, require_static_args_hashable=False)
  check_inputs = partial(_check_input_dtype_jacfwd, holomorphic)
  tree_map(check_inputs, dyn_args)

  push_forward = partial(_jvp, f_partial, dyn_args)
  batched_push_forward = vmap(push_forward, out_axes=(None, -1))
  y, jac = batched_push_forward(_std_basis(dyn_args))

  check_outputs = partial(_check_output_dtype_jacfwd, holomorphic)
  tree_map(check_outputs, y)

  # A single integer argnums means the caller expects a Jacobian w.r.t. one
  # argument rather than a tuple of arguments.
  single_argnum = isinstance(argnums, int)
  example_args = dyn_args[0] if single_argnum else dyn_args
  unravel = partial(_jacfwd_unravel, example_args)
  return y, tree_map(unravel, y, jac)
def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
  # Regression test for #5570: vmapping multivariate_normal.logpdf over a
  # leading batch axis must give the same result as calling it directly on
  # the batched inputs.
  rand = jtu.rand_default(self.rng())
  x = rand((nbatch, ndim), dtype)
  mean = 5 * rand((nbatch, ndim), dtype)
  factor = rand((nbatch, ndim, 2 * ndim), dtype)
  # Per-batch Gram matrix factor @ factor^T, which is symmetric PSD.
  cov = factor @ factor.transpose(0, 2, 1)

  logpdf = lsp_stats.multivariate_normal.logpdf
  direct = logpdf(x, mean, cov)
  mapped = api.vmap(logpdf)(x, mean, cov)
  self.assertArraysEqual(direct, mapped)
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                   rtol=None, atol=None):
  # Run api.vmap(op, bdims) on inputs that carry an extra batch dimension of
  # size `bdim_size`. Compare the result against a reference that applies
  # `op` to each batch slice and stacks the results; for an empty batch the
  # reference is a zero-length array with op's output shape and dtype.
  batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
  args = [rng(shp, dt) for shp, dt in zip(batched_shapes, dtypes)]
  slice_at = args_slicer(args, bdims)
  ans = api.vmap(op, bdims)(*args)
  if bdim_size != 0:
    expected = np.stack([op(*slice_at(idx)) for idx in range(bdim_size)])
  else:
    # With no batch elements there is nothing to stack; evaluate op once on
    # unbatched inputs only to learn its output shape and dtype.
    unbatched_args = [rng(shp, dt) for shp, dt in zip(shapes, dtypes)]
    sample_out = op(*unbatched_args)
    expected = np.zeros((0,) + sample_out.shape, sample_out.dtype)
  self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists.

  Any leading batch dimensions are flattened into a single axis and
  ``_lu_blocked`` is mapped over it with ``api.vmap``; the results are then
  reshaped back to carry the original batch dimensions.

  Returns:
    ``(lu, pivot, perm)`` whose trailing shapes are ``(m, n)``,
    ``(min(m, n),)`` and ``(m,)``, each preceded by the batch dims of ``x``.
  """
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if not batch_dims:
    lu, pivot, perm = _lu_blocked(x)
    return lu, pivot, perm

  flat_batch = np.prod(batch_dims, dtype=np.int64)
  flat_x = lax.reshape(x, (flat_batch, m, n))
  lu, pivot, perm = api.vmap(_lu_blocked)(flat_x)
  return (lax.reshape(lu, batch_dims + (m, n)),
          lax.reshape(pivot, batch_dims + (min(m, n),)),
          lax.reshape(perm, batch_dims + (m,)))
def wrapped(*args):
  """Call ``pyfunc`` on ``args`` with generalized-ufunc style vectorization.

  Core dimensions come from ``signature``, or are empty for every argument
  when no signature is given. The remaining leading (loop) dimensions are
  handled by stacking one ``api.vmap`` per dimension around the
  output-checking wrapper.
  """
  error_context = ("on vectorized function with excluded={!r} and "
                   "signature={!r}".format(excluded, signature))
  excluded_func, args = _apply_excluded(pyfunc, excluded, args)
  args = tuple(map(jnp.asarray, args))

  if signature is None:
    input_core_dims = [()] * len(args)
    output_core_dims = None
  else:
    input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

  broadcast_shape, dim_sizes = _parse_input_dimensions(
      args, input_core_dims, error_context)

  checked_func = _check_output_dims(
      excluded_func, dim_sizes, output_core_dims, error_context)

  # Rather than broadcasting every argument to the full broadcast shape,
  # expand the missing leading dimensions with vmap where possible. Pushing
  # broadcasting into vmap lets primitives use more efficient batching rules
  # when only some arguments are batched (e.g. lax_linalg.triangular_solve).
  vec_args = []
  vmap_counts = []
  for operand, core_dims in zip(args, input_core_dims):
    # Explicitly broadcast the dimensions each argument already has, because
    # they may be of size 1, which vmap does not handle.
    # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
    # doing all vectorization with vmap? This *might* be a little more
    # efficient but would require more careful book-keeping.
    core_shape = tuple(dim_sizes[dim] for dim in core_dims)
    full_shape = broadcast_shape + core_shape
    # A 0-d argument stays 0-d (full_shape[-0:] would be the whole shape).
    target_shape = full_shape[-operand.ndim:] if operand.ndim else ()
    vec_args.append(jnp.broadcast_to(operand, target_shape))
    # The number of loop dimensions this argument contributes to vmap.
    vmap_counts.append(len(target_shape) - len(core_shape))

  vectorized_func = checked_func
  while any(vmap_counts):
    # Map axis 0 of every argument that still has loop dimensions left.
    in_axes = tuple(None if count <= 0 else 0 for count in vmap_counts)
    vmap_counts = [max(count - 1, 0) for count in vmap_counts]
    vectorized_func = api.vmap(vectorized_func, in_axes)
  return vectorized_func(*vec_args)
def value_and_jacrev_f(*args, **kwargs):
  """Evaluate ``fun`` at ``args`` and return its value with a reverse-mode Jacobian.

  Every standard-basis cotangent of the output is pulled back at once with
  ``vmap(pullback)``. The per-argument results are unravelled and then
  tree-transposed so that the outer structure follows the output ``y`` and
  the inner structure follows the differentiated arguments.

  Returns:
    ``(y, jac)`` when ``has_aux`` is false, otherwise ``((y, aux), jac)``.
  """
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(
      f, argnums, args, require_static_args_hashable=False)
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
  if has_aux:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  else:
    y, pullback = _vjp(f_partial, *dyn_args)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)

  jac = vmap(pullback)(_std_basis(y))
  # A single integer argnums means the caller expects a Jacobian w.r.t. one
  # argument, not a tuple of per-argument Jacobians.
  jac = jac[0] if isinstance(argnums, int) else jac
  example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
  jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
  # Compute the transpose once; previously it was duplicated in both return
  # branches and followed by an unreachable bare `return`.
  jac_tree = tree_transpose(tree_structure(example_args), tree_structure(y),
                            jac_tree)
  if has_aux:
    return (y, aux), jac_tree
  return y, jac_tree
def _solve(a, b):
  """Solve ``a @ x == b`` through ``lax.custom_linear_solve`` using one LU factorization.

  ``b`` may be ``[..., m]`` (a vector right-hand side) or ``[..., m, k]``
  (a matrix right-hand side); the latter is handled by vmapping the
  vector solve.
  """
  _check_solve_shapes(a, b)

  # custom_linear_solve requires b's leading dimensions to match a's, so
  # replace each size-1 dimension of b with the paired dimension taken from
  # a.shape[:-1] + (1,).
  target_shape = []
  for dim_a, dim_b in zip(a.shape[:-1] + (1,), b.shape):
    target_shape.append(dim_a if dim_b == 1 else dim_b)
  b = jnp.broadcast_to(b, tuple(target_shape))

  # Factor `a` once, outside of differentiation, so the same factorization
  # serves both the forward solve and the transpose solve used when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))

  def matvec(x):
    return _matvec_multiply(a, x)

  def solve(_, x):
    return lu_solve(lu_, permutation, x, trans=0)

  def transpose_solve(_, x):
    return lu_solve(lu_, permutation, x, trans=1)

  custom_solve = partial(lax.custom_linear_solve, matvec,
                         solve=solve, transpose_solve=transpose_solve)
  if a.ndim != b.ndim + 1:
    # b.shape == [..., m, k]: map the vector solve over b's last axis.
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
  # b.shape == [..., m]
  return custom_solve(b)
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
  """Fold ``data`` into an RBG key.

  The 4-element key is viewed as two rows of two. ``data`` is folded into
  each row independently with ``_threefry_fold_in``, and the rows are
  flattened back into 4 elements.
  """
  rows = key.reshape(2, 2)
  fold_each_row = vmap(_threefry_fold_in, (0, None), 0)
  return fold_each_row(rows, data).reshape(4)
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
  """Split an RBG key into ``num`` new 4-element keys.

  Each of the key's two rows of two is split independently with
  ``_threefry_split``. The per-row results are stacked along axis 1
  (``out_axes=1``) so the final reshape yields one 4-element key per split.
  NOTE(review): assumes ``_threefry_split(row, num)`` returns shape
  ``(num, 2)`` — confirm.
  """
  rows = key.reshape(2, 2)
  split_each_row = vmap(_threefry_split, (0, None), 1)
  return split_each_row(rows, num).reshape(num, 4)