def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                data = pin_memory(data)
            except Exception:
                data = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del r  # save memory
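# A minimal sketch (an assumption, not the DataLoader's actual wiring) of how
# `_pin_memory_loop` is typically started: as a daemon thread that drains the
# worker result queue and republishes pinned batches on a second queue until
# `done_event` is set. Plain `queue.Queue`s stand in for the real
# multiprocessing queues; MP_STATUS_CHECK_INTERVAL, pin_memory and
# ExceptionWrapper are assumed to be in scope as above.
import queue
import threading

import torch

worker_result_queue = queue.Queue()   # filled by worker processes (stand-in)
data_queue = queue.Queue()            # consumed by the training loop
pin_done_event = threading.Event()    # tells the loop to exit

pin_thread = threading.Thread(
    target=_pin_memory_loop,
    args=(worker_result_queue, data_queue,
          torch.cuda.current_device(), pin_done_event))
pin_thread.daemon = True              # never keep the interpreter alive
pin_thread.start()

# ... consume (idx, batch) tuples from data_queue during training ...

pin_done_event.set()                  # request shutdown
pin_thread.join()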
def _work_loss(GPU, inputs, targets, batch, total_batch):
    try:
        cosine = self.LM(inputs, targets=targets, batch=batch,
                         total_batch=total_batch)
        with lock:
            Cosine_MutiGPU[GPU] = cosine
    except Exception:
        with lock:
            Cosine_MutiGPU[GPU] = ExceptionWrapper(
                where="in loss computation on GPU {}".format(GPU))
def wx_norm(sub_weights, temp_x, GPU):
    try:
        cosine = F.linear(F.normalize(temp_x), F.normalize(sub_weights))
        # cosine = F.linear(F.normalize(temp_x), sub_weights)
        cosine = torch.chunk(cosine, len(self.GPUS))
        with lock:
            results[GPU] = cosine
    except Exception:
        with lock:
            results[GPU] = ExceptionWrapper(
                where="in weight shard on GPU {}".format(GPU))
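# A minimal sketch of how per-GPU closures like `_work_loss` / `wx_norm` are
# usually dispatched: one thread per shard, a shared dict guarded by a lock,
# and a reraise pass back in the caller. `lock`, `results` and the sharding
# below are illustrative stand-ins for the enclosing class's real attributes.
import threading

from torch._utils import ExceptionWrapper

lock = threading.Lock()
results = {}

def run_weight_shards(sub_weight_list, temp_x_list):
    threads = [
        threading.Thread(target=wx_norm, args=(w, x, gpu))
        for gpu, (w, x) in enumerate(zip(sub_weight_list, temp_x_list))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    for gpu in sorted(results):
        if isinstance(results[gpu], ExceptionWrapper):
            results[gpu].reraise()    # surface the shard's failure in the caller
    return [results[gpu] for gpu in sorted(results)]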
def _worker(i, module, input, kwargs, device=None):
    torch.set_grad_enabled(grad_enabled)
    if device is None:
        device = get_a_var(input).get_device()
    try:
        with torch.cuda.device(device):
            # this also avoids accidental slicing of `input` if it is a Tensor
            if not isinstance(input, (list, tuple)):
                input = (input,)
            output = module.predict(*input, **kwargs)
        with lock:
            results[i] = output
    except Exception:
        with lock:
            results[i] = ExceptionWrapper(
                where="in replica {} on device {}".format(i, device))
def _worker(i, module, input, target, kwargs, device=None):
    torch.set_grad_enabled(grad_enabled)
    if device is None:
        device = get_a_var(input).get_device()
    try:
        with torch.cuda.device(device):
            if not isinstance(input, (list, tuple)):
                input = (input,)
            if not isinstance(target, (list, tuple)):
                target = (target,)
            output = module(*input, *target, **kwargs)
        with lock:
            results[i] = output
    except Exception:
        with lock:
            results[i] = ExceptionWrapper(
                where="in replica {} on device {}".format(i, device))
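# A minimal sketch (modeled on torch.nn.parallel.parallel_apply) of the
# dispatcher that the two `_worker` variants above assume. In PyTorch proper,
# `_worker` is a nested function closing over `grad_enabled`, `lock` and
# `results`; here those names live at module level so the definition above can
# resolve them at call time. The function name and `targets` handling are
# illustrative assumptions, not the library's exact code.
import threading

import torch
from torch._utils import ExceptionWrapper

lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()

def parallel_apply_with_targets(modules, inputs, targets, kwargs_tup=None, devices=None):
    kwargs_tup = kwargs_tup or ({},) * len(modules)
    devices = devices or [None] * len(modules)
    results.clear()

    threads = [
        threading.Thread(target=_worker, args=(i, m, inp, tgt, kw, dev))
        for i, (m, inp, tgt, kw, dev) in enumerate(
            zip(modules, inputs, targets, kwargs_tup, devices))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    outputs = []
    for i in range(len(modules)):
        out = results[i]
        if isinstance(out, ExceptionWrapper):
            out.reraise()   # re-raise the replica's exception in the caller
        outputs.append(out)
    return outputs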
def do_one_step():
    try:
        r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
        return
    idx, data = r
    if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
        try:
            data = pin_memory(data, device)
        except Exception:
            data = ExceptionWrapper(
                where="in pin memory thread for device {}".format(device_id))
        r = (idx, data)
    while not done_event.is_set():
        try:
            out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
            break
        except queue.Full:
            continue
async def _single_task(in_queue, out_queue, device_id, done_event, async_sleep):
    while not done_event.is_set():
        try:
            task = in_queue.get_nowait()
        except queue.Empty:
            await asyncio.sleep(async_sleep)
            continue
        if not done_event.is_set() and not isinstance(task, ExceptionWrapper):
            try:
                task = pin_memory(task)
            except Exception:
                task = ExceptionWrapper(
                    where="in pin memory thread for device {}".format(device_id))
        while not done_event.is_set():
            try:
                out_queue.put(task, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del task
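# A minimal sketch of driving the asyncio variant above: several `_single_task`
# coroutines share the same queues, so pinning of one batch can overlap with
# waiting on the next while each coroutine yields via `asyncio.sleep` whenever
# its input queue is empty. The thread wrapper, task count and sleep interval
# are illustrative choices, not values taken from the original code.
import asyncio
import threading

def start_async_pin_thread(in_queue, out_queue, device_id, done_event,
                           num_tasks=2, async_sleep=0.01):
    def _run():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        tasks = [
            _single_task(in_queue, out_queue, device_id, done_event, async_sleep)
            for _ in range(num_tasks)
        ]
        loop.run_until_complete(asyncio.gather(*tasks))
        loop.close()

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t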
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
                                                  auto_collation, collate_fn,
                                                  drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
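# The `_worker_info` global set above is what `torch.utils.data.get_worker_info()`
# exposes inside a worker process. A small usage sketch of the usual pattern:
# split an IterableDataset's range across workers from `__iter__`. The dataset
# below is a made-up example, not part of the DataLoader source.
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:               # single-process loading: yield everything
            lo, hi = self.start, self.end
        else:                          # split the range across num_workers
            per_worker = int(math.ceil((self.end - self.start) / info.num_workers))
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))

loader = DataLoader(RangeDataset(0, 10), num_workers=2)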
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn,
                 seed, init_fn, worker_id):
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        # signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # global _worker_info
        # _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
        #                           seed=seed, dataset=dataset)

        init_exception = None
        try:
            if init_fn is not None:
                init_fn(worker_id)
            fetcher = _MapDatasetFetcher(dataset, collate_fn)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        iteration_end = False
        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                # Try to fetch a batch of indices.
                r = index_queue.get(timeout=5.0)
            except queue.Empty:
                # Queue is empty; keep looping until data arrives.
                continue
            # A `_ResumeIteration` marker means this worker finished the epoch;
            # it should not exit, since it will be reused for the next epoch.
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put(r)
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _MapDatasetFetcher(dataset, collate_fn)
                continue
            # `None` is the final signal: training is over (or something went
            # wrong), so this worker process should shut down.
            elif r is None:
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            # If initialization failed, return the wrapped exception as this
            # batch's result.
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    # dataset + collate_fn assemble the batch data.
                    data = fetcher.fetch(index)
                except Exception:
                    data = ExceptionWrapper(
                        where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))  # publish on the shared result queue
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    # The main process requested shutdown; release the result queue.
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
def _worker_loop(
    dataset_kind,
    dataset,
    index_queue,
    data_queue,
    done_event,
    auto_collation,
    collate_fn,
    drop_last,
    seed,
    init_fn,
    worker_id,
    num_workers,
    persistent_workers,
):
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal # NOQA
        # module's handlers are executed after Python returns from C low-level # NOQA
        # handlers, likely when the same fatal signal had already happened # NOQA
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers # NOQA
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(
            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
        )

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)
            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id)
            )

        iteration_end = False
        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind,
                    dataset,
                    auto_collation,
                    collate_fn,
                    drop_last,
                )
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                if done_event.is_set() or (iteration_end and not persistent_workers):
                    break
                continue
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if (
                        isinstance(e, StopIteration)
                        and dataset_kind == _DatasetKind.Iterable
                    ):
                        data = _IterableDatasetStopIteration(worker_id)
                        iteration_end = True
                    else:
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id)
                        )
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
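# A minimal sketch of the main-process side of the protocol that the
# persistent-worker loop above expects: indices arrive as (idx, batch_indices)
# tuples, a `_ResumeIteration` marker resets the fetcher between epochs, and
# `None` plus `done_event` shuts the worker down. The function names and queue
# bookkeeping are simplified assumptions, not the DataLoader's actual code.
def drive_one_epoch(index_queue, data_queue, batched_indices):
    index_queue.put(_ResumeIteration())          # reset/recreate the fetcher
    for send_idx, batch_indices in enumerate(batched_indices):
        index_queue.put((send_idx, batch_indices))
    results = {}
    for _ in range(len(batched_indices)):
        idx, data = data_queue.get()
        if isinstance(data, ExceptionWrapper):
            data.reraise()                       # propagate worker failures
        results[idx] = data                      # reorder out-of-order batches
    return [results[i] for i in range(len(batched_indices))]

def shutdown_worker(index_queue, done_event, worker_process):
    done_event.set()
    index_queue.put(None)                        # final signal; worker breaks out
    worker_process.join()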