Ejemplo n.º 1
0
  def test_checkpointing(self):
    """Round-trips three GDAs through tensorstore-based serialization.

    Two (8, 2) integer GDAs sharded on a (4, 2) mesh and one empty (0,) GDA on
    an 8-device 1-D mesh are serialized into separate temp directories, then
    deserialized (the second one with ``P('x')`` instead of the axes it was
    saved with), and the restored shards, shapes and dtypes are checked.
    """
    mesh_2d = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    axes = P('x', 'y')
    total = util.prod(shape)

    # First GDA holds 0 .. total-1.
    data_a = np.arange(total).reshape(shape)
    gda_a = GlobalDeviceArray.from_callback(shape, mesh_2d, axes,
                                            lambda idx: data_a[idx])
    dir_a = pathlib.Path(self.create_tempdir('first').full_path)

    # Second GDA holds total .. 2*total-1.
    data_b = np.arange(total, 2 * total).reshape(shape)
    gda_b = GlobalDeviceArray.from_callback(shape, mesh_2d, axes,
                                            lambda idx: data_b[idx])
    dir_b = pathlib.Path(self.create_tempdir('second').full_path)

    # Third GDA is empty and not partitioned (P(None)) on a 1-D mesh.
    mesh_1d = jtu.create_global_mesh((8,), ('x',))
    gda_c = GlobalDeviceArray.from_callback((0,), mesh_1d, P(None),
                                            lambda idx: np.array([]))
    dir_c = pathlib.Path(self.create_tempdir('third').full_path)

    paths = [str(d) for d in (dir_a, dir_b, dir_c)]
    tspecs = jax.tree_map(serialization.get_tensorstore_spec, paths)
    serialization.run_serialization([gda_a, gda_b, gda_c], tspecs)

    restored_a, restored_b, restored_c = serialization.run_deserialization(
        [mesh_2d, mesh_2d, mesh_1d], [axes, P('x'), P(None)], tspecs)

    # Restored with P('x', 'y'): local shards are (2, 1) slices.
    shards_a = restored_a.local_shards
    self.assertArraysEqual(shards_a[0].data.to_py(), np.array([[0], [2]]))
    self.assertArraysEqual(shards_a[1].data.to_py(), np.array([[1], [3]]))
    self.assertEqual(shards_a[0].data.shape, (2, 1))
    self.assertEqual(restored_a.dtype, np.int32)

    # Restored with P('x'): the first two local shards hold the same block.
    shards_b = restored_b.local_shards
    for k in (0, 1):
      self.assertArraysEqual(shards_b[k].data.to_py(),
                             np.array([[16, 17], [18, 19]]))
    self.assertEqual(shards_b[0].data.shape, (2, 2))
    self.assertEqual(restored_b.dtype, np.int32)

    for replica, shard in enumerate(restored_c.local_shards):
      self.assertEqual(shard.index, (slice(None),))
      self.assertEqual(shard.replica_id, replica)
      self.assertArraysEqual(shard.data.to_py(), np.array([]))
    self.assertEqual(restored_c.dtype, np.float32)
Ejemplo n.º 2
0
    def test_mesh_hash(self):
        """Checks that same-layout meshes hash equally and others do not."""
        names = ('x', 'y')
        mesh_4x2 = jtu.create_global_mesh((4, 2), names)
        mesh_2x4 = jtu.create_global_mesh((2, 4), names)
        mesh_4x2_again = jtu.create_global_mesh((4, 2), names)

        # Different device grid shape -> different hash.
        self.assertNotEqual(hash(mesh_4x2), hash(mesh_2x4))
        # Identically constructed mesh -> identical hash.
        self.assertEqual(hash(mesh_4x2), hash(mesh_4x2_again))
