def _create_tensor_dicts_from_qiterable(qiterable: QIterable,
                                        output_queue: Queue,
                                        iterator: DataIterator,
                                        shuffle: bool,
                                        index: int) -> None:
    """
    Pulls instances from ``qiterable.output_queue``, converts them into
    ``TensorDict``s using ``iterator``, and puts them on the ``output_queue``.

    Parameters
    ----------
    qiterable : ``QIterable``
        Source of instances; exposes ``output_queue`` plus the shared counters
        ``num_active_workers`` and ``num_inflight_items``.
    output_queue : ``Queue``
        Joinable queue that receives each ``TensorDict`` and, finally, this
        worker's ``index`` as an end-of-stream marker.
    iterator : ``DataIterator``
        Converts a stream of ``Instance``s into ``TensorDict``s.
    shuffle : ``bool``
        Passed through to ``iterator``.
    index : ``int``
        This worker's id, used for logging and as the completion sentinel.
    """
    logger.info(f"Iterator worker: {index} PID: {os.getpid()}")

    def instances() -> Iterator[Instance]:
        # Keep draining while any producer is still running or items have been
        # queued but not yet accounted for.
        while qiterable.num_active_workers.value > 0 or qiterable.num_inflight_items.value > 0:
            while True:
                try:
                    # Non-blocking get. (The previous version passed both
                    # ``block=False`` and ``timeout=1.0``; ``timeout`` is
                    # ignored when ``block`` is false, so it was dead and
                    # misleading.)
                    instance = qiterable.output_queue.get(block=False)
                except Empty:
                    break
                # Decrement BEFORE yielding: if the consumer abandons the
                # generator mid-yield (GeneratorExit), a decrement placed
                # after the yield would be skipped and the inflight counter
                # would leak, keeping the outer loop condition true forever.
                with qiterable.num_inflight_items.get_lock():
                    qiterable.num_inflight_items.value -= 1
                yield instance

    for tensor_dict in iterator(instances(), num_epochs=1, shuffle=shuffle):
        output_queue.put(tensor_dict)

    # Signal completion for this worker.
    output_queue.put(index)

    # We need to ensure we've gotten all the tensors out of this queue before
    # this process ends. Otherwise we'll crash. See
    # https://github.com/pytorch/pytorch/issues/7181. This appears to be an
    # issue specifically with tensors, perhaps due to the refcounting involved
    # in managing them in shared memory.
    output_queue.join()
def _create_tensor_dicts_from_queue(input_queue: Queue,
                                    output_queue: Queue,
                                    iterator: DataIterator,
                                    shuffle: bool,
                                    index: int) -> None:
    """
    Pulls instances from ``input_queue``, converts them into ``TensorDict``s
    using ``iterator``, and puts them on the ``output_queue``.

    The worker finishes by putting its own ``index`` on ``output_queue`` as a
    completion marker, then joins the queue so the consumer has drained every
    tensor before this process exits.
    """
    logger.info(f"Iterator worker: {index} PID: {os.getpid()}")

    def instances() -> Iterator[Instance]:
        # ``None`` is the sentinel the producer uses to signal end-of-stream.
        while True:
            item = input_queue.get()
            if item is None:
                return
            yield item

    for tensor_dict in iterator(instances(), num_epochs=1, shuffle=shuffle):
        output_queue.put(tensor_dict)

    # Tell the consumer this worker is done.
    output_queue.put(index)

    # We need to ensure we've gotten all the tensors out of this queue before
    # this process ends. Otherwise we'll crash. See
    # https://github.com/pytorch/pytorch/issues/7181. This appears to be an
    # issue specifically with tensors, perhaps due to the refcounting involved
    # in managing them in shared memory. If you're working on this code, be
    # aware that I've only been able to reproduce this issue on Linux. Testing
    # on a Mac alone is not sufficient.
    output_queue.join()