Code Example #1
File: remat_impl.py Project: xueeinstein/jax
def _dummy_remat_result(aval: core.AbstractValue):
    """A result that will be discarded"""
    if aval is core.abstract_token:
        return lax.create_token()
    else:
        return lax.broadcast(np.array(0, dtype=aval.dtype),
                             aval.shape)  # type: ignore
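
Note on the call above: lax.broadcast(x, sizes) returns an array of shape tuple(sizes) + x.shape, so broadcasting a 0-d scalar yields an array of exactly aval.shape. A minimal standalone sketch of that property (not part of the original file):

import numpy as np
from jax import lax

# Broadcasting a 0-d scalar prepends the requested sizes, producing a
# zero-filled array of exactly that shape, here a throwaway "dummy" result.
dummy = lax.broadcast(np.array(0, dtype=np.float32), (2, 3))
print(dummy.shape)  # (2, 3)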
Code Example #2
File: linalg.py Project: ahoenselaar/jax
def inv(a):
    if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
        raise ValueError(
            f"Argument to inv must have shape [..., n, n], got {a.shape}.")
    return solve(
        a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)),
                         a.shape[:-2]))
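
The broadcast here builds one identity matrix per batch element, so solve(a, eye) computes the inverse batchwise. A minimal sketch of the identity construction, with a hypothetical batch shape:

import jax.numpy as jnp
from jax import lax

a = jnp.zeros((4, 3, 3))  # hypothetical batch of four 3x3 matrices
eye = lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2])
print(eye.shape)  # (4, 3, 3): one identity matrix per batch element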
Code Example #3
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
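
For the bit_width == 64 branch above, pairs of 32-bit words are packed into single 64-bit values by shifting the first half left and OR-ing in the second. A minimal sketch of that packing step under the same dtype conventions:

import numpy as np
from jax import lax

hi = lax.convert_element_type(np.uint32(1), np.uint64)
lo = lax.convert_element_type(np.uint32(2), np.uint64)
packed = lax.shift_left(hi, np.uint64(32)) | lo
print(packed)  # 4294967298 == (1 << 32) | 2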
Code Example #4
File: attention.py Project: backpropper/flax
def tri(n, m, k=0):
  # Build the triangular mask from iotas rather than a baked-in constant,
  # so XLA can construct it during computation and fuse it with the
  # attention ops.
  x = jnp.arange(n, dtype=jnp.int32)
  y = jnp.arange(m, dtype=jnp.int32)
  mask = lax.ge(
      (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k,
      lax.broadcast(y, [n]))
  return mask
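
For intuition, lax.broadcast(y, [n]) tiles the column indices across n rows, while broadcast_in_dim lays the row indices down axis 0; comparing the two yields the triangular mask. A minimal sketch of the first half:

import jax.numpy as jnp
from jax import lax

y = jnp.arange(3, dtype=jnp.int32)
print(lax.broadcast(y, [3]))
# [[0 1 2]
#  [0 1 2]
#  [0 1 2]]   <- column index j in every row; lax.ge(i + k, j) then
#               selects the lower triangle.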
Code Example #5
def broadcast_to(arr, shape):
  """Like Numpy's broadcast_to but doesn't necessarily return views."""
  arr = arr if isinstance(arr, ndarray) or isscalar(arr) else array(arr)
  if _shape(arr) != shape:
    # TODO(mattjj): revise this to call lax.broadcast_in_dim rather than
    # lax.broadcast and lax.transpose
    _broadcast_shapes(shape, _shape(arr))  # error checking
    nlead = len(shape) - len(_shape(arr))
    diff, = onp.where(onp.not_equal(shape[nlead:], _shape(arr)))

    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    kept_dims = tuple(onp.delete(onp.arange(len(shape)), new_dims))
    perm = onp.argsort(new_dims + kept_dims)

    broadcast_dims = onp.take(shape, new_dims)
    squeezed_array = squeeze(arr, diff)
    return lax.transpose(lax.broadcast(squeezed_array, broadcast_dims), perm)
  else:
    return arr
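
The transpose is needed because lax.broadcast can only add leading dimensions; to grow an axis anywhere else, the code broadcasts in front and then permutes. A minimal sketch of that two-step pattern:

import numpy as np
from jax import lax

x = np.ones((3,), np.float32)
leading = lax.broadcast(x, (4,))           # shape (4, 3): new axis in front
trailing = lax.transpose(leading, (1, 0))  # shape (3, 4): new axis moved last
print(leading.shape, trailing.shape)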
Code Example #6
File: lax_vmap_test.py Project: x1489/jax
def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
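
Informally, the batching property exercised here is that lax.broadcast composes with vmap: mapping over a batch puts the batch axis in front of the broadcast result. A minimal sketch, independent of the test harness:

import jax.numpy as jnp
from jax import lax, vmap

batch = jnp.arange(3.0)  # three scalar inputs
out = vmap(lambda x: lax.broadcast(x, (2,)))(batch)
print(out.shape)  # (3, 2): batch axis first, then the broadcast sizes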
Code Example #7
File: api_test.py Project: yyht/jax
def f(a, b, c):
    a = lax.broadcast(a, (2,))
    return lax.select(a, b, c)
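
Context: lax.select requires its boolean predicate to have the same shape as the two branches, which is why the scalar a is first broadcast to shape (2,). A minimal sketch:

import jax.numpy as jnp
from jax import lax

pred = lax.broadcast(jnp.array(True), (2,))  # scalar -> shape (2,) predicate
b = jnp.array([1.0, 2.0])
c = jnp.array([3.0, 4.0])
print(lax.select(pred, b, c))  # [1. 2.]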
Code Example #8
def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
  rng = rng_factory(self.rng())
  args = (rng(shape, dtype),)
  broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
  check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
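
The property check_grads verifies here, informally: the reverse-mode gradient of a broadcast sums the cotangent over the added dimensions. A minimal sketch of that fact:

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: jnp.sum(lax.broadcast(x, (3,)))
print(jax.grad(f)(jnp.array(2.0)))  # 3.0: each of the 3 copies contributes 1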
Code Example #9
File: lax_autodiff_test.py Project: zhaowilliam/jax
def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
  rng = jtu.rand_default(self.rng())
  args = (rng(shape, dtype),)
  broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
  check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
Code Example #10
def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory):
    rng = rng_factory(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
Code Example #11
def full(shape, fill_value, dtype=None):
    if dtype:
        fill_value = lax.convert_element_type(fill_value, dtype)
    return lax.broadcast(fill_value, tuple(shape))
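
Usage note: full first casts fill_value to the requested dtype, then broadcasts the resulting scalar to the full shape. An equivalent direct sketch:

import numpy as np
from jax import lax

fill = lax.convert_element_type(1, np.float32)  # 0-d scalar of target dtype
print(lax.broadcast(fill, (2, 2)))
# [[1. 1.]
#  [1. 1.]]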