def testUndefinedResourcesConstraint(self, mesh, resources):
    """Expect a ValueError when a sharding constraint names an undefined axis.

    The spec is built from ``resources``; the error must mention the
    with_sharding_constraint arguments, the spec itself, and report that
    resource axis ``x`` is undefined.
    """
    operand = jnp.ones((2, 2))
    spec = P(resources,)
    expected_error = (
        r"One of with_sharding_constraint arguments"
        r".*" + spec_regex(spec) + r", but resource axis "
        r"x is undefined.")
    with self.assertRaisesRegex(ValueError, expected_error):
        constrained = pjit(
            lambda arr: with_sharding_constraint(arr, spec),
            in_axis_resources=None,
            out_axis_resources=None)
        constrained(operand)
def testRankTooLowOuts(self):
    """Expect a ValueError when the output spec needs rank 2 but the output is rank 0."""
    vector = jnp.arange(2)
    spec = P('x', 'y')
    expected_error = (
        r"One of pjit outputs.*" + spec_regex(spec) + r", which implies "
        r"that it has a rank of at least 2, but it is 0")
    with self.assertRaisesRegex(ValueError, expected_error):
        # Summing the vector gives a scalar, which cannot satisfy a 2-axis spec.
        total = pjit(lambda v: v.sum(),
                     in_axis_resources=None,
                     out_axis_resources=spec)
        total(vector)
def testCaching(self):
    """Check that repeated pjit calls of the same function do not retrace.

    ``f`` asserts on the ``should_be_tracing`` flag, so any unexpected trace
    after the first call fails the test. The mesh is re-created for the last
    call to show that a fresh mesh does not trigger a retrace either.
    """
    def f(x):
        assert should_be_tracing
        return jnp.sin(x) * 2

    def run_once():
        # A new pjit wrapper and a new PartitionSpec are built on every call.
        return pjit(f, in_axis_resources=P(('x', 'y')),
                    out_axis_resources=None)(x)

    x = np.arange(16).reshape(4, 4)
    devices = np.array(list(jax.local_devices())[:4])
    if devices.size < 4:
        raise unittest.SkipTest("Test requires 4 devices")
    devices = devices.reshape((2, 2))
    axis_names = ('x', 'y')

    with mesh(devices, axis_names):
        should_be_tracing = True
        run_once()
        should_be_tracing = False
        run_once()

    # Re-create the mesh to make sure that has no influence on caching
    with mesh(devices, axis_names):
        should_be_tracing = False
        run_once()
