Example #1
 def testAxisSizes(self, mesh, axis_resources):
     result = xmap(lambda: lax.axis_index('i'),
                   in_axes=(),
                   out_axes=['i', ...],
                   axis_sizes={'i': 6},
                   axis_resources=dict(axis_resources))()
     self.assertAllClose(result, jnp.arange(6, dtype=result.dtype))
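A minimal sketch of the same computation with vmap (xmap has since been removed from JAX); it assumes a JAX version where vmap accepts axis_size with no mapped inputs:

import jax
import jax.numpy as jnp
from jax import lax

# A size-6 named axis with no mapped inputs: each position along the
# axis returns its own index, yielding arange(6).
result = jax.vmap(lambda: lax.axis_index('i'),
                  axis_name='i', axis_size=6)()
# result == jnp.arange(6)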
Example #2
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
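The bit_width in [8, 16] branch above shifts and masks each 32-bit word instead of calling .view(dtype); a small NumPy sketch (assuming a little-endian byte layout) shows the two are equivalent:

import numpy as np

# Each uint32 word is shifted right by multiples of bit_width and masked,
# which matches reinterpreting the buffer as the narrower dtype.
words = np.array([0x04030201, 0x08070605], dtype=np.uint32)
shifts = np.uint32(8) * np.arange(4, dtype=np.uint32)  # bit_width = 8
lanes = (words[None, :] >> shifts[:, None]) & np.uint32(0xFF)
assert (lanes.T.ravel() == words.view(np.uint8)).all()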
Example #3
 def _internal_inverse_pth_root_all():
   # Runs inside a pmap: each replica inverts only its own shard of the
   # statistics, then all_gather reassembles the full set on every replica.
   preconditioners = jnp.array(all_statistics)
   current_replica = lax.axis_index(batch_axis_name)
   preconditioners, errors = _matrix_inverse_pth_root_vmap(
       all_statistics[current_replica], all_exponents[current_replica])
   preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
   errors = jax.lax.all_gather(errors, batch_axis_name)
   preconditioners_flat = _unbatch(preconditioners)
   errors_flat = _unbatch(errors)
   return preconditioners_flat, errors_flat
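A stripped-down sketch of the same pattern with hypothetical names; vmap stands in for pmap here so the sketch runs on a single device:

import jax
import jax.numpy as jnp
from jax import lax

# Each position along 'batch' squares only the row selected by its
# axis_index; all_gather then reassembles every partial result.
def sharded_square(xs):
    i = lax.axis_index('batch')
    return lax.all_gather(xs[i] ** 2, 'batch')

xs = jnp.arange(12.0).reshape(4, 3)
out = jax.vmap(sharded_square, in_axes=None,
               axis_name='batch', axis_size=4)(xs)
# out[k] is the full (4, 3) array of squares, identical for every k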
Example #4
 def testNamedShape(self, mesh, axis_resources):
   x = np.arange(4)
   y = 2
   f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')),
            in_axes=(['i', ...], {}),
            out_axes=(['i', ...], ['i', ...]),
            axis_resources=dict(axis_resources))
   z, w = f(x, y)
   self.assertEqual(z.aval.named_shape, {})
   self.assertEqual(w.aval.named_shape, {})
Example #5
 def testResourceConflictNestInner(self):
   f = xmap(lambda x: lax.axis_index('i') + x,
            in_axes=[], out_axes=['i'],
            axis_sizes={'i': 4}, axis_resources={'i': 'x'})
   h = xmap(f, in_axes=['j', ...], out_axes=['j', ...],
            axis_resources={'j': 'x'})
   x = np.arange(4)
   error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
            r"coincide in the named_shape of a value returned from a primitive "
            r"add created at .*")
   with self.assertRaisesRegex(JAXTypeError, error):
     h(x)
Example #6
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                  [0., 1., 0.],
                  [0., 0., 1.]], dtype=float32)

  Indicies outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
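The except TypeError branch is the named-axis path: when axis is an axis name rather than an integer, the encoding reduces to comparing x against lax.axis_index. A minimal sketch of that comparison under a vmap-introduced named axis:

import jax
import jax.numpy as jnp
from jax import lax

# Reproduces the collective branch by hand: along a named axis of size
# num_classes, position == index yields the one-hot encoding.
def named_one_hot(x):
    return (x == lax.axis_index('c')).astype(jnp.float32)

out = jax.vmap(named_one_hot, in_axes=None, out_axes=0,
               axis_name='c', axis_size=3)(jnp.array(1))
# out == [0., 1., 0.]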
Example #7
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)  # type: ignore[arg-type]
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
    return jnp.asarray(lhs == rhs, dtype=dtype)
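This version differs from Example #6 only in its last step: instead of materializing jnp.arange and broadcasting it, it builds the class-index grid with a single lax.broadcasted_iota. A quick sketch of the equivalence:

import jax.numpy as jnp
from jax import lax

# Both expressions produce the same broadcast grid of class indices.
a = lax.broadcast_in_dim(jnp.arange(3), (1, 3), (1,))
b = lax.broadcasted_iota(jnp.int32, (1, 3), 1)
assert (a == b).all()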
Example #8
def get_axis_index(axis_name=None):
  # JAX backend: the position along the named mapped axis.
  if JAX_MODE:
    return lax.axis_index(axis_name)
  # TF backend: the replica id from the current distribution context.
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
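A usage sketch on the JAX backend (assuming JAX_MODE is true), where the helper simply defers to lax.axis_index:

import jax
import jax.numpy as jnp

# Each device reports its own position along the 'replicas' axis.
ids = jax.pmap(lambda _: get_axis_index('replicas'),
               axis_name='replicas')(jnp.zeros(jax.device_count()))
# ids == [0, 1, ..., device_count - 1]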
Example #9
 def testAxisIndex(self):
   x = np.arange(10)
   self.assertAllClose(
     vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
     x - np.arange(x.shape[0]))
Example #10
 def f(x):
   return x * lax.axis_index('i')
Example #11
 def f(x):
   i_size = lax.psum(1, 'i')
   return x + lax.axis_index('i') + i_size * lax.axis_index('j')
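Nesting two mapped axes turns the pair of indices into a single row-major id; a minimal sketch with nested vmaps (axis 'i' of size 2 inside axis 'j' of size 3):

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    i_size = lax.psum(1, 'i')
    return x + lax.axis_index('i') + i_size * lax.axis_index('j')

g = jax.vmap(jax.vmap(f, axis_name='i'), axis_name='j')
out = g(jnp.zeros((3, 2)))
# out == [[0., 1.], [2., 3.], [4., 5.]] -- a distinct id per (j, i) pair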
Example #12
 def testCollectiveAllGather(self):
   x = jnp.arange(4)
   result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'),
                 in_axes=['i', ...], out_axes=['i', ...])(x)
   self.assertAllClose(result, x + x[jnp.newaxis].T)
Example #13
 def plate_fun(key, *args, **kwargs):
     key = random.fold_in(key, lax.axis_index(name))
     return f(key, *args, **kwargs)
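A runnable sketch of the same plate pattern (hypothetical axis name 'plate'): folding axis_index into the key gives each position along the mapped axis its own independent PRNG stream:

import jax
from jax import lax, random

# Every lane folds its own index into the shared key before sampling,
# so the draws along the mapped axis are independent.
def draw(key):
    key = random.fold_in(key, lax.axis_index('plate'))
    return random.normal(key)

samples = jax.vmap(draw, in_axes=None, axis_name='plate',
                   axis_size=4)(random.PRNGKey(0))
# samples has shape (4,), one independent draw per plate position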