예제 #1
0
    def get_batch(
        self,
        input_pack: PackType,
        context_type: Type[Annotation],
        requests: DataRequest,
    ):
        r"""Feed one pack into the batcher and lazily produce batches.

        Each item obtained from ``_get_data_batch`` is merged into the
        pending batch kept on this object. Whenever the pending batch has to
        be released, a ``(data packs, None, batch data)`` triple is yielded
        and the pending state is reset.

        Args:
            input_pack: The pack whose data is added to the pending batch.
            context_type: Passed through to ``_get_data_batch``.
            requests: Passed through to ``_get_data_batch``.

        Returns:
            An iterator of triples. The middle (instance) slot is always
            ``None`` in this basic batcher, kept only so the shape stays
            compatible with the existing implementation.
        """
        batch_stream = self._get_data_batch(input_pack, context_type, requests)

        for chunk, num_instances in batch_stream:
            # Fold the new chunk into the pending batch and remember how
            # many instances it contributed.
            self.current_batch = merge_batches([self.current_batch, chunk])
            self.current_batch_sources.append(num_instances)
            # The pack is repeated ``num_instances`` times in the pool.
            self.data_pack_pool.extend([input_pack] * num_instances)

            # Keep accumulating only when batches are allowed to span packs
            # AND the batcher condition (``_should_yield()``) is not met yet.
            # In every other case the pending batch is released below.
            if self.cross_pack and not self._should_yield():
                continue

            yield self.data_pack_pool, None, self.current_batch

            # Rebind to fresh containers; the objects that were just yielded
            # are not modified by this reset.
            self.current_batch = {}
            self.current_batch_sources = []
            self.data_pack_pool = []
예제 #2
0
    def get_batch(
            self, input_pack: PackType, context_type: Type[Annotation],
            requests: DataRequest) -> Iterator[Dict]:
        r"""Feed one pack into the batcher and lazily produce data batches.

        The pack is appended to ``data_pack_pool`` first (the pool is not
        reset by this method). Each item obtained from ``_get_data_batch``
        is then merged into the pending batch, which is yielded and reset
        whenever it has to be released.

        Args:
            input_pack: The pack whose data is added to the pending batch.
            context_type: Passed through to ``_get_data_batch``.
            requests: Passed through to ``_get_data_batch``.

        Returns:
            An iterator over the released batches.
        """
        # Record the pack before any of its data is read.
        self.data_pack_pool.append(input_pack)

        batch_stream = self._get_data_batch(input_pack, context_type, requests)
        for chunk, num_instances in batch_stream:
            # Fold the new chunk into the pending batch and remember how
            # many instances it contributed.
            self.current_batch = merge_batches([self.current_batch, chunk])
            self.current_batch_sources.append(num_instances)

            # Keep accumulating only when batches are allowed to span packs
            # AND the batcher condition (``_should_yield()``) is not met yet.
            # In every other case the pending batch is released below.
            if self.cross_pack and not self._should_yield():
                continue

            yield self.current_batch

            # Rebind to fresh containers; the batch that was just yielded is
            # not modified by this reset.
            self.current_batch, self.current_batch_sources = {}, []
예제 #3
0
    def get_batch(
        self, input_pack: PackType
    ) -> Iterator[Tuple[List[PackType], List[Optional[Annotation]], Dict]]:
        r"""Feed a data pack to the batcher and yield formatted features
        according to the batching logic. Each element of the iterator is a
        triplet of data packs, context instances and batched data.

        Args:
            input_pack: The input data pack to get features from.

        Returns:
            An iterator of tuples, each containing the list of data packs,
            the list of context instances and the batch data.
            Note: For backward compatibility, the contexts are returned as
            a list of ``None`` with the same length as the data pack list.
        """
        # Cache the new pack and generate batches.
        for (data_batch, instance_num) in self._get_data_batch(input_pack):
            self.current_batch = merge_batches([self.current_batch, data_batch])
            self.current_batch_sources.append(instance_num)
            # The pack is repeated ``instance_num`` times in the pool.
            self.data_pack_pool.extend([input_pack] * instance_num)

            # Yield a batch on two conditions.
            # 1. If we do not want to have batches from different packs, we
            # should yield since this pack is exhausted.
            # 2. We should also yield when the batcher condition is met:
            # i.e. ``_should_yield()`` is True.
            if not self._cross_pack or self._should_yield():
                yield (
                    self.data_pack_pool,
                    [None] * len(self.data_pack_pool),
                    self.current_batch,
                )
                # Rebind to fresh containers; the objects that were just
                # yielded are not modified by this reset.
                self.current_batch = {}
                self.current_batch_sources = []
                self.data_pack_pool = []