def test_checkpointing(self):
  """Serializes three GDAs and checks what deserialization hands back.

  Two (8, 2) arrays on a 4x2 mesh and one zero-length array on an 8-device
  1D mesh are written to separate temp directories. The second array is
  restored with `P('x')` rather than the `P('x', 'y')` it was built with.
  """
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  pspec = P('x', 'y')
  count = util.prod(shape)

  # First GDA: values 0 .. count - 1.
  data_a = np.arange(count).reshape(shape)
  gda_a = GlobalDeviceArray.from_callback(
      shape, mesh, pspec, lambda index: data_a[index])
  dir_a = pathlib.Path(self.create_tempdir('first').full_path)

  # Second GDA: values count .. 2 * count - 1.
  data_b = np.arange(count, count + count).reshape(shape)
  gda_b = GlobalDeviceArray.from_callback(
      shape, mesh, pspec, lambda index: data_b[index])
  dir_b = pathlib.Path(self.create_tempdir('second').full_path)

  # Third GDA: zero-length array on a 1D mesh.
  mesh_1d = jtu.create_global_mesh((8,), ('x',))
  gda_c = GlobalDeviceArray.from_callback(
      (0,), mesh_1d, P(None), lambda index: np.array([]))
  dir_c = pathlib.Path(self.create_tempdir('third').full_path)

  tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                        [str(d) for d in (dir_a, dir_b, dir_c)])
  serialization.run_serialization([gda_a, gda_b, gda_c], tspecs)

  restored_a, restored_b, restored_c = serialization.run_deserialization(
      [mesh, mesh, mesh_1d],
      [pspec, P('x'), P(None)],
      tspecs)

  # First array: restored with the original partitioning.
  for pos, want in enumerate(([[0], [2]], [[1], [3]])):
    self.assertArraysEqual(restored_a.local_shards[pos].data.to_py(),
                           np.array(want))
  self.assertEqual(restored_a.local_shards[0].data.shape, (2, 1))
  self.assertEqual(restored_a.dtype, np.int32)

  # Second array: the first two local shards are expected to hold the same
  # (2, 2) block.
  for pos in (0, 1):
    self.assertArraysEqual(restored_b.local_shards[pos].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
  self.assertEqual(restored_b.local_shards[0].data.shape, (2, 2))
  self.assertEqual(restored_b.dtype, np.int32)

  # Third array: every local shard is empty and covers the full slice.
  for replica, shard in enumerate(restored_c.local_shards):
    self.assertEqual(shard.index, (slice(None),))
    self.assertEqual(shard.replica_id, replica)
    self.assertArraysEqual(shard.data.to_py(), np.array([]))
  self.assertEqual(restored_c.dtype, np.float32)
def test_mesh_hash(self):
  """Checks mesh hashing: same shape hashes alike, swapped shape does not."""
  axis_names = ('x', 'y')
  mesh_4x2 = jtu.create_global_mesh((4, 2), axis_names)
  mesh_2x4 = jtu.create_global_mesh((2, 4), axis_names)
  mesh_4x2_again = jtu.create_global_mesh((4, 2), axis_names)

  self.assertNotEqual(hash(mesh_4x2), hash(mesh_2x4))
  self.assertEqual(hash(mesh_4x2), hash(mesh_4x2_again))
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids, expected_is_fully_replicated):
  """Checks shard layout and metadata of an (8, 2) GDA on a 4x2 mesh.

  Args:
    mesh_axes: partitioning handed to `GlobalDeviceArray.from_callback`.
    expected_index: expected indices of the first two local shards.
    expected_shard_shape: expected shape of each local shard's data.
    expected_replica_ids: expected replica id of every local shard, in order.
    expected_is_fully_replicated: expected value of `is_fully_replicated`.
  """
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  data = np.arange(prod(shape)).reshape(shape)

  arr = GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: data[index])

  self.assertEqual(arr.ndim, 2)
  self.assertEqual(arr.size, 16)
  self.assertEqual(arr.mesh_axes, mesh_axes)

  # Only the first two local shards are compared against the source data.
  for pos in (0, 1):
    self.assertEqual(arr.local_shards[pos].index, expected_index[pos])
    self.assertArraysEqual(arr.local_data(pos), data[expected_index[pos]])
  self.assertEqual(arr.local_data(0).shape, expected_shard_shape)

  self.assertListEqual([shard.replica_id for shard in arr.local_shards],
                       expected_replica_ids)
  self.assertListEqual([shard.device.id for shard in arr.local_shards],
                       [0, 1, 2, 3, 4, 5, 6, 7])
  self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)

  for shard in arr.local_shards:
    self.assertEqual(
        shard.data.aval,
        core.ShapedArray(expected_shard_shape, shard.data.dtype))

  # Global and local shards are expected to line up one-to-one.
  for global_shard, local_shard in safe_zip(arr.global_shards,
                                            arr.local_shards):
    self.assertEqual(global_shard.device, local_shard.device)
    self.assertEqual(global_shard.index, local_shard.index)
    self.assertEqual(global_shard.replica_id, local_shard.replica_id)
    self.assertEqual(global_shard.data.aval, local_shard.data.aval)
    self.assertArraysEqual(global_shard.data, local_shard.data)