def test_pjit(self):
    """Check that each new input dtype adds one file to the cache directory.

    The pjit-ed function is called first with int64 inputs and then with
    float32 inputs; the cache directory must hold 1 and then 2 entries.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)

        @partial(pjit,
                 in_axis_resources=(P('x'), P('x')),
                 out_axis_resources=None)
        def f(x, y):
            return x + y

        shape = (8, 8)
        for expected_files, dtype in enumerate((np.int64, np.float32), start=1):
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            f(x, x + 1)
            files_in_directory = len(os.listdir(tmpdir))
            self.assertEqual(files_in_directory, expected_files)
def testRankTooLowConstraint(self):
    """Expect a ValueError when a rank-1 value gets a 2-axis sharding constraint."""
    vector = jnp.arange(2)
    spec = P('x', 'y')
    expected_error = (
        r"One of with_sharding_constraint arguments " +
        r"was given.*" + spec_regex(spec) + r", which implies "
        r"that it has a rank of at least 2, but it is 1")
    with self.assertRaisesRegex(ValueError, expected_error):
        constrained = pjit(
            lambda v: with_sharding_constraint(v, spec),
            in_axis_resources=None,
            out_axis_resources=None)
        constrained(vector)
def f(x):
    """Run an identity pjit on ``x`` inside a one-device mesh with axis 'x'.

    The mesh is built from the first local device only.

    Args:
        x: value forwarded unchanged through the inner pjit-ed function.

    Returns:
        Whatever the inner pjit-ed identity function returns for ``x``.
    """
    single_device = np.array([jax.local_devices()[0]])
    # Axis names are passed as a real 1-tuple: ('x') without the trailing
    # comma is just the string 'x'.
    with mesh(single_device, ('x',)):
        @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
        def h(x):
            return x
        return h(x)
def test_array_sharded_astype(self):
    """Check that ``astype`` on a sharded array gives the requested dtype and values."""
    with jax._src.config.jax_array(True):
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        mesh_sharding = sharding.MeshPspecSharding(global_mesh, P('x', 'y'))
        arr, input_data = create_array((8, 2), mesh_sharding)

        converted = arr.astype(jnp.float32)

        self.assertEqual(converted.dtype, np.float32)
        self.assertArraysEqual(converted, input_data.astype(np.float32))
def testPyTreeOutputs(self):
    """Check sharded_jit on a function that returns a nested tuple of arrays.

    The function is run once with an explicit nested ``out_parts`` tree and
    once with ``out_parts=None``; both results must match the unsharded call.

    Raises:
        SkipTest: when fewer than 2 devices are available.
    """
    if jax.device_count() < 2:
        # State the reason so skipped runs are self-explanatory in reports.
        raise SkipTest("requires 2 devices")

    def f(x):
        return x + 1, ((x + 2, x + 3), x + 4)

    shape = (4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = sharded_jit(f, in_parts, out_parts)(x)
    expected = f(x)
    self.assertAllClose(result, expected, check_dtypes=False)

    # Same computation with no output partitioning specified.
    out_parts = None
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
def testShardingConstraint(self):
    """Exercise with_sharding_constraint inside sharded_jit.

    Three cases are covered: outer partitions whose count matches the
    constraint, outer partitions whose count does not (a ValueError is
    expected), and a fully replicated outer sharded_jit.
    """
    if jax.local_device_count() < 2:
        raise SkipTest("requires 2 devices")

    def f(x):
        return with_sharding_constraint(x + 1, P(1, 2)) * 2

    shape = (8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    expected = (x + 1) * 2

    # Matching sharded_jit partitions
    actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
    # the default.
    for index in (0, 1):
        buf = actual.device_buffers[index]
        shape_fn = getattr(buf, "xla_shape", buf.shape)
        self.assertEqual(shape_fn().dimensions(), (4, 8))

    # Mismatched sharded_jit partitions
    mismatch_error = (
        r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
        r"\(total partitions: 2\) doesn't match expected number of partitions: "
        r"4. If these partitions look right, check outer sharded_jit and/or "
        r"other with_sharding_constraint calls.")
    with self.assertRaisesRegex(ValueError, mismatch_error):
        sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

    # Replicated sharded_jit
    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    first_copy = actual.device_buffers[0].to_py()
    second_copy = actual.device_buffers[1].to_py()
    self.assertAllClose(first_copy, second_copy, check_dtypes=False)
def testInAxesNone(self):
    """Compare pmap over sharded_jit with plain pmap when two in_axes are None.

    Only the third (dummy) argument is mapped; the two matrix operands are
    passed with ``in_axes=None``.
    """
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)

    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
        raise SkipTest("requires %d devices" % num_shards)

    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1

    def f(x, y, _):
        return x @ y

    sharded_f = sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions)
    result = pmap(sharded_f, in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
def testNonDivisibleConstraint(self, mesh, resources):
    """Expect a ValueError when dimension 0 (size 3) is not divisible by the mesh size.

    The expected divisor is the product of the second element of every entry
    in ``mesh``.
    """
    operand = jnp.ones((3, 2))
    spec = P(resources,)
    total_size = np.prod([axis[1] for axis in mesh], dtype=np.int64)
    mesh_size = str(total_size)
    expected_error = (
        r"One of with_sharding_constraint arguments"
        r".*" + spec_regex(spec) + r".*implies that the size of "
        r"its dimension 0 should be divisible by " + mesh_size +
        r", but it is equal to 3")
    with self.assertRaisesRegex(ValueError, expected_error):
        constrained = pjit(
            lambda arr: with_sharding_constraint(arr, spec),
            in_axis_resources=None,
            out_axis_resources=None)
        constrained(operand)
def testNonDivisibleArgs(self, mesh, resources):
    """Expect a ValueError when a pjit argument's dimension 0 (size 3) is not divisible.

    The expected divisor is the product of the second element of every entry
    in ``mesh``.
    """
    operand = jnp.ones((3, 2))
    spec = P(resources, None)
    total_size = np.prod([axis[1] for axis in mesh], dtype=np.int64)
    mesh_size = str(total_size)
    expected_error = (
        r"One of pjit arguments.*" + spec_regex(spec) + r".*"
        r"implies that the size of its dimension 0 should be "
        r"divisible by " + mesh_size + r", but it is equal to 3")
    with self.assertRaisesRegex(ValueError, expected_error):
        identity = pjit(lambda arr: arr,
                        in_axis_resources=spec,
                        out_axis_resources=None)
        identity(operand)
def testConstraintShardsXMapAxis(self):
    """Expect a JAXTypeError when a constraint reuses a mesh axis xmap already uses.

    The xmap maps named axis ``i`` onto mesh axis ``x`` and the inner sharding
    constraint also asks for ``x``.
    """
    spec = P('x')
    f = xmap(
        lambda arr: with_sharding_constraint(arr, axis_resources=spec),
        in_axes=['i', ...],
        out_axes=['i', ...],
        axis_resources={'i': 'x'})
    operand = jnp.arange(4).reshape((2, 2))
    expected_error = (
        r"with_sharding_constraint input has an axis resources specification of " +
        spec_regex(spec) + r" that uses one or more mesh axes already used by "
        r"xmap to partition a named axis appearing in its named_shape \(both "
        r"use mesh axes `x`\)")
    with self.assertRaisesRegex(JAXTypeError, expected_error):
        f(operand)
def input(self, x):
    """Project ``x`` and split the projection into q, v, k and ff pieces.

    The projection gets a trailing axis of length ``self.mp_num`` inserted
    before its last axis, is split along the last axis at multiples of
    ``d_head * n_head // mp_num``, and the first three pieces are passed
    through ``self.head_split``.

    Returns:
        Tuple ``(q, v, k, ff)``; ``ff`` is the unsplit remainder.
    """
    # [batch, seq, dim]
    projected = self.input_proj(x)

    # [batch, seq, mp, dim//mp]
    projected = maybe_shard(projected, P("dp", None, "mp"))
    split_shape = projected.shape[:-1] + (self.mp_num, -1)
    mp_split = maybe_shard(jnp.reshape(projected, split_shape),
                           P("dp", None, "mp", None))

    local_dim = self.d_head * self.n_head // self.mp_num
    boundaries = [local_dim * multiple for multiple in (1, 2, 3)]
    q, v, k, ff = jnp.split(mp_split, boundaries, axis=-1)

    q, v, k = (self.head_split(piece) for piece in (q, v, k))
    return q, v, k, ff
def test_array_delete(self):
    """Check that ``delete`` marks the array deleted and clears its cached fields."""
    with jax._src.config.jax_array(True):
        global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        mesh_sharding = sharding.MeshPspecSharding(global_mesh, P('x', 'y'))
        arr, _ = create_array((8, 2), mesh_sharding)

        arr.delete()

        with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
            arr._check_if_deleted()
        # Both private fields are expected to be None after deletion.
        for cleared in (arr._npy_value, arr._arrays):
            self.assertIsNone(cleared)
def testNotEnoughDevices(self):
    """Expect a ValueError when sharded_jit asks for one device more than exist locally."""
    available = jax.local_device_count()
    requested = available + 1

    @partial(sharded_jit, in_parts=P(requested), out_parts=None)
    def f(x):
        return x + x

    expected_error = (
        f"sharded_jit computation requires {requested} devices, "
        f"but only {available} devices are available.")
    with self.assertRaisesRegex(ValueError, expected_error):
        f(np.ones(requested))
def test_checkpointing_with_bigger_shape(self):
    """Serialize an (8, 2) GDA and read it back with global shape (12, 2).

    The data is written under a (2, 2) mesh and read under a (4, 2) mesh. The
    expected per-device shards below show the rows beyond the saved data
    coming back as zeros.
    """
    write_mesh = create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = util.prod(global_input_shape)

    # First GDA
    global_input_data1 = np.arange(num).reshape(global_input_shape)

    def cb1(index):
        return global_input_data1[index]

    gda1 = GlobalDeviceArray.from_callback(
        global_input_shape, write_mesh, P('x', 'y'), cb1)
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir1)]
    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    serialization.run_serialization([gda1], tspecs)

    read_mesh = create_global_mesh((4, 2), ('x', 'y'))
    m1, = serialization.run_deserialization(
        [read_mesh],
        [P('x', 'y')],
        tspecs,
        [(12, 2)],
    )

    # Expected shard contents keyed by device id.
    expected_data = {
        0: np.array([[0], [2], [4]]),
        1: np.array([[1], [3], [5]]),
        2: np.array([[6], [8], [10]]),
        3: np.array([[7], [9], [11]]),
        4: np.array([[12], [14], [0]]),
        5: np.array([[13], [15], [0]]),
        6: np.array([[0], [0], [0]]),
        7: np.array([[0], [0], [0]]),
    }

    for shard in m1.local_shards:
        self.assertArraysEqual(shard.data.to_py(),
                               expected_data[shard.device.id])
def shard_strategy(shape_dtype, parallel):
    """Pick a partition spec ``P(...)`` for a parameter based on its shape.

    The decision reads the module-level ``config`` mapping (keys ``n_vocab``,
    ``d_model`` and ``layers``).

    NOTE(review): this block was reconstructed from a flattened source line;
    confirm which ``if`` each of the two trailing ``else`` clauses pairs with.

    Args:
        shape_dtype: object exposing ``ndim`` and ``shape``.
        parallel: value placed in the sharded position of the returned spec;
            presumably a mesh axis name — confirm against callers.

    Returns:
        A ``P(...)`` spec for the matched case.

    Raises:
        NotImplementedError: for shapes that no branch handles.
        AssertionError: for a 3-D layer parameter whose two trailing
            dimensions are equal.
    """
    if shape_dtype.ndim <= 1:
        return P()
    # embedding/projection layers
    elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
        return P(parallel, None)
    elif shape_dtype.shape == (config["d_model"], config["n_vocab"]):
        return P(None, parallel)

    # a transformer layer
    elif shape_dtype.shape[0] == config["layers"]:
        if shape_dtype.ndim == 2:
            # a channel wise variable (e.g. layernorm parameters)
            # replicate it for speed
            return P(None)
        elif shape_dtype.ndim == 3:
            # a weight matrix
            matrix_size = shape_dtype.shape[1:]

            assert matrix_size[0] != matrix_size[1]  # this case is ambiguous

            if matrix_size[0] == config["d_model"]:
                # shard along the axis which is _not_ the model dimension
                return P(None, None, parallel)
            elif matrix_size[1] == config["d_model"]:
                return P(None, parallel, None)
            # NOTE(review): with this indentation, a 3-D parameter where
            # neither trailing dimension equals d_model falls through and
            # returns None implicitly — confirm whether that is intended.
        else:
            raise NotImplementedError("borked")

    else:
        raise NotImplementedError("borked")
def test_pjit_gsda_wrong_resource_for_gsda_input(self):
    """Expect a ValueError when a GDA's partitioning differs from in_axis_resources.

    The GDA is built with mesh axes ``['x']`` while pjit is given
    ``P('x', 'y')``; the error text must match the literal message below.
    """
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x']
    element_count = prod(global_input_shape)
    global_input_data = np.arange(
        element_count, dtype=np.float32).reshape(global_input_shape)

    def cb(index):
        return global_input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # Must match the library's message exactly, hence the literal spelling.
    expected_message = (
        "Got an input GDA to pjit with different partitioning than specified "
        "in the in_axis_resources argument to pjit. The paritioning must "
        'match, or use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources`. '
        "Got GDA spec: <partitions=(('x',),) sync=2>, "
        "pjit spec: <partitions=(('x',), ('y',)) sync=2>")
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
        @partial(pjit,
                 in_axis_resources=P('x', 'y'),
                 out_axis_resources=P('x', 'y'))
        def f(x):
            return x

        f(gda_obj)
def testNestedDifferentResources(self):
    """Expect a RuntimeError when an inner pjit runs under a different physical mesh.

    The outer pjit-ed function opens a new one-device mesh and calls an inner
    pjit inside it; the call must fail with "Changing the physical mesh is
    not allowed".
    """
    @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
    def f(x):
        single_device = np.array([jax.local_devices()[0]])
        # Axis names are passed as a real 1-tuple: ('x') without the trailing
        # comma is just the string 'x'.
        with mesh(single_device, ('x',)):
            @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
            def h(x):
                return x
            return h(x)

    xshape = (2, 5, 6)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    with self.assertRaisesRegex(
            RuntimeError, "Changing the physical mesh is not allowed.*"):
        f(x)
def test_gda_block_until_ready(self):
    """Check that ``block_until_ready`` returns the very same GDA object."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
        return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(gda.block_until_ready(), gda)
