def _worker_loop(input_queue, output_queue, done_event, init_fn, worker_id):
    """Worker-process main loop: execute callables from ``input_queue``.

    Pulls zero-argument callables off ``input_queue``, runs them, and puts
    each result (or an ``ExceptionWrapper`` on failure) on ``output_queue``.
    A ``None`` item on ``input_queue`` is the final shutdown signal; once
    ``done_event`` is set, remaining items are drained without being run.
    """
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python
        # signal module's handlers are executed after Python returns from the
        # C low-level handlers, likely when the same fatal signal had already
        # happened again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        # Keep each worker single-threaded so workers don't oversubscribe CPUs.
        torch.set_num_threads(1)
        watchdog = ManagerWatchdog()

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id)

        # Run the optional per-worker initializer; capture any failure so it
        # can be reported to the main process instead of crashing silently.
        init_exception = None
        try:
            if init_fn is not None:
                init_fn(worker_id)
        except Exception:
            init_exception = ExceptionWrapper(where=f'in process {worker_id}')

        # Handshake: tell the main process whether initialization succeeded.
        if watchdog.is_alive():
            output_queue.put((None, init_exception))

        while watchdog.is_alive():
            try:
                task = input_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if task is None:
                # Received the final shutdown signal.
                break
            if done_event.is_set():
                # `done_event` is set but the final ``None`` hasn't arrived
                # yet: keep draining the queue, skipping all processing.
                continue

            try:
                outcome = task()
            except Exception:
                # Important: don't store exc_info in a local variable;
                # `ExceptionWrapper` does the correct thing.
                # See NOTE [ Python Traceback Reference Cycle Problem ]
                outcome = ExceptionWrapper(where=f'in process {worker_id}')
            output_queue.put(outcome)
            del outcome, task  # save memory
    except KeyboardInterrupt:
        # The main process will raise KeyboardInterrupt anyways.
        pass
    output_queue.cancel_join_thread()
    output_queue.close()
def _worker_loop(index_queue, data_queue, done_event, seed, init_fn,
                 worker_id, cnt):
    """Worker-process loop that produces batches via ``cnt.increment``.

    Reads ``(idx, batch_indices)`` pairs from ``index_queue``, calls
    ``cnt.increment(batch_indices)`` to produce samples, and puts
    ``(idx, samples)`` — or ``(idx, ExceptionWrapper)`` on failure — on
    ``data_queue``. ``None`` on ``index_queue`` is the final signal.

    See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    the logic of this function.
    """
    try:
        global _use_shared_memory
        _use_shared_memory = True

        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python
        # signal module's handlers run only after Python returns from the C
        # low-level handlers, likely when the same fatal signal happened
        # again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        # Seed both RNGs so each worker is deterministic given its seed.
        random.seed(seed)
        torch.manual_seed(seed)

        # Don't block worker exit on the queue's feeder thread.
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                job = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if job is None:
                # Received the final signal.
                assert done_event.is_set()
                return
            if done_event.is_set():
                # Done event is set but the final ``None`` hasn't arrived
                # yet: keep draining, skipping the processing steps.
                continue

            idx, batch_indices = job
            try:
                # NOTE(review): samples come from the shared counter object,
                # not from indexing a dataset directly — presumably
                # `cnt.increment` wraps the fetch; verify against caller.
                samples = cnt.increment(batch_indices)
            except Exception:
                # Important: don't store exc_info in a variable, see
                # NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # The main process will raise KeyboardInterrupt anyways.
        pass
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn,
             scale, seed, init_fn, worker_id):
    """Multi-scale worker loop for super-resolution training.

    Like the standard worker loop, but when training with multiple scales it
    picks a random scale index per batch, switches the dataset to it via
    ``dataset.set_scale``, and appends the chosen ``idx_scale`` to the
    collated batch so the consumer knows which scale was used.
    """
    try:
        collate._use_shared_memory = True
        # C-level handlers for SIGBUS/SIGSEGV; Python-level handlers would
        # run too late for these fatal signals.
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Don't block worker exit on the queue's feeder thread.
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                job = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if job is None:
                # Final shutdown signal.
                assert done_event.is_set()
                return
            if done_event.is_set():
                # Drain remaining items without processing them.
                continue

            idx, batch_indices = job
            try:
                idx_scale = 0
                # Only randomize the scale during training with >1 scale.
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)
                samples = collate_fn([dataset[i] for i in batch_indices])
                # Tack the scale index onto the batch for the consumer.
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # The main process will raise KeyboardInterrupt anyways.
        pass