def test_gda_subset_devices(self, mesh_axes, expected_index,
                            expected_shard_shape, expected_replica_ids):
  """Checks shard layout of an (8, 2) GDA built on a 2x2 mesh.

  Args:
    mesh_axes: partitioning handed to `GlobalDeviceArray.from_callback`.
    expected_index: expected indices of the first two local shards.
    expected_shard_shape: expected shape of the first local shard's data.
    expected_replica_ids: expected replica id of every local shard, in order.
  """
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  shape = (8, 2)
  data = np.arange(prod(shape)).reshape(shape)

  arr = GlobalDeviceArray.from_callback(
      shape, mesh, mesh_axes, lambda index: data[index])

  for pos in (0, 1):
    self.assertEqual(arr.local_shards[pos].index, expected_index[pos])
    self.assertArraysEqual(arr.local_data(pos), data[expected_index[pos]])
  self.assertEqual(arr.local_data(0).shape, expected_shard_shape)

  self.assertListEqual([shard.replica_id for shard in arr.local_shards],
                       expected_replica_ids)

  for global_shard, local_shard in safe_zip(arr.global_shards,
                                            arr.local_shards):
    self.assertEqual(global_shard.device, local_shard.device)
    self.assertEqual(global_shard.index, local_shard.index)
    self.assertEqual(global_shard.replica_id, local_shard.replica_id)
    self.assertArraysEqual(global_shard.data, local_shard.data)
def indices_replica_id_calc_cached(mesh_shape, mesh_axes, state):
  """Calls `get_shard_indices_replica_ids` repeatedly with identical inputs.

  Args:
    mesh_shape: shape of the ("x", "y") mesh to build.
    mesh_axes: partitioning passed through to the function under test.
    state: loop driver; the body runs for as long as it is truthy
      (presumably a benchmark state object; confirm against the caller).
  """
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  shape = (2048, 2048)
  while state:
    gda.get_shard_indices_replica_ids(shape, mesh, mesh_axes)