Ejemplo n.º 3
0
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    """Checks sharding of an (8, 2) GDA over a (4, 2) mesh.

    For the given ``mesh_axes`` the expected_* arguments give the indices of
    the first two local shards, the shard shape, the replica ids of all local
    shards, and whether the GDA is fully replicated. ``global_shards`` is also
    cross-checked against ``local_shards`` entry by entry.
    """
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    source = np.arange(prod(shape)).reshape(shape)

    arr = GlobalDeviceArray.from_callback(shape, mesh, mesh_axes,
                                          lambda idx: source[idx])
    self.assertEqual(arr.ndim, 2)
    self.assertEqual(arr.size, 16)
    self.assertEqual(arr.mesh_axes, mesh_axes)
    for k in (0, 1):
      self.assertEqual(arr.local_shards[k].index, expected_index[k])
      self.assertArraysEqual(arr.local_data(k), source[expected_index[k]])
    self.assertEqual(arr.local_data(0).shape, expected_shard_shape)
    self.assertListEqual([s.replica_id for s in arr.local_shards],
                         expected_replica_ids)
    # All 8 mesh devices are local, listed in device-id order.
    self.assertListEqual([s.device.id for s in arr.local_shards],
                         list(range(8)))
    self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
    for shard in arr.local_shards:
      self.assertEqual(shard.data.aval,
                       core.ShapedArray(expected_shard_shape, shard.data.dtype))
    for glob, loc in safe_zip(arr.global_shards, arr.local_shards):
      for attr in ('device', 'index', 'replica_id'):
        self.assertEqual(getattr(glob, attr), getattr(loc, attr))
      self.assertEqual(glob.data.aval, loc.data.aval)
      self.assertArraysEqual(glob.data, loc.data)
Ejemplo n.º 4
0
    def test_gda_subset_devices(self, mesh_axes, expected_index,
                                expected_shard_shape, expected_replica_ids):
        """Checks sharding of an (8, 2) GDA over a (2, 2) mesh.

        Judging by the test name, the (2, 2) mesh presumably covers only a
        subset of the available devices. The expected_* arguments give the
        indices of the first two local shards, the shard shape and the
        replica ids of all local shards for ``mesh_axes``.
        """
        mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
        shape = (8, 2)
        source = np.arange(prod(shape)).reshape(shape)

        arr = GlobalDeviceArray.from_callback(shape, mesh, mesh_axes,
                                              lambda idx: source[idx])
        for k in (0, 1):
            self.assertEqual(arr.local_shards[k].index, expected_index[k])
            self.assertArraysEqual(arr.local_data(k),
                                   source[expected_index[k]])
        self.assertEqual(arr.local_data(0).shape, expected_shard_shape)
        self.assertListEqual([s.replica_id for s in arr.local_shards],
                             expected_replica_ids)
        for glob, loc in safe_zip(arr.global_shards, arr.local_shards):
            for attr in ('device', 'index', 'replica_id'):
                self.assertEqual(getattr(glob, attr), getattr(loc, attr))
            self.assertArraysEqual(glob.data, loc.data)
Ejemplo n.º 5
0
def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
    """Benchmark repeated shard-index / replica-id computation.

    Calls ``gda.get_shard_indices_replica_ids`` for a (2048, 2048) array on a
    mesh of shape ``mesh_shape`` once per benchmark iteration of ``state``.
    """
    shape = (2048, 2048)
    mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
    while state:
        gda.get_shard_indices_replica_ids(shape, mesh, mesh_axes)
Ejemplo n.º 6
0
    def test_gda_batched_callback_with_devices(self):
        """Builds a GDA from a batched callback that places buffers itself.

        The callback receives one ``(index, devices)`` pair per distinct shard
        (4 here, each with 2 devices) and returns one device buffer per device.
        """
        mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        shape = (8, 2)
        mesh_axes = ['x']
        source = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        def cb(batch):
            self.assertLen(batch, 4)
            buffers = []
            for index, devices in batch:
                self.assertLen(devices, 2)
                shard = source[index]
                buffers += [jax.device_put(shard, dev) for dev in devices]
            return buffers

        arr = GlobalDeviceArray.from_batched_callback_with_devices(
            shape, mesh, mesh_axes, cb)
        # The first two local shards both hold rows 0-1 of the input.
        expected = np.array([[0, 1], [2, 3]], dtype=np.float32)
        for k in (0, 1):
            self.assertArraysEqual(arr.local_data(k).to_py(), expected)
