Ejemplo n.º 1
0
    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue, lock,
                      rx: Connection) -> None:
        """Read instances, batch them, and hand each batch to the consumer.

        Every item placed on ``queue`` is a 2-tuple:

        * ``(batch, None)`` for a successfully produced batch,
        * ``(None, (repr(error), traceback_text))`` if anything raised,
        * ``(None, None)`` as the final "this worker is done" marker.

        ``lock`` is registered with ``Tqdm`` before any work starts. ``rx`` is
        only forwarded to ``_safe_queue_put``; presumably it lets that helper
        notice that the parent went away -- confirm against the helper.

        The method returns early, without sending the final marker, as soon as
        ``_safe_queue_put`` reports a failed put.
        """
        Tqdm.set_lock(lock)
        try:
            worker_info = WorkerInfo(self.num_workers, worker_id)
            self.reader._set_worker_info(worker_info)
            instances = self.reader.read(self.data_path)
            batches = self._instances_to_batches(
                instances, move_to_device=self._worker_cuda_safe)
            for batch in batches:
                delivered = self._safe_queue_put(
                    worker_id, (batch, None), queue, rx)
                if not delivered:
                    # Couldn't put item on queue because parent process has
                    # exited, so there is nobody left to produce for.
                    return
        except Exception as err:
            # Ship a printable description of the failure instead of raising
            # here, so the consumer gets to see what went wrong.
            failure = (repr(err), traceback.format_exc())
            if not self._safe_queue_put(worker_id, (None, failure), queue, rx):
                return

        # Tell the consumer (main thread) that this worker has nothing more
        # to send.
        queue.put((None, None))

        # Block until every queued item has been acknowledged, so this
        # process can exit safely.
        queue.join()
    def _instance_worker(self, worker_id: int, queue: mp.JoinableQueue, lock) -> None:
        """Read instances and send them one at a time to the consumer.

        Every item placed on ``queue`` is a 2-tuple:

        * ``(instance, None)`` for each instance read,
        * ``(None, (repr(error), traceback_text))`` if anything raised,
        * ``(None, None)`` as the final "this worker is done" marker.

        Only the first instance is inspected: if any of its ``TextField``s
        already carries token indexers, a ``ValueError`` is raised. That error
        is caught by the surrounding handler and reported through the queue
        like any other failure.

        ``lock`` is registered with ``Tqdm`` before any work starts.
        """
        Tqdm.set_lock(lock)
        try:
            self.reader._set_worker_info(WorkerInfo(self.num_workers, worker_id))
            instances = self.reader.read(self.data_path)
            for position, instance in enumerate(instances):
                if position == 0:
                    # Check the first instance to make sure it doesn't contain
                    # any TextFields with token_indexers, because we don't want
                    # to be duplicating those by sending them across processes.
                    for field_name, field in instance.fields.items():
                        if not isinstance(field, TextField):
                            continue
                        if field._token_indexers is None:
                            continue
                        raise ValueError(
                            f"Found a TextField ({field_name}) with token_indexers already "
                            "applied, but you're using num_workers > 0 in your data loader. "
                            "Make sure your dataset reader's text_to_instance() method doesn't "
                            "add any token_indexers to the TextFields it creates. Instead, the token_indexers "
                            "should be added to the instances in the apply_token_indexers() method of your "
                            "dataset reader (which you'll have to implement if you haven't done "
                            "so already)."
                        )
                queue.put((instance, None))
        except Exception as err:
            failure = (repr(err), traceback.format_exc())
            queue.put((None, failure))

        # Tell the consumer that this worker has nothing more to send.
        queue.put((None, None))

        # Block until every queued item has been acknowledged, so this
        # process can exit safely.
        queue.join()
Ejemplo n.º 3
0
    def _batch_worker(self, worker_id: int, queue: mp.JoinableQueue) -> None:
        """Read instances, batch them, and put each batch on ``queue``.

        Every item placed on ``queue`` is a 2-tuple:

        * ``(batch, None)`` for a successfully produced batch,
        * ``(None, (repr(error), traceback_text))`` if anything raised,
        * ``(None, None)`` as the final "this worker is done" marker, which
          is sent whether or not an error occurred.
        """
        try:
            worker_info = WorkerInfo(self.num_workers, worker_id)
            self.reader._set_worker_info(worker_info)
            instances = self.reader.read(self.data_path)
            batches = self._instances_to_batches(
                instances, move_to_device=self._worker_cuda_safe)
            for batch in batches:
                queue.put((batch, None))
        except Exception as err:
            # Report the failure through the queue rather than raising here.
            failure = (repr(err), traceback.format_exc())
            queue.put((None, failure))

        # Tell the consumer (main thread) that this worker has nothing more
        # to send.
        queue.put((None, None))

        # Block until every queued item has been acknowledged, so this
        # process can exit safely.
        queue.join()
Ejemplo n.º 4
0
    def _batch_worker(self, instance_queue: mp.JoinableQueue,
                      batch_queue: mp.JoinableQueue) -> None:
        """Turn instances from ``instance_queue`` into chunks of batches.

        Instances are gathered with ``_gather_instances``, batched with
        ``_instances_to_batches``, and then grouped by ``lazy_groups_of``
        using ``self._batch_chunk_size``. Every item placed on
        ``batch_queue`` is a 2-tuple:

        * ``(batch_chunk, None)`` for each group of batches,
        * ``(None, (error, traceback_text))`` if anything raised,
        * ``(None, None)`` as the final "this worker is done" marker, which
          is sent whether or not an error occurred.
        """
        try:
            instances = self._gather_instances(instance_queue)
            batches = self._instances_to_batches(instances)
            chunks = lazy_groups_of(batches, self._batch_chunk_size)
            for batch_chunk in chunks:
                batch_queue.put((batch_chunk, None))
        except Exception as e:
            # NOTE(review): the exception object itself (not its repr) is put
            # on the queue, so it presumably has to survive being sent to the
            # consumer -- confirm this is intended.
            failure = (e, traceback.format_exc())
            batch_queue.put((None, failure))

        # Tell the consumer (main thread) that this worker has nothing more
        # to send.
        batch_queue.put((None, None))

        # Wait for the consumer (in the main process) to finish receiving all
        # batch groups, to avoid prematurely closing the queue.
        batch_queue.join()