def testNestedXMapDifferentResources(self):
    """Entering a new mesh inside an xmapped function must raise.

    The outer xmap maps named axis 'a' (positional axis 0) using the
    schedule entry ('a', 'x'). Its body enters a ``mesh`` context built
    from a 0-d object array and an empty axis-name tuple, then calls an
    inner, vectorized xmap over axis 'b'. The test expects a RuntimeError
    whose message matches "Changing the resource environment".
    """
    @partial(xmap, in_axes=A({'a': 0}), out_axes=A({'a': 0}),
             schedule=[('a', 'x')])
    def f(x):
        # `np.object` was deprecated in NumPy 1.20 and removed in 1.24
        # (AttributeError); the builtin `object` is the identical dtype.
        with mesh(np.empty((), dtype=object), ()):
            @partial(xmap, in_axes=A({'b': 0}), out_axes=A({'b': 0}),
                     schedule=[('b', 'vectorize')])
            def h(x):
                return x
            return h(x)

    xshape = (2, 5, 6)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    with self.assertRaisesRegex(RuntimeError,
                                "Changing the resource environment.*"):
        f(x)
def testPdotBatching(self):
    """pdot over named axis 'i', batched over named axis 'j', matches einsum.

    Inputs have shapes (2, 3, 8) and (2, 8, 5). 'j' is positional axis 0 of
    both; 'i' is axis 2 of the first input and axis 1 of the second. The
    mapped result is compared with the batched matmul 'nij,njk->nik'.
    """
    def contract(lhs, rhs):
        return lax.pdot(lhs, rhs, 'i')

    rng = np.random.RandomState(0)
    # Draw order matters for reproducibility: first x, then y.
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    lhs_axes = A({'i': 2, 'j': 0})
    rhs_axes = A({'i': 1, 'j': 0})
    # 'i' appears twice in the schedule: once on resource 'r1' and once
    # vectorized.
    plan = [('j', 'vectorize'), ('i', 'r1'), ('i', 'vectorize')]
    batched = xmap(contract,
                   in_axes=[lhs_axes, rhs_axes],
                   out_axes=A({'j': 0}),
                   schedule=plan)

    actual = batched(x, y)
    self.assertAllClose(actual, jnp.einsum('nij,njk->nik', x, y))
def testXMapCollectives(self):
    """psum over an xmapped axis, alongside an unreduced elementwise output.

    The first input (16, 5, 2) maps 'x' to axis 0 and 'z' to axis 1; 'x' is
    summed away with psum, so the first output is (a + 2).sum(0). The second
    input (6, 16) maps 'y' to axis 1 and is emitted with 'y' first, so the
    second output is (b * 4).T.
    """
    def body(a, b):
        reduced = lax.psum(a + 2, 'x')
        return reduced, b * 4

    axes_in = [A({'x': 0, 'z': 1}), A({'y': 1})]
    axes_out = [A({'z': 0}), A({'y': 0})]
    # Each named axis may be assigned to several resources; 'x' uses
    # r1, r2 and vectorization, 'y' uses r1 and vectorization, 'z' uses r3.
    plan = [
        ('x', 'r1'),
        ('x', 'r2'),
        ('y', 'r1'),
        ('z', 'r3'),
        ('x', 'vectorize'),
        ('y', 'vectorize'),
    ]
    mapped = xmap(body, in_axes=axes_in, out_axes=axes_out, schedule=plan)

    a = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
    b = jnp.arange(6 * 16).reshape((6, 16))
    summed, scaled = mapped(a, b)
    self.assertAllClose(summed, (a + 2).sum(0))
    self.assertAllClose(scaled, (b * 4).T)