def allgather(self, tensor_list, tensor, allgather_options=AllGatherOptions()):
    """Allgather tensors across the group into a list of tensors.

    Args:
        tensor_list: the tensor list to store the results.
        tensor: the tensor to be allgather-ed across the group.
        allgather_options: allgather options.
            NOTE(review): currently not consumed anywhere in this body —
            confirm whether it should be forwarded to the NCCL call.

    Returns:
        None
    """
    _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)

    communicator = self._get_nccl_communicator()
    cuda_stream = self._get_cuda_stream()

    # Describe the local send buffer for NCCL.
    nccl_dtype = nccl_util.get_nccl_tensor_dtype(tensor)
    send_ptr = nccl_util.get_tensor_ptr(tensor)
    elem_count = nccl_util.get_tensor_n_elements(tensor)

    # Receive into one contiguous staging buffer; no need to pre-copy the
    # caller's tensors in (copy=False), since NCCL overwrites it entirely.
    staging = _flatten_for_scatter_gather(tensor_list, copy=False)
    recv_ptr = nccl_util.get_tensor_ptr(staging)

    communicator.allGather(send_ptr, recv_ptr, elem_count, nccl_dtype, cuda_stream.ptr)

    # Scatter the gathered rows back into the caller-provided tensors.
    for idx, dst in enumerate(tensor_list):
        nccl_util.copy_tensor(dst, staging[idx])
def _flatten_for_scatter_gather(tensor_list, copy=False):
    """Flatten the tensor for gather/scatter operations.

    Args:
        tensor_list: the list of tensors to be scattered/gathered.
        copy: whether the copy the tensors in tensor_list into the buffer.

    Returns:
        The flattened tensor buffer.
    """
    if not tensor_list:
        raise RuntimeError("Received an empty list.")

    # All tensors are assumed to share the first tensor's shape/dtype
    # (validated by the caller), so a single [len, *shape] buffer suffices.
    first = tensor_list[0]
    # note we need a cupy dtype here.
    cupy_dtype = nccl_util.get_cupy_tensor_dtype(first)
    shape = [len(tensor_list)] + nccl_util.get_tensor_shape(first)
    buffer = cupy.empty(shape, dtype=cupy_dtype)

    if copy:
        for slot, src in zip(buffer, tensor_list):
            nccl_util.copy_tensor(slot, src)
    return buffer
def preprocess_fn(stream):
    # Stage every input tensor into the flattened send buffer
    # (closes over `tensor_lists` / `input_flattened` from the enclosing scope).
    for group_idx, group in enumerate(tensor_lists):
        for slot_idx, src in enumerate(group):
            nccl_util.copy_tensor(input_flattened[group_idx][slot_idx], src)
def postprocess_fn(stream):
    # TODO(Hao): designate a copy stream.
    # Unpack the flattened receive buffer back into the caller's tensors
    # (closes over `tensor_lists` / `output_flattened` from the enclosing scope).
    for group_idx, group in enumerate(tensor_lists):
        for slot_idx, dst in enumerate(group):
            nccl_util.copy_tensor(dst, output_flattened[group_idx][slot_idx])