def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory_param, device_id):
    """
    This is copied from dataloader. It uses a different `pin_memory()`.
    It'd probably be best to merge.
    """
    if pin_memory_param:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None or done_event.is_set():
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory_param:
                batch = pin_memory(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))
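# A minimal sketch (not from the original code) of the queue protocol that a
# pin-memory thread like the one above consumes: the producer puts (idx, batch)
# tuples on in_queue, terminates with a `None` sentinel, and the consumer can
# also be stopped via done_event. The worker body and CUDA pinning are stubbed
# out so the sketch runs with the standard library alone; all names here are
# illustrative assumptions.
import queue
import threading

def _sketch_consumer_loop(in_queue, out_queue, done_event):
    while True:
        r = in_queue.get()
        if r is None or done_event.is_set():
            break
        idx, batch = r
        # A real loop would call pin_memory(batch) here.
        out_queue.put((idx, batch))

if __name__ == "__main__":
    in_q, out_q = queue.Queue(), queue.Queue()
    done = threading.Event()
    t = threading.Thread(target=_sketch_consumer_loop, args=(in_q, out_q, done), daemon=True)
    t.start()
    in_q.put((0, [1, 2, 3]))   # one "batch"
    in_q.put(None)             # final-signal sentinel
    t.join()
    print(out_q.get())         # -> (0, [1, 2, 3])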
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)

    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)
            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
def _worker_loop(index_queue, data_queue, done_event, seed, init_fn, worker_id, cnt):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    try:
        global _use_shared_memory
        _use_shared_memory = True

        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Done event is set, but the final signal (None) hasn't arrived
                # yet. Keep draining the queue until it does, skipping the
                # processing steps.
                continue
            idx, batch_indices = r
            try:
                samples = cnt.increment(batch_indices)
                # if cnt.val.value % interval == 0:
                #     print('change')
                #     dataset.transform(np.random.choice(transform_fns))
                # samples = collate_fn([dataset[i] for i in batch_indices])
                # print(cnt.val.value)
            except Exception:
                # It is important that we don't store exc_info in a variable,
                # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
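# A minimal sketch (assumed, not part of the original) of the shutdown handshake
# the worker loop above relies on: the main process sets done_event *first* and
# only then puts one `None` sentinel per worker, which is why the worker can
# assert done_event.is_set() when it sees `None`. The helper names below are
# hypothetical.
import multiprocessing as mp

def _sketch_worker(index_queue, data_queue, done_event):
    while True:
        r = index_queue.get()
        if r is None:
            assert done_event.is_set()
            return
        if done_event.is_set():
            continue  # drain remaining work without processing it
        idx, indices = r
        data_queue.put((idx, [i * 2 for i in indices]))  # stand-in for real work

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    index_q, data_q = ctx.Queue(), ctx.Queue()
    done = ctx.Event()
    w = ctx.Process(target=_sketch_worker, args=(index_q, data_q, done))
    w.start()
    index_q.put((0, [1, 2, 3]))
    print(data_q.get())      # -> (0, [2, 4, 6])
    done.set()               # 1) announce shutdown
    index_q.put(None)        # 2) then send the final signal
    w.join()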
def run(self):
    # Disable polluting stderr with errors that are supposed to happen.
    sys.stderr = open(os.devnull, "w")
    try:
        super(ErrorTrackingProcess, self).run()
        self._cconn.send(None)
    except Exception:
        self._cconn.send(ExceptionWrapper(sys.exc_info()))
        raise
def run(self):
    if HAS_FAULTHANDLER:
        faulthandler.enable()
        if not IS_WINDOWS:
            # windows does not have faulthandler.register
            faulthandler.register(signal.SIGUSR1, chain=True)
    if self.disable_stderr:
        # Disable polluting stderr with errors that are supposed to happen.
        sys.stderr = open(os.devnull, "w")
    try:
        super(ErrorTrackingProcess, self).run()
        self._cconn.send(None)
    except Exception:
        self._cconn.send(ExceptionWrapper(sys.exc_info()))
        raise
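# A minimal sketch (assumed, not the original test harness) of the pattern the
# two run() overrides above belong to: a Process subclass that reports the
# child's exception back to the parent over a Pipe, so the parent can inspect
# the failure instead of only seeing a non-zero exit code. The class and helper
# names are illustrative, and traceback.format_exc() stands in for
# ExceptionWrapper.
import multiprocessing as mp
import os
import sys
import traceback

class SketchErrorTrackingProcess(mp.Process):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._pconn, self._cconn = mp.Pipe()

    def run(self):
        # Mirror the snippets above: hide the expected traceback from stderr.
        sys.stderr = open(os.devnull, "w")
        try:
            super().run()
            self._cconn.send(None)                     # clean exit
        except Exception:
            self._cconn.send(traceback.format_exc())   # stand-in for ExceptionWrapper
            raise

    @property
    def exception(self):
        if self._pconn.poll():
            return self._pconn.recv()
        return None

def _boom():
    raise RuntimeError("expected failure")

if __name__ == "__main__":
    p = SketchErrorTrackingProcess(target=_boom)
    p.start()
    p.join()
    print(p.exception)   # traceback text from the child, or None on success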
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue
            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)
                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        pass
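# A small illustrative sketch (assumptions, not the original multi-scale code)
# of why both _ms_loop variants append idx_scale to the collated batch: the
# collate function must return a mutable list so the worker can tack the chosen
# scale index onto the end, and the consumer unpacks it as the last element.
def _sketch_collate(batch):
    # Stand-in collate: pair up the "low-res"/"high-res" items as lists.
    lr, hr = zip(*batch)
    return [list(lr), list(hr)]

dataset_items = [("lr0", "hr0"), ("lr1", "hr1")]
samples = _sketch_collate([dataset_items[i] for i in (0, 1)])
samples.append(1)                     # idx_scale chosen by the worker
lr_batch, hr_batch, idx_scale = samples
print(lr_batch, hr_batch, idx_scale)  # ['lr0', 'lr1'] ['hr0', 'hr1'] 1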
def _worker_loop(
    data_reader,
    batch_queue,
    data_queue,
    global_done_event,
    worker_done_event,
    seed,
    init_fn,
    worker_id,
):
    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    # TODO: numpy doesn't take seed bigger than INT32
    # np.random.seed(seed)
    torch.manual_seed(seed)

    # Do not wait for the putting thread to join when this worker exits.
    # Otherwise, this worker may always be waiting to put and never checks
    # batch_queue and global_done_event for the termination signal.
    data_queue.cancel_join_thread()

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    shard = data_reader.get_shard(worker_id)
    shard_itr = iter(shard)

    shard_done = False

    while True:
        if shard_done:
            # Wait until the main thread acknowledges the WorkerDone message
            # or signals shutdown.
            if (not watchdog.is_alive() or global_done_event.is_set()
                    or worker_done_event.wait(0.1)):
                break
            continue

        try:
            idx = batch_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive() and not global_done_event.is_set():
                continue
            else:
                break
        # Use global_done_event so that we can get a faster exit signal even if
        # there are still batches in batch_queue.
        if idx is None or global_done_event.is_set():
            break
        try:
            samples = next(shard_itr)
        except StopIteration:
            # Signal to the main thread that this worker has run out of data.
            # The worker cannot exit immediately because the queue might not be
            # flushed immediately.
            data_queue.put((idx, WorkerDone(worker_id)))
            shard_done = True
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
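# A minimal sketch (assumptions, not the original reader code) of the WorkerDone
# handshake used above: when a worker's shard is exhausted it posts a "done"
# marker on data_queue and keeps spinning until the main side either sets its
# worker_done_event (acknowledgement) or signals global shutdown. Threads stand
# in for processes, and SketchWorkerDone stands in for WorkerDone.
import queue
import threading

class SketchWorkerDone:
    def __init__(self, worker_id):
        self.worker_id = worker_id

def _sketch_shard_worker(shard, batch_queue, data_queue, global_done, worker_done, worker_id):
    it = iter(shard)
    shard_done = False
    while True:
        if shard_done:
            if global_done.is_set() or worker_done.wait(0.1):
                break
            continue
        idx = batch_queue.get()
        if idx is None or global_done.is_set():
            break
        try:
            samples = next(it)
        except StopIteration:
            data_queue.put((idx, SketchWorkerDone(worker_id)))
            shard_done = True
        else:
            data_queue.put((idx, samples))

if __name__ == "__main__":
    batch_q, data_q = queue.Queue(), queue.Queue()
    global_done, worker_done = threading.Event(), threading.Event()
    t = threading.Thread(
        target=_sketch_shard_worker,
        args=(["a", "b"], batch_q, data_q, global_done, worker_done, 0),
        daemon=True,
    )
    t.start()
    for idx in range(3):               # request one more batch than the shard holds
        batch_q.put(idx)
    for _ in range(3):
        idx, item = data_q.get()
        if isinstance(item, SketchWorkerDone):
            worker_done.set()          # acknowledge so the worker can exit
        else:
            print(idx, item)
    t.join()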