def testInfeed(self):
  """Checks that jit and pjit computations reading the same infeeds agree.

  Three arrays (y, z, w) are first fed to a ``jit`` computation on the first
  local device. They are then fed again to a ``pjit`` computation, in which
  ``y`` is replicated, ``z`` is partitioned along axis 0 and ``w`` along
  axis 1. Both results must match.
  """
  devices = np.array(jax.local_devices())
  nr_devices = len(devices)
  # Each device receives a `rows_per_device`-row slice of z and a
  # `cols_per_device`-column slice of w below, so the global shape must be a
  # multiple of those per-device sizes.
  rows_per_device, cols_per_device = 3, 5
  shape = (nr_devices * rows_per_device, nr_devices * cols_per_device)

  def f_for_jit(x):
    token = lax.create_token(x)
    (y,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (z,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    (w,), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32),))
    return x + y + z + w

  x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  y = x * 2.
  z = x * 3.
  w = x * 4.

  # Transfer data to infeed before executing the function. For GPUs, the
  # execution of the compiled function is blocking, so transferring data
  # to infeed before executing ensures that the execution does not deadlock
  # waiting for the infeed data.
  logging.info('Transfering to infeed for the jit call')
  d = devices[0]
  d.transfer_to_infeed((y,))
  d.transfer_to_infeed((z,))
  d.transfer_to_infeed((w,))

  # JIT
  logging.info('Making jit call')
  res0 = jax.jit(f_for_jit)(x)
  self.assertAllClose(res0, x + y + z + w, check_dtypes=True)

  # PJIT
  def f_for_pjit(x):
    token = lax.create_token(x)
    # A replicated infeed
    (y,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(None,))
    # An infeed sharded on first axis
    (z,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(nr_devices, 1),))
    # An infeed sharded on second axis
    (w,), token = lax.infeed(
        token,
        shape=(jax.ShapedArray(x.shape, np.float32),),
        partitions=(P(1, nr_devices),))
    return x + y + z + w

  logging.info('Transfering to infeed for the pjit call')
  for didx, d in enumerate(devices):
    # Transfer the whole array to all devices for replicated.
    d.transfer_to_infeed((y,))
    # For sharded infeed, transfer only the needed slices to each device.
    # Every infeed in f_for_pjit declares a 1-tuple shape, so each transfer
    # must be a 1-tuple as well (note the trailing commas).
    row_start = rows_per_device * didx
    col_start = cols_per_device * didx
    d.transfer_to_infeed((z[row_start:row_start + rows_per_device, :],))
    d.transfer_to_infeed((w[:, col_start:col_start + cols_per_device],))

  with mesh(devices, ['d']):
    logging.info('Making pjit call')
    res = pjit(
        f_for_pjit,
        in_axis_resources=(P('d'),),
        out_axis_resources=P('d'))(x)

  self.assertAllClose(res0, res, check_dtypes=True)
