Example #1
    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
        """Broadcast tensors to all other gpus following options.

        Args:
            tensors (List): tensors to be broadcast or received.
            broadcast_options: broadcast options.

        Returns:
            None
        """
        # Flatten the (process root_rank, device root_tensor) pair into a
        # single global device rank, assuming len(tensors) devices per process.
        root_rank = (
            len(tensors) * broadcast_options.root_rank + broadcast_options.root_tensor
        )

        def collective_fn(input_tensor, output_tensor, comm, stream):
            comm.broadcast(
                nccl_util.get_tensor_ptr(input_tensor),
                nccl_util.get_tensor_ptr(output_tensor),
                nccl_util.get_tensor_n_elements(input_tensor),
                nccl_util.get_nccl_tensor_dtype(input_tensor),
                root_rank,
                stream.ptr,
            )

        # In-place: the same tensors serve as both inputs and outputs.
        self._collective(tensors, tensors, collective_fn)
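
The root-rank arithmetic above appears to flatten a (process rank, device index) pair into a single global device rank. A minimal, self-contained sketch with hypothetical numbers:

    # Hypothetical setup: 2 tensors (devices) per process, and the caller asks
    # for device 0 on process 1 to act as the broadcast root.
    tensors_per_process = 2           # len(tensors)
    root_rank, root_tensor = 1, 0     # BroadcastOptions fields
    global_root = tensors_per_process * root_rank + root_tensor
    assert global_root == 2           # process 1, device 0 -> global rank 2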
Example #2
    def broadcast(self, tensor, broadcast_options=BroadcastOptions()):
        """Broadcast tensor to all other processes following options.

        Args:
            tensor: the tensor to be broadcast.
            broadcast_options: broadcast options.

        Returns:
            None
        """
        comm = self._get_nccl_communicator()
        stream = self._get_cuda_stream()

        dtype = nccl_util.get_nccl_tensor_dtype(tensor)
        ptr = nccl_util.get_tensor_ptr(tensor)
        n_elems = nccl_util.get_tensor_n_elements(tensor)
        # In-place broadcast: the same device pointer serves as both the
        # send and the receive buffer.
        comm.broadcast(ptr, ptr, n_elems, dtype, broadcast_options.root_rank,
                       stream.ptr)
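
The in-place convention (send and receive buffers alias) is common across collective libraries. For comparison, a self-contained sketch of the same semantics in plain `torch.distributed`, a different API used here only because it can run in a single process with the gloo backend:

    import torch
    import torch.distributed as dist

    # A single-process gloo group: collectives degenerate to no-ops, which is
    # enough to show the in-place call pattern without multiple GPUs.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    t = torch.arange(4.0)
    dist.broadcast(t, src=0)  # in place: t is both the source and destination
    print(t)                  # tensor([0., 1., 2., 3.])
    dist.destroy_process_group()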
Example #3
    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):
        """Broadcast tensors to all other processes following options.

        Args:
            tensors (List): tensors to be broadcast or received.
            broadcast_options: broadcast options.

        Returns:
            None
        """
        root_rank = broadcast_options.root_rank

        def collective_fn(input_tensor, output_tensor, context):
            pygloo.broadcast(context, gloo_util.get_tensor_ptr(input_tensor),
                             gloo_util.get_tensor_ptr(output_tensor),
                             gloo_util.get_tensor_n_elements(input_tensor),
                             gloo_util.get_gloo_tensor_dtype(input_tensor),
                             root_rank)

        self._collective(tensors, tensors, collective_fn)
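
Note the design difference from the NCCL variants: GLOO is a CPU-oriented collective library, so `collective_fn` here receives only the GLOO `context` rather than a communicator plus a CUDA stream, and `pygloo.broadcast` is a free function that takes that context as its first argument.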
Example #4
    def broadcast(self, tensor, broadcast_options=BroadcastOptions()):
        raise NotImplementedError()
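
Example #4 is likely the backend-agnostic stub that concrete groups, such as the NCCL and GLOO variants above, override. A minimal, self-contained sketch of that pattern, with every name hypothetical:

    from abc import ABC, abstractmethod

    class BaseGroup(ABC):
        """Hypothetical abstract base mirroring the stub in Example #4."""

        @abstractmethod
        def broadcast(self, tensor, broadcast_options=None):
            raise NotImplementedError()

    class LoopbackGroup(BaseGroup):
        """Toy backend: with a single process there is nothing to transmit."""

        def broadcast(self, tensor, broadcast_options=None):
            return None  # the root's data is already in place

    LoopbackGroup().broadcast([1.0, 2.0])  # dispatches to the concrete backend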