def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
    Tqdm.set_lock(lock)
    try:
        self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
        instances = self.reader.read(self.data_path)
        checked_for_token_indexers: bool = False
        for instance in instances:
            # Check the first instance to make sure it doesn't contain any TextFields with
            # token_indexers because we don't want to be duplicating those by sending
            # them across processes.
            if not checked_for_token_indexers:
                for field_name, field in instance.fields.items():
                    if isinstance(field, TextField) and field._token_indexers is not None:
                        raise ValueError(
                            f"Found a TextField ({field_name}) with token_indexers already "
                            "applied, but you're using num_workers > 0 in your data loader. "
                            "Make sure your dataset reader's text_to_instance() method doesn't "
                            "add any token_indexers to the TextFields it creates. Instead, the "
                            "token_indexers should be added to the instances in the "
                            "apply_token_indexers() method of your dataset reader (which you'll "
                            "have to implement if you haven't done so already)."
                        )
                checked_for_token_indexers = True
            queue.put((instance, None))
    except Exception as e:
        queue.put((None, (repr(e), traceback.format_exc())))

    # Indicate to the consumer that this worker is finished.
    queue.put((None, None))

    # Wait until this process can safely exit.
    queue.join()
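
# A minimal sketch of the pattern the error message above asks for, assuming the
# AllenNLP 2.x DatasetReader API. The reader name, field name ("text"), and data
# format are illustrative assumptions, not taken from this codebase.
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer


class WhitespaceReader(DatasetReader):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path: str):
        with open(file_path) as data_file:
            for line in data_file:
                yield self.text_to_instance(line.strip())

    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        # No token_indexers are attached to the TextField here, so the instance
        # stays cheap to send across worker processes.
        return Instance({"text": TextField([Token(t) for t in text.split()])})

    def apply_token_indexers(self, instance: Instance) -> None:
        # The data loader calls this after the instance has crossed the process
        # boundary, which is when the indexers get attached.
        instance.fields["text"]._token_indexers = self._token_indexers  # type: ignore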
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock, rx: Connection) -> None:
    Tqdm.set_lock(lock)
    try:
        self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
        instances = self.reader.read(self.data_path)
        for batch in self._instances_to_batches(
            instances, move_to_device=self._worker_cuda_safe
        ):
            if self._safe_queue_put(worker_id, (batch, None), queue, rx):
                continue
            else:
                # Couldn't put item on queue because parent process has exited.
                return
    except Exception as e:
        if not self._safe_queue_put(
            worker_id, (None, (repr(e), traceback.format_exc())), queue, rx
        ):
            return

    # Indicate to the consumer (main thread) that this worker is finished.
    queue.put((None, None))

    # Wait until this process can safely exit.
    queue.join()
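
# _safe_queue_put() itself isn't shown in this section. Below is a hedged sketch of one
# way such a helper could work, assuming the worker holds the "rx" end of a
# multiprocessing Pipe from the parent: keep trying to put the item with a short
# timeout, and bail out (returning False) as soon as the parent signals shutdown.
# The function name and details are illustrative assumptions, not the real helper.
import multiprocessing as mp
from multiprocessing.connection import Connection
from queue import Full
from typing import Any


def safe_queue_put_sketch(
    worker_id: int, item: Any, queue: mp.JoinableQueue, rx: Connection
) -> bool:
    while True:
        # If the parent sent anything down the pipe, treat it as a stop signal.
        if rx.poll():
            queue.cancel_join_thread()
            return False
        try:
            # Use a timeout so we regularly loop back and re-check the pipe instead
            # of blocking forever on a queue that nobody is draining anymore.
            queue.put(item, True, 0.1)
            return True
        except Full:
            continue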
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
    try:
        self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
        instances = self.reader.read(self.data_path)
        for batch in self._instances_to_batches(
            instances, move_to_device=self._worker_cuda_safe
        ):
            queue.put((batch, None))
    except Exception as e:
        queue.put((None, (repr(e), traceback.format_exc())))

    # Indicate to the consumer (main thread) that this worker is finished.
    queue.put((None, None))

    # Wait until this process can safely exit.
    queue.join()
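
# For context, a hedged and simplified sketch of the consumer side of this protocol
# (the data loader's real consumer loop isn't shown here). Each worker ends its stream
# with a (None, None) sentinel, so the consumer drains the queue until it has seen one
# sentinel per worker, re-raises any error tuple a worker sent, and acknowledges every
# item with task_done() so the workers' queue.join() calls can return.
def drain_queue_sketch(queue, num_workers: int):
    done_count = 0
    while done_count < num_workers:
        item, worker_error = queue.get()
        try:
            if worker_error is not None:
                error_repr, tb = worker_error
                raise RuntimeError(f"worker failed with {error_repr}\n{tb}")
            if item is None:
                # A (None, None) sentinel means one more worker has finished.
                done_count += 1
            else:
                yield item
        finally:
            # Acknowledge the item so the worker's queue.join() can eventually return.
            queue.task_done()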
def test_instance_slicing(
    monkeypatch,
    reader_class,
    world_size: Optional[int],
    num_workers: Optional[int],
    max_instances: Optional[int],
):
    """
    Ensure that the instances read by each worker are always unique and that the total
    adds up to `max_instances`.
    """
    results: List[Set[int]] = []

    minimum_expected_result_size = max_instances or TOTAL_INSTANCES
    maximum_expected_result_size = max_instances or TOTAL_INSTANCES

    if world_size is not None and num_workers is not None:
        minimum_expected_result_size //= world_size
        minimum_expected_result_size //= num_workers
        maximum_expected_result_size = minimum_expected_result_size + 1
        for global_rank in range(world_size):
            monkeypatch.setattr(common_util, "is_distributed", lambda: True)
            monkeypatch.setattr(dist, "get_rank", lambda: global_rank)
            monkeypatch.setattr(dist, "get_world_size", lambda: world_size)
            for worker_id in range(num_workers):
                reader = reader_class(max_instances=max_instances)
                reader._set_worker_info(WorkerInfo(num_workers, worker_id))
                result = set(
                    x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
                )
                results.append(result)
    elif world_size is not None:
        minimum_expected_result_size //= world_size
        maximum_expected_result_size = minimum_expected_result_size + 1
        for global_rank in range(world_size):
            monkeypatch.setattr(common_util, "is_distributed", lambda: True)
            monkeypatch.setattr(dist, "get_rank", lambda: global_rank)
            monkeypatch.setattr(dist, "get_world_size", lambda: world_size)
            reader = reader_class(max_instances=max_instances)
            result = set(
                x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
            )
            results.append(result)
    elif num_workers is not None:
        minimum_expected_result_size //= num_workers
        maximum_expected_result_size = minimum_expected_result_size + 1
        for worker_id in range(num_workers):
            reader = reader_class(max_instances=max_instances)
            reader._set_worker_info(WorkerInfo(num_workers, worker_id))
            result = set(
                x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
            )
            results.append(result)
    else:
        reader = reader_class(max_instances=max_instances)
        result = set(
            x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
        )
        results.append(result)

    # We need to check that all of the result sets are mutually exclusive and that their
    # union has size `max_instances`.
    # Checking that they're mutually exclusive is equivalent to checking that the sum
    # of the sizes of the sets is equal to the size of the union.
    union: Set[int] = set()
    total: int = 0
    for result in results:
        union |= result
        total += len(result)
        # Also make sure the size of each set is within the expected bounds.
        assert minimum_expected_result_size <= len(result)
        assert len(result) <= maximum_expected_result_size
    assert len(union) == total == (max_instances or TOTAL_INSTANCES)
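
# A minimal sketch of the kind of reader this test exercises, assuming the AllenNLP 2.x
# DatasetReader API. TOTAL_INSTANCES, the class name, and the "index" field mirror the
# test above but are otherwise illustrative assumptions. shard_iterable() is what makes
# the slices read by different distributed ranks and worker processes mutually exclusive.
from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import LabelField

TOTAL_INSTANCES = 100


class RangeReader(DatasetReader):
    def __init__(self, **kwargs):
        super().__init__(
            manual_distributed_sharding=True,
            manual_multiprocess_sharding=True,
            **kwargs,
        )

    def _read(self, file_path: str):
        # shard_iterable() skips the items that belong to other ranks/workers,
        # so each worker yields a disjoint subset of range(TOTAL_INSTANCES).
        for i in self.shard_iterable(range(TOTAL_INSTANCES)):
            yield Instance({"index": LabelField(i, skip_indexing=True)})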