def test_partition_spec_mismatch_semantically_equivalent(self):
    """Check that a GDA with empty mesh axes is accepted where P(None) is required.

    The pjit output's ``_mesh_axes`` is asserted to be ``()``, and that output
    is then fed back into the same pjit-ed function whose input spec is
    ``P(None)``.
    """
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = [None]
    element_count = prod(global_input_shape)
    global_input_data = np.arange(
        element_count, dtype=np.float32).reshape(global_input_shape)

    def cb(index):
        return global_input_data[index]

    with jax._src.config.parallel_functions_output_gda(True):
        gda_obj = global_device_array.GlobalDeviceArray.from_callback(
            global_input_shape, global_mesh, mesh_axes, cb)

        @partial(pjit, in_axis_resources=P(None), out_axis_resources=P(None))
        def f(x):
            return x

        output_gda = f(gda_obj)
        # Ensure output_gda._mesh_axes = P() is matched with P(None).
        self.assertEqual(output_gda._mesh_axes, ())
        # P(None) is in_axis_resources.
        f(output_gda)
def testLowerCompile(self):
    """Check that a lowered and compiled pjit yields the expected per-device shards.

    The expected product is split into 4 equal pieces and compared against
    the first four device buffers of the result.
    """
    @partial(pjit,
             in_axis_resources=P(('x', 'y'),),
             out_axis_resources=P(('x', 'y'),))
    def f(x, y):
        return x @ y

    shape = (8, 8)
    x = jnp.arange(np.prod(shape)).reshape(shape)
    expected = x @ (x + 1)

    exe = f.lower(x, x + 1).compile()
    actual = exe(x, x + 1)

    splits = np.split(expected, 4)
    for index, piece in enumerate(splits):
        self.assertAllClose(actual.device_buffers[index].to_py(), piece,
                            check_dtypes=False)
