def _get_nccl_collective_communicator(self, comm_key, device_list):
    """Create or retrieve an NCCL communicator from cache.

    If the communicator is found in cache, return it. If not, a
    communicator, a stream, and an event will be created per device and
    put in cache.
    TODO(Hao): this function is not thread-safe now.

    Args:
        comm_key (str): the key to query the communicator cache.
        device_list (List): a list of GPU devices of the current process
            that participate in the collective.

    Returns:
        communicator: the NCCL communicator corresponding to the devices.
    """
    if not comm_key:
        raise RuntimeError("Got empty communicator key.")
    for d in device_list:
        self._used_gpu_indices.add(d)

    # TODO(Hao): lock the _dev_comm_map here.
    if comm_key in self._dev_comm_map:
        return self._dev_comm_map[comm_key]

    group_key = self._generate_group_key(comm_key)
    if self.rank == 0:
        nccl_uid = self._generate_nccl_uid(group_key)
    else:
        rendezvous = Rendezvous(group_key)
        rendezvous.meet()
        nccl_uid = rendezvous.get_nccl_id()

    # Now create the communicators.
    actual_world_size = len(device_list) * self.world_size
    comms = [None] * len(device_list)
    streams = [None] * len(device_list)
    events = [None] * len(device_list)

    nccl_util.groupStart()
    for i, device in enumerate(device_list):
        actual_rank = self.rank * len(device_list) + i
        with nccl_util.Device(device):
            comms[i] = nccl_util.create_nccl_communicator(
                actual_world_size, nccl_uid, actual_rank
            )
            # Request a stream from the pool;
            # note the device_idx is an absolute index.
            streams[i] = get_stream_pool(device).get_stream()
            # TODO(Fu): double check the parameters
            events[i] = cupy.cuda.Event()
    nccl_util.groupEnd()

    # TODO(Fu): lock
    self._dev_comm_map[comm_key] = comms
    self._dev_streams_map[comm_key] = streams
    self._dev_event_map[comm_key] = events
    return comms
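
# Illustrative sketch only (an assumption, not the implementation behind
# `get_stream_pool(device)` used above): a minimal per-device stream pool
# that lazily creates a few non-blocking CUDA streams on the target device
# and hands them out round-robin, so communicators on the same GPU can get
# dedicated streams instead of the default one.
import threading

import cupy


class _StreamPoolSketch:
    def __init__(self, device_idx, pool_size=4):
        self._device_idx = device_idx
        self._pool_size = pool_size
        self._streams = []
        self._next = 0
        self._lock = threading.Lock()

    def get_stream(self):
        with self._lock:
            if len(self._streams) < self._pool_size:
                # Streams must be created while the target device is current.
                with cupy.cuda.Device(self._device_idx):
                    self._streams.append(cupy.cuda.Stream(non_blocking=True))
            stream = self._streams[self._next % len(self._streams)]
            self._next += 1
            return stream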
def _collective(
    self,
    input_tensors,
    output_tensors,
    collective_fn,
    preprocess_fn=None,
    postprocess_fn=None,
):
    """A method to encapsulate all collective calls.

    Args:
        input_tensors: the list of the input tensors.
        output_tensors: the list of the output tensors.
        collective_fn: the collective function call.
        preprocess_fn: preprocess procedures before collective calls.
        postprocess_fn: postprocess procedures after collective calls.

    Returns:
        None
    """
    _check_gpu_tensors(input_tensors)
    _check_gpu_tensors(output_tensors)

    devices = nccl_util.get_tensor_device_list(input_tensors)
    key = _get_comm_key_from_devices(devices)
    comms = self._get_nccl_collective_communicator(key, devices)
    streams = self._dev_streams_map[key]
    events = self._dev_event_map[key]

    # TODO(Hao): sync streams and events
    self._sync_streams(devices, events, streams)

    # Make the collective call
    if preprocess_fn:
        preprocess_fn(streams)

    nccl_util.groupStart()
    # TODO(Fu): how to recordStreams as there are no library functions
    # We also need to make sure input tensors are not freed before their
    # usages on ncclStreams finish. This can be achieved by calling
    # c10::cuda::CUDACachingAllocator::recordStream, which remembers the
    # usage stream (ncclStream), creates an event on the usage stream
    # when GC attempts to free the input tensor, and delays GC until that
    # event is done.
    for i, tensor in enumerate(input_tensors):
        collective_fn(tensor, output_tensors[i], comms[i], streams[i])
    nccl_util.groupEnd()

    if postprocess_fn:
        postprocess_fn(streams)
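
# Illustrative sketch only (an assumption, not confirmed by this source):
# what the `self._sync_streams(devices, events, streams)` step above could
# look like with cupy primitives. For each device, an event is recorded on
# the current (compute) stream and the dedicated NCCL stream waits on it,
# so the collective is ordered after the work that produced its inputs
# without forcing a full device synchronization.
def _sync_streams_sketch(self, device_list, events, streams):
    # `cupy` and `nccl_util` are assumed to be the same modules used by the
    # methods above.
    for i, device in enumerate(device_list):
        with nccl_util.Device(device):
            # Mark the current position of the compute stream ...
            events[i].record(cupy.cuda.get_current_stream())
            # ... and make the NCCL stream wait until that point completes.
            streams[i].wait_event(events[i])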