Example #1
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
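
A minimal sketch of the blocked branch above, where vmap maps a per-key hash over split subkeys while broadcasting one shared counter array via in_axes=(0, None); fake_hash is a hypothetical stand-in for threefry_2x32:

import jax
import jax.numpy as jnp

def fake_hash(key, counters):
    # Stand-in for threefry_2x32: mixes a (2,)-word key into the counters.
    return counters * key[0] + key[1]

keys = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)  # three (2,)-word keys
counters = jnp.arange(4, dtype=jnp.uint32)            # shared across the batch
blocks = jax.vmap(fake_hash, in_axes=(0, None))(keys, counters)
print(blocks.shape)  # (3, 4): one block of bits per key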
Example #2
def value_and_jacfwd_f(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(
        f, argnums, args, require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
    pushfwd = partial(_jvp, f_partial, dyn_args)
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
    tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    return y, tree_map(partial(_jacfwd_unravel, example_args), y, jac)
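
The out_axes=(None, -1) trick above is how forward-mode Jacobians are built: vmap a jvp over the identity basis, keep the shared primal output unbatched, and stack tangents along the last axis. A self-contained sketch (f and x are illustrative):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)
pushfwd = lambda v: jax.jvp(f, (x,), (v,))   # one jvp per basis vector
y, jac = jax.vmap(pushfwd, out_axes=(None, -1))(jnp.eye(3))
print(jnp.allclose(jac, jax.jacfwd(f)(x)))  # True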
Example #3
    def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
        # Regression test for #5570
        rng = jtu.rand_default(self.rng())
        x = rng((nbatch, ndim), dtype)
        mean = 5 * rng((nbatch, ndim), dtype)
        factor = rng((nbatch, ndim, 2 * ndim), dtype)
        cov = factor @ factor.transpose(0, 2, 1)

        result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
        result2 = api.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
        self.assertArraysEqual(result1, result2)
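
The test checks that a batched logpdf call agrees with mapping the unbatched function via vmap. The same property in miniature, using the univariate norm.logpdf for brevity:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.linspace(-1.0, 1.0, 5)
batched = norm.logpdf(x)               # one vectorized call
mapped = jax.vmap(norm.logpdf)(x)      # per-element via vmap
print(jnp.allclose(batched, mapped))  # True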
Example #4
def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                   rtol=None, atol=None):
  batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
  args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
  args_slice = args_slicer(args, bdims)
  ans = api.vmap(op, bdims)(*args)
  if bdim_size == 0:
    args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    out = op(*args)
    expected = np.zeros((0,) + out.shape, out.dtype)
  else:
    expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
  self.assertAllClose(ans, expected, rtol=rtol, atol=atol)
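
The invariant this helper asserts, in miniature: vmap(op) over a batch axis must equal stacking op applied to each slice. The op and shapes below are illustrative:

import jax
import jax.numpy as jnp
import numpy as np

op = jnp.cumsum
batch = jnp.arange(8.0).reshape(4, 2)  # batch of 4 along axis 0
ans = jax.vmap(op)(batch)
expected = np.stack([op(batch[i]) for i in range(4)])
print(np.allclose(ans, expected))  # True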
Example #5
File: linalg.py Project: nbswords/jax
def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists."""
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if len(batch_dims) > 0:
    batch_size = np.prod(batch_dims, dtype=np.int64)
    lu, pivot, perm = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
    lu = lax.reshape(lu, batch_dims + (m, n))
    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
    perm = lax.reshape(perm, batch_dims + (m,))
  else:
    lu, pivot, perm = _lu_blocked(x)
  return lu, pivot, perm
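
The batching pattern above (collapse all leading batch dimensions into one, vmap the per-matrix routine, then restore them) works for any rank-2 kernel. A sketch with a hypothetical per_matrix in place of _lu_blocked:

import jax
import jax.numpy as jnp

def per_matrix(x):   # stand-in for _lu_blocked: operates on one matrix
    return x @ x.T

x = jnp.ones((2, 3, 4, 4))                  # leading batch dims (2, 3)
flat = x.reshape((-1,) + x.shape[-2:])      # collapse to (6, 4, 4)
out = jax.vmap(per_matrix)(flat)
out = out.reshape(x.shape[:-2] + out.shape[-2:])  # restore (2, 3, 4, 4)
print(out.shape)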
Example #6
    def wrapped(*args):
        error_context = ("on vectorized function with excluded={!r} and "
                         "signature={!r}".format(excluded, signature))
        excluded_func, args = _apply_excluded(pyfunc, excluded, args)
        args = tuple(map(jnp.asarray, args))

        if signature is not None:
            input_core_dims, output_core_dims = _parse_gufunc_signature(
                signature)
        else:
            input_core_dims = [()] * len(args)
            output_core_dims = None

        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args, input_core_dims, error_context)

        checked_func = _check_output_dims(excluded_func, dim_sizes,
                                          output_core_dims, error_context)

        # Rather than broadcasting all arguments to full broadcast shapes, prefer
        # expanding dimensions using vmap when possible. By pushing broadcasting
        # into vmap, we can make use of more efficient batching rules for
        # primitives where only some arguments are batched (e.g., for
        # lax_linalg.triangular_solve).

        vec_args = []
        vmap_counts = []

        for arg, core_dims in zip(args, input_core_dims):
            # Explicitly broadcast the dimensions already found on each argument,
            # because these dimensions might be of size 1, which vmap doesn't
            # handle.
            # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
            # doing all vectorization with vmap? This *might* be a little more
            # efficient but would require more careful book-keeping.
            core_shape = tuple(dim_sizes[dim] for dim in core_dims)
            full_shape = broadcast_shape + core_shape
            vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()

            vec_arg = jnp.broadcast_to(arg, vec_shape)
            vec_args.append(vec_arg)

            vmap_count = len(vec_shape) - len(core_shape)
            vmap_counts.append(vmap_count)

        vectorized_func = checked_func
        while any(vmap_counts):
            in_axes = tuple(0 if c > 0 else None for c in vmap_counts)
            vmap_counts = [max(c - 1, 0) for c in vmap_counts]
            vectorized_func = api.vmap(vectorized_func, in_axes)
        return vectorized_func(*vec_args)
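
The while loop above wraps one vmap layer per broadcast dimension, passing in_axes=None for arguments that have run out of axes. Unrolled by hand for two arguments of different rank (core, a, and b are illustrative):

import jax
import jax.numpy as jnp

def core(a, b):                       # scalar core function
    return a * b

a = jnp.ones((2, 3))                  # needs two vmap layers
b = jnp.arange(3.0)                   # needs only one
f = jax.vmap(core, in_axes=(0, 0))    # innermost broadcast dim, size 3
f = jax.vmap(f, in_axes=(0, None))    # outer dim exists only on a
print(f(a, b).shape)  # (2, 3)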
Example #7
def value_and_jacrev_f(*args, **kwargs):
    f = lu.wrap_init(fun, kwargs)
    f_partial, dyn_args = argnums_partial(
        f, argnums, args, require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int),
             dyn_args)
    if not has_aux:
        y, pullback = _vjp(f_partial, *dyn_args)
    else:
        y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
    jac = vmap(pullback)(_std_basis(y))
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    if not has_aux:
        return y, tree_transpose(tree_structure(example_args),
                                 tree_structure(y), jac_tree)
    else:
        return (y, aux), tree_transpose(tree_structure(example_args),
                                        tree_structure(y), jac_tree)
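
The jac = vmap(pullback)(_std_basis(y)) line is the reverse-mode counterpart of Example #2: vmap the vjp pullback over the identity basis of the output space, one cotangent per Jacobian row. A self-contained sketch (f and x are illustrative):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)
y, pullback = jax.vjp(f, x)
jac, = jax.vmap(pullback)(jnp.eye(y.size))   # one cotangent per row
print(jnp.allclose(jac, jax.jacrev(f)(x)))  # True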
Example #8
File: linalg.py Project: nbswords/jax
def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
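
The final vmap call maps a vector solver over the trailing column axis of b so the same code handles matrix right-hand sides. In miniature, using jnp.linalg.solve as the per-column solver:

import jax
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0], [0.0, 4.0]])
b = jnp.arange(1.0, 7.0).reshape(2, 3)             # three right-hand sides
vec_solve = lambda rhs: jnp.linalg.solve(a, rhs)   # solves a @ x = rhs
x = jax.vmap(vec_solve, in_axes=1, out_axes=1)(b)  # map over columns of b
print(jnp.allclose(a @ x, b))  # True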
Example #9
def _rbg_fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
    return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2),
                                                 data).reshape(4)
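
A sketch of the reshape-and-map pattern above: treat the length-4 RBG key as two 2-word Threefry keys, fold the same data into each, then flatten back; fold is a hypothetical stand-in for _threefry_fold_in:

import jax
import jax.numpy as jnp

def fold(key, data):   # stand-in for _threefry_fold_in
    return key + jnp.uint32(data)

key = jnp.arange(4, dtype=jnp.uint32)
out = jax.vmap(fold, (0, None), 0)(key.reshape(2, 2), 7).reshape(4)
print(out)  # [7 8 9 10]: both half-keys folded with the same data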
Example #10
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
    return vmap(_threefry_split, (0, None), 1)(key.reshape(2, 2),
                                               num).reshape(num, 4)
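
Here out_axes=1 interleaves the per-half-key outputs so the final reshape yields num complete 4-word keys. A sketch with a hypothetical fake_split in place of _threefry_split:

import jax
import jax.numpy as jnp

def fake_split(key, num):   # stand-in for _threefry_split: num (2,)-word keys
    return key[None, :] + jnp.arange(num, dtype=jnp.uint32)[:, None]

key = jnp.arange(4, dtype=jnp.uint32)
halves = jax.vmap(fake_split, (0, None), 1)(key.reshape(2, 2), 3)
print(halves.reshape(3, 4).shape)  # (3, 4): three derived 4-word keys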