def testPyTreeArgs(self):
    """Check sharded_jit on a function whose arguments are nested containers.

    The wrapped function is called once with an explicit ``in_parts`` tree
    and once with ``in_parts=None``. In both cases the result must match the
    plain call ``f(a, b, c)`` and be a ``ShardedDeviceArray`` holding two
    device buffers.

    Raises:
        SkipTest: if fewer than 2 devices are available.
    """
    if jax.device_count() < 2:
        # State the reason so the skip is self-explanatory in test reports.
        raise SkipTest("requires 2 devices")

    def f(a, b, c):
        a1, a2 = a
        c1, (c2, c3) = c
        return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
        return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(4, 4), 1)
    b = _make_arg(4, 4)
    c = [2, (_make_arg(4, 4), _make_arg(4, 4))]

    out_parts = P(2, 1)
    expected = f(a, b, c)

    # First an in_parts tree matching the argument structure, then None.
    for in_parts in ((None, P(2, 1), [None, P(2, 1)]), None):
        result = sharded_jit(f, in_parts, out_parts)(a, b, c)
        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 2)
def testPyTreeArgs(self):
    """Check pmap(sharded_jit(f)) on nested-container arguments.

    The result is compared against ``pmap(f)`` on the same inputs and must
    be a ``ShardedDeviceArray`` holding four device buffers.

    Raises:
        SkipTest: if fewer than 4 local devices are available.
    """
    if jax.local_device_count() < 4:
        raise SkipTest("requires 4 devices")

    def f(a, b, c):
        a1, a2 = a
        c1, (c2, c3) = c
        return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
        return np.arange(prod(shape)).reshape(shape)

    # Every leaf has a leading axis of size 2, which pmap maps over.
    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
    c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

    in_parts = (None, P(2, 1), (None, P(2, 1)))
    out_parts = P(2, 1)

    result = pmap(
        sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(a, b, c)
    expected = pmap(f)(a, b, c)

    self.assertAllClose(result, expected, check_dtypes=False)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertEqual(len(result.device_buffers), 4)
def testCompilationCache(self):
    """Check that calling a sharded_jit function twice compiles only once.

    Raises:
        SkipTest: if fewer than 2 local devices are available.
    """
    # P(2) asks for two partitions, so skip rather than fail when they
    # cannot be provided.
    if jax.local_device_count() < 2:
        raise SkipTest("requires 2 devices")

    f = lambda x: x + 1
    sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))
    shape = (2,)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)

    # Both calls run inside the context that expects exactly 1 compilation.
    with jtu.assert_num_jit_and_pmap_compilations(1):
        sharded_f(x)
        sharded_f(x)
def testPyTreeOutputs(self):
    """Check sharded_jit on a function that returns nested tuples.

    The wrapped function is called once with an ``out_parts`` tree that
    mirrors the output structure and once with ``out_parts=None``; both
    results must match the plain call ``f(x)``.

    Raises:
        SkipTest: if fewer than 2 devices are available.
    """
    if jax.device_count() < 2:
        # State the reason so the skip is self-explanatory in test reports.
        raise SkipTest("requires 2 devices")

    def f(x):
        return x + 1, ((x + 2, x + 3), x + 4)

    shape = (4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    expected = f(x)

    # out_parts tree with the same nesting as f's return value.
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)

    out_parts = None
    result = sharded_jit(f, in_parts, out_parts)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