class PmapOfShardedJitTest(jtu.JaxTestCase):
  """Checks ``pmap(sharded_jit(f))`` against the reference ``pmap(f)``."""

  def setUp(self):
    super().setUp()

    if jtu.device_under_test() == "gpu":
      # NOTE(review): this is a process-wide environment change and is not
      # restored after the test.
      os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"

  # TODO(skye): make a similar version for ShardedJitTest and run the same tests
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f).

    Both are called on two integer arrays of shape (2, 4, 4): ``x`` and
    ``x + 1``. The leading axis of size 2 is the pmapped axis.

    Args:
      f: binary function under test.
      in_partitions: ``in_parts`` passed to ``sharded_jit``. The number of
        shards is ``2 * prod(in_partitions[0])``.
      out_partitions: ``out_parts`` passed to ``sharded_jit``.
      dtype: unused; kept for signature compatibility.

    Raises:
      SkipTest: if there are fewer local devices than required shards.
    """
    shape = (2, 4, 4)
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    # Every output leaf must be a ShardedDeviceArray spread over all shards.
    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
      self.assertIsInstance(r, pxla.ShardedDeviceArray)
      self.assertLen(r.device_buffers, num_shards)

  @parameterized.named_parameters(
      {
          "testcase_name": "_in_parts={}_out_parts={}".format(
              in_partitions, out_partitions).replace(" ", ""),
          "in_partitions": in_partitions,
          "out_partitions": out_partitions,
      }
      for in_partitions in [
          (P(2, 1), P(2, 1)),
          (P(2, 1), P(1, 2)),
          (P(2, 2), P(2, 2)),
          (P(4, 1), P(2, 2)),
      ]
      for out_partitions in [in_partitions[0], None])
  def testBasic(self, in_partitions, out_partitions):

    def f(x, y):
      return lax.dot(x, y)

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters(
      {
          "testcase_name": "_in_parts={}_out_parts={}".format(
              in_partitions, out_partitions).replace(" ", ""),
          "in_partitions": in_partitions,
          "out_partitions": out_partitions,
      }
      for in_partitions in [
          (P(2, 1), P(2, 1)),
          (P(2, 1), P(1, 2)),
          (P(4, 1), P(2, 2)),
      ]
      for out_partitions in [(in_partitions[1], in_partitions[0], None),
                             (None, None, None)])
  def testMultipleOutputs(self, in_partitions, out_partitions):

    def f(x, y):
      a = lax.dot(x, y)
      # TODO(skye): use these more interesting outputs once returning constants
      # works
      # return a, a + 1, 3
      return a, a + x, x + y

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters(
      {
          "testcase_name": "_in_parts={}_out_parts={}".format(
              in_partitions, out_partitions).replace(" ", ""),
          "in_partitions": in_partitions,
          "out_partitions": out_partitions,
      }
      for in_partitions in [
          (P(2, 1), P(2, 1)),
          (P(2, 1), P(1, 2)),
          (P(4, 1), P(2, 2)),
      ]
      for out_partitions in [in_partitions[0], None])
  def testArrayConstants(self, in_partitions, out_partitions):

    def f(x, y):
      a = lax.dot(x, y)
      b = a + jnp.ones(a.shape)
      c = b + jnp.ones(a.shape[0])[jnp.newaxis]
      return c

    self._runTest(f, in_partitions, out_partitions)

  def testPyTreeArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(a, b, c):
      a1, a2 = a
      c1, (c2, c3) = c
      return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
      return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
    c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

    in_parts = (None, P(2, 1), (None, P(2, 1)))
    out_parts = P(2, 1)

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(a, b, c)
    expected = pmap(f)(a, b, c)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)

  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

  def testShardingConstraint(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    @partial(sharded_jit, in_parts=None, out_parts=None)
    def f(x):
      y = jnp.dot(x, x)
      y = with_sharding_constraint(y, P(2, 1))
      return y * 2

    def expected_f(x):
      return jnp.dot(x, x) * 2

    shape = (2, 8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    result = pmap(f)(x)
    expected = pmap(expected_f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    # Only the third (dummy) argument is mapped; x and y are broadcast.
    in_axes = (None, None, 0)
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
def f(x):
  # Compute x @ x, pin the product's sharding to P(2, 1), then double it.
  return with_sharding_constraint(jnp.dot(x, x), P(2, 1)) * 2
def embedding(x):
  """Shard ``x`` as P("dp", None), then apply an EmbeddingShardV2 layer."""
  sharded_x = maybe_shard(x, P("dp", None))
  embed_layer = EmbeddingShardV2(config)
  return embed_layer(sharded_x)
def f(x):
  # Add one, constrain the intermediate's sharding to P(2, 1), then double.
  incremented = x + 1
  constrained = with_sharding_constraint(incremented, P(2, 1))
  return constrained * 2
class GDATest(jtu.JaxTestCase):
  """Tests for GlobalDeviceArray construction, shard layout and metadata."""

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but for convenient purposes, checking for only
       # 2. The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    # On a single process the global and local shard views must coincide.
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  @parameterized.named_parameters(
      ("mesh_x_y_z", P("x", "y", "z"),
       # There are more slices but for convenient purposes, checking for only
       # 2. The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 2), slice(0, 1)),
        (slice(0, 4), slice(0, 2), slice(1, 2))),
       (4, 2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_xy_z", P(("x", "y"), "z"),
       ((slice(0, 2), slice(0, 2), slice(None)),
        (slice(0, 2), slice(2, 4), slice(None))),
       (2, 2, 2),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_z", P("z"),
       ((slice(0, 4), slice(None), slice(None)),
        (slice(4, 8), slice(None), slice(None))),
       (4, 4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3]),
  )
  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    global_input_shape = (8, 4, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 3)
    self.assertEqual(gda.size, 64)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)

    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x", P("x"),
       # There are more slices but for convenient purposes, checking for only
       # 2. The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2),), (slice(2, 4),)),
       (2,),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_none", P(),
       ((slice(None),), (slice(None),)),
       (16,),
       [0, 1, 2, 3, 4, 5, 6, 7]),
  )
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    # Axis names must be a tuple; ('x') would just be the string 'x'.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  def test_gda_shape_0_1d_mesh(self):
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (0,)
    mesh_axes = P(None)

    def cb(index):
      return np.array([])

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 0)
    for i, s in enumerate(gda.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(gda.dtype, np.float32)
    self.assertEqual(
        gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
        (0,))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but for convenient purposes, checking for only
       # 2. The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
       (4, 1),
       [0, 0, 0, 0]),
  )
  def test_gda_subset_devices(self, mesh_axes, expected_index,
                              expected_shard_shape, expected_replica_ids):
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertArraysEqual(g.data, l.data)

  def test_gda_batched_callback(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      # The batched callback is invoked once with one index per local device.
      self.assertLen(indices, len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_batched_callback_with_devices(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x')
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

    def cb(cb_inp):
      # 4 distinct shards along 'x', each replicated on the 2 'y' devices.
      self.assertLen(cb_inp, 4)
      dbs = []
      for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        # Named `shard_data` to avoid shadowing the `array` module.
        shard_data = global_input_data[index]
        dbs.extend([jax.device_put(shard_data, device) for device in devices])
      return dbs

    gda = GlobalDeviceArray.from_batched_callback_with_devices(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_str_repr(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))"))

  def test_gda_equality_raises_not_implemented(self):
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    same_input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesRegex(NotImplementedError,
        'GlobalDeviceArray equality is intentionally unimplemented.'):
      # The comparison itself is the operation expected to raise.
      input_gda == same_input_gda  # pylint: disable=pointless-statement

  def test_mesh_hash(self):
    global_mesh1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_mesh2 = jtu.create_global_mesh((2, 4), ('x', 'y'))
    global_mesh3 = jtu.create_global_mesh((4, 2), ('x', 'y'))

    self.assertNotEqual(hash(global_mesh1), hash(global_mesh2))
    self.assertEqual(hash(global_mesh1), hash(global_mesh3))

  def test_device_mismatch(self):
    devices = jax.devices()
    if len(devices) < 8:
      raise unittest.SkipTest("Test requires 8 global devices.")
    # Deliberately permuted device order so that it disagrees with the order
    # of the buffers built below from jax.local_devices().
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    global_mesh = Mesh(mesh_devices, ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

    dbs = [
        jax.device_put(global_input_data[indices[d]], d)
        for d in jax.local_devices()
    ]

    with self.assertRaisesRegex(
        ValueError,
        'The `global_mesh.local_devices` and `device_buffers` device order'):
      GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

  def test_gda_block_until_ready(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    self.assertIs(gda.block_until_ready(), gda)
def f(x):
  # NOTE(review): assumes x is a 2-element list whose first element is a dict
  # with key "a" (matching the two specs below and the x[0]["a"] access).
  constrained = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
  # Shallow-copy the top-level container before updating an entry.
  out = constrained.copy()
  out[0]["a"] *= 2
  return out
def dispatch():
  """Run ``f`` on ``x`` with pjit inside a mesh with one axis 'd'.

  The result is discarded.
  """
  with mesh(devices, ['d']):
    logging.info('Making pjit call')
    sharded_f = pjit(
        f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))
    sharded_f(x)
def test_pjit_gda_multi_input_multi_output(self):
  """pjit over four GDA inputs emits four GDA outputs with requested shardings."""
  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  def cb(index):
    return input_data[index]

  mesh_axes1 = P('x', 'y')
  mesh_axes2 = P('x')
  mesh_axes3 = P(('x', 'y'))
  mesh_axes4 = P(None)
  gda1, gda2, gda3, gda4 = [
      global_device_array.GlobalDeviceArray.from_callback(
          global_input_shape, global_mesh, axes, cb)
      for axes in (mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4)
  ]

  with jax._src.config.parallel_functions_output_gda(True):
    @partial(
        pjit,
        # `FROM_GDA` applies to every input; presumably each input takes its
        # sharding from the GDA passed for it.
        in_axis_resources=FROM_GDA,
        out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
    def f(x, y, z, a):
      return x @ x.T, y, z, a
    out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)

  def check_output(out, shape, shard_shape, first_two_indices, replica_ids,
                   expected_for_shard):
    # Shared checks: type, global/shard shapes, the first two shard indices,
    # replica ids, and the data held by every local shard.
    self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
    self.assertEqual(out.shape, shape)
    self.assertEqual(out.local_shards[0].data.shape, shard_shape)
    self.assertEqual(out.local_shards[0].index, first_two_indices[0])
    self.assertEqual(out.local_shards[1].index, first_two_indices[1])
    self.assertListEqual([s.replica_id for s in out.local_shards], replica_ids)
    for s in out.local_shards:
      self.assertArraysEqual(s.data, expected_for_shard(s))

  expected_matrix_mul = input_data @ input_data.T
  check_output(out1, (8, 8), (2, 4),
               ((slice(0, 2), slice(0, 4)), (slice(0, 2), slice(4, 8))),
               [0, 0, 0, 0, 0, 0, 0, 0],
               lambda s: expected_matrix_mul[s.index])
  check_output(out2, (8, 2), (8, 2),
               ((slice(None), slice(None)), (slice(None), slice(None))),
               [0, 1, 2, 3, 4, 5, 6, 7],
               lambda s: input_data)
  check_output(out3, (8, 2), (2, 2),
               ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
               [0, 1, 0, 1, 0, 1, 0, 1],
               lambda s: input_data[s.index])
  check_output(out4, (8, 2), (1, 2),
               ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
               [0, 0, 0, 0, 0, 0, 0, 0],
               lambda s: input_data[s.index])
def testLowerWithDuckTyping(self):
  """``pjit(...).lower`` accepts a ShapeDtypeStruct in place of an array."""
  abstract_input = jax.ShapeDtypeStruct((2, 2), jnp.float32)
  add_four = pjit(lambda v: v + 4,
                  in_axis_resources=P('x'), out_axis_resources=P('x'))
  # Only checks that lowering does not raise.
  add_four.lower(abstract_input)
def f(x):
  """Outfeed ``x`` three times, then return it unchanged.

  The three outfeeds use these partitionings: replicated, sharded on the
  first axis, and sharded on the second axis.
  """
  token = lax.create_token(x)
  for partitions in ((None,), (P(nr_devices, 1),), (P(1, nr_devices),)):
    token = lax.outfeed(token, x, partitions=partitions)
  return x
def f(x):
  """Call a nested pjit ``h`` inside a new single-device mesh with axis 'x'.

  The mesh contains only the first local device.
  """
  # Axis names must be a tuple: ('x') is just the string 'x'. It only happened
  # to work because a one-character string iterates to the same single name.
  with mesh(np.array([jax.local_devices()[0]]), ('x',)):
    @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
    def h(x):
      return x
    return h(x)
class JaxArrayTest(jtu.JaxTestCase):
  """Tests for ``array.Array`` with the ``jax_array`` config enabled."""

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_jax_array_value(self, mesh_axes):
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, global_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, mesh_axes))
      for s in arr.addressable_shards:
        self.assertLen(s.data._arrays, 1)
        self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
      self.assertArraysEqual(arr._value, global_data)
      self.assertArraysEqual(arr._npy_value, global_data)

  def test_array_delete(self):
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, _ = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      arr.delete()
      with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
        arr._check_if_deleted()
      self.assertIsNone(arr._npy_value)
      self.assertIsNone(arr._arrays)

  def test_device_put(self):
    with jax._src.config.jax_array(True):
      numpy_array = np.array([1, 2, 3])
      arr = jax.device_put(numpy_array, jax.devices()[0])
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
      self.assertArraysEqual(arr, numpy_array)
      self.assertEqual(arr._committed, True)
      for i in arr.addressable_shards:
        self.assertArraysEqual(i.data, numpy_array)
        self.assertEqual(i.device, jax.devices()[0])
        self.assertEqual(i.index, (slice(None),))

  def test_device_put_array_delete(self):
    with jax._src.config.jax_array(True):
      arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
      arr.delete()
      with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
        arr._check_if_deleted()
      self.assertIsNone(arr._npy_value)
      self.assertIsNone(arr._arrays)

  def test_array_device_get(self):
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, input_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      self.assertArraysEqual(jax.device_get(arr), input_data)

  def test_repr(self):
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, _ = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      repr(arr)  # doesn't crash

  def test_jnp_array(self):
    with jax._src.config.jax_array(True):
      arr = jnp.array([1, 2, 3])
      self.assertIsInstance(arr, array.Array)
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
      self.assertEqual(arr._committed, False)

  def test_jnp_array_jit_add(self):
    with jax._src.config.jax_array(True):
      a = jnp.array([1, 2, 3])
      b = jnp.array([4, 5, 6])
      arr = jax.jit(lambda x, y: x + y)(a, b)
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_jnp_add(self):
    with jax._src.config.jax_array(True):
      arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_jnp_array_normal_add(self):
    with jax._src.config.jax_array(True):
      a = jnp.array([1, 2, 3])
      b = jnp.array([4, 5, 6])
      arr = a + b
      self.assertIsInstance(arr, array.Array)
      self.assertArraysEqual(arr, np.array([5, 7, 9]))
      self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

  def test_array_sharded_astype(self):
    with jax._src.config.jax_array(True):
      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
      input_shape = (8, 2)
      arr, input_data = create_array(
          input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
      arr_float32 = arr.astype(jnp.float32)
      self.assertEqual(arr_float32.dtype, np.float32)
      self.assertArraysEqual(arr_float32, input_data.astype(np.float32))

  def test_jnp_array_astype(self):
    with jax._src.config.jax_array(True):
      arr = jnp.array([1, 2, 3])
      arr_float32 = arr.astype(jnp.float32)
      self.assertEqual(arr_float32.dtype, np.float32)
      # Compare against an independent NumPy reference. Comparing against
      # arr.astype(...) would run the same conversion twice and always pass.
      self.assertArraysEqual(arr_float32,
                             np.array([1, 2, 3], dtype=np.float32))
def residual(x, mask):
  """Add a TransformerLayerShardV2 output to ``x`` (residual) and shard it."""
  layer = TransformerLayerShardV2(config, init_scale=2. / config["layers"])
  out = x + layer(x, mask)
  return maybe_shard(out, P("dp", None, "mp"))
class GDATest(jtu.JaxTestCase): @parameterized.named_parameters( ("mesh_x_y", ["x", "y"], # There are more slices but for convienient purposes, checking for only # 2. The indices + shard_shape + replica_id should be unique enough. ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))), (2, 1), [0, 0, 0, 0, 0, 0, 0, 0], False), ("mesh_x_y_pspec", P("x", "y"), ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))), (2, 1), [0, 0, 0, 0, 0, 0, 0, 0], False), ("mesh_x", ["x"], ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))), (2, 2), [0, 1, 0, 1, 0, 1, 0, 1], False), ("mesh_y", ["y"], ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))), (4, 2), [0, 0, 1, 1, 2, 2, 3, 3], False), ("mesh_none_y", [None, "y"], ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))), (8, 1), [0, 0, 1, 1, 2, 2, 3, 3], False), ("mesh_xy", [("x", "y")], ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))), (1, 2), [0, 0, 0, 0, 0, 0, 0, 0], False), ("mesh_fully_replicated", [], ((slice(None), slice(None)), (slice(None), slice(None))), (8, 2), [0, 1, 2, 3, 4, 5, 6, 7], True), ) def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids, expected_is_fully_replicated): global_mesh = create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) global_input_data = np.arange( prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb) self.assertEqual(gda.local_shards[0].index, expected_index[0]) self.assertArraysEqual(gda.local_data(0), global_input_data[expected_index[0]]) self.assertEqual(gda.local_shards[1].index, expected_index[1]) self.assertArraysEqual(gda.local_data(1), global_input_data[expected_index[1]]) self.assertEqual(gda.local_data(0).shape, expected_shard_shape) replica_ids = [i.replica_id for i in gda.local_shards] self.assertListEqual(replica_ids, expected_replica_ids) 
# Tail of a GDA-sharding test whose beginning precedes this span; it relies on
# `gda`, `expected_is_fully_replicated` and `expected_shard_shape` bound there.
  self.assertListEqual([i.device.id for i in gda.local_shards],
                       [0, 1, 2, 3, 4, 5, 6, 7])
  self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
  # Every local shard's abstract value must carry the expected shard shape.
  for s in gda.local_shards:
    self.assertEqual(s.data.aval,
                     core.ShapedArray(expected_shard_shape, s.data.dtype))
  # The global and local shard views must agree field by field.
  for g, l in safe_zip(gda.global_shards, gda.local_shards):
    self.assertEqual(g.device, l.device)
    self.assertEqual(g.index, l.index)
    self.assertEqual(g.replica_id, l.replica_id)
    self.assertEqual(g.data.aval, l.data.aval)
    self.assertArraysEqual(g.data, l.data)

# Each case: (name, mesh_axes, indices of the first two local shards,
# per-shard shape, replica ids of all local shards).
@parameterized.named_parameters(
    ("mesh_x_y_z", ["x", "y", "z"],
     # There are more slices, but for convenience only the first
     # 2 are checked. The indices + shard_shape + replica_id should be unique enough.
     ((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
     (4, 2, 1),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    ("mesh_xy_z", [("x", "y"), "z"],
     ((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
     (2, 2, 2),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    ("mesh_z", ["z"],
     ((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
     (4, 4, 2),
     [0, 0, 1, 1, 2, 2, 3, 3]),
)
def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Check shard indices, data, shape and replica ids of a 3-D GDA on a 2x2x2 mesh."""
  global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
  global_input_shape = (8, 4, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  # The callback slices the host-side array with the index each shard needs.
  def cb(index):
    return global_input_data[index]
  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)

