예제 #1
0
    def _memory_sized_lists(
            self, instances: Iterable[Instance]) -> Iterable[List[Instance]]:
        """
        Yield one epoch's worth of ``instances`` as a series of "memory-sized"
        lists, one list at a time.

        How the epoch is split up depends on the configuration:

        * ``_max_instances_in_memory`` is set (lazy dataset or not): the epoch
          iterator is grouped by ``lazy_groups_of`` using that value as the
          group size.
        * Lazy dataset and no in-memory limit: the epoch iterator is grouped by
          ``lazy_groups_of`` using ``_batch_size`` as the group size.
        * Non-lazy dataset, no in-memory limit, and ``_instances_per_epoch`` is
          ``None``: the whole dataset is yielded once, as a single list.
        * Non-lazy dataset, no in-memory limit, and ``_instances_per_epoch`` is
          set: everything the epoch iterator produces is collected into one
          list and yielded once.
        """
        dataset_is_lazy = is_lazy(instances)

        # Iterator over the instances that make up the upcoming epoch.
        epoch_iterator = self._take_instances(instances, self._instances_per_epoch)

        memory_limit = self._max_instances_in_memory

        # An explicit in-memory limit takes priority, whether or not the
        # dataset is lazy.
        if memory_limit is not None:
            yield from lazy_groups_of(epoch_iterator, memory_limit)
            return

        # Lazy dataset with no guidance on how many instances to load at once:
        # fall back to loading ``batch_size`` instances at a time.
        if dataset_is_lazy:
            yield from lazy_groups_of(epoch_iterator, self._batch_size)
            return

        # Non-lazy dataset where every epoch covers all the instances:
        # hand the dataset back as a list.
        if self._instances_per_epoch is None:
            yield ensure_list(instances)
            return

        # Non-lazy dataset with a fixed number of instances per epoch and no
        # limit on how many instances to keep in memory: materialize the whole
        # epoch iterator as one list.
        yield list(epoch_iterator)
예제 #2
0
    def __init__(self, instances: Iterable[Instance]) -> None:
        """
        Build a Batch from an iterable of instances.

        The instances are converted with ``ensure_list`` and stored on
        ``self.instances``; ``_check_types`` is then run on the new object.
        """
        super().__init__()

        # Convert first so that the stored attribute is always a list.
        instance_list: List[Instance] = ensure_list(instances)
        self.instances: List[Instance] = instance_list

        self._check_types()
예제 #3
0
 def get_num_batches(self, instances: Iterable[Instance]) -> int:
     """
     Returns the number of batches that ``instances`` will be split into, which
     can be useful for tracking progress through an epoch.

     The count is the epoch size divided by ``_batch_size``, rounded up, where
     the epoch size is ``_instances_per_epoch`` when that is set and otherwise
     the length of ``instances`` as a list. If ``instances`` is lazy and
     ``_instances_per_epoch`` is not set, the count cannot be computed and
     1 is returned instead.
     """
     lazily_loaded = is_lazy(instances)
     epoch_size = self._instances_per_epoch

     if epoch_size is None:
         if lazily_loaded:
             # No way to know how many instances there are, so just report 1.
             return 1
         # Not lazy, so the dataset can be materialized and measured.
         epoch_size = len(ensure_list(instances))

     return math.ceil(epoch_size / self._batch_size)