def test_checkpointing(self):
    """Round-trips two GDAs through async tensorstore (de)serialization.

    Both arrays are written with a (4, 2) mesh and `['x', 'y']` mesh axes.
    The first is read back with the same axes; the second is read back with
    `['x']`, so the checkpoint is restored under a different partitioning
    than the one it was written with.
    """
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x', 'y']
    num = util.prod(global_input_shape)

    # First array holds 0 .. num-1.
    global_input_data1 = np.arange(num).reshape(global_input_shape)

    def cb1(index):
        return global_input_data1[index]

    gsda1 = GlobalDeviceArray.from_callback(global_input_shape,
                                            global_mesh, mesh_axes, cb1)
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    # Second array holds num .. 2*num-1.
    global_input_data2 = np.arange(num,
                                   num + num).reshape(global_input_shape)

    def cb2(index):
        return global_input_data2[index]

    gsda2 = GlobalDeviceArray.from_callback(global_input_shape,
                                            global_mesh, mesh_axes, cb2)
    ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

    ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2)]

    # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
    # spelling that is available across JAX versions.
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec,
                                    ckpt_paths)

    # Async serialization: start every write, then await them together.
    async def run_serializer():
        future_writer = jax.tree_util.tree_map(serialization.async_serialize,
                                               ckpt_paths, [gsda1, gsda2],
                                               tspecs)
        return await asyncio.gather(*future_writer)

    asyncio.run(run_serializer())

    # Async deserialization: the second checkpoint is restored with ['x'].
    async def run():
        future_gsdas = jax.tree_util.tree_map(serialization.async_deserialize,
                                              ckpt_paths,
                                              [global_mesh, global_mesh],
                                              [mesh_axes, ['x']], tspecs)
        return await asyncio.gather(*future_gsdas)

    m1, m2 = asyncio.run(run())

    # m1 keeps the ('x', 'y') layout: every shard is a 2x1 tile.
    self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                           np.array([[0], [2]]))
    self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    # m2 is partitioned over 'x' only, so the first two local shards hold
    # the same 2x2 tile.
    self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)
