def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
        Tqdm.set_lock(lock)
        try:
            self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            checked_for_token_indexers: bool = False
            for instance in instances:
                # Check the first instance to make sure it doesn't contain any TextFields with
                # token_indexers because we don't want to be duplicating those by sending
                # them across processes.
                if not checked_for_token_indexers:
                    for field_name, field in instance.fields.items():
                        if isinstance(field, TextField) and field._token_indexers is not None:
                            raise ValueError(
                                f"Found a TextField ({field_name}) with token_indexers already "
                                "applied, but you're using num_workers > 0 in your data loader. "
                                "Make sure your dataset reader's text_to_instance() method doesn't "
                                "add any token_indexers to the TextFields it creates. Instead, the token_indexers "
                                "should be added to the instances in the apply_token_indexers() method of your "
                                "dataset reader (which you'll have to implement if you haven't done "
                                "so already)."
                            )
                    checked_for_token_indexers = True
                queue.put((instance, None))
        except Exception as e:
            queue.put((None, (repr(e), traceback.format_exc())))

        # Indicate to the consumer that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
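
The ValueError above points at AllenNLP's `apply_token_indexers()` hook: with `num_workers > 0`, instances are pickled and sent across processes, so token indexers (which can carry heavy state) should only be attached on the receiving side. A minimal sketch of a reader that follows this pattern; the reader name, indexer choice, and line-per-example format are illustrative assumptions, not part of the example above:

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token


class LinePerExampleReader(DatasetReader):  # hypothetical reader for illustration
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._token_indexers = {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path):
        with open(file_path) as f:
            for line in f:
                yield self.text_to_instance(line.strip())

    def text_to_instance(self, text: str) -> Instance:  # type: ignore[override]
        # Deliberately no token_indexers here, so instances stay cheap to pickle.
        return Instance({"tokens": TextField([Token(t) for t in text.split()])})

    def apply_token_indexers(self, instance: Instance) -> None:
        # Called in the consuming process, after the instance crossed the queue.
        instance.fields["tokens"].token_indexers = self._token_indexers  # type: ignore
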
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock, rx: Connection) -> None:
        Tqdm.set_lock(lock)
        try:
            self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            for batch in self._instances_to_batches(instances, move_to_device=self._worker_cuda_safe):
                if not self._safe_queue_put(worker_id, (batch, None), queue, rx):
                    # Couldn't put the item on the queue because the parent process has exited.
                    return
        except Exception as e:
            if not self._safe_queue_put(worker_id, (None, (repr(e), traceback.format_exc())), queue, rx):
                return

        # Indicate to the consumer (main thread) that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
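
`_safe_queue_put()` itself isn't shown in this example. A minimal sketch of what such a helper might look like, assuming the parent process holds the write end of a pipe whose read end (`rx`) becomes readable once the parent signals a stop or exits; the names and details here are illustrative:

import multiprocessing as mp
from multiprocessing.connection import Connection
from queue import Full
from typing import Any


def _safe_queue_put(self, worker_id: int, item: Any, queue: mp.JoinableQueue, rx: Connection) -> bool:
    while True:
        # A readable `rx` means the parent sent a stop signal or exited
        # (closing its end of the pipe), so give up rather than blocking
        # forever on a queue nobody is draining.
        if rx.poll():
            return False
        try:
            # Short timeout so we loop back around and re-check `rx`.
            queue.put(item, True, 0.1)
            return True
        except Full:
            continue
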
Example #3
def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
        try:
            self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            for batch in self._instances_to_batches(instances, move_to_device=self._worker_cuda_safe):
                queue.put((batch, None))
        except Exception as e:
            queue.put((None, (repr(e), traceback.format_exc())))

        # Indicate to the consumer (main thread) that this worker is finished.
        queue.put((None, None))

        # Wait until this process can safely exit.
        queue.join()
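
For context, the consuming side of this protocol (not shown in these examples) might look roughly like the sketch below: it re-raises worker errors, counts `(None, None)` sentinels to know when every worker has finished, and matches each `get()` with `task_done()` so the workers' trailing `queue.join()` can unblock. Names and details are illustrative, not the library's exact implementation:

import multiprocessing as mp
from typing import Any, Iterator


def _gather_from_workers(self, queue: mp.JoinableQueue) -> Iterator[Any]:
    done_count = 0
    while done_count < self.num_workers:
        item, worker_error = queue.get()
        try:
            if worker_error is not None:
                err_repr, tb = worker_error
                raise RuntimeError(f"worker failed with {err_repr}:\n{tb}")
            if item is None:
                # A (None, None) sentinel: one more worker has finished.
                done_count += 1
            else:
                yield item
        finally:
            # Every put() must be balanced by a task_done() so the workers'
            # final queue.join() returns and they can exit cleanly.
            queue.task_done()
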
def test_instance_slicing(
    monkeypatch,
    reader_class,
    world_size: Optional[int],
    num_workers: Optional[int],
    max_instances: Optional[int],
):
    """
    Ensure that the intances read by each worker are always unique and the total
    adds up to `max_instances`.
    """
    results: List[Set[int]] = []

    minimum_expected_result_size = max_instances or TOTAL_INSTANCES
    maximum_expected_result_size = max_instances or TOTAL_INSTANCES

    if world_size is not None and num_workers is not None:
        minimum_expected_result_size //= world_size
        minimum_expected_result_size //= num_workers
        maximum_expected_result_size = minimum_expected_result_size + 1
        for global_rank in range(world_size):
            monkeypatch.setattr(common_util, "is_distributed", lambda: True)
            monkeypatch.setattr(dist, "get_rank", lambda: global_rank)
            monkeypatch.setattr(dist, "get_world_size", lambda: world_size)
            for worker_id in range(num_workers):
                reader = reader_class(max_instances=max_instances)
                reader._set_worker_info(WorkerInfo(num_workers, worker_id))
                result = set(
                    x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
                )
                results.append(result)
    elif world_size is not None:
        minimum_expected_result_size //= world_size
        maximum_expected_result_size = minimum_expected_result_size + 1
        for global_rank in range(world_size):
            monkeypatch.setattr(common_util, "is_distributed", lambda: True)
            monkeypatch.setattr(dist, "get_rank", lambda: global_rank)
            monkeypatch.setattr(dist, "get_world_size", lambda: world_size)
            reader = reader_class(max_instances=max_instances)
            result = set(
                x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
            )
            results.append(result)
    elif num_workers is not None:
        minimum_expected_result_size //= num_workers
        maximum_expected_result_size = minimum_expected_result_size + 1
        for worker_id in range(num_workers):
            reader = reader_class(max_instances=max_instances)
            reader._set_worker_info(WorkerInfo(num_workers, worker_id))
            result = set(
                x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
            )
            results.append(result)
    else:
        reader = reader_class(max_instances=max_instances)
        result = set(
            x["index"].label for x in reader.read("the-path-doesnt-matter")  # type: ignore
        )
        results.append(result)

    # We need to check that all of the result sets are mutually exclusive and that
    # their union has size `max_instances` (or `TOTAL_INSTANCES` when `max_instances`
    # is None).
    # Checking that they're mutually exclusive is equivalent to checking that the
    # sum of the sizes of the individual sets equals the size of their union.

    union: Set[int] = set()
    total: int = 0
    for result in results:
        union |= result
        total += len(result)
        # Also make sure the size of the set is within the expected bounds.
        assert minimum_expected_result_size <= len(result)
        assert len(result) <= maximum_expected_result_size

    assert len(union) == total == (max_instances or TOTAL_INSTANCES)
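
The test assumes `reader_class` yields instances whose `"index"` field is a unique `LabelField` per instance. A minimal sketch of such a reader using AllenNLP's `shard_iterable()` helper, which strides the iterable across distributed ranks and data-loader workers; the class name is an assumption, and `TOTAL_INSTANCES` is the test module's constant:

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import LabelField


class MockSlicingReader(DatasetReader):  # hypothetical stand-in for reader_class
    def __init__(self, **kwargs):
        # shard_iterable() requires opting in to manual sharding.
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )

    def _read(self, file_path):
        # Each (rank, worker) combination gets a disjoint strided slice of the
        # indices, which is exactly the mutual-exclusivity property asserted above.
        for i in self.shard_iterable(range(TOTAL_INSTANCES)):
            yield self.text_to_instance(i)

    def text_to_instance(self, index: int) -> Instance:  # type: ignore[override]
        return Instance({"index": LabelField(index, skip_indexing=True)})
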