Exemplo n.º 1
0
    def _get_data_batch(
        self,
        data_pack: DataPack,
        context_type: Type[Annotation],
        requests: Optional[Dict[Type[Entry], Union[Dict, List]]] = None,
        offset: int = 0,
    ) -> Iterable[Tuple[Dict, int]]:
        r"""Group the instances of ``data_pack`` into batches.

        Instances are pulled from ``data_pack.get_data(context_type, requests,
        offset)`` and collected until their count equals ``self.batch_size``
        minus the sum of ``self.current_batch_sources``; the collected
        instances are then combined with ``batch_instances`` and yielded.
        Instances left over once the data pack is exhausted are yielded as a
        final, smaller batch.

        ``self.batch_is_full`` is set to ``True`` right before a full batch is
        yielded and back to ``False`` when the generator is resumed. The final
        leftover batch is yielded without touching the flag.

        Args:
            data_pack: The data pack to read instances from.
            context_type: Forwarded to ``data_pack.get_data``.
            requests: Forwarded to ``data_pack.get_data``.
            offset: Forwarded to ``data_pack.get_data``.

        Returns:
            An iterator of tuples ``(batch, cnt)``, where ``batch`` is the
            result of ``batch_instances`` and ``cnt`` is the number of
            instances that went into it.
        """
        # Number of instances already recorded in the current batch sources.
        # NOTE(review): this is read once, before iteration starts, so every
        # full batch from this pack uses the same threshold -- confirm that
        # callers do not expect it to be refreshed after a batch is emitted.
        carried_over = sum(self.current_batch_sources)
        pending: List[Dict] = []

        for instance in data_pack.get_data(context_type, requests, offset):
            pending.append(instance)
            # Keep collecting until the remaining capacity is exactly filled.
            # If the threshold is zero or negative this never matches and
            # everything is emitted by the final flush below.
            if len(pending) != self.batch_size - carried_over:
                continue

            full_batch = batch_instances(pending)
            # The flag is raised only while the consumer holds this batch.
            self.batch_is_full = True
            yield (full_batch, len(pending))
            pending = []
            self.batch_is_full = False

        # Emit whatever did not add up to a full batch.
        if pending:
            yield (batch_instances(pending), len(pending))
Exemplo n.º 2
0
    def _get_data_batch(
        self,
        multi_pack: MultiPack,
        context_type: Type[Annotation],
        requests: Optional[Dict[Type[Entry], Union[Dict, List]]] = None,
        offset: int = 0,
    ) -> Iterable[Tuple[Dict, int]]:
        r"""Group the instances of one pack inside ``multi_pack`` into batches.

        The pack named ``self.input_pack_name`` is looked up in ``multi_pack``
        and its instances are pulled via ``get_data(context_type, requests,
        offset)``. Instances are collected until their count equals
        ``self.batch_size`` minus the sum of ``self.current_batch_sources``;
        they are then combined with ``batch_instances`` and yielded. Tail
        instances that cannot make up a full batch are yielded as a final,
        smaller batch.

        ``self.batch_is_full`` is set to ``True`` right before a full batch is
        yielded and back to ``False`` when the generator is resumed. The final
        tail batch is yielded without touching the flag.

        Args:
            multi_pack: The multi pack holding the input pack.
            context_type: Forwarded to the input pack's ``get_data``.
            requests: Forwarded to the input pack's ``get_data``.
            offset: Forwarded to the input pack's ``get_data``.

        Returns:
            An iterator of tuples ``(batch, cnt)``, where ``batch`` is the
            result of ``batch_instances`` and ``cnt`` is the number of
            instances that went into it.
        """
        source_pack = multi_pack.get_pack(self.input_pack_name)

        pending: List[Dict] = []
        # Number of instances already recorded in the current batch sources.
        # NOTE(review): this is read once, before iteration starts, so every
        # full batch from this pack uses the same threshold -- confirm that
        # callers do not expect it to be refreshed after a batch is emitted.
        carried_over = sum(self.current_batch_sources)

        for instance in source_pack.get_data(context_type, requests, offset):
            pending.append(instance)
            # Keep collecting until the remaining capacity is exactly filled.
            if len(pending) != self.batch_size - carried_over:
                continue

            full_batch = batch_instances(pending)
            # The flag is raised only while the consumer holds this batch.
            self.batch_is_full = True
            yield (full_batch, len(pending))
            pending = []
            self.batch_is_full = False

        # Emit the tail instances that did not add up to a full batch.
        if pending:
            yield (batch_instances(pending), len(pending))
Exemplo n.º 3
0
    def _get_data_batch(
        self,
        data_pack: DataPack,
    ) -> Iterable[Tuple[Dict, int]]:
        r"""Group the data points of ``data_pack`` into batches.

        Data points are produced by ``self._get_instance(data_pack)`` and
        collected until their count equals ``self.configs.batch_size`` minus
        the sum of ``self.current_batch_sources``; the collected data points
        are then combined with ``batch_instances`` and yielded together with
        their count. Data points left over once the source is exhausted are
        yielded as a final, smaller batch.

        ``self.batch_is_full`` is set to ``True`` right before a full batch is
        yielded and back to ``False`` when the generator is resumed. The final
        leftover batch is yielded without touching the flag.

        Args:
            data_pack: The data pack to retrieve data from.

        Returns:
            An iterator of tuples ``(batch, cnt)``, where ``batch`` is the
            result of ``batch_instances`` and ``cnt`` is the number of
            data points that went into it.
        """
        # Number of instances already recorded in the current batch sources.
        # NOTE(review): this is read once, before iteration starts, so every
        # full batch from this pack uses the same threshold -- confirm that
        # callers do not expect it to be refreshed after a batch is emitted.
        carried_over = sum(self.current_batch_sources)
        pending: List[Dict] = []

        for data_point in self._get_instance(data_pack):
            pending.append(data_point)
            # Keep collecting until the remaining capacity is exactly filled.
            if len(pending) != self.configs.batch_size - carried_over:
                continue

            full_batch = batch_instances(pending)
            # The flag is raised only while the consumer holds this batch.
            self.batch_is_full = True
            yield full_batch, len(pending)
            pending = []
            self.batch_is_full = False

        # Emit whatever did not add up to a full batch.
        if pending:
            yield batch_instances(pending), len(pending)