Beispiel #2
0
  def test_checkpointing(self):
    """Round-trips three GDAs, including a zero-size one, via tensorstore.

    The second array is saved with `P('x', 'y')` and restored with `P('x')`.
    The third is an empty array replicated (`P(None)`) over a 1D mesh.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    num = util.prod(global_input_shape)

    # First GDA
    global_input_data1 = np.arange(num).reshape(global_input_shape)
    def cb1(index):
      return global_input_data1[index]
    gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                           mesh_axes, cb1)
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    # Second GDA
    global_input_data2 = np.arange(num, num + num).reshape(global_input_shape)
    def cb2(index):
      return global_input_data2[index]
    gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                           mesh_axes, cb2)
    ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

    # Third GDA: shape (0,), replicated over an 8-device 1D mesh.
    def cb3(index):
      return np.array([])
    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
    gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
    ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

    ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
    # `jax.tree_map` is a deprecated alias of `jax.tree_util.tree_map`.
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec,
                                    ckpt_paths)

    serialization.run_serialization([gda1, gda2, gda3], tspecs)

    m1, m2, m3 = serialization.run_deserialization(
        [global_mesh, global_mesh, global_mesh1d],
        [mesh_axes, P('x'), P(None)],
        tspecs)

    # m1 keeps the ('x', 'y') layout: each shard is a 2x1 tile.
    self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                           np.array([[0], [2]]))
    self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    # m2 is restored with P('x'): the first two local shards are identical.
    self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

    # m3 is fully replicated: every shard spans the whole (empty) range.
    for i, s in enumerate(m3.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(m3.dtype, np.float32)
Beispiel #3
0
 def test_gda_equality_raises_not_implemented(self):
   """Comparing two GDAs with `==` raises NotImplementedError."""
   global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
   global_input_shape = (8, 2)
   mesh_axes = P(None)
   global_input_data = np.arange(
       prod(global_input_shape)).reshape(global_input_shape)

   def cb(index):
     return global_input_data[index]

   # Two separately built GDAs backed by identical host data.
   first_gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes, cb)
   second_gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes, cb)
   expected_message = (
       'GlobalDeviceArray equality is intentionally unimplemented.')
   with self.assertRaisesRegex(NotImplementedError, expected_message):
     first_gda == second_gda  # pylint: disable=pointless-statement
    def test_device_mismatch(self):
        """Building a GDA whose buffers disagree with the mesh order fails.

        The mesh uses a permuted device order, while the buffers are placed
        in `jax.local_devices()` order; a ValueError is expected.
        """
        all_devices = jax.devices()
        if len(all_devices) < 8:
            raise unittest.SkipTest("Test requires 8 global devices.")
        # Device ids for each row of the (4, 2) mesh, deliberately permuted.
        device_order = [[0, 2], [3, 1], [4, 6], [7, 5]]
        mesh_devices = np.array(
            [[all_devices[i] for i in row] for row in device_order])
        global_mesh = Mesh(mesh_devices, ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = ['x', 'y']
        global_input_data = np.arange(
            prod(global_input_shape)).reshape(global_input_shape)
        indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

        device_buffers = []
        for device in jax.local_devices():
            device_buffers.append(
                jax.device_put(global_input_data[indices[device]], device))

        with self.assertRaisesRegex(
                ValueError,
                'The `global_mesh.local_devices` and `device_buffers` device order'
        ):
            GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes,
                              device_buffers)
    def test_gda_batched_callback_with_devices(self):
        """The batched-with-devices callback gets (index, devices) pairs.

        With a (4, 2) mesh partitioned only over 'x', the callback is expected
        to receive 4 entries, each listing 2 devices.
        """
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = ['x']
        global_input_data = np.arange(
            prod(global_input_shape),
            dtype=np.float32).reshape(global_input_shape)

        def cb(cb_inp):
            self.assertLen(cb_inp, 4)
            buffers = []
            for index, devices in cb_inp:
                self.assertLen(devices, 2)
                shard = global_input_data[index]
                # One buffer per device that owns this index.
                buffers.extend(jax.device_put(shard, d) for d in devices)
            return buffers

        gda = GlobalDeviceArray.from_batched_callback_with_devices(
            global_input_shape, global_mesh, mesh_axes, cb)
        # The first two local shards share the same 'x' slice.
        expected_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
        for shard_idx in (0, 1):
            self.assertArraysEqual(
                gda.local_data(shard_idx).to_py(), expected_shard_value)
    def test_gda_subset_devices(self, mesh_axes, expected_index,
                                expected_shard_shape, expected_replica_ids):
        """Checks GDA sharding on a (2, 2) mesh for parameterized mesh axes.

        Args:
          mesh_axes: partitioning passed to `from_callback`.
          expected_index: expected indices of the first two local shards.
          expected_shard_shape: expected shape of the first local shard.
          expected_replica_ids: expected replica ids of all local shards.
        """
        global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        global_input_data = np.arange(
            prod(global_input_shape)).reshape(global_input_shape)

        gda = GlobalDeviceArray.from_callback(
            global_input_shape, global_mesh, mesh_axes,
            lambda index: global_input_data[index])

        # The first two local shards hold exactly the expected slices.
        for i in (0, 1):
            self.assertEqual(gda.local_shards[i].index, expected_index[i])
            self.assertArraysEqual(gda.local_data(i),
                                   global_input_data[expected_index[i]])
        self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
        self.assertListEqual([s.replica_id for s in gda.local_shards],
                             expected_replica_ids)
        # Global and local shards are compared pairwise.
        for global_shard, local_shard in safe_zip(gda.global_shards,
                                                  gda.local_shards):
            self.assertEqual(global_shard.device, local_shard.device)
            self.assertEqual(global_shard.index, local_shard.index)
            self.assertEqual(global_shard.replica_id, local_shard.replica_id)
            self.assertArraysEqual(global_shard.data, local_shard.data)
Beispiel #7
0
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    """Checks GDA metadata and shards on a (4, 2) mesh for parameterized axes.

    Args:
      mesh_axes: partitioning passed to `from_callback`.
      expected_index: expected indices of the first two local shards.
      expected_shard_shape: expected shape of every local shard.
      expected_replica_ids: expected replica ids of all local shards, in order.
      expected_is_fully_replicated: expected `is_fully_replicated` value.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes,
        lambda index: global_input_data[index])

    # Global metadata.
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)

    # The first two local shards hold exactly the expected slices.
    for i in range(2):
      self.assertEqual(gda.local_shards[i].index, expected_index[i])
      self.assertArraysEqual(gda.local_data(i),
                             global_input_data[expected_index[i]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    self.assertListEqual([s.replica_id for s in gda.local_shards],
                         expected_replica_ids)
    self.assertListEqual([s.device.id for s in gda.local_shards],
                         list(range(8)))
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)

    for shard in gda.local_shards:
      self.assertEqual(shard.data.aval,
                       core.ShapedArray(expected_shard_shape, shard.data.dtype))
    # Global and local shards are compared pairwise.
    for global_shard, local_shard in safe_zip(gda.global_shards,
                                              gda.local_shards):
      self.assertEqual(global_shard.device, local_shard.device)
      self.assertEqual(global_shard.index, local_shard.index)
      self.assertEqual(global_shard.replica_id, local_shard.replica_id)
      self.assertEqual(global_shard.data.aval, local_shard.data.aval)
      self.assertArraysEqual(global_shard.data, local_shard.data)
Beispiel #8
0
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
    """Build a GlobalDeviceArray from host data; return it with that data.

    Args:
      global_shape: shape of the full (global) array.
      global_mesh: mesh the array is laid out on.
      mesh_axes: partitioning of the array over `global_mesh`.
      global_data: host array to shard. When None, it defaults to
        `np.arange(prod(global_shape))` reshaped to `global_shape`.

    Returns:
      A `(gda, global_data)` tuple.
    """
    data = (np.arange(prod(global_shape)).reshape(global_shape)
            if global_data is None else global_data)

    def _slice_for(index):
        # Called with a shard index; returns that slice of the host array.
        return data[index]

    gda = GlobalDeviceArray.from_callback(global_shape, global_mesh,
                                          mesh_axes, _slice_for)
    return gda, data
Beispiel #9
0
  def test_gda_block_until_ready(self):
    """`block_until_ready()` returns the same GDA object it was called on."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # assertIs reports both objects on failure; assertTrue(x is y) only
    # reports "False is not true".
    self.assertIs(gda.block_until_ready(), gda)
    def test_async_checkpointing(self):
        """Async checkpoint manager saves via a temp dir and can restore it.

        The first save writes into `temp_first` with `first` as the final
        checkpoint directory, and is then deserialized. A second save whose
        final directory already exists is expected to fail when
        `wait_until_finished` is called.
        """
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = P('x', 'y')
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb1)
        temp_ckpt_dir1 = pathlib.Path(
            self.create_tempdir('temp_first').full_path)
        ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

        # `jax.tree_map` is a deprecated alias of `jax.tree_util.tree_map`.
        s_tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec,
                                          [str(temp_ckpt_dir1)])

        manager = serialization.GlobalAsyncCheckpointManager()
        manager.serialize([gda1],
                          s_tspecs,
                          temp_checkpoint_dir=temp_ckpt_dir1,
                          final_checkpoint_dir=ckpt_dir1)
        manager.wait_until_finished()

        # `ckpt_dir1` is already a str.
        d_tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec,
                                          [ckpt_dir1])
        m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
        self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                               np.array([[0], [2]]))
        self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                               np.array([[1], [3]]))
        self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
        self.assertEqual(m1.dtype, np.int32)

        # Setup happens outside `assertRaises`, so an unrelated failure here
        # cannot make the test pass by accident.
        existing_ckpt_dir1 = pathlib.Path(
            self.create_tempdir('first').full_path)
        manager1 = serialization.GlobalAsyncCheckpointManager()

        # Will throw `file already exists` error when `tf.io.gfile.rename`.
        # `wait_until_finished` will raise the error.
        with self.assertRaises(Exception):
            manager1.serialize([gda1],
                               s_tspecs,
                               temp_checkpoint_dir=temp_ckpt_dir1,
                               final_checkpoint_dir=existing_ckpt_dir1)
            manager1.wait_until_finished()
Beispiel #11
0
 def test_gda_str_repr(self):
   """`str` shows shape and dtype; `repr` also shows mesh shape and axes."""
   global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
   global_input_shape = (8, 2)
   mesh_axes = P(('x', 'y'))
   global_input_data = np.arange(
       prod(global_input_shape)).reshape(global_input_shape)
   gda = GlobalDeviceArray.from_callback(
       global_input_shape, global_mesh, mesh_axes,
       lambda index: global_input_data[index])

   expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
   expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))")
   self.assertEqual(str(gda), expected_str)
   self.assertEqual(repr(gda), expected_repr)
Beispiel #12
0
 def test_gda_shape_0_1d_mesh(self):
   """A zero-size 1D GDA replicated (`P(None)`) over an 8-device 1D mesh."""
   # Trailing comma makes this a tuple of axis names, not the string 'x'.
   global_mesh = jtu.create_global_mesh((8,), ('x',))
   global_input_shape = (0,)
   mesh_axes = P(None)
   def cb(index):
     return np.array([])
   gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         mesh_axes, cb)
   self.assertEqual(gda.ndim, 1)
   self.assertEqual(gda.size, 0)
   # Fully replicated: each shard spans the whole (empty) range and local
   # shard i is replica i.
   for i, s in enumerate(gda.local_shards):
     self.assertEqual(s.index, (slice(None),))
     self.assertEqual(s.replica_id, i)
     self.assertArraysEqual(s.data.to_py(), np.array([]))
   self.assertEqual(gda.dtype, np.float32)
   self.assertEqual(
       gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
       (0,))
Beispiel #13
0
  def test_gda_batched_callback(self):
    """The batched callback receives one index per local device."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      # assertLen (as used elsewhere in these tests) reports the container
      # on failure, unlike comparing lengths with assertEqual.
      self.assertLen(indices, len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    # Dim 0 is split over both mesh axes, so each shard is a single row.
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)
Beispiel #14
0
 def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                       expected_replica_ids):
   """Checks GDA sharding of a length-16 vector on an 8-device 1D mesh.

   Args:
     mesh_axes: partitioning passed to `from_callback`.
     expected_index: expected indices of the first two local shards.
     expected_shard_shape: expected shape of the first local shard.
     expected_replica_ids: expected replica ids of all local shards.
   """
   # Trailing comma makes this a tuple of axis names, not the string 'x'.
   global_mesh = create_global_mesh((8,), ('x',))
   global_input_shape = (16,)
   global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
   def cb(index):
     return global_input_data[index]
   gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         mesh_axes, cb)
   self.assertEqual(gda.local_shards[0].index, expected_index[0])
   self.assertArraysEqual(gda.local_data(0),
                          global_input_data[expected_index[0]])
   self.assertEqual(gda.local_shards[1].index, expected_index[1])
   self.assertArraysEqual(gda.local_data(1),
                          global_input_data[expected_index[1]])
   self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
   replica_ids = [i.replica_id for i in gda.local_shards]
   self.assertListEqual(replica_ids, expected_replica_ids)