def testVmapModifiesAxisResources(self):
    """Check the ``sync`` state of pjit axis specs after the function is vmapped.

    Only the second argument is batched. Specs for unbatched values must be
    IN_SYNC and specs for batched values must be DIM_PERMUTE.
    """
    h = pjit(lambda x, y: (x + y, x, y),
             in_axis_resources=P('x'),
             out_axis_resources=None)
    x = jnp.arange(4)
    y = jnp.arange(5*4).reshape((5, 4))

    batched_h = jax.vmap(h, in_axes=(None, 0))
    jaxpr = jax.make_jaxpr(batched_h)(x, y).jaxpr
    eqn = jaxpr.eqns[0]
    self.assertIs(eqn.primitive, pjit_p)

    x_sync, y_sync = [spec.sync for spec in eqn.params['in_axis_resources']]
    self.assertEqual(x_sync, SpecSync.IN_SYNC)
    self.assertEqual(y_sync, SpecSync.DIM_PERMUTE)

    x_sync, y_sync, z_sync = [
        spec.sync for spec in eqn.params['out_axis_resources']]
    self.assertEqual(x_sync, SpecSync.DIM_PERMUTE)
    self.assertEqual(y_sync, SpecSync.IN_SYNC)
    self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
def test_gda_equality_raises_not_implemented(self):
    """Expect ``==`` between two GDAs to raise NotImplementedError."""
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
        return global_input_data[index]

    def build_gda():
        return GlobalDeviceArray.from_callback(
            global_input_shape, global_mesh, mesh_axes, cb)

    input_gda = build_gda()
    same_input_gda = build_gda()

    with self.assertRaisesRegex(
            NotImplementedError,
            'GlobalDeviceArray equality is intentionally unimplemented.'):
        # The comparison is evaluated only for the exception it raises.
        input_gda == same_input_gda
def testTwoMeshAxisSharding(self):
    """Check a pjit sharded over both mesh axes.

    The result must equal the unsharded product, be a ShardedDeviceArray with
    4 device buffers, and each buffer must hold one quarter of the expected
    rows.
    """
    @partial(pjit,
             in_axis_resources=P(('x', 'y'),),
             out_axis_resources=P(('x', 'y'),))
    def f(x, y):
        return x @ y

    shape = (8, 8)
    x = jnp.arange(np.prod(shape)).reshape(shape)
    actual = f(x, x + 1)
    expected = x @ (x + 1)

    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 4)

    splits = np.split(expected, 4)
    for index, piece in enumerate(splits):
        self.assertAllClose(actual.device_buffers[index].to_py(), piece,
                            check_dtypes=False)
def test_async_checkpointing(self):
    """Round-trip a GDA through GlobalAsyncCheckpointManager.

    The array is serialized via a temporary directory into a final directory
    and deserialized again; the first two local shards, their shape and the
    dtype are checked. A second serialize into an already existing final
    directory must raise.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    num = util.prod(global_input_shape)

    # First GDA
    global_input_data1 = np.arange(num).reshape(global_input_shape)

    def cb1(index):
        return global_input_data1[index]

    gda1 = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb1)
    temp_ckpt_dir1 = pathlib.Path(
        self.create_tempdir('temp_first').full_path)
    ckpt_dir1 = str(temp_ckpt_dir1).replace('temp_first', 'first')

    s_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                            [str(temp_ckpt_dir1)])

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize([gda1], s_tspecs,
                      temp_checkpoint_dir=temp_ckpt_dir1,
                      final_checkpoint_dir=ckpt_dir1)
    manager.wait_until_finished()

    d_tspecs = jax.tree_map(serialization.get_tensorstore_spec,
                            [str(ckpt_dir1)])
    m1, = manager.deserialize([global_mesh], [mesh_axes], d_tspecs)

    first_shard = m1.local_shards[0].data
    self.assertArraysEqual(first_shard.to_py(), np.array([[0], [2]]))
    self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    # Will throw `file already exists` error when `tf.io.gfile.rename`.
    # `wait_until_finished` will raise the error.
    with self.assertRaises(Exception):
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)
        manager1 = serialization.GlobalAsyncCheckpointManager()
        manager1.serialize([gda1], s_tspecs,
                           temp_checkpoint_dir=temp_ckpt_dir1,
                           final_checkpoint_dir=ckpt_dir1)
        manager1.wait_until_finished()
def test_mesh_pspec_sharding_interface(self):
    """Check the index map, device assignment and op sharding of a MeshPspecSharding.

    The sharding is built from a (4, 2) mesh with spec ``P('y', 'x')`` and
    queried for global shape (8, 4).
    """
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    pspec = P('y', 'x')
    global_shape = (8, 4)
    mp_sharding = sharding.MeshPspecSharding(mesh, pspec)

    di_map = mp_sharding.devices_indices_map(global_shape)
    op_sharding = mp_sharding._to_xla_op_sharding(len(global_shape))
    device_assignment = mp_sharding._device_assignment()

    first_device = mesh.devices.flat[0]
    self.assertEqual(di_map[first_device], (slice(0, 4), slice(0, 1)))
    self.assertArraysEqual(device_assignment, list(mesh.devices.flat))
    self.assertEqual(op_sharding.type, xc.OpSharding.Type.OTHER)
    self.assertListEqual(op_sharding.tile_assignment_dimensions, [2, 4])
    self.assertListEqual(op_sharding.tile_assignment_devices,
                         [0, 2, 4, 6, 1, 3, 5, 7])
def test_gda_str_repr(self):
    """Check the exact ``str`` and ``repr`` text of a GlobalDeviceArray."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
        return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    expected_str = 'GlobalDeviceArray(shape=(8, 2), dtype=int32)'
    expected_repr = ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                     "global_mesh_shape={'x': 4, 'y': 2}, "
                     "mesh_axes=PartitionSpec(('x', 'y'),))")
    self.assertEqual(str(gda), expected_str)
    self.assertEqual(repr(gda), expected_repr)
def testDeviceBufferAval(self):
    """Check a one-buffer pjit output and that ``repr`` of its buffers works.

    The identity pjit output must equal the input, be a ShardedDeviceArray
    with a single device buffer, and that buffer must hold the full input.
    """
    @partial(pjit, in_axis_resources=None, out_axis_resources=P('x'))
    def f(x):
        return x

    shape = (2, 2)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = x
    actual = f(x)

    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 1)
    only_buffer = actual.device_buffers[0]
    self.assertAllClose(only_buffer.to_py(), expected, check_dtypes=False)
    # Repro for a bug on device_buffer aval
    _ = repr(actual.device_buffers)