def testCollectivePermute1D(self):
    """pshuffle along a single named xmap axis 'i'.

    The input is ``arange(4)``, so the shuffled result is expected to equal
    the permutation vector itself.
    """
    permutation = np.array([3, 1, 2, 0])
    inputs = jnp.arange(4)

    def shuffle(v):
        return lax.pshuffle(v, 'i', permutation)

    mapped = xmap(shuffle, in_axes=['i', ...], out_axes=['i', ...])
    self.assertAllClose(mapped(inputs), permutation)
def testCollectivePermuteCyclicWithPShuffle(self):
    """pshuffle over a pmapped axis with a cyclic permutation.

    ``shift_right[i] == (i - 1) % device_count`` names the *source* device for
    output position ``i``. Device ``i`` therefore receives ``values[i - 1]``,
    which is a rotation to the right by one.
    """
    device_count = xla_bridge.device_count()
    values = onp.arange(device_count)
    shift_right = [(i - 1) % device_count for i in range(device_count)]
    f = lambda x: lax.pshuffle(x, perm=shift_right, axis_name='i')
    # Per pshuffle's contract, output[i] = x[perm[i]] = values[i - 1]. That is
    # onp.roll(values, 1). The left rotation onp.roll(values, -1) coincides
    # with it only when device_count <= 2, which would hide the mismatch on
    # small device counts.
    expected = onp.roll(values, 1)
    ans = onp.asarray(pmap(f, "i")(values))
    self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectivePermute2D(self):
    """pshuffle across two named xmap axes ('i', 'j').

    The 4-element permutation addresses the 2x2 grid of (i, j) positions.
    After flattening, the result is compared against the permutation itself,
    because the input is ``arange(4)``.
    """
    # NOTE(review): axis_resources maps 'i'/'j' onto mesh axes 'x'/'y'. The
    # mesh itself is presumably set up outside this view; confirm.
    permutation = np.array([3, 1, 2, 0])
    grid = jnp.arange(4).reshape((2, 2))
    mapped = xmap(
        lambda v: lax.pshuffle(v, ('i', 'j'), permutation),
        in_axes=['i', 'j', ...],
        out_axes=['i', 'j', ...],
        axis_resources={'i': 'x', 'j': 'y'},
    )
    flat = mapped(grid).reshape((-1,))
    self.assertAllClose(flat, permutation)
def testPShuffleWithBadPerm(self):
    """pshuffle must reject a perm that is not a permutation of range(n)."""
    device_count = xla_bridge.device_count()
    # Index 0 is replaced with 1, so 0 is missing. When there is more than one
    # device, 1 also appears twice.
    bad_perm = [1] + list(range(1, device_count))
    shuffle = lambda x: lax.pshuffle(x, perm=bad_perm, axis_name='i')
    with self.assertRaisesRegex(
            AssertionError,
            "Given `perm` does not represent a real permutation: \\[1.*\\]"):
        pmap(shuffle, "i")(onp.arange(device_count))