コード例 #1
0
from allennlp.common.util import lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.dataset import Batch

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class BasicIterator(DataIterator):
    u"""
    A minimal iterator: it optionally shuffles the instances and then groups them
    into batches of ``self._batch_size``.

    It accepts the same parameters as :class:`allennlp.data.iterators.DataIterator`.
    """
    def _create_batches(self, instances, shuffle):
        u"""
        Lazily yield one ``Batch`` per group of instances.

        Parameters
        ----------
        instances
            The instances to batch; handed to ``self._memory_sized_lists``.
        shuffle
            When truthy, each memory-sized list is shuffled in place before it is
            split into batches.
        """
        for chunk in self._memory_sized_lists(instances):
            # Shuffling happens per chunk, so order is only randomised within
            # one memory-sized list at a time.
            if shuffle:
                random.shuffle(chunk)
            for group in lazy_groups_of(iter(chunk), self._batch_size):
                # A group of ``_batch_size`` instances may be split further
                # by the size check; every resulting piece becomes its own Batch.
                for piece in self._ensure_batch_is_sufficiently_small(group):
                    yield Batch(piece)


# Register the class with ``DataIterator`` under the name "basic". The decorator
# returned by ``register`` is applied by hand and the result rebound to the class name.
BasicIterator = DataIterator.register(u"basic")(BasicIterator)
コード例 #2
0
            instance_list = sort_by_padding(instance_list,
                                            self._sorting_keys,
                                            self.vocab,
                                            self._padding_noise)

            batches = []
            for batch_instances in lazy_groups_of(iter(instance_list), self._batch_size):
                for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances):
                    batches.append(Batch(possibly_smaller_batches))

            move_to_front = self._biggest_batch_first and len(batches) > 1
            if move_to_front:
                # We'll actually pop the last _two_ batches, because the last one might not be full.
                last_batch = batches.pop()
                penultimate_batch = batches.pop()
            if shuffle:
                random.shuffle(batches)
            else:
                logger.warning(u"shuffle parameter is set to False,"
                               u" while bucket iterators by definition change the order of your data.")
            if move_to_front:
                batches.insert(0, penultimate_batch)
                batches.insert(0, last_batch)

            _i = batches
            while True:
                yield _i.next()

# Register the class with ``DataIterator`` under the name "bucket". The decorator
# returned by ``register`` is applied by hand and the result rebound to the class name.
BucketIterator = DataIterator.register(u"bucket")(BucketIterator)
コード例 #3
0
    Parameters
    ----------
    See :class:`BucketIterator`.
    """
    def __init__(self,
                 sorting_keys,
                 padding_noise=0.1,
                 biggest_batch_first=False,
                 batch_size=32,
                 instances_per_epoch=None,
                 max_instances_in_memory=None,
                 cache_instances=False):
        u"""
        Initialise the parent iterator with epoch tracking switched on, then warn
        that this class is deprecated.

        Every argument is forwarded unchanged, by keyword, to the parent
        constructor. The only value fixed here is ``track_epoch=True``.

        Parameters
        ----------
        sorting_keys
            Forwarded as ``sorting_keys``.
        padding_noise
            Forwarded as ``padding_noise`` (default ``0.1``).
        biggest_batch_first
            Forwarded as ``biggest_batch_first`` (default ``False``).
        batch_size
            Forwarded as ``batch_size`` (default ``32``).
        instances_per_epoch
            Forwarded as ``instances_per_epoch`` (default ``None``).
        max_instances_in_memory
            Forwarded as ``max_instances_in_memory`` (default ``None``).
        cache_instances
            Forwarded as ``cache_instances`` (default ``False``).
        """
        # NOTE(review): the parent is presumably BucketIterator (the deprecation
        # message points users there) -- confirm against the class header.
        super(EpochTrackingBucketIterator,
              self).__init__(sorting_keys=sorting_keys,
                             padding_noise=padding_noise,
                             biggest_batch_first=biggest_batch_first,
                             batch_size=batch_size,
                             instances_per_epoch=instances_per_epoch,
                             max_instances_in_memory=max_instances_in_memory,
                             track_epoch=True,
                             cache_instances=cache_instances)
        # stacklevel=2 attributes the warning to the code that constructs this
        # iterator rather than to this line, so the report points at the usage
        # that needs migrating.
        warnings.warn(
            u"EpochTrackingBucketIterator is deprecated, "
            u"please just use BucketIterator with track_epoch=True",
            DeprecationWarning,
            stacklevel=2)


# Register the class with ``DataIterator`` under the name "epoch_tracking_bucket". The
# decorator returned by ``register`` is applied by hand and the result rebound to the class name.
EpochTrackingBucketIterator = DataIterator.register(u"epoch_tracking_bucket")(
    EpochTrackingBucketIterator)