Ejemplo n.º 7
0
 def test_repr(self):
     """Checks that repr() of a sharded jax.Array does not raise."""
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         mp = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr = create_array((8, 2), mp)[0]
         repr(arr)  # Only checks that this doesn't crash.
Ejemplo n.º 8
0
 def test_array_device_get(self):
     """Checks that jax.device_get on a sharded Array returns the input data."""
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         mp = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr, expected = create_array((8, 2), mp)
         host_value = jax.device_get(arr)
         self.assertArraysEqual(host_value, expected)
Ejemplo n.º 9
0
 def test_array_sharded_astype(self):
     """Checks astype(float32) on a sharded Array: resulting dtype and values."""
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         mp = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr, host_data = create_array((8, 2), mp)
         converted = arr.astype(jnp.float32)
         self.assertEqual(converted.dtype, np.float32)
         self.assertArraysEqual(converted, host_data.astype(np.float32))
Ejemplo n.º 10
0
 def test_array_delete(self):
     """Checks delete(): later access raises, _npy_value/_arrays become None."""
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         mp = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr = create_array((8, 2), mp)[0]
         arr.delete()
         with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
             arr._check_if_deleted()
         for attr in ('_npy_value', '_arrays'):
             self.assertIsNone(getattr(arr, attr))
Ejemplo n.º 11
0
 def test_jax_array_value(self, mesh_axes):
     """Checks per-shard buffers and the full value of a sharded Array.

     Each addressable shard must wrap exactly one buffer equal to the matching
     slice of the input; ``_value`` and then ``_npy_value`` must equal the
     whole input.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         mp = sharding.MeshPspecSharding(mesh, mesh_axes)
         arr, expected = create_array((8, 2), mp)
         for shard in arr.addressable_shards:
             buffers = shard.data._arrays
             self.assertLen(buffers, 1)
             self.assertArraysEqual(buffers[0], expected[shard.index])
         # `_value` is read before `_npy_value`, as in the original ordering.
         self.assertArraysEqual(arr._value, expected)
         self.assertArraysEqual(arr._npy_value, expected)
Ejemplo n.º 12
0
  def test_checkpointing_with_bigger_shape(self):
    """Restores a checkpoint into a larger shape, another dtype and mesh.

    An (8, 2) int32 GDA saved from a (2, 2) mesh is read back as a (12, 2)
    float32 GDA on a (4, 2) mesh; per the expectations below, positions past
    the saved data read back as 0.
    """
    src_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    saved_shape = (8, 2)
    total = util.prod(saved_shape)

    saved_data = np.arange(total, dtype=np.int32).reshape(saved_shape)
    saved_gda = GlobalDeviceArray.from_callback(
        saved_shape, src_mesh, P('x', 'y'), lambda idx: saved_data[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir)])
    serialization.run_serialization([saved_gda], tspecs)

    (restored,) = serialization.run_deserialization(
        [jtu.create_global_mesh((4, 2), ('x', 'y'))],
        [P('x', 'y')],
        tspecs,
        [(12, 2)],
        [np.float32])

    # Expected (3, 1) shard contents, listed in device-id order 0..7.
    expected_rows = [
        [[0], [2], [4]],
        [[1], [3], [5]],
        [[6], [8], [10]],
        [[7], [9], [11]],
        [[12], [14], [0]],
        [[13], [15], [0]],
        [[0], [0], [0]],
        [[0], [0], [0]],
    ]
    expected_data = {
        dev_id: np.array(rows, dtype=np.float32)
        for dev_id, rows in enumerate(expected_rows)
    }
    for shard in restored.local_shards:
      self.assertArraysEqual(shard.data.to_py(),
                             expected_data[shard.device.id])
Ejemplo n.º 13
0
  def test_gda_block_until_ready(self):
    """Checks that ``block_until_ready`` returns the very same GDA object."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # assertIs reports both objects on failure, unlike assertTrue(a is b),
    # which only reports "False is not true".
    self.assertIs(gda.block_until_ready(), gda)