Beispiel #15
0
    def test_checkpointing_with_bigger_shape(self):
        """Restores an (8, 2) checkpoint as a (12, 2) array on a larger mesh.

        The array is saved from a (2, 2) mesh and read back on a (4, 2) mesh
        with global shape (12, 2); rows beyond the saved data read back as 0.
        """
        global_mesh = create_global_mesh((2, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               P('x', 'y'), cb1)
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

        ckpt_paths = [str(ckpt_dir1)]
        # `jax.tree_map` is a deprecated alias of `jax.tree_util.tree_map`.
        tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec,
                                        ckpt_paths)

        serialization.run_serialization([gda1], tspecs)

        m1, = serialization.run_deserialization(
            [create_global_mesh((4, 2), ('x', 'y'))],
            [P('x', 'y')],
            tspecs,
            [(12, 2)],
        )

        # Expected (3, 1) tile per device, keyed by device id. Devices 4-7
        # cover rows 6-11, which extend past the 8 saved rows (zeros).
        expected_data = {
            0: np.array([[0], [2], [4]]),
            1: np.array([[1], [3], [5]]),
            2: np.array([[6], [8], [10]]),
            3: np.array([[7], [9], [11]]),
            4: np.array([[12], [14], [0]]),
            5: np.array([[13], [15], [0]]),
            6: np.array([[0], [0], [0]]),
            7: np.array([[0], [0], [0]]),
        }

        for shard in m1.local_shards:
            self.assertArraysEqual(shard.data.to_py(),
                                   expected_data[shard.device.id])