def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    """Legacy multi-scale worker loop (no ``done_event`` / watchdog support).

    NOTE(review): a second ``def _ms_loop`` appears later in this file and
    rebinds the name at import time, so this definition is shadowed and is
    not reachable by name. Confirm it can be deleted.

    Repeatedly takes ``(idx, batch_indices)`` items from ``index_queue``.
    For each item it picks a scale index, loads the samples, collates them and
    puts ``(idx, samples)`` on ``data_queue``. If loading fails, it puts
    ``(idx, _utils.ExceptionWrapper(sys.exc_info()))`` instead. A ``None`` item
    ends the loop.

    Args:
        dataset: indexable dataset exposing ``train`` and ``set_scale(idx)``.
        index_queue: queue yielding ``(idx, batch_indices)`` tuples or ``None``.
        data_queue: queue receiving ``(idx, result)`` tuples.
        collate_fn: merges a list of samples. Its result must support
            ``.append`` because the chosen scale index is appended to it.
        scale: sequence of scales; only ``len(scale)`` is used here.
        seed: seed applied to both ``random`` and ``torch`` in this worker.
        init_fn: optional callable run once with ``worker_id`` before the loop.
        worker_id: this worker's id, forwarded to ``init_fn``.
    """
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    # Seed Python's RNG as well as torch's: ``random.randrange`` below picks
    # the scale, so an unseeded RNG would make that choice irreproducible
    # (and presumably identical across forked workers).
    random.seed(seed)
    torch.manual_seed(seed)

    # Run the user-supplied per-worker initializer, if any.
    if init_fn is not None:
        init_fn(worker_id)

    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)
            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)
        except Exception:
            data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            # Release the batch so it is not kept alive while this worker
            # blocks on the next ``index_queue.get()``.
            del samples
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    """Run a multi-scale DataLoader worker until told to stop.

    Each work item taken from ``index_queue`` is an ``(idx, batch_indices)``
    tuple. The worker processes it as follows:

    - Scale choice: if ``scale`` has more than one entry and
      ``dataset.train`` is truthy, a scale index is drawn with
      ``random.randrange`` and passed to ``dataset.set_scale``. Otherwise
      index 0 is used and ``set_scale`` is not called.
    - Loading: the samples ``dataset[i]`` for every ``i`` in
      ``batch_indices`` are merged with ``collate_fn``, and the scale index
      is appended to the collated result.
    - Output: ``(idx, samples)`` is put on ``data_queue``. If any of the
      above raises, ``(idx, _utils.ExceptionWrapper(sys.exc_info()))`` is put
      instead and the worker keeps running.

    The loop ends in three ways:

    - A ``None`` item arrives. It is asserted to arrive only after
      ``done_event`` is set.
    - The watchdog stops reporting alive.
    - A ``KeyboardInterrupt`` occurs; it is swallowed.

    Work items that arrive while ``done_event`` is set are skipped.

    Args:
        dataset: indexable dataset exposing ``train`` and ``set_scale(idx)``.
        index_queue: queue of ``(idx, batch_indices)`` tuples or ``None``.
            Its ``get`` must accept ``timeout=``.
        data_queue: queue receiving ``(idx, result)`` tuples. It must provide
            ``cancel_join_thread`` (e.g. a ``multiprocessing`` queue).
        done_event: event whose ``is_set()`` signals shutdown.
        collate_fn: merges a list of samples. Its result must support
            ``.append`` because the scale index is appended to it.
        scale: sequence of scales; only ``len(scale)`` is used here.
        seed: seed applied to both ``random`` and ``torch`` in this worker.
        init_fn: optional callable run once with ``worker_id`` before the loop.
        worker_id: this worker's id, forwarded to ``init_fn``.
    """
    global _use_shared_memory
    try:
        _use_shared_memory = True
        # NOTE(review): the global above only affects this module. Assuming
        # ``_utils`` is ``torch.utils.data._utils`` (consistent with the
        # ManagerWatchdog / MP_STATUS_CHECK_INTERVAL usage below), torch's own
        # ``default_collate`` presumably reads ``_utils.collate._use_shared_memory``
        # instead, so set that flag too when the module exists. Confirm against
        # the installed torch version.
        collate_module = getattr(_utils, 'collate', None)
        if collate_module is not None:
            collate_module._use_shared_memory = True
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        # Per multiprocessing.Queue semantics, keep process exit from blocking
        # on a join of the queue's feeder thread.
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = _utils.worker.ManagerWatchdog()
        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Timed out: loop back so the watchdog is re-checked.
                continue

            if r is None:
                # Final sentinel; shutdown must already have been requested.
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Shutdown requested but the sentinel has not arrived yet:
                # keep draining the queue without doing any work.
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)
                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                # Hand the failure to the consumer instead of ending the worker.
                data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                # Release the batch so it is not kept alive while this worker
                # blocks waiting for the next work item.
                del samples
    except KeyboardInterrupt:
        # Exit quietly; presumably the main process handles the interrupt.
        pass