def testShardingConstraint(self):
    """Exercise with_sharding_constraint inside a sharded_jit function.

    Three outer configurations are tried: partitions whose count matches the
    constraint, partitions whose count does not (a ValueError is expected),
    and ``in_parts``/``out_parts`` both set to None.

    Raises:
        SkipTest: if fewer than 2 local devices are available.
    """
    if jax.local_device_count() < 2:
        raise SkipTest("requires 2 devices")

    def f(x):
        constrained = with_sharding_constraint(x + 1, P(1, 2))
        return constrained * 2

    shape = (8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    expected = (x + 1) * 2

    # Matching sharded_jit partitions
    matching_fn = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))
    actual = matching_fn(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
    # the default.
    for index in (0, 1):
        buf = actual.device_buffers[index]
        # Prefer xla_shape when the buffer has it, otherwise fall back to
        # shape; either is called to obtain the dimensions.
        shape_getter = getattr(buf, "xla_shape", buf.shape)
        self.assertEqual(shape_getter().dimensions(), (4, 8))

    # Mismatched sharded_jit partitions
    mismatch_regex = (
        r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
        r"\(total partitions: 2\) doesn't match expected number of partitions: "
        r"4. If these partitions look right, check outer sharded_jit and/or "
        r"other with_sharding_constraint calls.")
    with self.assertRaisesRegex(ValueError, mismatch_regex):
        sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

    # Replicated sharded_jit
    replicated_fn = sharded_jit(f, in_parts=None, out_parts=None)
    actual = replicated_fn(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
    first_copy = actual.device_buffers[0].to_py()
    second_copy = actual.device_buffers[1].to_py()
    self.assertAllClose(first_copy, second_copy, check_dtypes=False)
def testCompilationCache(self):
    """Check that two calls to one sharded_jit function compile only once.

    Raises:
        SkipTest: if fewer than 2 local devices are available.
    """
    if jax.local_device_count() < 2:
        raise SkipTest("requires 2 devices")

    shape = (2,)
    flat_input = np.arange(prod(shape), dtype=np.float32)
    x = flat_input.reshape(shape)

    f = lambda x: x + 1
    sharded_f = sharded_jit(f, in_parts=P(2), out_parts=P(2))

    # Two identical invocations inside a context expecting 1 compilation.
    with jtu.assert_num_jit_and_pmap_compilations(1):
        for _ in range(2):
            sharded_f(x)
def testPyTreeOutputs(self):
    """Check pmap(sharded_jit(f)) when f returns nested tuples.

    The result is compared against ``pmap(f)`` on the same input.

    Raises:
        SkipTest: if fewer than 4 local devices are available.
    """
    if jax.local_device_count() < 4:
        raise SkipTest("requires 4 devices")

    def f(x):
        first = x + 1
        inner_pair = (x + 2, x + 3)
        last = x + 4
        return first, (inner_pair, last)

    shape = (2, 4, 4)
    x = np.arange(prod(shape)).reshape(shape)

    # The out_parts tree has the same nesting as f's return value.
    partitioned = sharded_jit(
        f,
        in_parts=(P(2, 1),),
        out_parts=(P(2, 1), ((P(1, 2), None), P(2, 1))))

    result = pmap(partitioned)(x)
    expected = pmap(f)(x)
    self.assertAllClose(result, expected, check_dtypes=False)
def testNestedShardingConstraint(self):
    """Check with_sharding_constraint inside a while_loop body under jit.

    The jitted function is wrapped in sharded_jit with ``in_parts`` and
    ``out_parts`` of None; the result must equal the input plus 10 and hold
    two device buffers.

    Raises:
        SkipTest: if fewer than 2 local devices are available.
    """
    if jax.local_device_count() < 2:
        raise SkipTest("requires 2 devices")

    shape = (8, 8)

    def keep_going(carry):
        return carry[0, 0] < 10.

    def advance(carry):
        return with_sharding_constraint(carry + 1., P(2, 1))

    def looped(x):
        return lax.while_loop(keep_going, advance, x)

    f = jit(looped)

    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    # x[0, 0] starts at 0 and each step adds 1, so the loop runs 10 times.
    expected = x + 10.

    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
    self.assertLen(actual.device_buffers, 2)
def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f).

    Args:
        f: two-argument function under test; it is called with ``(x, y)``.
        in_partitions: value passed as ``in_parts``; its first entry also
            determines how many shards are required.
        out_partitions: value passed as ``out_parts``.
        dtype: dtype of the generated input ``x`` (``y`` is ``x + 1``).

    Raises:
        SkipTest: if the required number of shards exceeds the number of
            local devices.
    """
    shape = (2, 4, 4)
    # The leading axis of `shape` times the partition count of the first arg.
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
        raise SkipTest("requires %d devices" % num_shards)

    # Honor the `dtype` argument; previously it was accepted but ignored.
    x = np.arange(prod(shape), dtype=dtype).reshape(shape)
    y = x + 1

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    # Every output leaf must be a ShardedDeviceArray with num_shards buffers.
    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
        self.assertIsInstance(r, pxla.ShardedDeviceArray)
        self.assertEqual(len(r.device_buffers), num_shards)
def testInAxesNone(self):
    """Check pmap(sharded_jit(f)) when some pmap in_axes entries are None.

    The first two arguments use ``in_axes`` of None and only the third, a
    dummy array of length ``replicas``, is mapped over. The result is
    compared against ``pmap(f, in_axes=...)`` with dtype checking enabled.

    Raises:
        SkipTest: if the required number of shards exceeds the number of
            local devices.
    """
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)

    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
        raise SkipTest("requires %d devices" % num_shards)

    shape = (4, 4)
    # x and y are deliberately the same array object.
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1

    def f(x, y, _):
        return x @ y

    partitioned = sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions)
    result = pmap(partitioned, in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
def testManyArgs(self):
    """Check pmap(sharded_jit(f)) on a function taking 200 array arguments.

    The result is compared against ``pmap(f)`` and must be a
    ``ShardedDeviceArray`` holding four device buffers.

    Raises:
        SkipTest: if fewer than 4 local devices are available.
    """
    if jax.local_device_count() < 4:
        raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
        return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    # The same array object is repeated for every argument.
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None

    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertEqual(len(result.device_buffers), 4)