def test_gda_batched_callback_with_devices(self):
  """Builds a GDA from a batched callback that places shards on devices."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  mesh_axes = ['x']
  data = np.arange(prod(shape), dtype=np.float32).reshape(shape)

  def place_shards(batch):
    # The callback is expected to receive four (index, devices) pairs, each
    # naming two devices.
    self.assertLen(batch, 4)
    buffers = []
    for index, devices in batch:
      self.assertLen(devices, 2)
      block = data[index]
      buffers.extend(jax.device_put(block, device) for device in devices)
    return buffers

  arr = GlobalDeviceArray.from_batched_callback_with_devices(
      shape, mesh, mesh_axes, place_shards)

  # The first two local shards are expected to hold the same block.
  for pos in (0, 1):
    self.assertArraysEqual(
        arr.local_data(pos).to_py(),
        np.array([[0, 1], [2, 3]], dtype=np.float32))
def test_repr(self):
  """Smoke test: `repr` of a sharded array must not raise."""
  with jax._src.config.jax_array(True):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, P('x', 'y'))
    arr, _ = create_array((8, 2), mesh_sharding)
    repr(arr)  # doesn't crash
def test_array_device_get(self):
  """`jax.device_get` on a sharded array should return the source data."""
  with jax._src.config.jax_array(True):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, P('x', 'y'))
    arr, source = create_array((8, 2), mesh_sharding)
    fetched = jax.device_get(arr)
    self.assertArraysEqual(fetched, source)
def test_array_sharded_astype(self):
  """`astype(jnp.float32)` on a sharded array converts dtype and values."""
  with jax._src.config.jax_array(True):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, P('x', 'y'))
    arr, source = create_array((8, 2), mesh_sharding)

    converted = arr.astype(jnp.float32)

    self.assertEqual(converted.dtype, np.float32)
    self.assertArraysEqual(converted, source.astype(np.float32))
def test_array_delete(self):
  """After `delete()`, the deletion check raises and cached fields are None."""
  with jax._src.config.jax_array(True):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, P('x', 'y'))
    arr, _ = create_array((8, 2), mesh_sharding)

    arr.delete()

    with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)
def test_jax_array_value(self, mesh_axes):
  """Checks per-shard buffers and the assembled value of a sharded array.

  Args:
    mesh_axes: partition spec used to build the `MeshPspecSharding`.
  """
  with jax._src.config.jax_array(True):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mesh_sharding = sharding.MeshPspecSharding(mesh, mesh_axes)
    arr, source = create_array((8, 2), mesh_sharding)

    # Each addressable shard should wrap exactly one buffer holding the
    # matching slice of the source data.
    for shard in arr.addressable_shards:
      self.assertLen(shard.data._arrays, 1)
      self.assertArraysEqual(shard.data._arrays[0], source[shard.index])

    self.assertArraysEqual(arr._value, source)
    self.assertArraysEqual(arr._npy_value, source)
def test_checkpointing_with_bigger_shape(self):
  """Restores an (8, 2) int32 checkpoint as a (12, 2) float32 array.

  The array is saved from a 2x2 mesh and restored onto a 4x2 mesh. The
  expectations below contain zeros for the rows beyond the saved 8.
  """
  save_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  shape = (8, 2)
  count = util.prod(shape)

  # First GDA
  data = np.arange(count, dtype=np.int32).reshape(shape)
  saved = GlobalDeviceArray.from_callback(
      shape, save_mesh, P('x', 'y'), lambda index: data[index])
  ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

  tspecs = jax.tree_map(serialization.get_tensorstore_spec, [str(ckpt_dir)])
  serialization.run_serialization([saved], tspecs)

  restored, = serialization.run_deserialization(
      [jtu.create_global_mesh((4, 2), ('x', 'y'))],
      [P('x', 'y')],
      tspecs,
      [(12, 2)],
      [np.float32],
  )

  # Expected (3, 1) shard per device id, listed in device-id order.
  rows_by_device = (
      [[0], [2], [4]],
      [[1], [3], [5]],
      [[6], [8], [10]],
      [[7], [9], [11]],
      [[12], [14], [0]],
      [[13], [15], [0]],
      [[0], [0], [0]],
      [[0], [0], [0]],
  )
  expected = {
      device_id: np.array(rows, dtype=np.float32)
      for device_id, rows in enumerate(rows_by_device)
  }

  for shard in restored.local_shards:
    self.assertArraysEqual(shard.data.to_py(), expected[shard.device.id])
def test_gda_block_until_ready(self):
  """`block_until_ready` must return the same GDA object it was called on."""
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = P(('x', 'y'))
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(
      global_input_shape, global_mesh, mesh_axes, cb)

  # assertIs states the identity check directly and reports both operands on
  # failure, unlike assertTrue(a is b) which only reports "False is not true".
  self.assertIs(gda.block_until_ready(), gda)
def gda_construction_callback(mesh_axes, state):
  """Repeatedly builds a (2048, 2048) GDA through `from_callback`.

  Args:
    mesh_axes: partitioning passed to `GlobalDeviceArray.from_callback`.
    state: loop driver; the body runs for as long as it is truthy
      (presumably a benchmark state object; confirm against the caller).
  """
  # The mesh is pinned at 8 local devices: more than that is unrealistic, and
  # because `from_callback` timing includes `device_put`, a huge local device
  # count (say 2048, which never happens in practice) would let `device_put`
  # dominate the measurement.
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (2048, 2048)
  data = np.arange(prod(shape)).reshape(shape)

  def slice_data(index):
    return data[index]

  while state:
    gda.GlobalDeviceArray.from_callback(shape, mesh, mesh_axes, slice_data)
def test_gda_equality_raises_not_implemented(self):
  """Comparing two GDAs with `==` should raise NotImplementedError."""
  mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  shape = (8, 2)
  mesh_axes = P(None)
  data = np.arange(prod(shape)).reshape(shape)

  def slice_data(index):
    return data[index]

  # Two separately constructed arrays over identical inputs.
  first, second = (
      GlobalDeviceArray.from_callback(shape, mesh, mesh_axes, slice_data)
      for _ in range(2)
  )

  with self.assertRaisesRegex(
      NotImplementedError,
      'GlobalDeviceArray equality is intentionally unimplemented.'):
    _ = first == second
def test_async_checkpointing(self):
  """Saves and restores a GDA through `GlobalAsyncCheckpointManager`.

  A second save aimed at an already existing final directory is then expected
  to raise.
  """
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  pspec = P('x', 'y')
  count = util.prod(shape)

  # First GDA
  data = np.arange(count).reshape(shape)
  saved = GlobalDeviceArray.from_callback(
      shape, mesh, pspec, lambda index: data[index])

  temp_dir = pathlib.Path(self.create_tempdir('temp_first').full_path)
  # The final directory path is derived from the temp path by name
  # substitution.
  final_dir = str(temp_dir).replace('temp_first', 'first')

  save_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                             [str(temp_dir)])

  manager = serialization.GlobalAsyncCheckpointManager()
  manager.serialize([saved], save_tspecs,
                    temp_checkpoint_dir=temp_dir,
                    final_checkpoint_dir=final_dir)
  manager.wait_until_finished()

  load_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                             [str(final_dir)])
  restored, = manager.deserialize([mesh], [pspec], load_tspecs)

  for pos, want in enumerate(([[0], [2]], [[1], [3]])):
    self.assertArraysEqual(restored.local_shards[pos].data.to_py(),
                           np.array(want))
  self.assertEqual(restored.local_shards[0].data.shape, (2, 1))
  self.assertEqual(restored.dtype, np.int32)

  # Per the original author: the rename (`tf.io.gfile.rename`) fails with
  # `file already exists`, and the error surfaces from `wait_until_finished`.
  with self.assertRaises(Exception):
    final_dir = pathlib.Path(self.create_tempdir('first').full_path)
    second_manager = serialization.GlobalAsyncCheckpointManager()
    second_manager.serialize([saved], save_tspecs,
                             temp_checkpoint_dir=temp_dir,
                             final_checkpoint_dir=final_dir)
    second_manager.wait_until_finished()
def gda_construction_raw(mesh_shape, mesh_axes, state):
  """Repeatedly calls the `GlobalDeviceArray` constructor on prebuilt buffers.

  Args:
    mesh_shape: shape of the ("x", "y") mesh to build.
    mesh_axes: partitioning passed to `get_shard_indices` and the constructor.
    state: loop driver; the body runs for as long as it is truthy
      (presumably a benchmark state object; confirm against the caller).
  """
  # The `device_put` calls happen before the loop below, so their cost is not
  # part of what the loop measures. All the devices used here are local.
  mesh = jtu.create_global_mesh(mesh_shape, ("x", "y"))
  shape = (2048, 2048)
  data = np.arange(prod(shape)).reshape(shape)
  indices_by_device = gda.get_shard_indices(shape, mesh, mesh_axes)

  buffers = [
      jax.device_put(data[indices_by_device[dev]], dev)
      for dev in mesh.local_devices
  ]

  while state:
    gda.GlobalDeviceArray(shape, mesh, mesh_axes, buffers)
def test_mesh_pspec_sharding_interface(self):
  """Exercises index map, op sharding and device assignment of a sharding."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_shape = (8, 4)
  mesh_sharding = sharding.MeshPspecSharding(mesh, P('y', 'x'))

  index_map = mesh_sharding.devices_indices_map(global_shape)
  xla_sharding = mesh_sharding._to_xla_op_sharding(len(global_shape))
  assignment = mesh_sharding._device_assignment()

  first_device = mesh.devices.flat[0]
  self.assertEqual(index_map[first_device], (slice(0, 4), slice(0, 1)))
  self.assertArraysEqual(assignment, list(mesh.devices.flat))
  self.assertEqual(xla_sharding.type, xc.OpSharding.Type.OTHER)
  self.assertListEqual(xla_sharding.tile_assignment_dimensions, [2, 4])
  self.assertListEqual(xla_sharding.tile_assignment_devices,
                       [0, 2, 4, 6, 1, 3, 5, 7])