Ejemplo n.º 14
0
def gda_construction_callback(mesh_axes, state):
  """Benchmark ``GlobalDeviceArray.from_callback`` on a (2048, 2048) array.

  ``from_callback`` also performs the ``device_put`` of each shard, so that
  transfer time is part of what each iteration measures.
  """
  # The mesh is kept at 8 local devices. Because `device_put` time is
  # included, it would dominate with e.g. 2048 local devices, and such local
  # device counts never occur in practice anyway.
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (2048, 2048)
  source = np.arange(prod(shape)).reshape(shape)

  def shard_for(index):
    return source[index]

  while state:
    gda.GlobalDeviceArray.from_callback(shape, mesh, mesh_axes, shard_for)
Ejemplo n.º 15
0
 def test_gda_equality_raises_not_implemented(self):
   """Checks that comparing two GDAs with == raises NotImplementedError."""
   mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
   shape = (8, 2)
   axes = P(None)
   source = np.arange(prod(shape)).reshape(shape)

   def make_gda():
     return GlobalDeviceArray.from_callback(shape, mesh, axes,
                                            lambda idx: source[idx])

   lhs = make_gda()
   rhs = make_gda()
   with self.assertRaisesRegex(
       NotImplementedError,
       'GlobalDeviceArray equality is intentionally unimplemented.'):
     _ = lhs == rhs
Ejemplo n.º 16
0
    def test_async_checkpointing(self):
        """Checks async checkpointing and that a second save cannot clobber it.

        A GDA is serialized via ``GlobalAsyncCheckpointManager`` into a
        temporary directory and finalized into ``final_checkpoint_dir``. It is
        then deserialized from that final directory and checked. Saving again
        to an already-existing final directory must fail.
        """
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = P('x', 'y')
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb1)
        temp_ckpt_dir1 = pathlib.Path(
            self.create_tempdir('temp_first').full_path)
        # Final directory is a sibling of the temp one; it is not created here.
        ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

        s_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(temp_ckpt_dir1)])

        manager = serialization.GlobalAsyncCheckpointManager()
        manager.serialize([gda1],
                          s_tspecs,
                          temp_checkpoint_dir=temp_ckpt_dir1,
                          final_checkpoint_dir=ckpt_dir1)
        manager.wait_until_finished()

        d_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                                [str(ckpt_dir1)])
        m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)
        self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                               np.array([[0], [2]]))
        self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                               np.array([[1], [3]]))
        self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
        self.assertEqual(m1.dtype, np.int32)

        # Setup that must succeed stays outside `assertRaises`. Otherwise an
        # unrelated failure here (e.g. in `create_tempdir`) would be swallowed
        # and the test would pass for the wrong reason.
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
        manager1 = serialization.GlobalAsyncCheckpointManager()

        # Will throw `file already exists` error when `tf.io.gfile.rename`.
        # `wait_until_finished` will raise the error.
        # NOTE(review): the exact exception type is not visible here, so the
        # check stays broad (`Exception`); narrow it once the type is known.
        with self.assertRaises(Exception):
            manager1.serialize([gda1],
                               s_tspecs,
                               temp_checkpoint_dir=temp_ckpt_dir1,
                               final_checkpoint_dir=ckpt_dir1)
            manager1.wait_until_finished()
Ejemplo n.º 17
0
def gda_construction_raw(mesh_shape, mesh_axes, state):
    """Benchmark constructing a GDA directly from pre-placed device buffers.

    The per-device buffers are created once before the loop, so `device_put`
    time is not part of the measurement. Only the mesh's local devices are
    used.
    """
    mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
    shape = (2048, 2048)
    source = np.arange(prod(shape)).reshape(shape)
    index_of = gda.get_shard_indices(shape, mesh, mesh_axes)
    buffers = [jax.device_put(source[index_of[dev]], dev)
               for dev in mesh.local_devices]

    while state:
        gda.GlobalDeviceArray(shape, mesh, mesh_axes, buffers)
Ejemplo n.º 18
0
    def test_mesh_pspec_sharding_interface(self):
        """Checks device/index map and XLA OpSharding of a MeshPspecSharding.

        With P('y', 'x') on a (4, 2) mesh, dim 0 of an (8, 4) array is split
        2 ways (over 'y') and dim 1 is split 4 ways (over 'x').
        """
        mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        global_shape = (8, 4)
        mp_sharding = sharding.MeshPspecSharding(mesh, P('y', 'x'))

        indices_by_device = mp_sharding.devices_indices_map(global_shape)
        op_sharding = mp_sharding._to_xla_op_sharding(len(global_shape))
        device_assignment = mp_sharding._device_assignment()
        first_device = mesh.devices.flat[0]

        self.assertEqual(indices_by_device[first_device],
                         (slice(0, 4), slice(0, 1)))
        self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
        self.assertEqual(op_sharding.type, xc.OpSharding.Type.OTHER)
        self.assertListEqual(op_sharding.tile_assignment_dimensions, [2, 4])
        self.assertListEqual(op_sharding.tile_assignment_devices,
                             [0, 2, 4, 6, 1, 3, 5, 7])
Ejemplo n.º 19
0
 def test_gda_str_repr(self):
   """Checks the exact str() and repr() text of an (8, 2) int32 GDA."""
   mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
   shape = (8, 2)
   source = np.arange(prod(shape)).reshape(shape)
   arr = GlobalDeviceArray.from_callback(shape, mesh, P(('x', 'y')),
                                         lambda idx: source[idx])

   expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
   expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))")
   self.assertEqual(str(arr), expected_str)
   self.assertEqual(repr(arr), expected_repr)
Ejemplo n.º 20
0
 def test_gda_shape_0_1d_mesh(self):
   """Checks an empty (0,) GDA, unpartitioned over an 8-device 1-D mesh.

   Every local shard covers the whole (empty) array, replica ids count up
   from 0 in local-shard order, and the reported dtype and shard shape are
   float32 and (0,).
   """
   # `('x',)` is a one-element tuple; a bare `('x')` is just the string 'x'
   # and only works by accident because the name has a single character.
   global_mesh = jtu.create_global_mesh((8,), ('x',))
   global_input_shape = (0,)
   mesh_axes = P(None)
   def cb(index):
     return np.array([])
   gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         mesh_axes, cb)
   self.assertEqual(gda.ndim, 1)
   self.assertEqual(gda.size, 0)
   for i, s in enumerate(gda.local_shards):
     self.assertEqual(s.index, (slice(None),))
     self.assertEqual(s.replica_id, i)
     self.assertArraysEqual(s.data.to_py(), np.array([]))
   # The callback returns float64 `np.array([])`, yet float32 is expected;
   # presumably because 64-bit types are disabled by default in JAX.
   self.assertEqual(gda.dtype, np.float32)
   self.assertEqual(
       gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
       (0,))
Ejemplo n.º 21
0
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    """Checks sharding of a (16,) GDA over an 8-device 1-D mesh.

    For the given ``mesh_axes`` the expected_* arguments give the indices of
    the first two local shards, the shard shape and the replica ids of all
    local shards.
    """
    # `('x',)` is a one-element tuple; a bare `('x')` is just the string 'x'.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    # np.arange is already 1-D, so no reshape is needed.
    global_input_data = np.arange(prod(global_input_shape))
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
Ejemplo n.º 22
0
  def test_gda_batched_callback(self):
    """Builds a GDA via ``from_batched_callback`` and checks two shards.

    The callback receives one index per local device and returns the matching
    input slices. With P(('x', 'y')) over the 8-device (4, 2) mesh, the first
    two local shards are expected to hold rows 0 and 1 respectively.
    """
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    axes = P(('x', 'y'))
    source = np.arange(prod(shape)).reshape(shape)

    def cb(indices):
      self.assertEqual(len(indices), len(mesh.local_devices))
      return list(map(source.__getitem__, indices))

    arr = GlobalDeviceArray.from_batched_callback(shape, mesh, axes, cb)
    expected_shards = (np.array([[0, 1]]), np.array([[2, 3]]))
    for k, expected in enumerate(expected_shards):
      self.assertArraysEqual(arr.local_data(k).to_py(), expected)
                           expected_second_shard_value)