def _worker_loop(
    data_reader,
    batch_queue,
    data_queue,
    global_done_event,
    worker_done_event,
    seed,
    init_fn,
    worker_id,
):
    """Shard-reading worker loop.

    Iterates this worker's shard of ``data_reader``; for each index pulled
    from ``batch_queue`` it emits ``(idx, samples)`` on ``data_queue``. When
    the shard is exhausted it sends ``WorkerDone(worker_id)`` and then waits
    for the main process to acknowledge (``worker_done_event``) or to signal
    shutdown before exiting.
    """
    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python
    # signal module's handlers run only after Python returns from the C
    # low-level handlers, likely when the same fatal signal happened again
    # already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    # TODO: numpy doesn't take seed bigger than INT32
    # np.random.seed(seed)
    torch.manual_seed(seed)

    # Do not wait for the putting thread to join when this worker exits.
    # Otherwise, this worker may always be waiting to put and never check
    # batch_queue and global_done_event for the termination signal.
    data_queue.cancel_join_thread()

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    shard_itr = iter(data_reader.get_shard(worker_id))
    shard_done = False

    while True:
        if shard_done:
            # Wait until the main thread acknowledges the WorkerDone message
            # or signals shutdown.
            parent_gone = not watchdog.is_alive()
            if (parent_gone or global_done_event.is_set()
                    or worker_done_event.wait(0.1)):
                break
            continue

        try:
            idx = batch_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive() and not global_done_event.is_set():
                continue
            break

        # Check global_done_event here too so we react to the exit signal
        # quickly even if batch_queue still holds pending batches.
        if idx is None or global_done_event.is_set():
            break

        try:
            samples = next(shard_itr)
        except StopIteration:
            # Tell the main thread this worker has run out of data. The
            # worker cannot exit immediately because the queue might not be
            # flushed immediately.
            data_queue.put((idx, WorkerDone(worker_id)))
            shard_done = True
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
def _custom_worker_loop(dataset_kind, dataset, index_queue, data_queue,
                        done_event, auto_collation, collate_fn, drop_last,
                        seed, init_fn, worker_id, num_workers):
    """DataLoader-style worker loop built on a fetcher object.

    Creates a fetcher via ``_CustomDatasetKind.create_fetcher`` and serves
    ``(idx, index)`` requests from ``index_queue``, replying on
    ``data_queue`` with ``(idx, data)`` where ``data`` is the fetched batch,
    an ``ExceptionWrapper``, or an ``_IterableDatasetStopIteration`` marker
    once an iterable-style dataset is exhausted.

    See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    the logic of this function.
    """
    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python
        # signal module's handlers are executed after Python returns from the
        # C low-level handlers, likely when the same fatal signal had already
        # happened again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        # Run initializer and build the fetcher; a failure here is shipped to
        # the main process as the response to the first request instead of
        # killing the worker outright.
        init_exception = None
        try:
            if init_fn is not None:
                init_fn(worker_id)
            fetcher = _CustomDatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some workers can exit earlier than others
        # because the IterableDataset behaves differently per worker. When
        # that happens, an `_IterableDatasetStopIteration` object is sent to
        # the main process with this worker's ID, so the main process stops
        # sending tasks here and eventually sends `None` to exit it properly.
        #
        # We cannot set `done_event` from a worker since it is shared among
        # all processes. Instead we set the `iteration_end` flag to signify
        # the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                request = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if request is None:
                # Received the final signal.
                assert done_event.is_set() or iteration_end
                break
            if done_event.is_set() or iteration_end:
                # `done_event` is set but the final ``None`` hasn't arrived
                # yet: keep draining, skipping the processing steps.
                continue

            idx, index = request
            if init_exception is not None:
                # Report the deferred init failure once, then clear it.
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    # NOTE(review): the Iterable check compares against
                    # `_DatasetKind.Iterable` while the fetcher came from
                    # `_CustomDatasetKind` — presumably the enums share
                    # values; verify.
                    if isinstance(e, StopIteration) \
                            and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        # (1) to save future `next(...)` calls, and
                        # (2) to avoid sending multiple
                        #     `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # Important: don't store exc_info in a variable;
                        # `ExceptionWrapper` does the correct thing. See
                        # NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(
                                worker_id))
            data_queue.put((idx, data))
            del data, idx, index, request  # save memory
    except KeyboardInterrupt:
        # The main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
def _worker_loop(
    data_reader,
    batch_queue,
    data_queue,
    global_done_event,
    worker_done_event,
    seed,
    init_fn,
    worker_id,
):
    """Worker loop that serves one shard of ``data_reader``.

    For every index received on ``batch_queue``, the next item of this
    worker's shard is placed on ``data_queue`` as ``(idx, samples)``.
    Exhaustion of the shard is reported with a ``WorkerDone`` message, after
    which the worker lingers until the main process acknowledges via
    ``worker_done_event`` or signals shutdown.
    """
    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python
    # signal module's handlers are executed only after Python returns from
    # the C low-level handlers, likely when the same fatal signal has
    # happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    # TODO: numpy doesn't take seed bigger than INT32
    # np.random.seed(seed)
    torch.manual_seed(seed)

    # Do not wait for the putting thread to join when this worker exits;
    # otherwise this worker could block forever on a put and never check
    # batch_queue and global_done_event for the termination signal.
    data_queue.cancel_join_thread()

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    my_shard = data_reader.get_shard(worker_id)
    shard_iter = iter(my_shard)
    exhausted = False

    while True:
        if exhausted:
            # Shard finished: linger until the main thread acknowledges the
            # WorkerDone message or signals shutdown.
            should_exit = (
                not watchdog.is_alive()
                or global_done_event.is_set()
                or worker_done_event.wait(0.1)
            )
            if should_exit:
                break
            continue

        try:
            idx = batch_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if not watchdog.is_alive() or global_done_event.is_set():
                break
            continue

        # Consult global_done_event so we get the exit signal promptly even
        # while batches remain queued in batch_queue.
        if idx is None or global_done_event.is_set():
            break

        try:
            samples = next(shard_iter)
        except StopIteration:
            # Signal the main thread that this worker has run out of data.
            # We can't exit immediately: the queue might not be flushed yet.
            data_queue.put((idx, WorkerDone(worker_id)))
            exhausted = True
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples