Esempio n. 1
0
 def send_right(x, axis_name):
     """Shift ``x`` one step "right" around the ring formed by ``axis_name``.

     Position ``i`` along the named axis sends its value to position
     ``(i + 1) % n``, where ``n`` is the size of that axis. After the call,
     position ``j`` holds the value that position ``(j - 1) % n`` held before.

     Args:
       x: The per-position value to send (anything ``lax.ppermute`` accepts).
       axis_name: Name of the mapped axis, bound by an enclosing transform
         such as ``pmap``, that defines the ring.

     Returns:
       The value received from the left-hand neighbour on the ring.
     """
     # Take the ring size from the axis itself rather than from an external
     # device count. This keeps the permutation a full cycle over exactly the
     # positions that exist on this axis.
     n = lax.psum(1, axis_name)
     # Pairs are (source, destination): i -> i + 1 is a rightward shift.
     right_perm = [(i, (i + 1) % n) for i in range(n)]
     return lax.ppermute(x, perm=right_perm, axis_name=axis_name)
Esempio n. 2
0
 def protate(x, axis_name):
     """Cyclically rotate ``x`` by one position along the named axis.

     Each source position ``src`` forwards its value to ``(src + 1) % size``,
     where ``size`` is the number of positions on ``axis_name``.

     Args:
       x: The per-position value to rotate (anything ``lax.ppermute`` accepts).
       axis_name: Name of the mapped axis along which to rotate.

     Returns:
       The value received from the preceding position on the ring.
     """
     # Summing the constant 1 over the axis yields the axis size.
     size = lax.psum(1, axis_name)
     ring_pairs = [(src, (src + 1) % size) for src in range(size)]
     return lax.ppermute(x, axis_name=axis_name, perm=ring_pairs)