示例#1
0
    def allgather(self,
                  tensor_list,
                  tensor,
                  allgather_options=AllGatherOptions()):
        """Gather `tensor` from every participant in the group into `tensor_list`.

        Args:
            tensor_list: pre-allocated tensors that receive one gathered
                tensor per group participant.
            tensor: the local tensor contributed to the allgather.
            allgather_options: allgather options (not consulted by this
                implementation).

        Returns:
            None
        """
        _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)

        communicator = self._get_nccl_communicator()
        cuda_stream = self._get_cuda_stream()

        # NCCL writes the gathered result into one contiguous region, so
        # stage the output in a flattened scratch buffer first.
        scratch = _flatten_for_scatter_gather(tensor_list, copy=False)
        communicator.allGather(
            nccl_util.get_tensor_ptr(tensor),
            nccl_util.get_tensor_ptr(scratch),
            nccl_util.get_tensor_n_elements(tensor),
            nccl_util.get_nccl_tensor_dtype(tensor),
            cuda_stream.ptr)

        # Unpack the contiguous result back into the caller's tensors.
        for idx, out in enumerate(tensor_list):
            nccl_util.copy_tensor(out, scratch[idx])
示例#2
0
def _flatten_for_scatter_gather(tensor_list, copy=False):
    """Allocate one contiguous buffer covering every tensor in the list.

    Args:
        tensor_list: tensors to be scattered/gathered; all entries are
            assumed to share the shape and dtype of the first one.
        copy: when True, also copy each tensor into its slot in the buffer.

    Returns:
        A cupy buffer of shape ``[len(tensor_list), *tensor_shape]``.

    Raises:
        RuntimeError: if ``tensor_list`` is empty.
    """
    if not tensor_list:
        raise RuntimeError("Received an empty list.")
    first = tensor_list[0]
    # cupy.empty needs a cupy dtype here (not an NCCL dtype code).
    shape = [len(tensor_list)] + nccl_util.get_tensor_shape(first)
    buf = cupy.empty(shape, dtype=nccl_util.get_cupy_tensor_dtype(first))
    if copy:
        for slot, src in zip(buf, tensor_list):
            nccl_util.copy_tensor(slot, src)
    return buf
示例#3
0
 def preprocess_fn(stream):
     # Stage every input tensor into the flattened send buffer before the
     # collective launches. `tensor_lists` / `input_flattened` come from
     # the enclosing scope.
     for row, tensors in enumerate(tensor_lists):
         for col, src in enumerate(tensors):
             nccl_util.copy_tensor(input_flattened[row][col], src)
示例#4
0
 def postprocess_fn(stream):
     # TODO(Hao): designate a copy stream.
     # Unpack the flattened receive buffer back into the caller's tensors.
     # `tensor_lists` / `output_flattened` come from the enclosing scope.
     for row, tensors in enumerate(tensor_lists):
         for col, dst in enumerate(tensors):
             nccl_util.copy_tensor(dst, output_flattened[row][col])