def testXMapMeshCollectives(self):
    """Check a psum-reducing xmap over a 2x2 device mesh.

    The first input binds named axis 'a' to dim 0 and 'b' to dim 1; the second
    binds 'c' to dim 0. Both 'a' and 'c' are scheduled onto mesh axis 'x' and
    'b' onto mesh axis 'y' (with 'a' and 'b' also vectorized). A psum over 'a'
    must match a plain sum over dim 0, and the elementwise-scaled second input
    must come back unchanged apart from the scaling.
    """
    available = list(jax.local_devices())
    if len(available) < 4:
        raise SkipTest("Test requires at least 4 local devices")

    def reduce_and_scale(x, y):
        # Reduce 2*x across named axis 'a'; y is only scaled elementwise.
        return lax.psum(x * 2, 'a'), y * 4

    def iota(dims):
        # Consecutive integers laid out in the requested shape.
        return jnp.arange(np.prod(dims)).reshape(dims)

    mesh_devices = np.array(available[:4]).reshape((2, 2))
    axis_schedule = [
        ('a', 'x'),
        ('b', 'y'),
        ('c', 'x'),
        ('a', 'vectorize'),
        ('b', 'vectorize'),
    ]
    with mesh(mesh_devices, ('x', 'y')):
        mapped = xmap(
            reduce_and_scale,
            in_axes=[A({'a': 0, 'b': 1}), A({'c': 0})],
            out_axes=[A({'b': 0}), A({'c': 0})],
            schedule=axis_schedule,
        )
        lhs = iota((16, 8, 5))
        rhs = iota((2, 7))
        summed, scaled = mapped(lhs, rhs)
        self.assertAllClose(summed, (lhs * 2).sum(0))
        self.assertAllClose(scaled, rhs * 4)
def f(x):
    """Apply an identity xmap to ``x`` inside a mesh with no named axes.

    The mesh is built from a 0-d object array and an empty tuple of axis names.
    The only named axis, 'b' (bound to dim 0 of ``x`` on both input and
    output), is scheduled purely as 'vectorize', so no mesh resources are used.
    The wrapped function ``h`` is the identity, so the result presumably equals
    ``x`` (the caller is expected to check this).
    """
    # ``np.object`` was only a deprecated alias for the builtin ``object`` and
    # was removed in NumPy 1.24; using the builtin gives the identical dtype.
    with mesh(np.empty((), dtype=object), ()):
        @partial(xmap, in_axes=A({'b': 0}), out_axes=A({'b': 0}),
                 schedule=[('b', 'vectorize')])
        def h(x):
            return x
        return h(x)
def new_f(*args, **kwargs):
    """Call the wrapped ``f`` inside a device mesh described by ``named_shape``.

    ``named_shape`` and ``f`` come from the enclosing scope. ``unzip2`` splits
    ``named_shape`` into axis names and axis sizes (presumably a sequence of
    ``(name, size)`` pairs -- confirm against the decorator's callers). The
    first ``prod(sizes)`` local devices are reshaped into the mesh.

    Raises:
        SkipTest: if fewer local devices are available than the mesh needs.
    """
    axis_names, shape = unzip2(named_shape)
    # np.prod(()) is the float 1.0, which is not a valid slice bound; force an
    # integer product so an empty named_shape selects one device instead of
    # raising TypeError. Non-empty integer shapes give the same value as before.
    size = int(np.prod(shape, dtype=int))
    local_devices = list(jax.local_devices())
    if len(local_devices) < size:
        raise SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape(shape)
    with mesh(mesh_devices, axis_names):
        return f(*args, **kwargs)