def send_right(x, axis_name):
    """Rotate ``x`` one step to the right around the ring of ``axis_name``.

    The participant at index ``i`` along the named axis sends its value to
    index ``(i + 1) % n``, where ``n`` is the size of that axis. After the
    call, index ``i`` holds the value previously held by index ``i - 1``, and
    index 0 holds the value from index ``n - 1``.

    Args:
        x: The per-participant value to rotate.
        axis_name: Name of the mapped axis (as given to ``pmap``/``vmap``)
            that forms the ring.

    Returns:
        The rotated value, with the same shape and dtype as ``x``.
    """
    # Take the ring size from the axis itself rather than from a global
    # device count. The two differ whenever the mapped axis does not span
    # every device, and a mismatched size yields an invalid or wrong
    # permutation. psum of the constant 1 over the axis evaluates to the
    # axis size at trace time.
    n = lax.psum(1, axis_name)
    # ppermute pairs are (source, destination): i sends to its right neighbour.
    right_perm = [(i, (i + 1) % n) for i in range(n)]
    return lax.ppermute(x, perm=right_perm, axis_name=axis_name)
def protate(x, axis_name):
    """Cyclically shift ``x`` by one position along the named axis ``axis_name``.

    Each participant ``k`` forwards its value to participant ``k + 1``, and
    the last participant wraps around to participant 0.

    Args:
        x: The per-participant value to shift.
        axis_name: Name of the mapped axis whose participants form the ring.

    Returns:
        The shifted value, with the same shape and dtype as ``x``.
    """
    # Number of participants along the axis, known at trace time.
    size = lax.psum(1, axis_name)
    ring = list(range(size))
    # Pair every source with its successor; the rotated list closes the ring.
    pairs = list(zip(ring, ring[1:] + ring[:1]))
    return lax.ppermute(x, axis_name=axis_name, perm=pairs)