def alltoall(self, x, mesh_axis, split_axis, concat_axis):
  """Grouped alltoall (like MPI alltoall with splitting and concatenation).

  Args:
    x: a LaidOutTensor
    mesh_axis: an integer (the mesh axis along which to group)
    split_axis: an integer (the Tensor axis along which to split)
    concat_axis: an integer (the Tensor axis along which to concatenate)

  Returns:
    a LaidOutTensor
  """
  x = x.to_laid_out_tensor()
  t = x.one_slice
  group_assignment = self._create_group_assignment([mesh_axis])
  dtype = t.dtype
  if dtype == tf.float32:
    # There seems to be a bug with float32 alltoall.
    # Do it in bfloat16 until the bug is fixed.
    # TODO(noam): file a bug
    t = tf.to_bfloat16(t)
  t = tpu_ops.all_to_all(
      t,
      concat_dimension=concat_axis,
      split_dimension=split_axis,
      split_count=len(group_assignment[0]),
      group_assignment=group_assignment)
  t = tf.cast(t, dtype)
  x = self.LaidOutTensor([t])
  return x
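# A minimal sketch (not part of the original file) illustrating the
# semantics of a grouped alltoall as performed above: each participant in
# a group splits its slice into `split_count` pieces along `split_axis`,
# piece j goes to participant j, and each participant concatenates the
# pieces it receives along `concat_axis`. The helper and variable names
# here are hypothetical, chosen only for illustration.
#
# import numpy as np
#
# def simulated_alltoall(slices, split_axis, concat_axis):
#   """Simulate all_to_all over a single group of participants."""
#   n = len(slices)
#   # Each participant splits its slice into n pieces along split_axis.
#   pieces = [np.split(s, n, axis=split_axis) for s in slices]
#   # Participant j receives piece j from every participant and
#   # concatenates the received pieces along concat_axis.
#   return [
#       np.concatenate([pieces[i][j] for i in range(n)], axis=concat_axis)
#       for j in range(n)
#   ]
#
# # Four participants, each holding a (4, 2) slice: splitting along axis 0
# # and concatenating along axis 1 yields (1, 8) slices.
# slices = [np.full((4, 2), fill_value=i) for i in range(4)]
# out = simulated_alltoall(slices, split_axis=0, concat_axis=1)
# assert out[0].shape == (1, 8)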