Example #1
0
 def testCollectivePermute1D(self):
   """Check `lax.pshuffle` over a single xmapped named axis 'i'.

   The input is `arange(4)`, and the shuffled output is expected to equal
   `perm` itself. `perm` is its own inverse (it swaps positions 0 and 3 and
   leaves 1 and 2 in place), so this expectation holds whether `perm` is read
   as "where each output comes from" or "where each input goes to".
   """
   perm = np.array([3, 1, 2, 0])
   shuffle = lambda v: lax.pshuffle(v, 'i', perm)
   mapped = xmap(shuffle, in_axes=['i', ...], out_axes=['i', ...])
   self.assertAllClose(mapped(jnp.arange(4)), perm)
Example #2
0
 def testCollectivePermuteCyclicWithPShuffle(self):
   """Cyclically shift pmapped values with `lax.pshuffle`.

   One value per device (`arange(device_count)`) is mapped over axis 'i'.
   The permutation sends index k to (k - 1) mod n, and the gathered result
   is compared against `onp.roll(values, -1)`, ignoring dtypes.
   """
   n_devices = xla_bridge.device_count()
   values = onp.arange(n_devices)
   shift_right = [(k - 1) % n_devices for k in range(n_devices)]

   def shuffle(v):
     return lax.pshuffle(v, perm=shift_right, axis_name='i')

   ans = onp.asarray(pmap(shuffle, "i")(values))
   # NOTE(review): the roll direction (-1) depends on how this JAX version's
   # pshuffle interprets `perm`. Under an "output k gathers from input
   # perm[k]" convention the result would be roll(values, 1). Confirm
   # against the pshuffle in use; with a single device both agree.
   self.assertAllClose(ans, onp.roll(values, -1), check_dtypes=False)
Example #3
0
 def testCollectivePermute2D(self):
   """Check `lax.pshuffle` over the combined named axes ('i', 'j').

   A 2x2 `arange` grid is xmapped over 'i' and 'j', which are assigned to
   resources 'x' and 'y'. It is shuffled jointly over both axes, and the
   flattened result is compared to `perm`.
   """
   perm = np.array([3, 1, 2, 0])
   grid = jnp.arange(4).reshape((2, 2))
   # NOTE(review): 'x' and 'y' presumably name mesh axes supplied by a mesh
   # context (e.g. a test decorator) not visible in this snippet -- confirm.
   shuffled = xmap(
       lambda v: lax.pshuffle(v, ('i', 'j'), perm),
       in_axes=['i', 'j', ...],
       out_axes=['i', 'j', ...],
       axis_resources={'i': 'x', 'j': 'y'},
   )(grid)
   self.assertAllClose(shuffled.reshape((-1,)), perm)
Example #4
0
 def testPShuffleWithBadPerm(self):
   """`lax.pshuffle` must reject a `perm` that is not a permutation.

   Starting from the identity over the device count, position 0 is
   overwritten with 1. With several devices, index 1 then appears twice.
   With a single device the list becomes [1], which is out of range.
   Either way an AssertionError is expected whose message reports the
   offending list, beginning with "[1".
   """
   n_devices = xla_bridge.device_count()
   bad_perm = list(range(n_devices))
   bad_perm[0] = 1
   shuffle = lambda v: lax.pshuffle(v, perm=bad_perm, axis_name='i')
   with self.assertRaisesRegex(
       AssertionError,
       "Given `perm` does not represent a real permutation: \\[1.*\\]"):
     pmap(shuffle, "i")(onp.arange(n_devices))