Ejemplo n.º 1
  def testPyTreeArgs(self):
    """Checks sharded_jit on a function whose arguments are nested containers.

    `f` takes a tuple, an array and a list containing a tuple. The sharded
    result is compared against calling `f` directly, first with an explicit
    `in_parts` tree mirroring the argument structure and then with
    `in_parts=None`.
    """
    if jax.device_count() < 2:
      # Give the skip an explicit reason, like the other tests in this file.
      raise SkipTest("requires 2 devices")

    def f(a, b, c):
      a1, a2 = a
      c1, (c2, c3) = c
      return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
      # Integer array of the given shape holding 0..prod(shape)-1.
      return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(4, 4), 1)
    b = _make_arg(4, 4)
    c = [2, (_make_arg(4, 4), _make_arg(4, 4))]

    # One entry per positional argument; the entry for `c` mirrors its
    # list-of-(scalar, tuple) structure.
    in_parts = (None, P(2, 1), [None, P(2, 1)])
    out_parts = P(2, 1)

    result = sharded_jit(f, in_parts, out_parts)(a, b, c)
    expected = f(a, b, c)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 2)

    # Same checks with no input partitioning specified at all.
    in_parts = None
    result = sharded_jit(f, in_parts, out_parts)(a, b, c)
    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 2)
Ejemplo n.º 2
  def testPyTreeArgs(self):
    """Checks pmap(sharded_jit(f)) against pmap(f) for nested arguments.

    Every array argument has a leading dimension of size 2 that `pmap` maps
    over; the result is expected to be a ShardedDeviceArray with 4 device
    buffers.
    """
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(a, b, c):
      a1, a2 = a
      c1, (c2, c3) = c
      return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
      # Integer array of the given shape holding 0..prod(shape)-1.
      return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
    c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

    # One entry per positional argument; the entry for `c` mirrors its
    # nested tuple structure.
    in_parts = (None, P(2, 1), (None, P(2, 1)))
    out_parts = P(2, 1)

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(a, b, c)
    expected = pmap(f)(a, b, c)

    self.assertAllClose(result, expected, check_dtypes=False)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertEqual(len(result.device_buffers), 4)
Ejemplo n.º 3
    def testCompilationCache(self):
        # Builds a sharded version of a trivial increment function and a
        # two-element float32 input for it.
        f = lambda x: x + 1
        sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

        # NOTE(review): the body of this `with` block is not present in this
        # excerpt. Judging by the helper's name it presumably asserts that
        # exactly one compilation happens inside the block -- confirm.
        with jtu.assert_num_jit_and_pmap_compilations(1):