def test_gda_str_repr(self):
  """Pins the exact `str` and `repr` output of a GDA."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  pspec = P(('x', 'y'))
  data = np.arange(prod(shape)).reshape(shape)

  arr = GlobalDeviceArray.from_callback(
      shape, mesh, pspec, lambda index: data[index])

  expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
  expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                   "global_mesh_shape={'x': 4, 'y': 2}, "
                   "mesh_axes=PartitionSpec(('x', 'y'),))")

  self.assertEqual(str(arr), expected_str)
  self.assertEqual(repr(arr), expected_repr)
def test_gda_shape_0_1d_mesh(self):
  """Checks a zero-length GDA replicated over an 8-device 1D mesh."""
  # Axis names are passed as a real 1-tuple: `('x')` without the trailing
  # comma is just the string 'x'.
  global_mesh = jtu.create_global_mesh((8,), ('x',))
  global_input_shape = (0,)
  mesh_axes = P(None)

  def cb(index):
    return np.array([])

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.ndim, 1)
  self.assertEqual(gda.size, 0)

  # Every local shard is expected to cover the full slice, with replica ids
  # counting up in shard order.
  for i, s in enumerate(gda.local_shards):
    self.assertEqual(s.index, (slice(None),))
    self.assertEqual(s.replica_id, i)
    self.assertArraysEqual(s.data.to_py(), np.array([]))

  self.assertEqual(gda.dtype, np.float32)
  self.assertEqual(
      gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
      (0,))
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Checks shard layout of a length-16 GDA on an 8-device 1D mesh.

  Args:
    mesh_axes: partitioning handed to `GlobalDeviceArray.from_callback`.
    expected_index: expected indices of the first two local shards.
    expected_shard_shape: expected shape of the first local shard's data.
    expected_replica_ids: expected replica id of every local shard, in order.
  """
  # Axis names are passed as a real 1-tuple: `('x')` without the trailing
  # comma is just the string 'x'.
  global_mesh = jtu.create_global_mesh((8,), ('x',))
  global_input_shape = (16,)
  global_input_data = np.arange(prod(global_input_shape)).reshape(-1)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)

  # Only the first two local shards are compared against the source data.
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)

  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
def test_gda_batched_callback(self):
  """Builds a GDA from a batched callback and checks the first two shards."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  pspec = P(('x', 'y'))
  data = np.arange(prod(shape)).reshape(shape)

  def slice_batch(indices):
    # The callback should be handed one index per local device.
    self.assertEqual(len(indices), len(mesh.local_devices))
    return [data[index] for index in indices]

  arr = GlobalDeviceArray.from_batched_callback(
      shape, mesh, pspec, slice_batch)

  for pos, want in enumerate(([[0, 1]], [[2, 3]])):
    self.assertArraysEqual(arr.local_data(pos).to_py(), np.array(want))