def testXMapMeshCollectives(self):
    """Run an xmapped psum on a 2x2 mesh of local devices.

    The test is skipped when fewer than four local devices are available.
    The mapped function doubles its first argument and sums it over the
    named axis 'a', and multiplies its second argument by 4. The results
    are compared against the same computation written with plain array ops.
    """
    available = list(jax.local_devices())
    if len(available) < 4:
        raise SkipTest("Test requires at least 4 local devices")

    def f(a, b):
        return lax.psum(a * 2, 'a'), b * 4

    # Arrange the first four devices as a 2x2 grid with mesh axes 'x', 'y'.
    device_grid = np.array(available[:4]).reshape((2, 2))
    with mesh(device_grid, ('x', 'y')):
        fm = xmap(
            f,
            in_axes=[A({'a': 0, 'b': 1}), A({'c': 0})],
            out_axes=[A({'b': 0}), A({'c': 0})],
            schedule=[
                ('a', 'x'),
                ('b', 'y'),
                ('c', 'x'),
                ('a', 'vectorize'),
                ('b', 'vectorize'),
            ],
        )

        ashape = (16, 8, 5)
        bshape = (2, 7)
        a = jnp.arange(np.prod(ashape)).reshape(ashape)
        b = jnp.arange(np.prod(bshape)).reshape(bshape)

        summed, scaled = fm(a, b)

        # 'a' is positional axis 0 of the first input, so the psum over 'a'
        # is expected to equal a sum over that leading axis.
        self.assertAllClose(summed, (a * 2).sum(0))
        self.assertAllClose(scaled, b * 4)
def testXMap(self):
    """Check that xmap applies an elementwise function and honours out_axes.

    The first input maps 'x' to axis 0 and 'z' to axis 1, and the output
    swaps those two positions; the second input maps 'y' to axis 1 and the
    output places it at axis 0. The expected results are therefore the
    elementwise results with the corresponding axes transposed.
    """
    def f(a, b):
        return a + 2, b * 4

    # NOTE(review): the schedule refers to resources 'r1', 'r2' and 'r3',
    # which are not set up inside this body; presumably a decorator or
    # fixture outside this view provides them -- confirm.
    axis_schedule = [
        ('x', 'r1'),
        ('x', 'r2'),
        ('y', 'r1'),
        ('z', 'r3'),
        ('x', 'vectorize'),
        ('y', 'vectorize'),
    ]
    fm = xmap(
        f,
        in_axes=[A({'x': 0, 'z': 1}), A({'y': 1})],
        out_axes=[A({'x': 1, 'z': 0}), A({'y': 0})],
        schedule=axis_schedule,
    )

    a = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
    b = jnp.arange(6 * 16).reshape((6, 16))

    shifted, scaled = fm(a, b)

    expected_shifted = (a + 2).transpose((1, 0, 2))
    expected_scaled = (b * 4).T
    self.assertAllClose(shifted, expected_shifted)
    self.assertAllClose(scaled, expected_scaled)
def testXMapCompilationCache(self):
    """Check that a repeated xmap call does not re-run the Python body of f.

    The mapped function is called twice with the same input. The first call
    is allowed to execute f in Python; during the second call any execution
    of f's Python body raises AssertionError and fails the test.
    """
    def f(x):
        # Explicit raise instead of a bare `assert`: assert statements are
        # removed when Python runs with -O, which would make this test pass
        # without checking anything.
        if not python_should_be_executing:
            raise AssertionError(
                "f was executed in Python on a call that was expected to "
                "reuse the result of the earlier trace")
        return x * 2

    # NOTE(review): the schedule refers to resource 'x', which is not set up
    # inside this body; presumably a decorator or fixture outside this view
    # provides it -- confirm.
    fm = xmap(
        f,
        in_axes=A({'a': 0}),
        out_axes=A({'a': 0}),
        schedule=[('a', 'x'), ('a', 'vectorize')],
    )
    x = np.arange(8).reshape((2, 2, 2))

    # f reads this flag through its closure at the moment it is executed.
    python_should_be_executing = True
    fm(x)
    python_should_be_executing = False
    fm(x)
def testPdotBasic(self):
    """Check that lax.pdot over a named axis matches jnp.dot.

    The contracted axis 'i' is axis 1 of the first operand and axis 0 of the
    second; the output carries no named axes.
    """
    def f(x, y):
        return lax.pdot(x, y, 'i')

    # NOTE(review): resource 'r1' is not set up inside this body; presumably
    # a decorator or fixture outside this view provides it -- confirm.
    f_mapped = xmap(
        f,
        in_axes=[A({'i': 1}), A({'i': 0})],
        out_axes=A(),
        schedule=[('i', 'r1'), ('i', 'vectorize')],
    )

    # Fixed seed keeps the random operands reproducible across runs.
    rng = np.random.RandomState(0)
    x = rng.randn(3, 8)
    y = rng.randn(8, 5)

    product = f_mapped(x, y)
    expected = jnp.dot(x, y)
    self.assertAllClose(product, expected)
def testPdotBatching(self):
    """Check lax.pdot over 'i' with an additional mapped axis 'j'.

    Both operands map 'j' to axis 0, and the output keeps 'j' at axis 0.
    The result is compared against a batched matrix product written with
    jnp.einsum.
    """
    def f(x, y):
        return lax.pdot(x, y, 'i')

    # Fixed seed keeps the random operands reproducible across runs.
    rng = np.random.RandomState(0)
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    # NOTE(review): resource 'r1' is not set up inside this body; presumably
    # a decorator or fixture outside this view provides it -- confirm.
    operand_axes = [A({'i': 2, 'j': 0}), A({'i': 1, 'j': 0})]
    f_mapped = xmap(
        f,
        in_axes=operand_axes,
        out_axes=A({'j': 0}),
        schedule=[('j', 'vectorize'), ('i', 'r1'), ('i', 'vectorize')],
    )

    batched = f_mapped(x, y)
    expected = jnp.einsum('nij,njk->nik', x, y)
    self.assertAllClose(batched, expected)
def testXMapCollectives(self):
    """Check an xmapped psum over 'x' under fake resources r1, r2 and r3.

    The mapped function adds 2 to its first argument and sums it over the
    named axis 'x', and multiplies its second argument by 4. 'x' is axis 0
    of the first input, so the first result is compared with a sum over that
    axis; the second input has 'y' moved from axis 1 to axis 0, so the second
    result is compared with the transposed product.
    """
    def f(a, b):
        return lax.psum(a + 2, 'x'), b * 4

    with fake_resources(r1=4, r2=2, r3=5):
        axis_schedule = [
            ('x', 'r1'),
            ('x', 'r2'),
            ('y', 'r1'),
            ('z', 'r3'),
            ('x', 'vectorize'),
            ('y', 'vectorize'),
        ]
        fm = xmap(
            f,
            in_axes=[A({'x': 0, 'z': 1}), A({'y': 1})],
            out_axes=[A({'z': 0}), A({'y': 0})],
            schedule=axis_schedule,
        )

        a = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
        b = jnp.arange(6 * 16).reshape((6, 16))

        summed, scaled = fm(a, b)

        self.assertAllClose(summed, (a + 2).sum(0))
        self.assertAllClose(scaled, (b * 4).T)