Ejemplo n.º 4
  def testPyTreeOutputs(self):
    """Checks sharded_jit on a function returning a nested tuple of arrays.

    The sharded result is compared against calling `f` directly, first with
    an `out_parts` tree mirroring the output structure and then with
    `out_parts=None`.
    """
    if jax.device_count() < 2:
      # Give the skip an explicit reason, like the other tests in this file.
      raise SkipTest("requires 2 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    # Mirrors the structure of f's return value; one leaf is left as None.
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = sharded_jit(f, in_parts, out_parts)(x)
    expected = f(x)
    self.assertAllClose(result, expected, check_dtypes=False)

    # Same comparison with no output partitioning specified at all.
    out_parts = None
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
Ejemplo n.º 5
    def testShardingConstraint(self):
        """Checks with_sharding_constraint inside a sharded_jit function.

        Three cases are exercised: outer partitions whose total count matches
        the constraint, outer partitions whose total count does not match
        (expected to raise), and a fully unspecified outer sharded_jit.
        """
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
        # the default.
        # assertLen above guarantees exactly two buffers, so the loop checks
        # buffers 0 and 1; each is expected to hold a (4, 8) piece.
        for buf in actual.device_buffers:
            self.assertEqual(
                getattr(buf, "xla_shape", buf.shape)().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        # NOTE(review): ValueError is assumed to be the exception type raised
        # for this message -- confirm against with_sharding_constraint.
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
Ejemplo n.º 6
  def testCompilationCache(self):
    # Skipped unless at least two local devices are available.
    if jax.local_device_count() < 2:
      raise SkipTest("requires 2 devices")
    # Builds a sharded version of a trivial increment function and a
    # two-element float32 input for it.
    f = lambda x: x + 1
    sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
    shape = (2,)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

    # NOTE(review): the body of this `with` block is not present in this
    # excerpt. Judging by the helper's name it presumably asserts that
    # exactly one compilation happens inside the block -- confirm.
    with jtu.assert_num_jit_and_pmap_compilations(1):
Ejemplo n.º 7
  def testPyTreeOutputs(self):
    """Compares pmap(sharded_jit(f)) with pmap(f) for a nested-tuple output."""
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      inner_pair = (x + 2, x + 3)
      return x + 1, (inner_pair, x + 4)

    dims = (2, 4, 4)
    x = np.arange(prod(dims)).reshape(dims)

    # The output spec mirrors the nesting of f's return value.
    sharded_f = sharded_jit(
        f,
        in_parts=(P(2, 1),),
        out_parts=(P(2, 1), ((P(1, 2), None), P(2, 1))),
    )

    result = pmap(sharded_f)(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
Ejemplo n.º 8
  def testNestedShardingConstraint(self):
    """Checks with_sharding_constraint used inside a lax.while_loop body.

    The loop adds 1 to the whole array until element [0, 0] reaches 10.
    Since the input starts at 0 in that position, the expected output is
    `x + 10`.
    """
    if jax.local_device_count() < 2:
      raise SkipTest("requires 2 devices")

    shape = (8, 8)

    def f(x):
      # while_loop(cond_fun, body_fun, init_val): the loop state starts at x.
      return lax.while_loop(lambda i: i[0,0] < 10.,
                            lambda i: with_sharding_constraint(i + 1., P(2, 1)),
                            x)

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = x + 10.
    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
Ejemplo n.º 9
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f).

    Args:
      f: function of two array arguments.
      in_partitions: `in_parts` passed to sharded_jit; its first entry is
        used to compute how many devices the test needs.
      out_partitions: `out_parts` passed to sharded_jit.
      dtype: NOTE(review): currently unused -- the inputs are built without
        an explicit dtype. Confirm whether it was meant to be forwarded to
        `np.arange`.

    Raises:
      SkipTest: if fewer local devices are available than shards required.
    """
    shape = (2, 4, 4)
    # Leading dimension (mapped by pmap) times the partition count of the
    # first input.
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    flat_result, _ = tree_util.tree_flatten(result)
    for r in flat_result:
      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(r, pxla.ShardedDeviceArray)
      self.assertEqual(len(r.device_buffers), num_shards)
Ejemplo n.º 10
  def testInAxesNone(self):
    """Compares pmap(sharded_jit(f)) with pmap(f) when in_axes contains None.

    Only the third argument has a mapped axis (in_axes entry 0); the two
    matrix operands use in_axes entry None.
    """
    dims = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)

    # The same array object is used for both matrix operands.
    mat = np.arange(prod(dims), dtype=np.float32).reshape(dims)
    dummy = np.arange(replicas, dtype=np.float32) + 1

    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(lhs, rhs, _):
      return lhs @ rhs

    sharded_f = sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions)
    result = pmap(sharded_f, in_axes=in_axes)(mat, mat, dummy)
    expected = pmap(f, in_axes=in_axes)(mat, mat, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
Ejemplo n.º 11
  def testManyArgs(self):
    """Checks pmap(sharded_jit(f)) against pmap(f) with 200 array arguments.

    Every argument is the same (2, 4, 4) array with the same partition
    spec; `f` stacks the arguments and sums them.
    """
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    # The list holds num_args references to one array object.
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertEqual(len(result.device_buffers), 4)