@parameterized.named_parameters(
    ("mesh_x", ["x"],
     # There are more slices, but for convenience only the first
     # 2 are checked. The indices + shard_shape + replica_id should be unique enough.
     ((slice(0, 2),), (slice(2, 4),)),
     (2,),
     [0, 0, 0, 0, 0, 0, 0, 0]),
    # An empty mesh_axes list replicates the whole array on every device.
    ("mesh_none", [],
     ((slice(None),), (slice(None),)),
     (16,),
     [0, 1, 2, 3, 4, 5, 6, 7]),
)
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids):
  """Check shard indices, data, shape and replica ids of a 1-D GDA on an 8-device mesh."""
  # NOTE(review): ('x') is the plain string 'x', not a 1-tuple; ('x',) was
  # presumably intended -- confirm create_global_mesh accepts a string here.
  global_mesh = create_global_mesh((8,), ('x'))
  global_input_shape = (16,)
  global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
  def cb(index):
    return global_input_data[index]
  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)

@parameterized.named_parameters(
    ("mesh_x_y", ["x", "y"],
     # There are more slices, but for convenience only the first
     # 2 are checked. The indices + shard_shape + replica_id should be unique enough.
     ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
     (4, 1),
     [0, 0, 0, 0]),
)
def test_gda_subset_devices(self, mesh_axes, expected_index,
                            expected_shard_shape, expected_replica_ids):
  """Build a GDA on a 2x2 (4-device) mesh and check its shards and views."""
  # Presumably fewer devices than the host exposes, hence "subset" -- the
  # device count is not visible here.
  global_mesh = create_global_mesh((2, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  def cb(index):
    return global_input_data[index]
  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
  # The global and local shard views must agree.
  for g, l in safe_zip(gda.global_shards, gda.local_shards):
    self.assertEqual(g.device, l.device)
    self.assertEqual(g.index, l.index)
    self.assertEqual(g.replica_id, l.replica_id)
    self.assertArraysEqual(g.data, l.data)

def test_gda_batched_callback(self):
  """from_batched_callback receives all local indices at once and returns one array each."""
  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = [('x', 'y')]
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  # One call is expected, carrying one index per local device.
  def cb(indices):
    self.assertEqual(len(indices), len(global_mesh.local_devices))
    return [global_input_data[index] for index in indices]
  gda = GlobalDeviceArray.from_batched_callback(
      global_input_shape, global_mesh, mesh_axes, cb)
  expected_first_shard_value = np.array([[0, 1]])
  self.assertArraysEqual(gda.local_data(0).to_py(),
                         expected_first_shard_value)
  expected_second_shard_value = np.array([[2, 3]])
  self.assertArraysEqual(gda.local_data(1).to_py(),
                         expected_second_shard_value)

def test_gda_batched_callback_with_devices(self):
  """The callback gets (index, devices) pairs and returns per-device buffers."""
  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  # Sharding only over 'x' leaves each index on the 2 'y' devices, which is
  # why 4 pairs with 2 devices each are asserted below.
  mesh_axes = ['x']
  global_input_data = np.arange(
      prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
  def cb(cb_inp):
    self.assertLen(cb_inp, 4)
    dbs = []
    for inp in cb_inp:
      index, devices = inp
      self.assertLen(devices, 2)
      array = global_input_data[index]
      # Place the same slice on every device that shares this index.
      dbs.extend([jax.device_put(array, device) for device in devices])
    return dbs
  gda = GlobalDeviceArray.from_batched_callback_with_devices(
      global_input_shape, global_mesh, mesh_axes, cb)
  expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
  self.assertArraysEqual(gda.local_data(0).to_py(),
                         expected_first_shard_value)
  # The second local shard holds the same slice as the first (a 'y' replica).
  expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
  self.assertArraysEqual(gda.local_data(1).to_py(),
                         expected_second_shard_value)

def test_gda_str_repr(self):
  """Pin the exact str() and repr() text of a GlobalDeviceArray."""
  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = [('x', 'y')]
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  def cb(index):
    return global_input_data[index]
  gda = GlobalDeviceArray.from_callback(
      global_input_shape, global_mesh, mesh_axes, cb)
  # NOTE(review): 'dtype=int32' presumably relies on 64-bit types being
  # disabled, since np.arange yields int64 on most platforms -- confirm.
  self.assertEqual(str(gda), 'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
  self.assertEqual(
      repr(gda), ("GlobalDeviceArray(shape=(8, 2), dtype=int32, "
                  "global_mesh_shape={'x': 4, 'y': 2}, mesh_axes=[('x', 'y')])"))
def testNoopPartitionSpecs(self):
  """Partition specs that shard nothing must leave pjit equivalent to x * 2."""
  noop_specs = (P(), P(None), P(()), P((), None), P(None, None, ()))
  arr = jnp.arange(8).reshape((2, 2, 2))
  want = arr * 2
  for noop in noop_specs:
    doubled = pjit(lambda v: v * 2,
                   in_axis_resources=noop,
                   out_axis_resources=noop)(arr)
    self.assertAllClose(doubled, want)
def test_checkpointing(self):
  """Serialize three GDAs with tensorstore, restore them, and check the shards."""
  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = P('x', 'y')
  num = util.prod(global_input_shape)

  # GDA #1 holds 0..num-1, sharded over both mesh axes.
  first_values = np.arange(num).reshape(global_input_shape)
  gda1 = GlobalDeviceArray.from_callback(
      global_input_shape, global_mesh, mesh_axes,
      lambda idx: first_values[idx])
  ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

  # GDA #2 holds num..2*num-1 with the same layout.
  second_values = np.arange(num, num + num).reshape(global_input_shape)
  gda2 = GlobalDeviceArray.from_callback(
      global_input_shape, global_mesh, mesh_axes,
      lambda idx: second_values[idx])
  ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

  # GDA #3 is an empty array replicated over a 1-D mesh.
  global_mesh1d = create_global_mesh((8, ), ('x', ))
  gda3 = GlobalDeviceArray.from_callback(
      (0, ), global_mesh1d, P(None), lambda _: np.array([]))
  ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

  ckpt_paths = [str(d) for d in (ckpt_dir1, ckpt_dir2, ckpt_dir3)]
  tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

  serialization.run_serialization([gda1, gda2, gda3], tspecs)

  # The second checkpoint is deliberately restored with a different spec.
  restored = serialization.run_deserialization(
      [global_mesh, global_mesh, global_mesh1d],
      [mesh_axes, P('x'), P(None)],
      tspecs)
  m1, m2, m3 = restored

  self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                         np.array([[0], [2]]))
  self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                         np.array([[1], [3]]))
  self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
  self.assertEqual(m1.dtype, np.int32)

  self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
  self.assertEqual(m2.dtype, np.int32)

  for replica, shard in enumerate(m3.local_shards):
    self.assertEqual(shard.index, (slice(None), ))
    self.assertEqual(shard.replica_id, replica)
    self.assertArraysEqual(shard.data.to_py(), np.array([]))
  self.assertEqual(m3.dtype, np.float32)
def f(x):
  """Add 1 to `x` (with a sharding hint) until x[0, 0] is no longer below 10."""
  def keep_going(carry):
    return carry[0, 0] < 10.

  def step(carry):
    # Re-apply the P(2, 1) partitioning hint to each iteration's carry.
    return with_sharding_constraint(carry + 1., P(2, 1))

  return lax.while_loop(keep_going, step, x)
def f(x):
  """Add 1, pin the intermediate to P('x', 'y'), then double it."""
  constrained = with_sharding_constraint(x + 1, P('x', 'y'))
  return constrained * 2
def f(x):
  """Pull sin(constraint(x + 1)) back through its own VJP, seeded with its output."""
  def sharded_sin(z):
    return jnp.sin(with_sharding_constraint(z, P(2, 2)))

  primal, pullback = vjp(sharded_sin, x + 1)
  return pullback(primal)
def testNonHashableAxisResources(self):
  """pjit must accept dict-valued (unhashable) axis-resource pytrees."""
  arr = jnp.arange(4)

  def add_two(tree):
    return {'b': tree['a'] + 2}

  in_specs = ({'a': P('x')},)
  out_specs = {'b': P('x')}
  result = pjit(add_two,
                in_axis_resources=in_specs,
                out_axis_resources=out_specs)({'a': arr})
  self.assertAllClose(result, {'b': arr + 2})
def __init__(self, config):
    """Build the sharded init/train/eval/weight-cast pjit functions and initial state.

    Reads the 'dp' and 'mp' mesh axis sizes from `thread_resources.env`, so
    this presumably has to run inside an active mesh context (TODO confirm).
    Sets `self.init_pjit`, `self.train_pjit`, `self.eval_pjit` and
    `self.move_weights_pjit`. It then initialises `self.state` with
    `self.init_pjit`, and also sets `self.state_shard` and `self.eval_weights`.

    Args:
        config: dict-like model configuration. Keys read here: "optimizer",
            "bf16_optimizer", "early_cast", "early_collect", "layers",
            "n_vocab", "d_model", "cores_per_replica" and "seq".
    """
    self.config = config
    optimizer = config["optimizer"]

    bf16_optimizer = config.get("bf16_optimizer", False)
    early_cast = config.get("early_cast", False)
    early_collect = config.get("early_collect", True)

    def embedding(x):
        x = maybe_shard(x, P("dp", None))
        return EmbeddingShardV2(config)(x)

    def residual(x, mask):
        out = x + TransformerLayerShardV2(
            config, init_scale=2. / config["layers"])(x, mask)
        return maybe_shard(out, P("dp", None, "mp"))

    def transformer(x, mask):
        # hk.remat recomputes the residual block's activations in the backward pass.
        return hk.remat(residual)(x, mask)

    def projection(x):
        return Projection(config)(x)

    def init_fns():
        # Haiku init functions for each of the three model parts.
        embed_init_fn = hk.transform(
            hk.experimental.optimize_rng_use(embedding)).init
        transformer_init_fn = hk.transform(
            hk.experimental.optimize_rng_use(transformer)).init
        projection_init_fn = hk.transform(
            hk.experimental.optimize_rng_use(projection)).init

        return embed_init_fn, transformer_init_fn, projection_init_fn

    def shard_strategy(shape_dtype, parallel):
        """Choose a PartitionSpec for one parameter leaf from its shape.

        `parallel` is the mesh axis (or list of axes) to shard along.
        """
        # NOTE(review): the source text was whitespace-collapsed and held only
        # two `else: raise` clauses for three nested if-chains. They are paired
        # with the two innermost chains here; confirm against the original
        # nesting.
        if shape_dtype.ndim <= 1:
            return P()
        # embedding/projection layers
        elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
            return P(parallel, None)
        elif shape_dtype.shape == (config["d_model"], config["n_vocab"]):
            return P(None, parallel)

        # a transformer layer
        elif shape_dtype.shape[0] == config["layers"]:
            if shape_dtype.ndim == 2:
                # a channel wise variable (e.g. layernorm parameters)
                # replicate it for speed
                return P(None)
            elif shape_dtype.ndim == 3:
                # a weight matrix
                matrix_size = shape_dtype.shape[1:]

                assert matrix_size[0] != matrix_size[1]  # this case is ambiguous

                if matrix_size[0] == config["d_model"]:
                    # shard along the axis which is _not_ the model dimension
                    return P(None, None, parallel)
                elif matrix_size[1] == config["d_model"]:
                    return P(None, parallel, None)
                else:
                    raise NotImplementedError("borked")
            else:
                raise NotImplementedError("borked")

    def init(key, x):
        """Initialise params, step counter and optimizer state from `key`."""
        embed_init_fn, transformer_init_fn, projection_init_fn = init_fns()

        def init_scan_fn(key, x):
            # Split a fresh key per layer so each scanned layer gets its own init.
            new_key, key = jax.random.split(key)

            return new_key, transformer_init_fn(key, x, 0)

        e_key, t_key, p_key = jax.random.split(key, 3)
        # Leading axis = layers: scan stacks one parameter set per layer.
        input_shape = (config["layers"], ) + x.shape + (config["d_model"], )

        params = {
            "embed": embed_init_fn(e_key, x),
            "transformer": jax.lax.scan(init_scan_fn, t_key, xs=jax.random.uniform(t_key, input_shape, dtype=jnp.float32))[1],
            # NOTE(review): the dummy input here is drawn from t_key, not
            # p_key. That is presumably harmless because only its shape
            # matters, but confirm.
            "proj": projection_init_fn(p_key, jax.random.uniform(t_key, input_shape[1:], dtype=jnp.float32)),
        }

        return {
            "params": (to_bf16 if early_cast else to_f32)(params),
            "step": np.array(0),
            "opt_state": optimizer.init((to_bf16 if bf16_optimizer else to_f32)(params))
        }

    # NOTE(review): `assert` is stripped under `python -O`; a raise would keep
    # this config check active.
    assert thread_resources.env.shape['mp'] == config["cores_per_replica"]

    dp = thread_resources.env.shape['dp']
    mp = thread_resources.env.shape['mp']

    key = hk.PRNGSequence(42)
    # Dummy input for shape evaluation only: uniform in [0, 1) cast to uint32
    # is all zeros, which is fine because eval_shape uses only shape/dtype.
    x = jax.random.uniform(next(key), (mp * dp, 16), minval=0, maxval=1).astype(jnp.uint32)  # batch, seq

    head_print("starting shape evaluation")

    param_shapes = jax.eval_shape(init, jax.random.PRNGKey(42), x)

    state_shard = {
        "step": P(),

        # zero level 1: shard optimizer states over both MP and DP
        "opt_state": jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]), param_shapes["opt_state"]),

        # fp32 params are also sharded (so this is like a weird mix between zero-1 and zero-3...)
        "params": jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]), param_shapes["params"]),
    }

    head_print("sharding strategy:")
    jax.tree_multimap(head_print, state_shard, param_shapes)

    self.init_pjit = pjit(init, in_axis_resources=(None, P("dp")), out_axis_resources=state_shard)

    def apply_fns():
        embed_apply_fn = hk.without_apply_rng(hk.transform(embedding)).apply
        transformer_apply_fn = hk.without_apply_rng(hk.transform(transformer)).apply

        return embed_apply_fn, transformer_apply_fn

    def train_apply_fn(params, x, y):
        """Forward pass returning (mean loss, mean loss at the last position)."""
        embed_apply_fn, transformer_apply_fn = apply_fns()

        def train_loss(x, y):
            loss, _ = Projection(config).loss(x, y, z_loss=1.0)
            return loss.mean(), loss[:, -1].mean()

        projection_apply_fn = hk.without_apply_rng(hk.transform(train_loss)).apply

        x = embed_apply_fn(params["embed"], x)
        x = to_bf16(x)

        def apply_scan_fn(x, layer_state):
            return to_bf16(transformer_apply_fn(layer_state, x, 0)), None

        # Run the stacked layers sequentially via scan over the leading axis.
        x = jax.lax.scan(apply_scan_fn, x, xs=params["transformer"])[0]

        return projection_apply_fn(params["proj"], x, y)

    # Params sharded over "mp" only: the layout used when early_collect is on.
    mp_shard_strategy = jax.tree_map(partial(shard_strategy, parallel=["mp"]), param_shapes["params"])

    def train(state, ctx, tgt):
        """One optimizer step; the leading axis of ctx/tgt indexes microbatches."""
        if early_collect:
            bf16_params = maybe_shard(to_bf16(state["params"]), mp_shard_strategy)
        else:
            bf16_params = to_bf16(state["params"])

        def microbatch(old_grad, batch):
            ctx, tgt = batch

            val_grad_fn = jax.value_and_grad(train_apply_fn, has_aux=True, allow_int=True)
            (loss, last_loss), grad = val_grad_fn(bf16_params, ctx, tgt)

            # Accumulate gradients across microbatches.
            new_grad = jax.tree_multimap(lambda a, b: a + b, old_grad, grad)
            return new_grad, (loss, last_loss)

        if ctx.shape[0] == 1:
            # Single microbatch: skip the scan.
            val_grad_fn = jax.value_and_grad(train_apply_fn, has_aux=True, allow_int=True)
            (loss, last_loss), grad = val_grad_fn(bf16_params, ctx[0], tgt[0])
        else:
            # Several microbatches: scan with a bfloat16 zero-gradient accumulator.
            grad, (loss, last_loss) = jax.lax.scan(microbatch,
                                                   jax.tree_map(lambda x: jnp.zeros_like(x).astype(jnp.bfloat16),
                                                                bf16_params),
                                                   (ctx, tgt))

        updates, new_opt_state = optimizer.update(grad, state["opt_state"], state["params"])

        return to_f32(loss), to_f32(last_loss), {
            "params": optax.apply_updates(state["params"], to_f32(updates)),
            "step": state["step"] + 1,
            "opt_state": new_opt_state,
        }

    # donate_argnums=(0, ) lets XLA reuse the old state's buffers.
    self.train_pjit = pjit(train,
                           in_axis_resources=(state_shard, P(None, "dp"), P(None, "dp")),
                           out_axis_resources=(None, None, state_shard),
                           donate_argnums=(0, ))

    def eval_apply_fn(params, x, y, mask):
        """Masked forward pass returning per-example loss/accuracy statistics."""
        embed_apply_fn, transformer_apply_fn = apply_fns()

        if early_collect:
            bf16_params = maybe_shard(to_bf16(params), mp_shard_strategy)
        else:
            bf16_params = to_bf16(params)

        def eval_loss(x, y):
            loss, correct = Projection(config).loss(x, y)
            return {
                "loss": loss.mean(axis=-1),
                "last_loss": loss[:, -1],
                "all_loss": loss,
                "correct": correct
            }

        projection_apply_fn = hk.without_apply_rng(hk.transform(eval_loss)).apply

        x = embed_apply_fn(bf16_params["embed"], x)

        def apply_scan_fn(layer_in, layer_state):
            # The mask is threaded unchanged through the scan carry.
            x, mask = layer_in
            return (to_bf16(transformer_apply_fn(layer_state, x, mask)), mask), None

        x = jax.lax.scan(apply_scan_fn, (to_bf16(x), mask), xs=bf16_params["transformer"])[0][0]

        return projection_apply_fn(bf16_params["proj"], x, y)

    # NOTE(review): this local name shadows the builtin `eval`.
    def eval(params, ctx, tgt, ctx_length):
        # Additive mask: positions beyond each row's ctx_length get -1e10.
        mask = (jnp.arange(0, ctx.shape[1])[None, :] > ctx_length[:, None]) * -1e10

        # head_print("mask.shape", mask.shape)
        # head_print("ctx.shape", ctx.shape)
        # head_print("ctx_length.shape", ctx_length.shape)

        return eval_apply_fn(params, ctx, tgt, mask[:, None, None, :])

    self.eval_pjit = pjit(eval,
                          in_axis_resources=(mp_shard_strategy if early_collect else state_shard["params"], P("dp"), P("dp"), P("dp")),
                          out_axis_resources=P("dp"))

    # Casts params to bf16, re-laying them out for evaluation.
    self.move_weights_pjit = pjit(lambda x: to_bf16(x),
                                  in_axis_resources=(state_shard["params"], ),
                                  out_axis_resources=mp_shard_strategy if early_collect else state_shard["params"])

    seq = config["seq"]
    vocab = config["n_vocab"]

    # Per-host example batch: at least 1, otherwise dp divided over hosts.
    example_shape = (max(dp // jax.host_count(), 1), seq,)
    x = jax.random.uniform(next(key), example_shape, minval=0, maxval=vocab).astype(jnp.uint32)  # batch, len

    head_print("in shape", x.shape)

    head_print("dp", dp)
    head_print("mp", mp)

    self.state = self.init_pjit(next(key), x)
    self.state_shard = state_shard
    self.eval_weights = None

    param_count = hk.data_structures.tree_size(self.state['params'])
    # NOTE(review): the `* dp` factor presumably corrects for tree_size seeing a
    # per-host/per-shard view of the sharded params -- confirm.
    head_print(f"Total parameters: {param_count * dp}")