Example #1
    def testShardingConstraint(self):
        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(np.prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertEqual(actual.device_buffers[0].shape().dimensions(), (4, 8))
        self.assertEqual(actual.device_buffers[1].shape().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
Example #2
    def test_pjit_TwoMeshAxisSharding(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=P(("x", "y"), ),
                           out_axis_resources=P(("x", "y"), ))
        def jax_func(x, y):
            return x @ y

        x_shape = (24, 8)
        y_shape = (8, 2)
        x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
        y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
        self._check_sharding_annotations(
            jax_func,
            [x, y],
            expected=[
                r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
                r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
                r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3",  # output
            ],
            expected_opt=[
                # TODO: relax ordering
                r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3",  # y
                r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",  # x
                # TODO: why can't we see .*sharding={devices=\[4,1\]0,1,2,3
                r"f32\[1,6,2\]",  # output
            ],
            num_partitions=4)
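
The pjit-based examples (test_pjit_*) go through a _check_sharding_annotations helper whose mesh setup is not shown in these snippets. A minimal sketch of the kind of mesh context such a run needs, reusing jax_func, x and y from the test above and assuming four available devices plus the maps.mesh context manager that Example #23 below also uses (an assumption, not the helper's actual code):

import numpy as np
import jax
from jax.experimental import maps

# Hypothetical 2x2 mesh: the combined ("x", "y") axis of P(("x", "y")) then has
# 2 * 2 = 4 partitions, matching the expected devices=[4,1]0,1,2,3 annotations.
devices = np.array(jax.devices()[:4]).reshape((2, 2))
with maps.mesh(devices, ("x", "y")):
    out = jax_func(x, y)
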
Example #3
    def test_pjit_basic1D(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=(P("x"), P("x")),
                           out_axis_resources=None)
        def jax_func(x, y):
            return x + y

        shape = (8, 10)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
        print(f"HLO is {hlo}")
        print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
        self._check_sharding_annotations(
            jax_func,
            [x, x],
            expected=[
                r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
                r"f32\[8,10\].*sharding={replicated",  # output
            ],
            expected_opt=[
                r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
                # TODO: why don't we see "sharding={replicated"
                r"f32\[8,10\]",  # output
            ],
            num_partitions=2)
Example #4
    def test_pjit_basic2D(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=(P(None, "x", "y"), P("y")),
                           out_axis_resources=P("x"))
        def jax_func(x, y):
            return x @ y

        x_shape = (8, 6, 4)
        y_shape = (4, 2)
        x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
        y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
        self._check_sharding_annotations(
            jax_func,
            [x, y],
            expected=[
                r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
                r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
                r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate",  # output
            ],
            expected_opt=[
                # TODO: relax ordering
                r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate",  # y
                r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3",  # x
                # TODO: why can't we see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
                r"bf16\[4,6,2\]",  # output
            ],
            num_partitions=4)
Example #5
    def test_sharded_jit_with_sharding_constraint(self):
        """A sharding constraint in the middle."""
        def jax_func(x, y):
            logits1 = jnp.dot(x, y)
            return jnp.sin(
                sharded_jit.with_sharding_constraint(logits1, P(2, 1)))

        sharded_jax_func = sharded_jit.sharded_jit(jax_func,
                                                   in_parts=(P(1, 2), P(2, 1)),
                                                   out_parts=P(1, 2))
        xshape = (6, 8)
        x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
        yshape = (8, 10)
        y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
        self._check_sharding_annotations(
            sharded_jax_func,
            [x, y],
            expected=[
                r"f32\[6,8\].*sharding={devices=\[1,2\]",
                r"f32\[8,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,10\].*sine.*sharding={devices=\[1,2\]"
            ],
            expected_opt=[
                # TODO(necula): relax ordering
                r"f32\[4,10\].*sharding={devices=\[2,1\]",
                r"f32\[6,4\].*sharding={devices=\[1,2\]",
            ],
            num_partitions=2)
Example #6
 def test_sharded_jit_in_out(self):
     """Test input and output sharding annotations."""
     sharded_jax_func = sharded_jit.sharded_jit(jnp.dot,
                                                in_parts=(P(1, 2), P(2, 1)),
                                                out_parts=P(1, 2))
     xshape = (3, 8)
     x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
     yshape = (8, 5)
     y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
     self._check_sharding_annotations(
         sharded_jax_func,
         [x, y],
         expected=[
             r"f32\[3,8\].*sharding={devices=\[1,2\]",
             r"f32\[8,5\].*sharding={devices=\[2,1\]",
             r"f32\[3,5\].*sharding={devices=\[1,2\]"
         ],
         expected_opt=[
             # TODO(necula): relax ordering
             r"f32\[4,5\].*sharding={devices=\[2,1\]",
             r"f32\[3,4\].*sharding={devices=\[1,2\]",
             r"f32\[3,5\].*fusion",
             r"f32\[3,5\].*all-reduce",
         ],
         num_partitions=2)
Example #7
    def testPyTreeArgs(self):
        if jax.device_count() < 2:
            raise SkipTest

        def f(a, b, c):
            a1, a2 = a
            c1, (c2, c3) = c
            return a1 + a2 + b + c1 + c2 + c3

        def _make_arg(*shape):
            return np.arange(np.prod(shape)).reshape(shape)

        a = (_make_arg(4, 4), 1)
        b = _make_arg(4, 4)
        c = [2, (_make_arg(4, 4), _make_arg(4, 4))]

        in_parts = (None, P(2, 1), [None, P(2, 1)])
        out_parts = P(2, 1)

        result = sharded_jit(f, in_parts, out_parts)(a, b, c)
        expected = f(a, b, c)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 2)

        in_parts = None
        result = sharded_jit(f, in_parts, out_parts)(a, b, c)
        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 2)
Example #8
    def testPyTreeArgs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        def f(a, b, c):
            a1, a2 = a
            c1, (c2, c3) = c
            return a1 + a2 + b + c1 + c2 + c3

        def _make_arg(*shape):
            return np.arange(np.prod(shape)).reshape(shape)

        a = (_make_arg(2, 4, 4), _make_arg(2))
        b = _make_arg(2, 4, 4)
        c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

        in_parts = (None, P(2, 1), (None, P(2, 1)))
        out_parts = P(2, 1)

        result = pmap(sharded_jit(f, in_parts=in_parts,
                                  out_parts=out_parts))(a, b, c)
        expected = pmap(f)(a, b, c)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
        self.assertEqual(len(result.device_buffers), 4)
Example #9
    def testCompilationCache(self):
        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        with jtu.assert_num_jit_and_pmap_compilations(1):
            sharded_f(x)
            sharded_f(x)
Example #10
    def testCompilationCache(self):
        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        with jtu.count_jit_and_pmap_compiles() as count:
            sharded_f(x)
            sharded_f(x)
        self.assertEqual(count[0], 1)
Example #11
    def testTranslationRule(self):
        @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
        def f(x, y):
            return x + y

        # Test that the translation rule runs without error and produces the
        # OpShardings we expect somewhere.
        shape = (8, 8)
        hlo = jax.xla_computation(f)(np.ones(shape), np.ones(shape))
        self.assertIn("sharding={devices=[2,1]0,1}", hlo.as_hlo_text())
        self.assertIn("sharding={replicated}", hlo.as_hlo_text())
Example #12
    def apply(self,
              num_embeddings,
              features,
              embedding_init=default_embed_init,
              dtype=jnp.float32,
              num_partitions=2):
        """Layer that returns an embedding matrix.

    Args:
      num_embeddings: number of embeddings.
      features: Number of feature dimensions for each embedding.
      embedding_init: embedding initializer.
      dtype: dtype to use for activations.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      An embedding matrix suitable for embedding[inputs].
    """
        embedding = self.param('embedding', (num_embeddings, features),
                               embedding_init)
        embedding = jnp.asarray(embedding, dtype)
        if num_partitions > 1:
            embedding = with_sharding_constraint(embedding,
                                                 P(num_partitions, 1))
        return embedding
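
Stripped of the Flax module plumbing, the sharding idea in this layer is simply to split the embedding matrix's vocabulary axis across the partitions. A stand-alone sketch of that constraint under sharded_jit (hypothetical shapes, two partitions, at least two devices assumed; not the layer itself):

num_embeddings, features, num_partitions = 8, 4, 2

# Hypothetical stand-alone version of the constraint used in the layer above.
@partial(sharded_jit, in_parts=None, out_parts=P(num_partitions, 1))
def make_embedding(table):
    table = jnp.asarray(table, jnp.float32)
    # Split the vocabulary axis num_partitions ways, replicate the feature axis.
    return with_sharding_constraint(table, P(num_partitions, 1))

embedding = make_embedding(np.ones((num_embeddings, features)))
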
Example #13
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6),
           num_partitions=2):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     inputs_shape = inputs.shape
     inputs = inputs.reshape((-1, inputs_shape[-1]))
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     if num_partitions > 1:
         x = with_sharding_constraint(x, P(1, num_partitions))
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     output = output.reshape(inputs_shape[:-1] + (actual_out_dim, ))
     return output
Example #14
  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(np.prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
Example #15
    def testInfeed(self, partition_input):
        if jax.local_device_count() % 2 != 0:
            raise SkipTest

        shape = (jax.local_device_count() * 2, 4)
        # Run computation across all devices so we know which devices to feed.
        parts = P(jax.local_device_count(), 1)
        in_parts = parts if partition_input else None
        infeed_shapes = (jax.ShapedArray(shape, np.float32),
                         jax.ShapedArray((1, ), np.float32))
        infeed_parts = (parts, None)

        @partial(sharded_jit, in_parts=in_parts, out_parts=None)
        def f(x):
            token = lax.create_token(x)
            (y, z), token = lax.infeed(token,
                                       infeed_shapes,
                                       partitions=infeed_parts)
            return x @ y.T + z

        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        y = x + 1
        shard_size = shape[0] // jax.local_device_count()
        y_shards = [
            y[i:i + shard_size] for i in range(0, shape[0], shard_size)
        ]
        z = jnp.array([3.], dtype=np.float32)

        result = f(x)
        assert len(jax.local_devices()) == len(y_shards)
        for device, y_shard in zip(jax.local_devices(), y_shards):
            device.transfer_to_infeed((y_shard, z))

        expected = x @ y.T + z
        self.assertAllClose(result, expected, check_dtypes=False)
Example #16
  def testBasic(self):
    if jax.device_count() < 2:
      raise SkipTest

    @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
    def f(x, y):
      return x + y

    shape = (8, 8)
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    actual = f(x, x + 1)
    expected = x + (x + 1)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertIsInstance(actual, pxla.ShardedDeviceArray)
    self.assertLen(actual.device_buffers, 2)
    self.assertAllClose(actual.device_buffers[0].to_py(), expected,
                        check_dtypes=False)
Example #17
    def testPyTreeOutputs(self):
        if jax.device_count() < 2:
            raise SkipTest

        def f(x):
            return x + 1, ((x + 2, x + 3), x + 4)

        shape = (4, 4)
        x = np.arange(prod(shape)).reshape(shape)
        in_parts = (P(2, 1), )
        out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

        result = sharded_jit(f, in_parts, out_parts)(x)
        expected = f(x)
        self.assertAllClose(result, expected, check_dtypes=False)

        out_parts = None
        result = sharded_jit(f, in_parts, out_parts)(x)
        self.assertAllClose(result, expected, check_dtypes=False)
Example #18
    def testShardingConstraint(self):
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
        # the default.
        self.assertEqual(
            getattr(actual.device_buffers[0], "xla_shape",
                    actual.device_buffers[0].shape)().dimensions(), (4, 8))
        self.assertEqual(
            getattr(actual.device_buffers[1], "xla_shape",
                    actual.device_buffers[1].shape)().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
Example #19
  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
Example #20
    def testNotEnoughDevices(self):
        ndevices = jax.local_device_count()

        @partial(sharded_jit, in_parts=P(ndevices + 1), out_parts=None)
        def f(x):
            return x + x

        with self.assertRaisesRegex(
                ValueError,
                f"sharded_jit computation requires {ndevices + 1} devices, "
                f"but only {ndevices} devices are available."):
            f(np.ones(ndevices + 1))
Example #21
  def testGradOfShardingConstraint(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    @partial(sharded_jit, in_parts=P(4,1), out_parts=None)
    def f(x):
      y = x + 1
      p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2,2))), y)
      return vjp_f(p)

    def expected_f(x):
      y = x + 1
      p, vjp_f = vjp(lambda z: jnp.sin(z), y)
      return vjp_f(p)

    shape = (4, 4)
    x = jnp.arange(jnp.prod(shape), dtype=jnp.float32).reshape(shape)
    actual = f(x)
    expected = expected_f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
Example #22
  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.sum(args)

    shape = (2, 4, 4)
    args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)
Example #23
    def _pjit(inp):
        if isinstance(inp, GlobalDeviceArray):
            if inp.is_fully_replicated:
                return inp.local_data(0).to_py()
            global_mesh = inp._global_mesh
            in_axis_resources = FROM_GDA
        else:
            # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
            # Shape of local_mesh will always be (1, local_device_count())
            devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                      jax.local_device_count())
            global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
            in_axis_resources = P('processes')
            if inp.ndim == 0 or not tiled:
                inp = np.expand_dims(inp, axis=0)

        with maps.mesh(global_mesh.devices, global_mesh.axis_names):
            out = pjit(lambda x: x,
                       in_axis_resources=in_axis_resources,
                       out_axis_resources=None)(inp)
        return out.local_data(0).to_py()
Example #24
  def test_sharded_jit_replicated(self):
    """A replicated input and output."""

    sharded_jax_func = sharded_jit.sharded_jit(
        jnp.dot, in_parts=(P(1, 2), None), out_parts=None)
    xshape = (3, 8)
    x = np.arange(np.prod(xshape), dtype=np.float32).reshape(xshape)
    yshape = (8, 5)
    y = np.arange(np.prod(yshape), dtype=np.float32).reshape(yshape)
    self._check_sharding_annotations(
        sharded_jax_func, [x, y],
        expected=[
            r"f32\[3,8\].*sharding={devices=\[1,2\]",
            r"f32\[8,5\].*sharding={replicated}",
            r"f32\[3,5\].*sharding={replicated}"
        ],
        expected_opt=[
            # TODO(necula): relax ordering
            r"f32\[8,5\].*sharding={replicated}",
            r"f32\[3,4\].*sharding={devices=\[1,2\]",
        ],
        num_partitions=2)
Example #25
 def f(x):
     y = jnp.dot(x, x)
     y = with_sharding_constraint(y, P(2, 1))
     return y * 2
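
This is the same body used in the pmap testShardingConstraint further below (Example #26). On its own it could be run under sharded_jit roughly as follows, assuming at least two devices and the imports sketched above (a hypothetical wrapper, not the original test's code):

# Hypothetical wrapper; mirrors the pattern of the surrounding tests.
sharded_f = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))
x = np.arange(64, dtype=np.float32).reshape((8, 8))
result = sharded_f(x)  # dot(x, x) * 2, split into two row partitions
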
Example #26
class PmapOfShardedJitTest(jtu.JaxTestCase):
    def setUp(self):
        super(PmapOfShardedJitTest, self).setUp()
        if jtu.device_under_test() != "tpu":
            raise SkipTest

    # TODO(skye): make a similar version for ShardedJitTest and run the same tests
    def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
        """Compares pmap(sharded_jit(f, ...)) to pmap(f)"""
        shape = (2, 4, 4)
        num_shards = shape[0] * np.prod(in_partitions[0])
        if num_shards > jax.local_device_count():
            raise SkipTest("requires %d devices" % num_shards)

        x = np.arange(prod(shape)).reshape(shape)
        y = x + 1
        result = pmap(
            sharded_jit(f, in_parts=in_partitions,
                        out_parts=out_partitions))(x, y)
        expected = pmap(f)(x, y)
        self.assertAllClose(result, expected, check_dtypes=False)

        flat_result = tree_util.tree_flatten(result)[0]
        for r in flat_result:
            self.assertTrue(isinstance(r, pxla.ShardedDeviceArray))
            self.assertEqual(len(r.device_buffers), num_shards)

    @parameterized.named_parameters(
        {"testcase_name": "_in_parts={}_out_parts={}".format(
             in_partitions, out_partitions).replace(" ", ""),
         "in_partitions": in_partitions,
         "out_partitions": out_partitions}
        for in_partitions in [
            (P(2, 1), P(2, 1)),
            (P(2, 1), P(1, 2)),
            (P(2, 2), P(2, 2)),
            (P(4, 1), P(2, 2)),
        ]
        for out_partitions in [in_partitions[0], None])
    def testBasic(self, in_partitions, out_partitions):
        def f(x, y):
            return lax.dot(x, y)

        self._runTest(f, in_partitions, out_partitions)

    @parameterized.named_parameters(
        {"testcase_name": "_in_parts={}_out_parts={}".format(
             in_partitions, out_partitions).replace(" ", ""),
         "in_partitions": in_partitions,
         "out_partitions": out_partitions}
        for in_partitions in [(P(2, 1), P(2, 1)),
                              (P(2, 1), P(1, 2)),
                              (P(4, 1), P(2, 2))]
        for out_partitions in [(in_partitions[1], in_partitions[0], None),
                               (None, None, None)])
    def testMultipleOutputs(self, in_partitions, out_partitions):
        def f(x, y):
            a = lax.dot(x, y)
            # TODO(skye): use these more interesting outputs once returning constants
            # works
            # return a, a + 1, 3
            return a, a + x, x + y

        self._runTest(f, in_partitions, out_partitions)

    @parameterized.named_parameters(
        {"testcase_name": "_in_parts={}_out_parts={}".format(
             in_partitions, out_partitions).replace(" ", ""),
         "in_partitions": in_partitions,
         "out_partitions": out_partitions}
        for in_partitions in [(P(2, 1), P(2, 1)),
                              (P(2, 1), P(1, 2)),
                              (P(4, 1), P(2, 2))]
        for out_partitions in [in_partitions[0], None])
    def testArrayConstants(self, in_partitions, out_partitions):
        def f(x, y):
            a = lax.dot(x, y)
            b = a + jnp.ones(a.shape)
            c = b + jnp.ones(a.shape[0])
            return c

        self._runTest(f, in_partitions, out_partitions)

    def testPyTreeArgs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        def f(a, b, c):
            a1, a2 = a
            c1, (c2, c3) = c
            return a1 + a2 + b + c1 + c2 + c3

        def _make_arg(*shape):
            return np.arange(prod(shape)).reshape(shape)

        a = (_make_arg(2, 4, 4), _make_arg(2))
        b = _make_arg(2, 4, 4)
        c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

        in_parts = (None, P(2, 1), (None, P(2, 1)))
        out_parts = P(2, 1)

        result = pmap(sharded_jit(f, in_parts=in_parts,
                                  out_parts=out_parts))(a, b, c)
        expected = pmap(f)(a, b, c)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
        self.assertEqual(len(result.device_buffers), 4)

    def testPyTreeOutputs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        def f(x):
            return x + 1, ((x + 2, x + 3), x + 4)

        shape = (2, 4, 4)
        x = np.arange(prod(shape)).reshape(shape)
        in_parts = (P(2, 1), )
        out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

        result = pmap(sharded_jit(f, in_parts=in_parts,
                                  out_parts=out_parts))(x)
        expected = pmap(f)(x)

        self.assertAllClose(result, expected, check_dtypes=False)

    def testManyArgs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        num_args = 200

        def f(*args):
            return jnp.asarray(args).sum()

        shape = (2, 4, 4)
        args = [np.arange(prod(shape)).reshape(shape)] * num_args
        in_partitions = (P(2, 1), ) * num_args
        out_partitions = None
        result = pmap(
            sharded_jit(f, in_parts=in_partitions,
                        out_parts=out_partitions))(*args)
        expected = pmap(f)(*args)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
        self.assertEqual(len(result.device_buffers), 4)

    def testShardingConstraint(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        @partial(sharded_jit, in_parts=None, out_parts=None)
        def f(x):
            y = jnp.dot(x, x)
            y = with_sharding_constraint(y, P(2, 1))
            return y * 2

        def expected_f(x):
            return jnp.dot(x, x) * 2

        shape = (2, 8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        result = pmap(f)(x)
        expected = pmap(expected_f)(x)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 4)

    def testInAxesNone(self):
        shape = (4, 4)
        replicas = 2
        in_partitions = (P(2, 1), None, None)
        out_partitions = P(2, 1)
        in_axes = (None, None, 0)
        x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        dummy = np.arange(replicas, dtype=np.float32) + 1
        num_shards = replicas * np.prod(in_partitions[0])
        if num_shards > jax.local_device_count():
            raise SkipTest("requires %d devices" % num_shards)

        def f(x, y, _):
            return x @ y

        result = pmap(sharded_jit(f,
                                  in_parts=in_partitions,
                                  out_parts=out_partitions),
                      in_axes=in_axes)(x, y, dummy)
        expected = pmap(f, in_axes=in_axes)(x, y, dummy)
        self.assertAllClose(result, expected, check_dtypes=True)
Example #27
 def f(x):
     y = x + 1
     y = with_sharding_constraint(y, P(2, 1))
     return y * 2
Example #28
 def f(x):
     y = x + 1
     p, vjp_f = vjp(
         lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
     return vjp_f(p)
Example #29
 def f(x):
     return lax.while_loop(
         lambda i: i[0, 0] < 10.,
         lambda i: with_sharding_constraint(i + 1., P(2, 1)), x)
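
A sharding constraint can also sit inside control flow, as in this while_loop body. A hypothetical way to invoke the snippet, again assuming at least two devices and the imports sketched earlier (the original test's wrapping is not shown here):

# Hypothetical wrapper around the while_loop snippet above.
sharded_loop = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))
x = np.zeros((4, 4), dtype=np.float32)
result = sharded_loop(x)  # keeps adding 1 until result[0, 0] reaches 10.0
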
Example #30
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              num_partitions=2):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))
        if num_partitions > 1:
            partitions = P(1, 1, num_partitions, 1)
            query = with_sharding_constraint(query, partitions)
            key = with_sharding_constraint(key, partitions)
            value = with_sharding_constraint(value, partitions)

        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(lambda: (key.ndim, key.shape[-2:]))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                i = cache_entry.i
                one_hot_indices = jax.nn.one_hot(i, cshape[3],
                                                 dtype=key.dtype).reshape(
                                                     (1, 1, 1, cshape[3]))
                key = key.transpose((0, 2, 3, 1))
                key = cache_entry.key + key * one_hot_indices
                value = value.transpose((0, 2, 3, 1))
                value = cache_entry.value + value * one_hot_indices

                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                key = key.transpose((0, 3, 1, 2))
                value = value.transpose((0, 3, 1, 2))
                cshape = (cshape[0], cshape[3], cshape[1], cshape[2])

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(np.take(key.shape, attention_axis))
                attn_size = np.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                             padding_mask_key=key_padding_mask,
                                             query_shape=query.shape,
                                             key_shape=key.shape,
                                             attention_axis=attention_axis)
            mask_components.append(padding_mask)

        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')
        if num_partitions > 1:
            out = with_sharding_constraint(out, None)

        return out
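
The sharding-specific part of this attention layer is small: after the q/k/v projections of shape [batch, ..., heads, head_dim], the heads axis is constrained to be split across partitions. A minimal stand-alone sketch of just that step (hypothetical shapes; assumes num_partitions devices and the imports sketched earlier; not the Flax layer itself):

num_partitions = 2

# Hypothetical illustration of the q/k/v constraint above, outside of Flax.
@partial(sharded_jit, in_parts=None, out_parts=None)
def constrain_qkv(query, key, value):
    # [batch, length, heads, head_dim]: replicate batch, length and head_dim,
    # split the heads axis num_partitions ways.
    partitions = P(1, 1, num_partitions, 1)
    query = with_sharding_constraint(query, partitions)
    key = with_sharding_constraint(key, partitions)
    value = with_sharding_constraint(value, partitions)
    return query, key, value

q = k = v = np.ones((1, 8, 4, 16), dtype=np.float32)
q, k, v = constrain_qkv(q, k, v)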