def testShardingConstraint(self):
  def f(x):
    y = x + 1
    y = with_sharding_constraint(y, P(1, 2))
    return y * 2

  shape = (8, 8)
  x = np.arange(np.prod(shape)).reshape(shape)
  expected = (x + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  self.assertEqual(actual.device_buffers[0].shape().dimensions(), (4, 8))
  self.assertEqual(actual.device_buffers[1].shape().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  with self.assertRaisesRegex(
      ValueError,
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls."):
    sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(),
                      actual.device_buffers[1].to_py(),
                      check_dtypes=False)
def test_pjit_TwoMeshAxisSharding(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=P(("x", "y"),),
                     out_axis_resources=P(("x", "y"),))
  def jax_func(x, y):
    return x @ y

  x_shape = (24, 8)
  y_shape = (8, 2)
  x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
  y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
  self._check_sharding_annotations(
      jax_func, [x, y],
      expected=[
          r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
          r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",   # y
          r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
      ],
      expected_opt=[
          # TODO: relax ordering
          r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
          r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
          # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3 here?
          r"f32\[1,6,2\]",  # output
      ],
      num_partitions=4)
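# A minimal sketch (not part of the test harness above) of the mesh context a
# pjit'd function with named axis resources needs when called directly.
# Assumptions: 4 local devices arranged as a 2x2 mesh with axis names "x" and
# "y"; `jax_func` is the function from the test above; `maps` is imported from
# jax.experimental, as elsewhere in this file.
import numpy as np
import jax
from jax.experimental import maps

mesh_devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
with maps.mesh(mesh_devices, ("x", "y")):
  # With P(("x", "y")) on the first dimension, a (24, 8) input is split
  # 4 ways along dim 0 (devices=[4,1], matching the regexps above).
  res = jax_func(np.arange(24. * 8, dtype=np.float32).reshape(24, 8),
                 np.arange(8. * 2, dtype=np.float32).reshape(8, 2))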
def test_pjit_basic1D(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=(P("x"), P("x")),
                     out_axis_resources=None)
  def jax_func(x, y):
    return x + y

  shape = (8, 10)
  x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
  print(f"HLO is {hlo}")
  print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
  self._check_sharding_annotations(
      jax_func, [x, x],
      expected=[
          r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
          r"f32\[8,10\].*sharding={replicated",       # output
      ],
      expected_opt=[
          r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
          # TODO: why don't we see "sharding={replicated" on the output?
          r"f32\[8,10\]",  # output
      ],
      num_partitions=2)
def test_pjit_basic2D(self):
  @functools.partial(pjit.pjit,
                     in_axis_resources=(P(None, "x", "y"), P("y")),
                     out_axis_resources=P("x"))
  def jax_func(x, y):
    return x @ y

  x_shape = (8, 6, 4)
  y_shape = (4, 2)
  x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
  y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
  self._check_sharding_annotations(
      jax_func, [x, y],
      expected=[
          # x
          r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",
          # y
          r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",
          # output
          r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",
      ],
      expected_opt=[
          # TODO: relax ordering
          # y
          r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",
          # x
          r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",
          # TODO: why can't we see
          # sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate here?
          r"bf16\[4,6,2\]",  # output
      ],
      num_partitions=4)
def test_sharded_jit_with_sharding_constraint(self):
  """A sharding constraint in the middle."""
  def jax_func(x, y):
    logits1 = jnp.dot(x, y)
    return jnp.sin(sharded_jit.with_sharding_constraint(logits1, P(2, 1)))

  sharded_jax_func = sharded_jit.sharded_jit(
      jax_func, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
  xshape = (6, 8)
  x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
  yshape = (8, 10)
  y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
  self._check_sharding_annotations(
      sharded_jax_func, [x, y],
      expected=[
          r"f32\[6,8\].*sharding={devices=\[1,2\]",
          r"f32\[8,10\].*sharding={devices=\[2,1\]",
          r"f32\[6,10\].*sharding={devices=\[2,1\]",
          r"f32\[6,10\].*sine.*sharding={devices=\[1,2\]",
      ],
      expected_opt=[
          # TODO(necula): relax ordering
          r"f32\[4,10\].*sharding={devices=\[2,1\]",
          r"f32\[6,4\].*sharding={devices=\[1,2\]",
      ],
      num_partitions=2)
def test_sharded_jit_in_out(self):
  """Test input and output sharding annotations."""
  sharded_jax_func = sharded_jit.sharded_jit(
      jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
  xshape = (3, 8)
  x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
  yshape = (8, 5)
  y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
  self._check_sharding_annotations(
      sharded_jax_func, [x, y],
      expected=[
          r"f32\[3,8\].*sharding={devices=\[1,2\]",
          r"f32\[8,5\].*sharding={devices=\[2,1\]",
          r"f32\[3,5\].*sharding={devices=\[1,2\]",
      ],
      expected_opt=[
          # TODO(necula): relax ordering
          r"f32\[4,5\].*sharding={devices=\[2,1\]",
          r"f32\[3,4\].*sharding={devices=\[1,2\]",
          r"f32\[3,5\].*fusion",
          r"f32\[3,5\].*all-reduce",
      ],
      num_partitions=2)
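# For reference, an illustrative (hand-written, not captured from a real run)
# HLO line of the kind the `expected` regexps above are matched against:
#   %parameter.0 = f32[3,8]{1,0} parameter(0), sharding={devices=[1,2]0,1}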
def testPyTreeArgs(self):
  if jax.device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(a, b, c):
    a1, a2 = a
    c1, (c2, c3) = c
    return a1 + a2 + b + c1 + c2 + c3

  def _make_arg(*shape):
    return np.arange(np.prod(shape)).reshape(shape)

  a = (_make_arg(4, 4), 1)
  b = _make_arg(4, 4)
  c = [2, (_make_arg(4, 4), _make_arg(4, 4))]
  in_parts = (None, P(2, 1), [None, P(2, 1)])
  out_parts = P(2, 1)

  result = sharded_jit(f, in_parts, out_parts)(a, b, c)
  expected = f(a, b, c)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 2)

  in_parts = None
  result = sharded_jit(f, in_parts, out_parts)(a, b, c)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 2)
def testPyTreeArgs(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  def f(a, b, c):
    a1, a2 = a
    c1, (c2, c3) = c
    return a1 + a2 + b + c1 + c2 + c3

  def _make_arg(*shape):
    return np.arange(np.prod(shape)).reshape(shape)

  a = (_make_arg(2, 4, 4), _make_arg(2))
  b = _make_arg(2, 4, 4)
  c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))
  in_parts = (None, P(2, 1), (None, P(2, 1)))
  out_parts = P(2, 1)

  result = pmap(sharded_jit(f, in_parts=in_parts,
                            out_parts=out_parts))(a, b, c)
  expected = pmap(f)(a, b, c)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
  self.assertEqual(len(result.device_buffers), 4)
def testCompilationCache(self):
  f = lambda x: x + 1
  sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
  shape = (2,)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  with jtu.assert_num_jit_and_pmap_compilations(1):
    sharded_f(x)
    sharded_f(x)
def testCompilationCache(self):
  f = lambda x: x + 1
  sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
  shape = (2,)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  with jtu.count_jit_and_pmap_compiles() as count:
    sharded_f(x)
    sharded_f(x)
  self.assertEqual(count[0], 1)
def testTranslationRule(self):
  @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
  def f(x, y):
    return x + y

  # Test that the translation rule runs without error and produces the
  # OpShardings we expect somewhere.
  shape = (8, 8)
  hlo = jax.xla_computation(f)(np.ones(shape), np.ones(shape))
  self.assertIn("sharding={devices=[2,1]0,1}", hlo.as_hlo_text())
  self.assertIn("sharding={replicated}", hlo.as_hlo_text())
def apply(self,
          num_embeddings,
          features,
          embedding_init=default_embed_init,
          dtype=jnp.float32,
          num_partitions=2):
  """Layer that returns an embedding matrix.

  Args:
    num_embeddings: number of embeddings.
    features: number of feature dimensions for each embedding.
    embedding_init: embedding initializer.
    dtype: dtype to use for activations.
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    An embedding matrix suitable for embedding[inputs].
  """
  embedding = self.param('embedding', (num_embeddings, features),
                         embedding_init)
  embedding = jnp.asarray(embedding, dtype)
  if num_partitions > 1:
    embedding = with_sharding_constraint(embedding, P(num_partitions, 1))
  return embedding
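# A minimal sketch (hypothetical; assumes 2 devices and the sharded_jit / P /
# with_sharding_constraint imports used in the tests above) of how a layer
# like this is staged so the constraint in `apply` takes effect: the
# enclosing computation must be traced through sharded_jit, which applies
# constraint partitions at lowering time.
def lookup(embedding, inputs):
  embedding = with_sharding_constraint(embedding, P(2, 1))
  return embedding[inputs]

sharded_lookup = sharded_jit(lookup, in_parts=None, out_parts=None)
# sharded_lookup(table, ids): the table is laid out row-wise (2 shards)
# before the gather, even though in_parts requests replicated inputs.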
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6),
          num_partitions=2):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  inputs_shape = inputs.shape
  inputs = inputs.reshape((-1, inputs_shape[-1]))
  x = nn.Dense(inputs, mlp_dim, dtype=dtype, kernel_init=kernel_init,
               bias_init=bias_init)
  x = nn.relu(x)
  if num_partitions > 1:
    x = with_sharding_constraint(x, P(1, num_partitions))
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(x, actual_out_dim, dtype=dtype, kernel_init=kernel_init,
                    bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  output = output.reshape(inputs_shape[:-1] + (actual_out_dim,))
  return output
def testPyTreeOutputs(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  shape = (2, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)
  in_parts = (P(2, 1),)
  out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

  result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
  expected = pmap(f)(x)
  self.assertAllClose(result, expected, check_dtypes=False)
def testInfeed(self, partition_input):
  if jax.local_device_count() % 2 != 0:
    raise SkipTest("requires an even number of devices")

  shape = (jax.local_device_count() * 2, 4)
  # Run computation across all devices so we know which devices to feed.
  parts = P(jax.local_device_count(), 1)
  in_parts = parts if partition_input else None
  infeed_shapes = (jax.ShapedArray(shape, np.float32),
                   jax.ShapedArray((1,), np.float32))
  infeed_parts = (parts, None)

  @partial(sharded_jit, in_parts=in_parts, out_parts=None)
  def f(x):
    token = lax.create_token(x)
    (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
    return x @ y.T + z

  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  y = x + 1
  shard_size = shape[0] // jax.local_device_count()
  y_shards = [y[i:i + shard_size] for i in range(0, shape[0], shard_size)]
  z = jnp.array([3.], dtype=np.float32)

  result = f(x)
  assert len(jax.local_devices()) == len(y_shards)
  for device, y_shard in zip(jax.local_devices(), y_shards):
    device.transfer_to_infeed((y_shard, z))

  expected = x @ y.T + z
  self.assertAllClose(result, expected, check_dtypes=False)
def testBasic(self):
  if jax.device_count() < 2:
    raise SkipTest("requires 2 devices")

  @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
  def f(x, y):
    return x + y

  shape = (8, 8)
  x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  actual = f(x, x + 1)
  expected = x + (x + 1)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                      check_dtypes=False)
def testPyTreeOutputs(self):
  if jax.device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    return x + 1, ((x + 2, x + 3), x + 4)

  shape = (4, 4)
  x = np.arange(prod(shape)).reshape(shape)
  in_parts = (P(2, 1),)
  out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

  result = sharded_jit(f, in_parts, out_parts)(x)
  expected = f(x)
  self.assertAllClose(result, expected, check_dtypes=False)

  out_parts = None
  result = sharded_jit(f, in_parts, out_parts)(x)
  self.assertAllClose(result, expected, check_dtypes=False)
def testShardingConstraint(self):
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    y = x + 1
    y = with_sharding_constraint(y, P(1, 2))
    return y * 2

  shape = (8, 8)
  x = np.arange(prod(shape)).reshape(shape)
  expected = (x + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
  # the default.
  self.assertEqual(
      getattr(actual.device_buffers[0], "xla_shape",
              actual.device_buffers[0].shape)().dimensions(), (4, 8))
  self.assertEqual(
      getattr(actual.device_buffers[1], "xla_shape",
              actual.device_buffers[1].shape)().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  with self.assertRaisesRegex(
      ValueError,
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls."):
    sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(),
                      actual.device_buffers[1].to_py(),
                      check_dtypes=False)
def testInAxesNone(self):
  shape = (4, 4)
  replicas = 2
  in_partitions = (P(2, 1), None, None)
  out_partitions = P(2, 1)
  in_axes = (None, None, 0)
  x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
  dummy = np.arange(replicas, dtype=np.float32) + 1
  num_shards = replicas * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  def f(x, y, _):
    return x @ y

  result = pmap(
      sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
      in_axes=in_axes)(x, y, dummy)
  expected = pmap(f, in_axes=in_axes)(x, y, dummy)
  self.assertAllClose(result, expected, check_dtypes=True)
def testNotEnoughDevices(self):
  ndevices = jax.local_device_count()

  @partial(sharded_jit, in_parts=P(ndevices + 1), out_parts=None)
  def f(x):
    return x + x

  with self.assertRaisesRegex(
      ValueError,
      f"sharded_jit computation requires {ndevices + 1} devices, "
      f"but only {ndevices} devices are available."):
    f(np.ones(ndevices + 1))
def testGradOfShardingConstraint(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  @partial(sharded_jit, in_parts=P(4, 1), out_parts=None)
  def f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
    return vjp_f(p)

  def expected_f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(z), y)
    return vjp_f(p)

  shape = (4, 4)
  x = jnp.arange(jnp.prod(shape), dtype=jnp.float32).reshape(shape)
  actual = f(x)
  expected = expected_f(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
def testManyArgs(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  num_args = 200

  def f(*args):
    return jnp.asarray(args).sum()

  shape = (2, 4, 4)
  args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
  in_partitions = (P(2, 1),) * num_args
  out_partitions = None

  result = pmap(
      sharded_jit(f, in_parts=in_partitions,
                  out_parts=out_partitions))(*args)
  expected = pmap(f)(*args)
  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
  self.assertEqual(len(result.device_buffers), 4)
def _pjit(inp):
  if isinstance(inp, GlobalDeviceArray):
    if inp.is_fully_replicated:
      return inp.local_data(0).to_py()
    global_mesh = inp._global_mesh
    in_axis_resources = FROM_GDA
  else:
    # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
    # Shape of local_mesh will always be (1, local_device_count()).
    devices = np.array(jax.devices()).reshape(jax.process_count(),
                                              jax.local_device_count())
    global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
    in_axis_resources = P('processes')
    # `tiled` is captured from the enclosing scope.
    if inp.ndim == 0 or not tiled:
      inp = np.expand_dims(inp, axis=0)

  with maps.mesh(global_mesh.devices, global_mesh.axis_names):
    out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
               out_axis_resources=None)(inp)
  return out.local_data(0).to_py()
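# A minimal single-process sketch of the identity-pjit gather pattern used by
# _pjit above (assumptions: the leading dimension is divisible by the local
# device count; old-style pjit/maps API, as in the helper itself).
gather_devices = np.array(jax.local_devices())
with maps.mesh(gather_devices, ('local_devices',)):
  gathered = pjit(lambda x: x,
                  in_axis_resources=P('local_devices'),
                  out_axis_resources=None)(
                      np.arange(8. * len(gather_devices)))
# `gathered` holds the fully replicated value on every local device.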
def test_sharded_jit_replicated(self):
  """A replicated input and output."""
  sharded_jax_func = sharded_jit.sharded_jit(
      jnp.dot, in_parts=(P(1, 2), None), out_parts=None)
  xshape = (3, 8)
  x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
  yshape = (8, 5)
  y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
  self._check_sharding_annotations(
      sharded_jax_func, [x, y],
      expected=[
          r"f32\[3,8\].*sharding={devices=\[1,2\]",
          r"f32\[8,5\].*sharding={replicated}",
          r"f32\[3,5\].*sharding={replicated}",
      ],
      expected_opt=[
          # TODO(necula): relax ordering
          r"f32\[8,5\].*sharding={replicated}",
          r"f32\[3,4\].*sharding={devices=\[1,2\]",
      ],
      num_partitions=2)
def f(x):
  y = jnp.dot(x, x)
  y = with_sharding_constraint(y, P(2, 1))
  return y * 2
class PmapOfShardedJitTest(jtu.JaxTestCase):

  def setUp(self):
    super(PmapOfShardedJitTest, self).setUp()
    if jtu.device_under_test() != "tpu":
      raise SkipTest("requires TPU")

  # TODO(skye): make a similar version for ShardedJitTest and run the same
  # tests
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f)."""
    shape = (2, 4, 4)
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions,
                    out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
      self.assertTrue(isinstance(r, pxla.ShardedDeviceArray))
      self.assertEqual(len(r.device_buffers), num_shards)

  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions": in_partitions,
      "out_partitions": out_partitions
  } for in_partitions in [
      (P(2, 1), P(2, 1)),
      (P(2, 1), P(1, 2)),
      (P(2, 2), P(2, 2)),
      (P(4, 1), P(2, 2)),
  ] for out_partitions in [in_partitions[0], None])
  def testBasic(self, in_partitions, out_partitions):
    def f(x, y):
      return lax.dot(x, y)

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions": in_partitions,
      "out_partitions": out_partitions
  } for in_partitions in [(P(2, 1), P(2, 1)), (P(2, 1), P(1, 2)),
                          (P(4, 1), P(2, 2))]
  for out_partitions in [(in_partitions[1], in_partitions[0], None),
                         (None, None, None)])
  def testMultipleOutputs(self, in_partitions, out_partitions):
    def f(x, y):
      a = lax.dot(x, y)
      # TODO(skye): use these more interesting outputs once returning
      # constants works
      # return a, a + 1, 3
      return a, a + x, x + y

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions": in_partitions,
      "out_partitions": out_partitions
  } for in_partitions in [(P(2, 1), P(2, 1)), (P(2, 1), P(1, 2)),
                          (P(4, 1), P(2, 2))]
  for out_partitions in [in_partitions[0], None])
  def testArrayConstants(self, in_partitions, out_partitions):
    def f(x, y):
      a = lax.dot(x, y)
      b = a + jnp.ones(a.shape)
      c = b + jnp.ones(a.shape[0])
      return c

    self._runTest(f, in_partitions, out_partitions)

  def testPyTreeArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(a, b, c):
      a1, a2 = a
      c1, (c2, c3) = c
      return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
      return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
    c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))
    in_parts = (None, P(2, 1), (None, P(2, 1)))
    out_parts = P(2, 1)

    result = pmap(sharded_jit(f, in_parts=in_parts,
                              out_parts=out_parts))(a, b, c)
    expected = pmap(f)(a, b, c)
    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)

  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)
    self.assertAllClose(result, expected, check_dtypes=False)

  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None

    result = pmap(
        sharded_jit(f, in_parts=in_partitions,
                    out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)
    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)

  def testShardingConstraint(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    @partial(sharded_jit, in_parts=None, out_parts=None)
    def f(x):
      y = jnp.dot(x, x)
      y = with_sharding_constraint(y, P(2, 1))
      return y * 2

    def expected_f(x):
      return jnp.dot(x, x) * 2

    shape = (2, 8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    result = pmap(f)(x)
    expected = pmap(expected_f)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
def f(x):
  y = x + 1
  y = with_sharding_constraint(y, P(2, 1))
  return y * 2
def f(x):
  y = x + 1
  p, vjp_f = vjp(
      lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
  return vjp_f(p)
def f(x):
  return lax.while_loop(
      lambda i: i[0, 0] < 10.,
      lambda i: with_sharding_constraint(i + 1., P(2, 1)),
      x)
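# A sketch (assumes 2 devices; same imports as the tests above) of how loop
# fragments like f are exercised: staging through sharded_jit is what turns
# the with_sharding_constraint in the loop body into an actual layout.
shape = (8, 8)
x0 = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
g = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))
# g(x0) iterates i -> i + 1, keeping each iterate in a (2, 1) partitioning.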
def apply(self,
          inputs_q,
          inputs_kv,
          num_heads,
          dtype=jnp.float32,
          qkv_features=None,
          out_features=None,
          attention_axis=None,
          causal_mask=False,
          padding_mask=None,
          key_padding_mask=None,
          segmentation=None,
          key_segmentation=None,
          cache=None,
          broadcast_dropout=True,
          dropout_rng=None,
          dropout_rate=0.,
          deterministic=False,
          precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros,
          bias=True,
          num_partitions=2):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention, and projects the results to an output
  vector. This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (None means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad tokens.
    key_padding_mask: boolean specifying key-value tokens that are pad tokens.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see
      `jax.lax.Precision` for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
""" assert causal_mask or not cache, ( 'Caching is only support for causal attention.') if inputs_kv is None: inputs_kv = inputs_q if attention_axis is None: attention_axis = tuple(range(1, inputs_q.ndim - 1)) features = out_features or inputs_q.shape[-1] qkv_features = qkv_features or inputs_q.shape[-1] assert qkv_features % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.') head_dim = qkv_features // num_heads dense = nn.DenseGeneral.partial(axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init, bias_init=bias_init, bias=bias, precision=precision) # project inputs_q to multi-headed q/k/v # dimensions are then [bs, dims..., n_heads, n_features_per_head] query, key, value = (dense(inputs_q, dtype=dtype, name='query'), dense(inputs_kv, dtype=dtype, name='key'), dense(inputs_kv, dtype=dtype, name='value')) if num_partitions > 1: partitions = P(1, 1, num_partitions, 1) query = with_sharding_constraint(query, partitions) key = with_sharding_constraint(key, partitions) value = with_sharding_constraint(value, partitions) if cache: assert isinstance(cache, Cache), 'cache must be an instance of Cache' if self.is_initializing(): cache.store(lambda: (key.ndim, key.shape[-2:])) else: cache_entry = cache.retrieve(None) expected_shape = list(cache_entry.key.shape[:-2]) for attn_dim in attention_axis: expected_shape[attn_dim] = 1 expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] if expected_shape != inputs_q.shape: raise ValueError('Invalid shape provided, ' 'expected shape %s instead got %s.' % (expected_shape, inputs_q.shape)) if not isinstance(cache_entry, _CacheEntry): raise ValueError('Cache is not initialized.') cshape = cache_entry.key.shape i = cache_entry.i one_hot_indices = jax.nn.one_hot(i, cshape[3], dtype=key.dtype).reshape( (1, 1, 1, cshape[3])) key = key.transpose((0, 2, 3, 1)) key = cache_entry.key + key * one_hot_indices value = value.transpose((0, 2, 3, 1)) value = cache_entry.value + value * one_hot_indices one = jnp.array(1, jnp.uint32) cache_entry = cache_entry.replace(i=cache_entry.i + one, key=key, value=value) cache.store(cache_entry) key = key.transpose((0, 3, 1, 2)) value = value.transpose((0, 3, 1, 2)) cshape = (cshape[0], cshape[3], cshape[1], cshape[2]) # TODO(levskaya): verify this is still needed in translation decoding. 
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  # create attention masks
  mask_components = []

  if causal_mask:
    if cache and not self.is_initializing():
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(np.take(key.shape, attention_axis))
      attn_size = np.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                     padding_mask_key=key_padding_mask,
                                     query_shape=query.shape,
                                     key_shape=key.shape,
                                     attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(padding_mask_query=segmentation,
                                          padding_mask_key=key_segmentation,
                                          query_shape=query.shape,
                                          key_shape=key.shape,
                                          attention_axis=attention_axis,
                                          segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # attention mask in the form of attention bias
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # apply attention
  x = dot_product_attention(query,
                            key,
                            value,
                            dtype=dtype,
                            axis=attention_axis,
                            bias=attention_bias,
                            precision=precision,
                            dropout_rng=dropout_rng,
                            dropout_rate=dropout_rate,
                            broadcast_dropout=broadcast_dropout,
                            deterministic=deterministic)

  # back to the original inputs dimensions
  out = nn.DenseGeneral(x,
                        features=features,
                        axis=(-2, -1),
                        kernel_init=kernel_init,
                        bias_init=bias_init,
                        bias=bias,
                        dtype=dtype,
                        precision=precision,
                        name='out')
  if num_partitions > 1:
    # mark the projected output as replicated
    out = with_sharding_constraint(out, None)

  return out