Example #1
    def _get_nccl_collective_communicator(self, comm_key, device_list):
        """Create or retrieve an NCCL communicator from cache.

        If the communicator is found in the cache, return it. Otherwise, create
        the communicators (plus their streams and events) and put them in the cache.
        TODO(Hao): this function is not thread-safe now.

        Args:
            comm_key (str): the key to query the communicator cache.
            device_list (List): a list of GPU devices of the current process
                                that participate in the collective.

        Returns:
            communicators: the NCCL communicators corresponding to the devices.
        """
        if not comm_key:
            raise RuntimeError("Got empty communicator key.")
        for d in device_list:
            self._used_gpu_indices.add(d)

        # TODO(Hao): lock the _dev_comm_map here.
        if comm_key in self._dev_comm_map:
            return self._dev_comm_map[comm_key]

        group_key = self._generate_group_key(comm_key)
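        # Only rank 0 generates the NCCL unique ID; all other ranks block at the
        # rendezvous until the ID is published under the same group key.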
        if self.rank == 0:
            nccl_uid = self._generate_nccl_uid(group_key)
        else:
            rendezvous = Rendezvous(group_key)
            rendezvous.meet()
            nccl_uid = rendezvous.get_nccl_id()

        # Now create the communicators
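        # A process may own several GPUs, so the NCCL world size counts every GPU
        # across all processes, and each local device gets its own NCCL rank.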
        actual_world_size = len(device_list) * self.world_size
        comms = [None] * len(device_list)
        streams = [None] * len(device_list)
        events = [None] * len(device_list)
        nccl_util.groupStart()
        for i, device in enumerate(device_list):
            actual_rank = self.rank * len(device_list) + i
            with nccl_util.Device(device):
                comms[i] = nccl_util.create_nccl_communicator(
                    actual_world_size, nccl_uid, actual_rank
                )
                # request a stream from the pool
                # note: the device index here is the absolute GPU index.
                streams[i] = get_stream_pool(device).get_stream()
                # TODO(Fu): double check the parameters
                events[i] = cupy.cuda.Event()
        nccl_util.groupEnd()
        # TODO(Fu): lock
        self._dev_comm_map[comm_key] = comms
        self._dev_streams_map[comm_key] = streams
        self._dev_event_map[comm_key] = events
        return comms
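
A minimal usage sketch (not part of the example above): roughly how a caller inside the same group object would exercise the cache. _get_comm_key_from_devices is the helper that appears in Example #2; the group object, the device indices, and the call site are illustrative assumptions.

    # Hypothetical caller; assumes `group` is an initialized NCCL group instance.
    devices = [0, 1]                                  # GPUs owned by this process
    key = _get_comm_key_from_devices(devices)         # same helper as in Example #2
    comms = group._get_nccl_collective_communicator(key, devices)
    # A second call with the same key hits the cache and returns the same objects.
    assert group._get_nccl_collective_communicator(key, devices) is comms
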
Example #2
    def _collective(
        self,
        input_tensors,
        output_tensors,
        collective_fn,
        preprocess_fn=None,
        postprocess_fn=None,
    ):
        """A method to encapsulate all collective calls.

        Args:
            input_tensors: the list of input tensors.
            output_tensors: the list of output tensors.
            collective_fn: the per-tensor collective function to call.
            preprocess_fn: optional function invoked with the streams before the collective call.
            postprocess_fn: optional function invoked with the streams after the collective call.

        Returns:
            None
        """
        _check_gpu_tensors(input_tensors)
        _check_gpu_tensors(output_tensors)

        devices = nccl_util.get_tensor_device_list(input_tensors)
        key = _get_comm_key_from_devices(devices)
        comms = self._get_nccl_collective_communicator(key, devices)
        streams = self._dev_streams_map[key]
        events = self._dev_event_map[key]

        # TODO(Hao): sync streams and events
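        # Presumably this records events on the streams that produced the input
        # tensors and makes the dedicated NCCL streams wait on them.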
        self._sync_streams(devices, events, streams)

        # Make the collective call
        if preprocess_fn:
            preprocess_fn(streams)

        nccl_util.groupStart()
        # TODO(Fu): how to recordStreams as there are no library functions
        # We also need to make sure input tensors are not freed before their
        # usages on ncclStreams finish. This can be achieved by calling
        # c10::cuda::CUDACachingAllocator::recordStream, which remembers the
        # usage stream (ncclStream), creates an event on the usage stream
        # when GC attempts to free the input tensor, and delays GC until that
        # event is done.
        for i, tensor in enumerate(input_tensors):
            collective_fn(tensor, output_tensors[i], comms[i], streams[i])
        nccl_util.groupEnd()
        if postprocess_fn:
            postprocess_fn(streams)
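
A hedged sketch of how _collective might be driven for an in-place allreduce: the caller wraps the per-tensor NCCL call in collective_fn. The nccl_util helpers used below (tensor pointer, element count, NCCL dtype) and reduce_op are assumptions not shown in these examples; the allReduce call follows cupy's NcclCommunicator.allReduce signature.

    # Sketch only: assumes nccl_util exposes get_tensor_ptr, get_tensor_n_elements
    # and get_nccl_tensor_dtype, and that reduce_op is an NCCL reduction constant.
    def collective_fn(input_tensor, output_tensor, comm, stream):
        comm.allReduce(
            nccl_util.get_tensor_ptr(input_tensor),
            nccl_util.get_tensor_ptr(output_tensor),
            nccl_util.get_tensor_n_elements(input_tensor),
            nccl_util.get_nccl_tensor_dtype(input_tensor),
            reduce_op,
            stream.ptr,
        )

    # In-place allreduce: one tensor per local GPU, outputs overwrite inputs.
    self._collective(tensors, tensors, collective_fn)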