def reduce(self, tensors, reduce_options=ReduceOptions()):
    """Reduce tensors to the root rank through pygloo.

    The same list is handed to ``self._collective`` as both the input
    and the output argument.

    Args:
        tensors (List): the list of tensors to be reduced; this list
            holds only one tensor.
        reduce_options: reduce options. ``root_rank`` is read up front,
            ``reduceOp`` is read when the collective actually runs.

    Returns:
        None
    """
    destination_rank = reduce_options.root_rank

    def _gloo_reduce(input_tensor, output_tensor, context):
        # Resolve every pygloo argument first, in the same order the
        # call consumes them, then issue the reduce.
        send_ptr = gloo_util.get_tensor_ptr(input_tensor)
        recv_ptr = gloo_util.get_tensor_ptr(output_tensor)
        element_count = gloo_util.get_tensor_n_elements(input_tensor)
        gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
        gloo_op = gloo_util.get_gloo_reduce_op(reduce_options.reduceOp)
        pygloo.reduce(
            context,
            send_ptr,
            recv_ptr,
            element_count,
            gloo_dtype,
            gloo_op,
            destination_rank,
        )

    self._collective(tensors, tensors, _gloo_reduce)
def reduce(self, tensors, reduce_options=ReduceOptions()):
    """Reduce tensors to a destination gpu following options.

    The destination is addressed by a flattened index computed as
    ``len(tensors) * root_rank + root_tensor``. The same list is handed
    to ``self._collective`` as both the input and the output argument.

    Args:
        tensors (List): the list of tensors to be reduced; each tensor
            must reside on one gpu of the current process.
        reduce_options: reduce options. ``root_rank`` and ``root_tensor``
            are read up front, ``reduceOp`` is read when the collective
            actually runs.

    Returns:
        None
    """
    tensors_per_process = len(tensors)
    destination = (
        tensors_per_process * reduce_options.root_rank
        + reduce_options.root_tensor
    )

    def _nccl_reduce(input_tensor, output_tensor, comm, stream):
        # Resolve every NCCL argument first, in the same order the call
        # consumes them, then issue the reduce on the given stream.
        send_ptr = nccl_util.get_tensor_ptr(input_tensor)
        recv_ptr = nccl_util.get_tensor_ptr(output_tensor)
        element_count = nccl_util.get_tensor_n_elements(input_tensor)
        nccl_dtype = nccl_util.get_nccl_tensor_dtype(input_tensor)
        nccl_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp)
        comm.reduce(
            send_ptr,
            recv_ptr,
            element_count,
            nccl_dtype,
            nccl_op,
            destination,
            stream.ptr,
        )

    self._collective(tensors, tensors, _nccl_reduce)
def reduce(self, tensor, reduce_options=ReduceOptions()):
    """Reduce tensor to a destination process following options.

    Args:
        tensor: the tensor to be reduced.
        reduce_options: reduce options; ``reduceOp`` selects the
            operator and ``root_rank`` the destination.

    Returns:
        None
    """
    communicator = self._get_nccl_communicator()
    cuda_stream = self._get_cuda_stream()

    nccl_dtype = nccl_util.get_nccl_tensor_dtype(tensor)
    buffer_ptr = nccl_util.get_tensor_ptr(tensor)
    element_count = nccl_util.get_tensor_n_elements(tensor)
    nccl_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp)

    # In-place reduce: the same pointer is used as both the send and
    # the receive buffer.
    communicator.reduce(
        buffer_ptr,
        buffer_ptr,
        element_count,
        nccl_dtype,
        nccl_op,
        reduce_options.root_rank,
        cuda_stream.ptr,
    )
def reduce(self, tensor, reduce_options=ReduceOptions()):
    """Reduce is not implemented here; calling this always raises.

    Args:
        tensor: the tensor to be reduced (unused).
        reduce_options: reduce options (unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()