Code Example #1
    def alltoall(self, x, mesh_axis, split_axis, concat_axis):
        """Grouped alltoall (like MPI alltoall with splitting and concatenation).

    Args:
      x: a LaidOutTensor
      mesh_axis: an integer the mesh axis along which to group
      split_axis: an integer (the Tensor axis along which to split)
      concat_axis: an integer (the Tensor axis along which to concatenate)
    Returns:
      a LaidOutTensor
    """
        x = x.to_laid_out_tensor()
        t = x.one_slice
        group_assignment = self._create_group_assignment([mesh_axis])
        dtype = t.dtype
        if dtype == tf.float32:
            # There seems to be a bug with float32 alltoall.
            # Do it in bfloat16 until the bug is fixed.
            # TODO(noam): file a bug
            t = tf.to_bfloat16(t)
        t = tpu_ops.all_to_all(t,
                               concat_dimension=concat_axis,
                               split_dimension=split_axis,
                               split_count=len(group_assignment[0]),
                               group_assignment=group_assignment)
        t = tf.cast(t, dtype)
        x = self.LaidOutTensor([t])
        return x
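
For context, the group_assignment passed to tpu_ops.all_to_all is the standard XLA-style replica grouping: a 2-D list of shape [num_groups, group_size], where the replicas in each row exchange data with one another (hence split_count = len(group_assignment[0])). The helper below is a hypothetical sketch of how such a grouping can be derived from a logical mesh shape; it is not the library's _create_group_assignment, but it illustrates the same idea.

import numpy as np

def group_assignment_for_axis(mesh_shape, mesh_axis):
    """Group replica ids of a logical device mesh along one mesh axis.

    Replicas that differ only in their coordinate along `mesh_axis` land in
    the same group, so the all-to-all exchanges data among exactly those
    replicas.  The result has shape [num_groups, group_size].
    """
    num_replicas = int(np.prod(mesh_shape))
    coords = np.arange(num_replicas).reshape(mesh_shape)
    # Move the grouped mesh axis last, then flatten the remaining axes
    # into one row per group.
    grouped = np.moveaxis(coords, mesh_axis, -1).reshape(-1, mesh_shape[mesh_axis])
    return grouped.tolist()

# Example: a hypothetical 2x4 mesh grouped along mesh axis 1 gives
# [[0, 1, 2, 3], [4, 5, 6, 7]], so split_count (the group size) is 4.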
Code Example #2
  def alltoall(self, x, mesh_axis, split_axis, concat_axis):
    """Grouped alltoall (like MPI alltoall with splitting and concatenation).

    Args:
      x: a LaidOutTensor
      mesh_axis: an integer (the mesh axis along which to group)
      split_axis: an integer (the Tensor axis along which to split)
      concat_axis: an integer (the Tensor axis along which to concatenate)
    Returns:
      a LaidOutTensor
    """
    x = x.to_laid_out_tensor()
    t = x.one_slice
    group_assignment = self._create_group_assignment([mesh_axis])
    t = tpu_ops.all_to_all(
        t,
        concat_dimension=concat_axis,
        split_dimension=split_axis,
        split_count=len(group_assignment[0]),
        group_assignment=group_assignment)
    x = self.LaidOutTensor([t])
    return x
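
As a reading aid, here is a minimal NumPy sketch (our own, not part of the library) of what a single all-to-all group computes: each of the n participants splits its local slice into n pieces along split_axis, piece j is delivered to participant j, and every participant concatenates the pieces it receives, ordered by sender, along concat_axis.

import numpy as np

def alltoall_reference(slices, split_axis, concat_axis):
    """Single-group all-to-all: slices[i] is participant i's local tensor."""
    n = len(slices)
    # Each participant splits its slice into n equal pieces along split_axis.
    pieces = [np.split(s, n, axis=split_axis) for s in slices]
    # Participant j keeps piece j from every sender i, concatenated in
    # sender order along concat_axis.
    return [np.concatenate([pieces[i][j] for i in range(n)], axis=concat_axis)
            for j in range(n)]

For example, with two participants each holding a [4, 2] array, splitting along axis 0 and concatenating along axis 1 yields two [2, 4] arrays, which matches the per-replica shape change performed by tpu_ops.all_to_all above.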