def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
    """AllReduce a list of tensors in place following the given options.

    The same list is handed to ``self._collective`` as both the input and
    the output, so every tensor is reduced into its own buffer.

    Args:
        tensors: the list of tensors to be reduced; each tensor is
            located on CPU.
        allreduce_options: allreduce options; its ``reduceOp`` selects
            the reduction operation.

    Returns:
        None
    """

    def collective_fn(input_tensor, output_tensor, context):
        # Resolve the raw arguments expected by the gloo binding first,
        # then issue a single allreduce call on the given context.
        send_ptr = gloo_util.get_tensor_ptr(input_tensor)
        recv_ptr = gloo_util.get_tensor_ptr(output_tensor)
        n_elements = gloo_util.get_tensor_n_elements(input_tensor)
        gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
        gloo_op = gloo_util.get_gloo_reduce_op(allreduce_options.reduceOp)
        pygloo.allreduce(
            context, send_ptr, recv_ptr, n_elements, gloo_dtype, gloo_op)

    # Input and output are the same list: the reduction happens in place.
    self._collective(tensors, tensors, collective_fn)
def allreduce(self, tensors, allreduce_options=AllReduceOptions()):
    """AllReduce tensors across the collective group following options.

    The same list is handed to ``self._collective`` as both the input and
    the output, so every tensor is reduced into its own buffer.

    Args:
        tensors (List): the list of tensors to be reduced. Each tensor
            must reside on one GPU of the current process.
        allreduce_options: allreduce options; its ``reduceOp`` selects
            the reduction operation.

    Returns:
        None
    """

    def collective_fn(input_tensor, output_tensor, comm, stream):
        # Resolve the raw arguments expected by the communicator first,
        # then issue a single allReduce call on the given stream.
        send_ptr = nccl_util.get_tensor_ptr(input_tensor)
        recv_ptr = nccl_util.get_tensor_ptr(output_tensor)
        n_elements = nccl_util.get_tensor_n_elements(input_tensor)
        nccl_dtype = nccl_util.get_nccl_tensor_dtype(input_tensor)
        nccl_op = nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp)
        comm.allReduce(
            send_ptr, recv_ptr, n_elements, nccl_dtype, nccl_op, stream.ptr)

    # Input and output are the same list: the reduction happens in place.
    self._collective(tensors, tensors, collective_fn)
def allreduce(self, tensor, allreduce_options=AllReduceOptions()):
    """AllReduce the tensor in place across the collective group.

    The tensor's pointer is passed to the communicator as both the send
    and the receive buffer, so the result overwrites the input.

    Args:
        tensor: the tensor to be reduced; each tensor is located on a GPU.
        allreduce_options: allreduce options; its ``reduceOp`` selects
            the reduction operation.

    Returns:
        None
    """
    communicator = self._get_nccl_communicator()

    # The stream comes straight from ``_get_cuda_stream`` (the default
    # stream for now, per the original note).
    # TODO(Hao): implement a simple stream manager here
    cuda_stream = self._get_cuda_stream()

    nccl_dtype = nccl_util.get_nccl_tensor_dtype(tensor)
    buffer_ptr = nccl_util.get_tensor_ptr(tensor)
    n_elements = nccl_util.get_tensor_n_elements(tensor)
    nccl_op = nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp)

    # Same pointer for send and receive: in-place allreduce.
    communicator.allReduce(
        buffer_ptr, buffer_ptr, n_elements, nccl_dtype, nccl_op,
        cuda_stream.ptr)
def allreduce(self, tensor, allreduce_options=AllReduceOptions()):
    """Unimplemented allreduce entry point.

    Args:
        tensor: not used by this implementation.
        allreduce_options: not used by this implementation.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError