def testUnmaterializedAxis(self):
    """An unmaterialized sharded axis yields flat integer indices."""
    # Axis 0 of a (4, 8) array split into 4 unmaterialized shards.
    row_sharded = pxla.ShardingSpec(shards_per_axis=(4, 1),
                                    is_axis_materialized=(False, True),
                                    replication_factor=1)
    self.assertEqual(pxla.spec_to_indices((4, 8), row_sharded),
                     (0, 1, 2, 3))

    # Axis 1 of a (2, 2) array split into 2 unmaterialized shards; the
    # materialized leading axis contributes a full slice per index.
    col_sharded = pxla.ShardingSpec(shards_per_axis=(1, 2),
                                    is_axis_materialized=(True, False),
                                    replication_factor=1)
    self.assertEqual(pxla.spec_to_indices((2, 2), col_sharded),
                     ((slice(None), 0), (slice(None), 1)))
def testReplication(self):
    """Each shard index appears replication_factor consecutive times."""
    replicated = pxla.ShardingSpec(shards_per_axis=(2, 1),
                                   is_axis_materialized=(False, True),
                                   replication_factor=3)
    indices = pxla.spec_to_indices((2, 8), replicated)
    # 2 shards x replication factor 3 -> each index repeated 3 times.
    self.assertEqual(indices, (0, 0, 0, 1, 1, 1))
def testNoSharding(self):
    """A fully unsharded array maps to a single full-slice index."""
    unsharded = pxla.ShardingSpec(shards_per_axis=(1, 1),
                                  is_axis_materialized=(True, True),
                                  replication_factor=1)
    self.assertEqual(pxla.spec_to_indices((4, 8), unsharded),
                     (slice(None),))
def testUnshardedAxis(self):
    """Sharding a materialized axis produces slice objects along it."""
    spec = pxla.ShardingSpec(shards_per_axis=(2, 1),
                             is_axis_materialized=(True, True),
                             replication_factor=1)
    indices = pxla.spec_to_indices((4, 8), spec)
    # Axis 0 (length 4) split in two -> one half-slice per shard.
    self.assertEqual(indices, (slice(0, 2), slice(2, 4)))
def testOneLogicalTwoMeshAxesSharding(self):
    """One logical axis over two mesh axes chunks twice; the sharded-axis
    order follows the order of the mesh resources."""
    def quadruple(x):
        return x * 4

    f_xy = xmap(quadruple, in_axes=['a', ...], out_axes={1: 'a'},
                axis_resources={'a': ('x', 'y')})
    f_yx = xmap(quadruple, in_axes=['a', ...], out_axes={1: 'a'},
                axis_resources={'a': ('y', 'x')})

    vshape = (4, 5)
    v = jnp.arange(np.prod(vshape)).reshape(vshape)

    z_xy = f_xy(v)
    self.assertEqual(
        z_xy.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))

    # Swapping mesh resources swaps the sharded-axis order, not the chunking.
    z_yx = f_yx(v)
    self.assertEqual(
        z_yx.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
def testReshardInput(self):
    """pmap reshards a ShardedDeviceArray whose layout doesn't match."""
    if xla_bridge.device_count() < 6:
        raise SkipTest("testReshardInput requires 6 devices")
    # Hand-build a ShardedDeviceArray whose sharding disagrees with what the
    # subsequent pmap expects, forcing a reshard of the input.
    shard = np.arange(6).reshape((3, 2))
    bufs = [xla.device_put(shard, dev) for dev in xla_bridge.devices()[:4]]
    aval = ShapedArray((6, 4), shard.dtype)
    mismatched_spec = pxla.ShardingSpec(shards_per_axis=(2, 2),
                                        is_axis_materialized=(True, True),
                                        replication_factor=2)
    arr = pxla.ShardedDeviceArray(aval, mismatched_spec, bufs)

    result = pmap(lambda x: x + 1)(arr)
    self.assertAllClose(result, arr + 1, check_dtypes=True)
    # One buffer per device used by the pmap.
    self.assertEqual(len(result.device_buffers), 6)