def testTrees(self):
  ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
  def protate(x, axis_name):
    n = lax.psum(1, axis_name)
    return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)])

  tree_f = lambda f: partial(tree_util.tree_map, f)
  jax_f = lambda p: pmap(lambda x: p(x, 'i'), 'i')
  onp_f = lambda p: tree_f(lambda x: onp.broadcast_to(p(x, 0), x.shape))
  onp_transpose = tree_f(onp.transpose)
  onp_rotate = tree_f(lambda x: onp.concatenate([x[-1:], x[:-1]]))

  n = xla_bridge.device_count()
  x = {'a': onp.arange(1 * n * n, 2 * n * n).reshape([n, n]),
       'b': onp.arange(2 * n * n, 3 * n * n).reshape([n, n]),
       'c': onp.arange(4 * n * n, 5 * n * n).reshape([n, n])}

  assert_allclose = partial(tree_util.tree_multimap,
                            partial(self.assertAllClose, check_dtypes=False))
  assert_allclose(jax_f(lax.pmax)(x), onp_f(onp.max)(x))
  assert_allclose(jax_f(lax.pmin)(x), onp_f(onp.min)(x))
  assert_allclose(jax_f(lax.psum)(x), onp_f(onp.sum)(x))
  assert_allclose(jax_f(lax.pmean)(x), onp_f(onp.mean)(x))
  if jtu.device_under_test() not in ("cpu", "gpu"):
    # NOTE: all-to-all and ppermute only supported on TPU.
    assert_allclose(jax_f(ptranspose)(x), onp_transpose(x))
    assert_allclose(jax_f(protate)(x), onp_rotate(x))
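A minimal sketch, not part of the original test: the transpose and rotate semantics that testTrees checks across devices can also be exercised with vmap on CPU, since vmap supplies a named axis to collectives without requiring multiple devices. The array shape and variable names here are illustrative assumptions.

import numpy as np
from jax import lax, vmap

n = 4
x = np.arange(n * n).reshape((n, n))
# all_to_all(v, 'i', 0, 0) exchanges the named axis with local axis 0,
# which globally amounts to transposing the first two axes.
y = vmap(lambda v: lax.all_to_all(v, 'i', 0, 0), axis_name='i')(x)
np.testing.assert_array_equal(np.asarray(y), x.T)

# ppermute with a shift-by-one permutation rotates along the named axis.
perm = [(i, (i + 1) % n) for i in range(n)]
z = vmap(lambda v: lax.ppermute(v, 'i', perm), axis_name='i')(x)
np.testing.assert_array_equal(np.asarray(z), np.concatenate([x[-1:], x[:-1]]))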
def testAllToAll(self, vmap_axis, split_axis, concat_axis):
  shape = (4, 4, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)
  f = vmap(lambda x: lax.all_to_all(x, 'i', split_axis, concat_axis),
           in_axes=vmap_axis, axis_name='i')
  y = f(x)
  ref = jnp.moveaxis(x, (vmap_axis, split_axis + (vmap_axis <= split_axis)),
                     (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
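A minimal sketch, assuming concrete axis values, of the property testAllToAll parametrizes over: under vmap, all_to_all maps over split_axis and materializes the named axis at concat_axis, which matches a plain moveaxis on the unmapped array. The shape and axis choices below are illustrative assumptions; the split axis must have the same size as the named axis.

import numpy as np
import jax.numpy as jnp
from jax import lax, vmap

vmap_axis, split_axis, concat_axis = 0, 1, 2
shape = (2, 3, 2, 5)  # the mapped axis (size 2) matches the split-axis size
x = np.arange(np.prod(shape)).reshape(shape)
f = vmap(lambda v: lax.all_to_all(v, 'i', split_axis, concat_axis),
         in_axes=vmap_axis, axis_name='i')
ref = jnp.moveaxis(x, (vmap_axis, split_axis + (vmap_axis <= split_axis)),
                   (concat_axis + 1, 0))
np.testing.assert_array_equal(np.asarray(f(x)), np.asarray(ref))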
def f(x):
  return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)
def f(x):
  return lax.all_to_all(x, 'i', 0, 0)