def testVMap(self):
    """vmap of a pjit-ed function: the batched output picks up a leading
    unsharded batch dimension while the unbatched output keeps the plain
    chunked sharding (assumes a mesh with axis 'x' of size 2 is active —
    set up by the enclosing fixture; confirm against the test runner)."""
    add_and_pass = pjit(lambda a, b: (a + b, a),
                        in_axis_resources=P('x'),
                        out_axis_resources=P('x'))
    vec = jnp.arange(4)
    batch = jnp.arange(5 * 4).reshape((5, 4))
    batched_fn = jax.vmap(add_and_pass, in_axes=(None, 0), out_axes=(0, None))
    summed, passed = batched_fn(vec, batch)
    self.assertAllClose(summed, vec + batch)
    self.assertAllClose(passed, vec)
    # Batched result: batch dim (axis 0) is unsharded, data dim is chunked in 2.
    self.assertEqual(summed.sharding_spec.sharding,
                     (pxla.NoSharding(), pxla.Chunked([2])))
    # Unbatched result keeps the one-dimensional chunked spec.
    self.assertEqual(passed.sharding_spec.sharding, (pxla.Chunked([2]),))
def testNestedMesh(self, mesh, axis_resources):
    """Nested xmaps, each bound to a distinct mesh axis via axis_resources.

    Args:
      mesh: sequence of (mesh axis name, size) pairs describing the active
        device mesh (supplied by the parameterized fixture).
      axis_resources: two (logical axis, mesh axis) pairs; the first binds the
        outer xmap's 'a', the second the inner xmap's 'b'.
    """
    @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
             axis_resources=dict([axis_resources[0]]))
    def f(x):
        y = x * 2

        @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
                 axis_resources=dict([axis_resources[1]]))
        def h(y):
            # psum reduces over BOTH the inner and the enclosing named axis.
            return jnp.sin(y), lax.psum(y, ('a', 'b'))

        return h(y)

    xshape = (4, 2, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    y = f(x)
    # First output: 'a' (axis 1) moved to front by out_axes, 'b' (axis 0)
    # moved to position 1 — net transpose (1, 2, 0). Second: full reduction
    # over the two mapped axes.
    self.assertAllClose(
        y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
    # FIX: pxla.Chunked takes the list of chunk sizes (consistent with the
    # sibling tests using Chunked([2]) / Chunked((2, 2))); the previous
    # bare-int spelling Chunked(2) does not compare equal to the spec built
    # by the runtime in this version.
    self.assertEqual(
        y[0].sharding_spec.sharding,
        (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
    self.assertEqual(
        y[0].sharding_spec.mesh_mapping,
        (pxla.Replicated(2), pxla.ShardedAxis(0)) +
        (pxla.Replicated(2),) * (len(mesh) - 2))
def testNestedXMapMesh(self):
    """Nested xmaps using the schedule= API: outer axis 'a' on mesh axis 'y',
    inner axis 'b' on mesh axis 'x'.

    NOTE(review): `A` is a file-local helper wrapping the in/out axes
    mapping — presumably a positional-axes constructor; verify its
    definition elsewhere in this file.
    """
    @partial(xmap, in_axes=A({'a': 1}), out_axes=A({'a': 0}),
             schedule=[('a', 'y')])
    def f(x):
        y = x * 2

        @partial(xmap, in_axes=A({'b': 0}), out_axes=A({'b': 1}),
                 schedule=[('b', 'x')])
        def h(y):
            return jnp.sin(y)

        return h(y)

    xshape = (2, 3, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    y = f(x)
    # 'a' (input axis 1) moves to front, 'b' (input axis 0) moves to
    # position 1 — net permutation (1, 2, 0).
    self.assertAllClose(y, jnp.sin(x * 2).transpose((1, 2, 0)))
    # Make sure the op really ran across a 2D mesh.
    self.assertEqual(y.sharding_spec.sharding,
                     (pxla.Chunked(3), None, None))
    self.assertEqual(y.sharding_spec.mesh_mapping,
                     (pxla.Replicated(2), pxla.ShardedAxis(0)))
def testNestedMesh(self):
    """An xmap over mesh axis 'y' that contains an xmap over mesh axis 'x'."""
    @partial(xmap, in_axes={1: 'a'}, out_axes={0: 'a'},
             axis_resources={'a': 'y'})
    def outer(x):
        doubled = x * 2

        @partial(xmap, in_axes={0: 'b'}, out_axes={1: 'b'},
                 axis_resources={'b': 'x'})
        def inner(v):
            return jnp.sin(v)

        return inner(doubled)

    shape = (2, 3, 5)
    data = jnp.arange(np.prod(shape)).reshape(shape)
    result = outer(data)
    # Outer map pulls axis 1 to the front, inner map sends axis 0 to
    # position 1 — together a (1, 2, 0) transpose of sin(2x).
    self.assertAllClose(result, jnp.sin(data * 2).transpose((1, 2, 0)))
    # Verify the computation really executed across a 2D device mesh.
    self.assertEqual(result.sharding_spec.sharding,
                     (pxla.Chunked(3), None, None))
    self.assertEqual(result.sharding_spec.mesh_mapping,
                     (pxla.Replicated(2), pxla.ShardedAxis(0)))
def testOneLogicalTwoMeshAxesSharding(self):
    """One logical axis 'a' split over two mesh axes, in both orderings.

    Swapping the mesh-axis order ('x','y') vs ('y','x') must only swap the
    order of the ShardedAxis entries in the mesh mapping, not the chunking.
    """
    def quadruple(v):
        return v * 4

    def make_xmapped(resources):
        # Map logical axis 'a' (input axis 0 -> output axis 1) over the
        # given tuple of mesh axes.
        return xmap(quadruple, in_axes=['a', ...], out_axes={1: 'a'},
                    axis_resources={'a': resources})

    data = jnp.arange(np.prod((4, 5))).reshape((4, 5))

    zxy = make_xmapped(('x', 'y'))(data)
    self.assertEqual(
        zxy.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))

    zyx = make_xmapped(('y', 'x'))(data)
    self.assertEqual(
        zyx.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
def testNestedMesh(self, mesh, axis_resources):
    """Nested xmaps with per-level axis_resources; also checks the lowered
    HLO carries full (non-partial) sharding annotations under SPMD lowering.

    Args:
      mesh: sequence describing the active device mesh (from the
        parameterized fixture); only its length is read here.
      axis_resources: two (logical axis, mesh axis) pairs — the first binds
        the outer 'a', the second the inner 'b'.
    """
    @partial(xmap, in_axes={1: 'a'}, out_axes=({1: 'b'} if False else {0: 'a'}, {}),
             axis_resources=dict([axis_resources[0]])) if False else None
    @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
             axis_resources=dict([axis_resources[0]]))
    def f(x):
        y = x * 2

        @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
                 axis_resources=dict([axis_resources[1]]))
        def h(y):
            # Multiply by a constant array to better exercise the partial_eval rule
            return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b'))

        return h(y)

    xshape = (4, 2, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    y = f(x)
    # First output: axes permuted to (1, 2, 0) by the two out_axes specs;
    # second output: psum reduces fully over both mapped axes.
    self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])).transpose((1, 2, 0)),
                            (x * 2).sum((0, 1))))
    self.assertEqual(y[0].sharding_spec.sharding,
                     (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
    # Remaining mesh axes beyond the two mapped ones are fully replicated.
    self.assertEqual(y[0].sharding_spec.mesh_mapping,
                     (pxla.Replicated(2), pxla.ShardedAxis(0)) +
                     (pxla.Replicated(2),) * (len(mesh) - 2))
    if maps.EXPERIMENTAL_SPMD_LOWERING:
        hlo = jax.xla_computation(f)(x).as_hlo_text()
        # Make sure that there are non-partial sharding specs in the HLO
        self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")