Example 1
    def test_pjit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            cc.initialize_cache(tmpdir)

            @partial(pjit,
                     in_axis_resources=(P('x'), P('x')),
                     out_axis_resources=None)
            def f(x, y):
                return x + y

            shape = (8, 8)
            x = np.arange(prod(shape), dtype=np.int64).reshape(shape)
            f(x, x + 1)
            files_in_directory = len(os.listdir(tmpdir))
            self.assertEqual(files_in_directory, 1)
            x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
            f(x, x + 1)
            files_in_directory = len(os.listdir(tmpdir))
            self.assertEqual(files_in_directory, 2)
Example 2
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("threefry_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, ((max_count * 32 // bit_width), ), (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
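
A worked check of the ceil-division and packing arithmetic above; the values (a hypothetical 16-bit draw of 5 elements) are illustrative and not from the original snippet:

bit_width, size = 16, 5
max_count, r = divmod(bit_width * size, 32)  # 80 bits -> (2, 16)
if r > 0:
    max_count += 1                           # 3 uint32 words are generated
# Each uint32 word yields 32 // bit_width == 2 uint16 values, so the reshape
# produces max_count * 32 // bit_width == 6 values, then truncates to size.
assert (max_count, max_count * 32 // bit_width) == (3, 6)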
Example 3
def psum_bind(*args, axis_name, axis_index_groups):
  if all(not isinstance(x, core.Tracer) for x in args):
    if axis_index_groups is not None:
      size = len(axis_index_groups[0])
    elif isinstance(axis_name, (list, tuple)):
      size = prod([core.axis_frame(name).size for name in axis_name])  # type: ignore
    else:
      size = core.axis_frame(axis_name).size  # type: ignore
    return tuple(size * x for x in args)
  return core.Primitive.bind(
      psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
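
A minimal usage sketch of the constant-folding path above (not part of the original snippet): when no argument is a tracer, the bind multiplies by the axis size instead of emitting a psum.

import jax
import numpy as np
from jax import lax

n = jax.local_device_count()
# `1` is a Python constant, not a tracer, so psum(1, 'i') folds to the axis
# size n on every device instead of binding the psum primitive.
out = jax.pmap(lambda x: lax.psum(1, 'i'), axis_name='i')(np.zeros(n))
print(out)  # e.g. [4 4 4 4] on a 4-device host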
Example 4
  def testCompilationCache(self):
    if jax.local_device_count() < 2:
      raise SkipTest("requires 2 devices")
    f = lambda x: x + 1
    sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
    shape = (2,)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

    with jtu.assert_num_jit_and_pmap_compilations(1):
      sharded_f(x)
      sharded_f(x)
Example 5
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
    """Test utility for setting up meshes given mesh data from `schedules`."""
    # This is similar to the `with_mesh` function above, but isn't a decorator.
    axis_names, shape = unzip2(named_shape)
    size = prod(shape)
    local_devices = list(jax.local_devices())
    if len(local_devices) < size:
        raise SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape(shape)
    with mesh(mesh_devices, axis_names):
        yield
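
A minimal usage sketch for the generator above, assuming it is wrapped with `contextlib.contextmanager` where it is defined (the decorator is not shown in the snippet; the wrapper name below is hypothetical):

from contextlib import contextmanager

with_mesh_ctx = contextmanager(with_mesh)  # hypothetical explicit wrapping

def run_under_mesh():
    # Requests a mesh of shape (2, 1) with axis names ('x', 'y'); the utility
    # raises SkipTest if fewer than 2 local devices are available.
    with with_mesh_ctx([('x', 2), ('y', 1)]):
        pass  # code under test runs with the mesh active here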
Example 6
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
    chunk_size = []
    for mesh_axis, size in zip(mesh_axes, global_shape):
        if not mesh_axis:
            chunk_size.append(size)
        elif isinstance(mesh_axis, tuple):
            m = prod([global_mesh.shape[ma] for ma in mesh_axis])
            chunk_size.append(size // m)
        else:
            chunk_size.append(size // global_mesh.shape[mesh_axis])
    if len(chunk_size) != len(global_shape):
        chunk_size.extend(global_shape[len(chunk_size):])
    return tuple(chunk_size)
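
A worked check of the chunking logic above, using a hypothetical 4x2 mesh {'x': 4, 'y': 2} and an (8, 2) global shape (plain arithmetic, no devices needed):

mesh = {'x': 4, 'y': 2}
# mesh_axes = P('x', 'y'): every dimension is divided by its mesh axis size.
assert (8 // mesh['x'], 2 // mesh['y']) == (2, 1)
# mesh_axes = P(('x', 'y'),): dim 0 is divided by 4 * 2; dim 1 is appended whole.
assert (8 // (mesh['x'] * mesh['y']), 2) == (1, 2)
# mesh_axes = P(None, 'y'): dim 0 is replicated; dim 1 is divided by 2.
assert (8, 2 // mesh['y']) == (8, 1)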
Example 7
def _parse_dim(spec):
  if '+' in spec:
    # materialize the terms so np.sum reduces them rather than the map object
    return np.sum(list(map(_parse_dim, spec.split('+'))))
  elif '*' in spec:
    return prod(map(_parse_dim, spec.split('*')))
  elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
    return _parse_lit(spec)
  elif spec[0] in _identifiers:
    return _parse_id(spec)
  elif spec == '_':
    return _monomorphic_dim
  else:
    raise ShapeSyntaxError(spec)
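
A summary of the dimension grammar accepted above, read directly off the branches (the example specs are illustrative only):

# '3', '-1'       -> literal dimensions via _parse_lit
# 'n'             -> a named size variable via _parse_id
# 'a+b', 'n+1'    -> sums of sub-specs
# '2*n'           -> products of sub-specs
# '_'             -> a monomorphic (unconstrained) dimension
# anything else   -> ShapeSyntaxError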
Example 8
  def test_gda_block_until_ready(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    self.assertTrue(gda.block_until_ready() is gda)
Example 9
def gda_construction_callback(mesh_axes, state):
  # Keep the mesh at 8 local devices; using more than 8 local devices is
  # unrealistic. Since `from_callback` also measures `device_put` time, that
  # cost would dominate with e.g. 2048 local devices (a count that will never
  # occur in practice).
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (2048, 2048)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  def cb(index):
    return global_input_data[index]

  while state:
    gda.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
Example 10
  def test_gda_equality_raises_not_implemented(self):
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    same_input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesRegex(NotImplementedError,
        'GlobalDeviceArray equality is intentionally unimplemented.'):
      input_gda == same_input_gda
Example 11
async def async_deserialize(mesh,
                            mesh_axes,
                            tensorstore_spec,
                            global_shape=None,
                            dtype=None):
    t = ts.open(ts.Spec(tensorstore_spec), open=True,
                context=TS_CONTEXT).result()
    shape = t.shape if global_shape is None else global_shape
    requires_padding = prod(shape) > prod(t.shape)

    if requires_padding:
        new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)

    async def cb(index):
        if requires_padding:
            # This is needed because the shape the array was saved with is smaller
            # than the requested shape of the array in which it will be reloaded. So
            # the extra values will be filled with 0s.
            out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
            requested_domain = ts.IndexTransform(
                input_shape=shape)[index].domain
            restricted_domain = t.domain.intersect(requested_domain)
            await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
                restricted_domain].write(t[restricted_domain])
        else:
            out = await t[index].read()

        if dtype is not None:
            # Cast while reloading on process to avoid 2 copies on device if the
            # casting is done on device.
            return out.astype(dtype)
        return out

    return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes,
                                                cb)
Example 12
    def test_async_checkpointing(self):
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = P('x', 'y')
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb1)
        temp_ckpt_dir1 = pathlib.Path(
            self.create_tempdir('temp_first').full_path)
        ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

        s_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(temp_ckpt_dir1)])

        manager = serialization.GlobalAsyncCheckpointManager()
        manager.serialize([gda1],
                          s_tspecs,
                          temp_checkpoint_dir=temp_ckpt_dir1,
                          final_checkpoint_dir=ckpt_dir1)
        manager.wait_until_finished()

        d_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(ckpt_dir1)])
        m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
        self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                               np.array([[0], [2]]))
        self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                               np.array([[1], [3]]))
        self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
        self.assertEqual(m1.dtype, np.int32)

        # Will throw `file already exists` error when `tf.io.gfile.rename`.
        # `wait_until_finished` will raise the error.
        with self.assertRaises(Exception):
            ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
            manager1 = serialization.GlobalAsyncCheckpointManager()
            manager1.serialize([gda1],
                               s_tspecs,
                               temp_checkpoint_dir=temp_ckpt_dir1,
                               final_checkpoint_dir=ckpt_dir1)
            manager1.wait_until_finished()
Example 13
  def test_gda_str_repr(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))"))
Example 14
  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
Example 15
  def testBasic1D(self):
    @partial(pjit,
             in_axis_resources=(P('x'), P('x')),
             out_axis_resources=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 2)
    self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                        check_dtypes=False)
Example 16
def gda_construction_raw(mesh_shape, mesh_axes, state):
    # `device_put` time is not measured in this benchmark. All the devices here
    # are local.
    global_mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
    global_input_shape = (2048, 2048)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    global_indices = gda.get_shard_indices(global_input_shape, global_mesh,
                                           mesh_axes)
    dbs = [
        jax.device_put(global_input_data[global_indices[device]], device)
        for device in global_mesh.local_devices
    ]

    while state:
        gda.GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)
Example 17
  def testNestedShardingConstraint(self):
    if jax.local_device_count() < 2:
      raise SkipTest("requires 2 devices")

    shape = (8, 8)

    @jit
    def f(x):
      return lax.while_loop(lambda i: i[0,0] < 10.,
                            lambda i: with_sharding_constraint(i + 1., P(2, 1)),
                            x)

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = x + 10.
    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
Example 18
  def testBasic(self):
    if jax.device_count() < 2:
      raise SkipTest

    @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 2)
    self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                        check_dtypes=False)
Example 19
  def init(key, shape, dtype=dtype):
      if len(shape) < 2:
          raise ValueError(
              "orthogonal initializer requires at least a 2D shape")
      n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
      matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
      A = random.normal(key, matrix_shape, dtype)
      Q, R = jnp.linalg.qr(A)
      diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
      Q *= diag_sign  # needed for a uniform distribution
      if n_rows < n_cols: Q = Q.T
      Q = jnp.reshape(
          Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
      Q = jnp.moveaxis(Q, -1, column_axis)
      return scale * Q
Example 20
  def testDeviceBufferAval(self):

    @partial(pjit, in_axis_resources=None, out_axis_resources=P('x'))
    def f(x):
      return x

    shape = (2, 2)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x)
    expected = x
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 1)
    self.assertAllClose(
        actual.device_buffers[0].to_py(), expected, check_dtypes=False)
    # Repro for a bug on device_buffer aval
    _ = repr(actual.device_buffers)
Example 21
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f)"""
    shape = (2, 4, 4)
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
      self.assertTrue(isinstance(r, pxla.ShardedDeviceArray))
      self.assertEqual(len(r.device_buffers), num_shards)
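
A hypothetical call sketch for the helper above (the argument values are illustrative, not taken from the original tests):

#   self._runTest(lambda x, y: x + y,
#                 in_partitions=(P(2, 1), P(2, 1)),
#                 out_partitions=P(2, 1))
# With shape (2, 4, 4) this pmaps over the leading axis and shards each
# per-device (4, 4) operand into 2 partitions, so 2 * 2 == 4 devices are needed.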
Example 22
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    global_mesh = create_global_mesh((8,), ('x'))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape,
                                          global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
Example 23
def _irfft_transpose(t, fft_lengths):
    # The transpose of IRFFT is the RFFT of the cotangent times a scaling
    # factor and a mask. The mask scales the cotangent for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these components
    # are de-duplicated in the RFFT.
    x = fft(t, xla_client.FftType.RFFT, fft_lengths)
    n = x.shape[-1]
    is_odd = fft_lengths[-1] % 2
    full = partial(lax.full_like, t, dtype=t.dtype)
    mask = lax.concatenate(
        [full(1.0, shape=(1,)),
         full(2.0, shape=(n - 2 + is_odd,)),
         full(1.0, shape=(1 - is_odd,))],
        dimension=0)
    scale = 1 / prod(fft_lengths)
    out = scale * mask * x
    assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
    return out
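
A worked example of the mask and scale above, assuming a hypothetical fft_lengths == (8,):

# n      = 8 // 2 + 1 == 5   (length of the RFFT output's last axis)
# is_odd = 8 % 2 == 0
# mask   = [1, 2, 2, 2, 1]   (DC and Nyquist bins appear once in the full
#                             spectrum; the middle bins are Hermitian duplicates)
# scale  = 1 / 8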
Example 24
  def test_pjit_gsda_mesh_mismatch(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x', 'y']
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    with self.assertRaisesRegex(ValueError,
                                "Pjit's mesh and GDA's mesh should be equal."):
      @partial(pjit, in_axis_resources=FROM_GDA, out_axis_resources=P('x', 'y'))
      def f(x):
        return x

      f(gda_obj)
Example 25
  def test_gda_batched_callback(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      self.assertEqual(len(indices), len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)
Example 26
  def testPyTreeOutputs(self):
    if jax.device_count() < 2:
      raise SkipTest

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = sharded_jit(f, in_parts, out_parts)(x)
    expected = f(x)
    self.assertAllClose(result, expected, check_dtypes=False)

    out_parts = None
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
Example 27
  def testShardingConstraint(self):
    if jax.local_device_count() < 2:
      raise SkipTest("requires 2 devices")

    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, P(1,2))
      return y * 2

    shape = (8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    expected = (x + 1) * 2

    # Matching sharded_jit partitions
    actual = sharded_jit(f, in_parts=P(2,1), out_parts=P(2,1))(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
    # the default.
    self.assertEqual(
        getattr(actual.device_buffers[0], "xla_shape",
                actual.device_buffers[0].shape)().dimensions(), (4, 8))
    self.assertEqual(
        getattr(actual.device_buffers[1], "xla_shape",
                actual.device_buffers[1].shape)().dimensions(), (4, 8))

    # Mismatched sharded_jit partitions
    with self.assertRaisesRegex(
        ValueError,
        r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
        r"\(total partitions: 2\) doesn't match expected number of partitions: "
        r"4. If these partitions look right, check outer sharded_jit and/or "
        r"other with_sharding_constraint calls."):
      sharded_jit(f, in_parts=P(2,2), out_parts=P(2,2))(x)

    # Replicated sharded_jit
    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    self.assertAllClose(actual.device_buffers[0].to_py(),
                        actual.device_buffers[1].to_py(),
                        check_dtypes=False)
Example 28
  def testGradOfShardingConstraint(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    @partial(sharded_jit, in_parts=P(4,1), out_parts=None)
    def f(x):
      y = x + 1
      p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2,2))), y)
      return vjp_f(p)

    def expected_f(x):
      y = x + 1
      p, vjp_f = vjp(lambda z: jnp.sin(z), y)
      return vjp_f(p)

    shape = (4, 4)
    x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
    actual = f(x)
    expected = expected_f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
Example 29
  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)
Example 30
  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)