Example #1
    def reducescatter(self,
                      tensor,
                      tensor_list,
                      reducescatter_options=ReduceScatterOptions()):
        """Reducescatter a list of tensors across the group.

        Args:
            tensor: the output after reducescatter (could be unspecified).
            tensor_list: the list of tensors to be reduced then scattered.
            reducescatter_options: reducescatter options.

        Returns:
            None
        """
        _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)

        comm = self._get_nccl_communicator()
        stream = self._get_cuda_stream()
        dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0])
        n_elems = nccl_util.get_tensor_n_elements(tensor_list[0])
        reduce_op = nccl_util.get_nccl_reduce_op(
            reducescatter_options.reduceOp)

        # Flatten (with copy) the input tensor list into one contiguous
        # send buffer; NCCL reduce-scatters it into the output tensor.
        flattened = _flatten_for_scatter_gather(tensor_list, copy=True)
        send_ptr = nccl_util.get_tensor_ptr(flattened)
        recv_ptr = nccl_util.get_tensor_ptr(tensor)
        comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op,
                           stream.ptr)
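For orientation: a reduce-scatter over a group of size N means every rank contributes a list of N tensors, and rank i receives the element-wise reduction of entry i across all ranks. The single-process NumPy model below is a hypothetical sketch for checking that contract (the helper name is invented, not part of the library):

import numpy as np

def simulate_reducescatter(per_rank_inputs):
    # per_rank_inputs[r][i] is the i-th tensor contributed by rank r.
    # Rank i's output is the element-wise sum of entry i across ranks.
    world_size = len(per_rank_inputs)
    return [
        sum(per_rank_inputs[r][i] for r in range(world_size))
        for i in range(world_size)
    ]

# Two ranks, each contributing two tensors of 1s and 2s respectively:
# every rank's output slot holds 1 + 2.
inputs = [[np.full(4, r + 1.0) for _ in range(2)] for r in range(2)]
outputs = simulate_reducescatter(inputs)
assert np.allclose(outputs[0], np.full(4, 3.0))
assert np.allclose(outputs[1], np.full(4, 3.0))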
Example #2
    def reducescatter(self,
                      tensors,
                      tensor_lists,
                      reducescatter_options=ReduceScatterOptions()):
        """Reduce the scatter a list of tensors across the group.

        Args:
            tensors (List): the output tensors (could be unspecified), each
                            located on CPU.
            tensor_lists (List[List]): the list of tensors to be reduced then
                                       scattered.
            reducescatter_options: reduce-scatter options.

        Returns:
            None
        """
        def collective_fn(input_tensor, output_tensor, context):
            size = gloo_util.get_tensor_n_elements(input_tensor)
            world_size = self._gloo_context.size
            pygloo.reduce_scatter(
                context,
                gloo_util.get_tensor_ptr(input_tensor),
                gloo_util.get_tensor_ptr(output_tensor),
                size,
                # one equal-sized chunk per rank (assumes even divisibility)
                [size // world_size for _ in range(world_size)],
                gloo_util.get_gloo_tensor_dtype(output_tensor),
                gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp),
            )

        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)
        input_flattened = [
            _flatten_for_scatter_gather(tensor_list, copy=False)
            for tensor_list in tensor_lists
        ]

        def preprocess_fn():
            # Copy each input tensor into its slot of the flattened buffer.
            for i, tensor_list in enumerate(tensor_lists):
                for j, tensor in enumerate(tensor_list):
                    gloo_util.copy_tensor(input_flattened[i][j], tensor)

        self._collective(input_flattened,
                         tensors,
                         collective_fn,
                         preprocess_fn=preprocess_fn)
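The copy=False flatten above allocates one contiguous buffer per tensor list without copying; preprocess_fn then fills it just before the collective runs, so the backend sees a single send buffer. A minimal NumPy sketch of that pattern (the helper below is a hypothetical stand-in for _flatten_for_scatter_gather, not the library's internals):

import numpy as np

def flatten_for_scatter_gather(tensor_list):
    # Allocate one contiguous [world_size, *chunk_shape] buffer; slot i
    # will hold tensor_list[i]. Mirrors the copy=False path, where the
    # buffer starts empty and is filled later by preprocess_fn.
    return np.empty((len(tensor_list),) + tensor_list[0].shape,
                    dtype=tensor_list[0].dtype)

tensor_list = [np.full(3, r, dtype=np.float32) for r in range(4)]
flattened = flatten_for_scatter_gather(tensor_list)
# "preprocess_fn": copy each input into its slot right before the collective.
for i, t in enumerate(tensor_list):
    flattened[i] = t
send_buffer = flattened.ravel()  # one contiguous buffer for a single call
assert send_buffer.flags["C_CONTIGUOUS"]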
Example #3
    def reducescatter(self,
                      tensors,
                      tensor_lists,
                      reducescatter_options=ReduceScatterOptions()):
        """Reduce then scatter a list of tensors across the group.

        Args:
            tensors (List): the output tensors (could be unspecified), each
                            located on a GPU of the current process.
            tensor_lists (List[List]): the list of tensors to be reduced then
                                       scattered.
            reducescatter_options: reduce-scatter options.

        Returns:
            None
        """

        def collective_fn(input_tensor, output_tensor, comm, stream):
            comm.reduceScatter(
                nccl_util.get_tensor_ptr(input_tensor),
                nccl_util.get_tensor_ptr(output_tensor),
                nccl_util.get_tensor_n_elements(output_tensor),
                nccl_util.get_nccl_tensor_dtype(output_tensor),
                nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp),
                stream.ptr)

        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)
        input_flattened = [
            _flatten_for_scatter_gather(tensor_list, copy=False)
            for tensor_list in tensor_lists
        ]

        def preprocess_fn(stream):
            # Fill the flattened send buffers with the input tensors before
            # the collective launches on this stream.
            for i, tensor_list in enumerate(tensor_lists):
                for j, tensor in enumerate(tensor_list):
                    nccl_util.copy_tensor(input_flattened[i][j], tensor)

        self._collective(
            input_flattened,
            tensors,
            collective_fn,
            preprocess_fn=preprocess_fn)
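Assuming these backend methods are reached through Ray's ray.util.collective front end (init_collective_group plus a module-level reducescatter, as in the public API), a two-GPU call site might look like the sketch below; the signatures are quoted from memory and should be verified against your Ray version:

import ray
import torch
import ray.util.collective as col

@ray.remote(num_gpus=1)
class Worker:
    def setup(self, world_size, rank):
        # Join the NCCL collective group for this process.
        col.init_collective_group(world_size, rank,
                                  backend="nccl", group_name="default")

    def run(self, world_size, rank):
        # Each rank contributes `world_size` tensors; rank i receives the
        # element-wise sum of entry i across all ranks.
        tensor_list = [
            torch.ones(4, device="cuda") * (rank + 1)
            for _ in range(world_size)
        ]
        output = torch.empty(4, device="cuda")
        col.reducescatter(output, tensor_list, group_name="default")
        return output.cpu()

ray.init()
workers = [Worker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
# With ranks contributing 1s and 2s, both outputs should be full(4, 3.0).
print(ray.get([w.run.remote(2, i) for i, w in enumerate(workers)]))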
Example #4
    def reducescatter(self,
                      tensor,
                      tensor_list,
                      reducescatter_options=ReduceScatterOptions()):
        raise NotImplementedError()