Example #1
 def testUndefinedResourcesConstraint(self, mesh, resources):
     x = jnp.ones((2, 2))
     spec = P(resources,)
     with self.assertRaisesRegex(
             ValueError, r"One of with_sharding_constraint arguments"
             r".*" + spec_regex(spec) + r", but resource axis "
             r"x is undefined."):
         pjit(lambda x: with_sharding_constraint(x, spec),
              in_axis_resources=None,
              out_axis_resources=None)(x)
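
Note: this test exercises the error path where a PartitionSpec names a mesh axis the surrounding mesh does not define. Below is a minimal sketch of the non-error path, assuming a single-host setup and the same experimental pjit API these tests target (import paths may differ across JAX versions):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint

devices = np.array(jax.local_devices()[:1])  # a 1-device mesh that defines 'x'
with mesh(devices, ('x',)):
    f = pjit(lambda x: with_sharding_constraint(x, P('x')),
             in_axis_resources=None, out_axis_resources=None)
    f(jnp.ones((2, 2)))  # 'x' is now a defined resource axis, so no ValueError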
Example #2
 def testRankTooLowOuts(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
     error = (r"One of pjit outputs.*" + spec_regex(spec) +
              r", which implies "
              r"that it has a rank of at least 2, but it is 0")
     with self.assertRaisesRegex(ValueError, error):
         pjit(lambda x: x.sum(),
              in_axis_resources=None,
              out_axis_resources=spec)(x)
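
Note: the fix for this rank error is to stop requesting a 2-D sharding for a 0-D result. A minimal corrected call, run under the same mesh, replicating the scalar output instead:

pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=None)(x)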
Example #3
File: pjit_test.py Project: rsepassi/jax
  def testCaching(self):
    def f(x):
      assert should_be_tracing
      return jnp.sin(x) * 2

    x = np.arange(16).reshape(4, 4)
    devices = np.array(list(jax.local_devices())[:4])
    if devices.size < 4:
      raise unittest.SkipTest("Test requires 4 devices")
    devices = devices.reshape((2, 2))
    with mesh(devices, ('x', 'y')):
      should_be_tracing = True
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
      should_be_tracing = False
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
    # Re-create the mesh to make sure that it has no influence on caching
    with mesh(devices, ('x', 'y')):
      should_be_tracing = False
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
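
Note: the test detects retracing with an assert inside the traced function. A counter-based variant of the same check, assuming it runs inside the same `with mesh(devices, ('x', 'y')):` block with `x` as above:

trace_count = 0

def g(x):
    global trace_count
    trace_count += 1  # executes only while tracing, not on cached calls
    return jnp.sin(x) * 2

g_sharded = pjit(g, in_axis_resources=P(('x', 'y')), out_axis_resources=None)
g_sharded(x)
g_sharded(x)
assert trace_count == 1  # the second call reused the cached executable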
Example #4
    def test_pjit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            cc.initialize_cache(tmpdir)

            @partial(pjit,
                     in_axis_resources=(P('x'), P('x')),
                     out_axis_resources=None)
            def f(x, y):
                return x + y

            shape = (8, 8)
            x = np.arange(prod(shape), dtype=np.int64).reshape(shape)
            f(x, x + 1)
            files_in_directory = len(os.listdir(tmpdir))
            self.assertEqual(files_in_directory, 1)
            x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
            f(x, x + 1)
            files_in_directory = len(os.listdir(tmpdir))
            self.assertEqual(files_in_directory, 2)
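
Note: `cc` here is presumably the experimental persistent compilation cache module. A sketch of the imports this snippet relies on (paths match the JAX version these tests target and may have moved since):

import os
import tempfile
from functools import partial

import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.pjit import pjit

The second cache file appears because changing the input dtype from int64 to float32 gives f a new signature, forcing a second compilation.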
Example #5
 def testRankTooLowConstraint(self):
     x = jnp.arange(2)
     spec = P('x', 'y')
     error = (r"One of with_sharding_constraint arguments " +
              r"was given.*" + spec_regex(spec) + r", which implies "
              r"that it has a rank of at least 2, but it is 1")
     with self.assertRaisesRegex(ValueError, error):
         pjit(lambda x: with_sharding_constraint(x, spec),
              in_axis_resources=None,
              out_axis_resources=None)(x)
Example #6
        def f(x):
            with mesh(np.array([jax.local_devices()[0]]), ('x',)):

                @partial(pjit,
                         in_axis_resources=P('x'),
                         out_axis_resources=None)
                def h(x):
                    return x

                return h(x)
Example #7
File: array_test.py Project: romanngg/jax
 def test_array_sharded_astype(self):
     with jax._src.config.jax_array(True):
         global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         input_shape = (8, 2)
         arr, input_data = create_array(
             input_shape,
             sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
         arr_float32 = arr.astype(jnp.float32)
         self.assertEqual(arr_float32.dtype, np.float32)
         self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
Example #8
  def testPyTreeOutputs(self):
    if jax.device_count() < 2:
      raise SkipTest

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = sharded_jit(f, in_parts, out_parts)(x)
    expected = f(x)
    self.assertAllClose(result, expected, check_dtypes=False)

    out_parts = None
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
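
Note: `out_parts` must mirror the output pytree leaf-for-leaf, with `None` meaning "replicate this leaf". Unpacking makes the correspondence explicit:

out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
y0, ((y1, y2), y3) = sharded_jit(f, in_parts, out_parts)(x)
# y0 is partitioned 2x1, y1 is 1x2, y2 is replicated (None), y3 is 2x1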
Example #9
    def testShardingConstraint(self):
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
        # the default.
        self.assertEqual(
            getattr(actual.device_buffers[0], "xla_shape",
                    actual.device_buffers[0].shape)().dimensions(), (4, 8))
        self.assertEqual(
            getattr(actual.device_buffers[1], "xla_shape",
                    actual.device_buffers[1].shape)().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
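
Note: the mismatch case reduces to a counting rule: the total partition count of a with_sharding_constraint spec must equal that of the enclosing sharded_jit. A quick check of the products involved:

import numpy as np

assert np.prod((1, 2)) == np.prod((2, 1))  # 2 == 2: composes with P(2, 1)
assert np.prod((1, 2)) != np.prod((2, 2))  # 2 != 4: the mismatched case above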
Example #10
  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
Example #11
File: pjit_test.py Project: rsepassi/jax
 def testNonDivisibleConstraint(self, mesh, resources):
   x = jnp.ones((3, 2))
   spec = P(resources,)
   mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
   with self.assertRaisesRegex(ValueError,
                               r"One of with_sharding_constraint arguments"
                               r".*" + spec_regex(spec) + r".*implies that the size of "
                               r"its dimension 0 should be divisible by " + mesh_size +
                               r", but it is equal to 3"):
     pjit(lambda x: with_sharding_constraint(x, spec),
          in_axis_resources=None, out_axis_resources=None)(x)
Example #12
 def testNonDivisibleArgs(self, mesh, resources):
     x = jnp.ones((3, 2))
     spec = P(resources, None)
     mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
     with self.assertRaisesRegex(
             ValueError,
             r"One of pjit arguments.*" + spec_regex(spec) + r".*"
             r"implies that the size of its dimension 0 should be "
             r"divisible by " + mesh_size + r", but it is equal to 3"):
         pjit(lambda x: x, in_axis_resources=spec,
              out_axis_resources=None)(x)
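
Note: both divisibility tests reduce to modular arithmetic: a dimension of size 3 cannot be split evenly over a resource of size 2. A quick check, plus the usual fix of padding to a divisible size:

assert 3 % 2 != 0  # why a (3, 2) array cannot shard dim 0 across 2 devices
x_padded = jnp.pad(jnp.ones((3, 2)), ((0, 1), (0, 0)))  # pad dim 0 from 3 to 4
assert x_padded.shape[0] % 2 == 0  # now divisible, so the spec is accepted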
Example #13
File: pjit_test.py Project: rsepassi/jax
 def testConstraintShardsXMapAxis(self):
   spec = P('x')
   f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
            in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
   x = jnp.arange(4).reshape((2, 2))
   error = (r"with_sharding_constraint input has an axis resources specification of " +
            spec_regex(spec) + r" that uses one or more mesh axes already used by "
            r"xmap to partition a named axis appearing in its named_shape \(both "
            r"use mesh axes `x`\)")
   with self.assertRaisesRegex(JAXTypeError, error):
     f(x)
Example #14
    def input(self, x):
        # [batch, seq, dim]
        projected = self.input_proj(x)

        # [batch, seq, mp, dim//mp]
        projected = maybe_shard(projected, P("dp", None, "mp"))
        mp_split = jnp.reshape(projected,
                               projected.shape[:-1] + (self.mp_num, -1))
        mp_split = maybe_shard(mp_split, P("dp", None, "mp", None))

        local_dim = self.d_head * self.n_head // self.mp_num

        q, v, k, ff = jnp.split(mp_split,
                                [local_dim, local_dim * 2, local_dim * 3],
                                axis=-1)

        q = self.head_split(q)
        v = self.head_split(v)
        k = self.head_split(k)

        return q, v, k, ff
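
Note: `maybe_shard` is not defined in this snippet. From context it behaves like with_sharding_constraint that degrades to a no-op when no enclosing mesh defines the named axes; a hypothetical minimal definition under that assumption:

def maybe_shard(x, resource):
    try:
        return with_sharding_constraint(x, resource)
    except ValueError:
        # No enclosing mesh defines these resource axes; leave x unsharded.
        return x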
Example #15
 def test_array_delete(self):
     with jax._src.config.jax_array(True):
         global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         input_shape = (8, 2)
         arr, _ = create_array(
             input_shape,
             sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
         arr.delete()
         with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
             arr._check_if_deleted()
         self.assertIsNone(arr._npy_value)
         self.assertIsNone(arr._arrays)
Example #16
  def testNotEnoughDevices(self):
    ndevices = jax.local_device_count()

    @partial(sharded_jit, in_parts=P(ndevices + 1), out_parts=None)
    def f(x):
      return x + x

    with self.assertRaisesRegex(
        ValueError,
        f"sharded_jit computation requires {ndevices + 1} devices, "
        f"but only {ndevices} devices are available."):
      f(np.ones(ndevices + 1))
Example #17
    def test_checkpointing_with_bigger_shape(self):
        global_mesh = create_global_mesh((2, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               P('x', 'y'), cb1)
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

        ckpt_paths = [str(ckpt_dir1)]
        tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

        serialization.run_serialization([gda1], tspecs)

        m1, = serialization.run_deserialization(
            [create_global_mesh((4, 2), ('x', 'y'))],
            [P('x', 'y')],
            tspecs,
            [(12, 2)],
        )

        expected_data = {
            0: np.array([[0], [2], [4]]),
            1: np.array([[1], [3], [5]]),
            2: np.array([[6], [8], [10]]),
            3: np.array([[7], [9], [11]]),
            4: np.array([[12], [14], [0]]),
            5: np.array([[13], [15], [0]]),
            6: np.array([[0], [0], [0]]),
            7: np.array([[0], [0], [0]]),
        }

        for l in m1.local_shards:
            self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
Example #18
        def shard_strategy(shape_dtype, parallel):
            if shape_dtype.ndim <= 1:
                return P()
            # embedding/projection layers
            elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
                return P(parallel, None)
            elif shape_dtype.shape == (config["d_model"], config["n_vocab"]):
                return P(None, parallel)

            # a transformer layer
            elif shape_dtype.shape[0] == config["layers"]:
                if shape_dtype.ndim == 2:
                    # a channel wise variable (e.g. layernorm parameters)
                    # replicate it for speed
                    return P(None)
                elif shape_dtype.ndim == 3:
                    # a weight matrix
                    matrix_size = shape_dtype.shape[1:]

                    assert matrix_size[0] != matrix_size[1]  # this case is ambiguous

                    if matrix_size[0] == config["d_model"]:
                        # shard along the axis which is _not_ the model dimension
                        return P(None, None, parallel)
                    elif matrix_size[1] == config["d_model"]:
                        return P(None, parallel, None)
                else:
                    raise NotImplementedError("borked")

            else:
                raise NotImplementedError("borked")
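
Note: a hypothetical way to apply shard_strategy: map it over a pytree of abstract shapes to build the in_axis_resources pytree for pjit. `init_fn`, `rng`, and the 'mp' axis name below are assumptions for illustration, not from the original:

param_shapes = jax.eval_shape(init_fn, rng)  # pytree of ShapeDtypeStruct leaves
in_axis_resources = jax.tree_map(lambda s: shard_strategy(s, 'mp'), param_shapes)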
Example #19
File: pjit_test.py Project: rsepassi/jax
  def test_pjit_gsda_wrong_resource_for_gsda_input(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x']
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesWithLiteralMatch(ValueError, (
        "Got an input GDA to pjit with different partitioning than specified "
        "in the in_axis_resources argument to pjit. The paritioning must "
        'match, or use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources`. '
        "Got GDA spec: <partitions=(('x',),) sync=2>, "
        "pjit spec: <partitions=(('x',), ('y',)) sync=2>")):
      @partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
      def f(x):
        return x

      f(gda_obj)
Example #20
File: pjit_test.py Project: rsepassi/jax
 def testNestedDifferentResources(self):
   @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
   def f(x):
     with mesh(np.array([jax.local_devices()[0]]), ('x',)):
       @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
       def h(x):
         return x
       return h(x)
   xshape = (2, 5, 6)
   x = jnp.arange(np.prod(xshape)).reshape(xshape)
   with self.assertRaisesRegex(RuntimeError,
                               "Changing the physical mesh is not allowed.*"):
     f(x)
Example #21
  def test_gda_block_until_ready(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    self.assertTrue(gda.block_until_ready() is gda)
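
Note: because block_until_ready returns the GDA itself, it can be chained inline:

first_shard = gda.block_until_ready().local_shards[0]  # data is ready to read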
Example #22
  def test_partition_spec_mismatch_semantically_equivalent(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = [None]
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    with jax._src.config.parallel_functions_output_gda(True):
      gda_obj = global_device_array.GlobalDeviceArray.from_callback(
          global_input_shape, global_mesh, mesh_axes, cb)

      @partial(pjit, in_axis_resources=P(None), out_axis_resources=P(None))
      def f(x):
        return x

      output_gda = f(gda_obj)
      # output_gda._mesh_axes is the empty P(), which must be treated as
      # semantically equivalent to the P(None) in in_axis_resources.
      self.assertEqual(output_gda._mesh_axes, ())
      # Feeding the output back through f checks that P() matches P(None).
      f(output_gda)
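
Note: the assertion relies on PartitionSpec being a tuple subclass in this JAX version: P() is () while P(None) is (None,), so the two differ structurally but both mean "fully replicated", which is why f accepts its own output again. A structural illustration under that assumption:

assert tuple(P()) == ()
assert tuple(P(None)) == (None,)  # structurally different, semantically equal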
Example #23
File: pjit_test.py Project: rsepassi/jax
  def testLowerCompile(self):
    @partial(pjit,
             in_axis_resources=P(('x', 'y'),),
             out_axis_resources=P(('x', 'y'),))
    def f(x, y):
      return x @ y

    shape = (8, 8)
    x = jnp.arange(np.prod(shape)).reshape(shape)
    expected = x @ (x + 1)

    exe = f.lower(x, x + 1).compile()
    actual = exe(x, x + 1)

    splits = np.split(expected, 4)
    self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
                        check_dtypes=False)
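
Note: the lower/compile split lets the same executable be reused without retracing. A sketch of that reuse, assuming the mesh, shapes, and dtypes stay fixed:

exe = f.lower(x, x + 1).compile()  # trace, lower, and compile exactly once
for _ in range(3):
    out = exe(x, x + 1)            # each call runs the cached executable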
Example #24
File: pjit_test.py Project: rsepassi/jax
 def testVmapModifiesAxisResources(self):
   h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
   x = jnp.arange(4)
   y = jnp.arange(5*4).reshape((5, 4))
   jaxpr = jax.make_jaxpr(jax.vmap(h, in_axes=(None, 0)))(x, y).jaxpr
   eqn = jaxpr.eqns[0]
   self.assertIs(eqn.primitive, pjit_p)
   x_sync, y_sync = (spec.sync for spec in eqn.params['in_axis_resources'])
   self.assertEqual(x_sync, SpecSync.IN_SYNC)
   self.assertEqual(y_sync, SpecSync.DIM_PERMUTE)
   x_sync, y_sync, z_sync = (spec.sync for spec in eqn.params['out_axis_resources'])
   self.assertEqual(x_sync, SpecSync.DIM_PERMUTE)
   self.assertEqual(y_sync, SpecSync.IN_SYNC)
   self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
Example #25
 def test_gda_equality_raises_not_implemented(self):
   global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
   global_input_shape = (8, 2)
   mesh_axes = P(None,)
   global_input_data = np.arange(
       prod(global_input_shape)).reshape(global_input_shape)
   def cb(index):
     return global_input_data[index]
   input_gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes, cb)
   same_input_gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes, cb)
   with self.assertRaisesRegex(NotImplementedError,
       'GlobalDeviceArray equality is intentionally unimplemented.'):
     input_gda == same_input_gda
Example #26
File: pjit_test.py Project: rsepassi/jax
  def testTwoMeshAxisSharding(self):
    @partial(pjit,
             in_axis_resources=P(('x', 'y'),),
             out_axis_resources=P(('x', 'y'),))
    def f(x, y):
      return x @ y

    shape = (8, 8)
    x = jnp.arange(np.prod(shape)).reshape(shape)
    actual = f(x, x + 1)
    expected = x @ (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 4)

    splits = np.split(expected, 4)
    self.assertAllClose(actual.device_buffers[0].to_py(), splits[0],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[1].to_py(), splits[1],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[2].to_py(), splits[2],
                        check_dtypes=False)
    self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
                        check_dtypes=False)
Example #27
    def test_async_checkpointing(self):
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = P('x', 'y')
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb1)
        temp_ckpt_dir1 = pathlib.Path(
            self.create_tempdir('temp_first').full_path)
        ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

        s_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(temp_ckpt_dir1)])

        manager = serialization.GlobalAsyncCheckpointManager()
        manager.serialize([gda1],
                          s_tspecs,
                          temp_checkpoint_dir=temp_ckpt_dir1,
                          final_checkpoint_dir=ckpt_dir1)
        manager.wait_until_finished()

        d_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(ckpt_dir1)])
        m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
        self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                               np.array([[0], [2]]))
        self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                               np.array([[1], [3]]))
        self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
        self.assertEqual(m1.dtype, np.int32)

        # Serializing into an already-existing final directory makes
        # `tf.io.gfile.rename` fail with a `file already exists` error;
        # `wait_until_finished` re-raises that error.
        with self.assertRaises(Exception):
            ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
            manager1 = serialization.GlobalAsyncCheckpointManager()
            manager1.serialize([gda1],
                               s_tspecs,
                               temp_checkpoint_dir=temp_ckpt_dir1,
                               final_checkpoint_dir=ckpt_dir1)
            manager1.wait_until_finished()
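
Note: the async protocol in miniature: serialize returns after scheduling the writes into the temporary directory, and wait_until_finished blocks until the data is committed and renamed to the final directory, re-raising any error (as the last block above demonstrates):

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([gda1], s_tspecs,
                  temp_checkpoint_dir=temp_ckpt_dir1,
                  final_checkpoint_dir=ckpt_dir1)  # returns quickly
# ... the training loop can keep running here ...
manager.wait_until_finished()  # blocks; re-raises any serialization error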
Example #28
    def test_mesh_pspec_sharding_interface(self):
        mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        pspec = P('y', 'x')
        global_shape = (8, 4)
        mp_sharding = sharding.MeshPspecSharding(mesh, pspec)
        di_map = mp_sharding.devices_indices_map(global_shape)
        op_sharding = mp_sharding._to_xla_op_sharding(len(global_shape))
        device_assignment = mp_sharding._device_assignment()

        self.assertEqual(di_map[mesh.devices.flat[0]],
                         (slice(0, 4), slice(0, 1)))
        self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
        self.assertEqual(op_sharding.type, xc.OpSharding.Type.OTHER)
        self.assertListEqual(op_sharding.tile_assignment_dimensions, [2, 4])
        self.assertListEqual(op_sharding.tile_assignment_devices,
                             [0, 2, 4, 6, 1, 3, 5, 7])
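
Note: why tile_assignment_devices is [0, 2, 4, 6, 1, 3, 5, 7]: P('y', 'x') maps array dim 0 to mesh axis 'y' (size 2) and dim 1 to 'x' (size 4), so the (4, 2) device grid is transposed before flattening. A quick numpy check, assuming row-major device ids 0..7:

import numpy as np

device_grid = np.arange(8).reshape(4, 2)  # mesh ('x', 'y') layout
print(device_grid.T.ravel())              # -> [0 2 4 6 1 3 5 7]
# tile_assignment_dimensions [2, 4]: dim 0 split 2 ways ('y'), dim 1 split 4 ways ('x')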
Example #29
 def test_gda_str_repr(self):
   global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
   global_input_shape = (8, 2)
   mesh_axes = P(('x', 'y'))
   global_input_data = np.arange(
       prod(global_input_shape)).reshape(global_input_shape)
   def cb(index):
     return global_input_data[index]
   gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes, cb)
   self.assertEqual(str(gda),
                    'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
   self.assertEqual(
       repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                   "global_mesh_shape={'x': 4, 'y': 2}, "
                   "mesh_axes=PartitionSpec(('x', 'y'),))"))
Example #30
File: pjit_test.py Project: rsepassi/jax
  def testDeviceBufferAval(self):

    @partial(pjit, in_axis_resources=None, out_axis_resources=P('x'))
    def f(x):
      return x

    shape = (2, 2)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x)
    expected = x
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 1)
    self.assertAllClose(
        actual.device_buffers[0].to_py(), expected, check_dtypes=False)
    # Regression repro for a bug in the device_buffer aval (exercised via repr)
    _ = repr(actual.device_buffers)