def test_few_instances_per_epoch(self):
     """Successive short epochs must continue where the previous one stopped.

     With ``batch_size=2`` and ``instances_per_epoch=3`` every call to
     ``_create_batches`` is expected to produce one batch of two instances
     followed by one batch holding a single instance.  The expected indices
     show each call picking up after the last instance served, going back
     to index 0 once index 4 has been used.  The same expectations are
     checked for both ``self.instances`` and ``self.lazy_instances``.
     """
     # pylint: disable=protected-access
     # Indices into ``self.instances`` expected per batch, one row per epoch.
     expected_layout = (
         ((0, 1), (2,)),
         ((3, 4), (0,)),
         ((1, 2), (3,)),
     )
     for dataset in (self.instances, self.lazy_instances):
         # One iterator per dataset so its position carries over between
         # the consecutive epochs below.
         iterator = BasicIterator(batch_size=2, instances_per_epoch=3)
         for epoch_layout in expected_layout:
             epoch_batches = list(
                 iterator._create_batches(dataset, shuffle=False))
             produced = [batch.instances for batch in epoch_batches]
             wanted = [[self.instances[index] for index in group]
                       for group in epoch_layout]
             assert produced == wanted
 def test_create_batches_groups_correctly(self):
     """Unshuffled batching with ``batch_size=2`` must yield [2, 2, 1] groups.

     The batches are expected to hold instances 0-1, 2-3 and 4 of
     ``self.instances`` in that order, for both ``self.instances`` and
     ``self.lazy_instances``.
     """
     # pylint: disable=protected-access
     # Indices into ``self.instances`` expected in each successive batch.
     expected_layout = ((0, 1), (2, 3), (4,))
     for dataset in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         all_batches = list(
             iterator._create_batches(dataset, shuffle=False))
         produced = [batch.instances for batch in all_batches]
         wanted = [[self.instances[index] for index in group]
                   for group in expected_layout]
         assert produced == wanted
 def test_max_instances_in_memory(self):
     """``max_instances_in_memory=3`` must split one epoch into [2, 1, 2].

     With ``batch_size=2`` the expected batches hold instances 0-1, then 2
     alone, then 3-4 of ``self.instances``.  The expectation is checked for
     both ``self.instances`` and ``self.lazy_instances``.
     """
     # pylint: disable=protected-access
     # Indices into ``self.instances`` expected in each successive batch.
     expected_layout = ((0, 1), (2,), (3, 4))
     for dataset in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
         all_batches = list(
             iterator._create_batches(dataset, shuffle=False))
         produced = [batch.instances for batch in all_batches]
         wanted = [[self.instances[index] for index in group]
                   for group in expected_layout]
         assert produced == wanted
    def test_shuffle(self):
        """Shuffling must reorder batches without changing their contents.

        For both ``self.instances`` and ``self.lazy_instances`` an epoch of
        100 instances is batched once in order and once shuffled.  The two
        runs must have the same number of batches, must not compare equal,
        and must contain every instance the same number of times.
        """

        # pylint: disable=protected-access
        def tally(batches):
            """Count how often each instance occurs across ``batches``."""
            occurrences = Counter()
            for batch in batches:
                for instance in batch:
                    occurrences[instance] += 1
            return occurrences

        for dataset in (self.instances, self.lazy_instances):
            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            ordered = list(iterator._create_batches(dataset, shuffle=False))
            shuffled = list(iterator._create_batches(dataset, shuffle=True))

            # Same amount of work either way.
            assert len(ordered) == len(shuffled)

            # With 100 instances, shuffling is expected to change the order.
            assert ordered != shuffled

            # The multiset of instances must be untouched by shuffling.
            ordered_tally = tally(ordered)
            shuffled_tally = tally(shuffled)
            assert ordered_tally == shuffled_tally
    def test_multiple_cursors(self):
        """One iterator must track its position in each dataset separately.

        Two independent copies of the data are fed alternately to a single
        ``BasicIterator`` with ``batch_size=1`` and ``instances_per_epoch=2``.
        Each dataset's first epoch is expected to yield instances 0 and 1,
        and its second epoch instances 2 and 3, regardless of the epochs
        drawn from the other dataset in between.  The scenario runs once
        with eager list copies and once with ``_LazyInstances`` wrappers.
        """
        # pylint: disable=protected-access
        lazy_first = _LazyInstances(
            lambda: (instance for instance in self.instances))
        lazy_second = _LazyInstances(
            lambda: (instance for instance in self.instances))

        eager_first = self.instances[:]
        eager_second = self.instances[:]

        dataset_pairs = [
            (eager_first, eager_second),
            (lazy_first, lazy_second),
        ]
        for first_dataset, second_dataset in dataset_pairs:
            iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
            iterator.index_with(self.vocab)

            # ``start`` is the index of the first instance expected in the
            # epoch: 0 for the first epoch of a dataset, 2 for its second.
            for start in (0, 2):
                # Alternate datasets within each epoch round: dataset 1
                # first, then dataset 2.
                for dataset in (first_dataset, second_dataset):
                    epoch_batches = list(
                        iterator._create_batches(dataset, shuffle=False))
                    produced = [batch.instances for batch in epoch_batches]
                    assert produced == [[self.instances[start]],
                                        [self.instances[start + 1]]]
    def test_maximum_samples_per_batch(self):
        """``maximum_samples_per_batch`` must cap batch size without dropping data.

        Using ``batch_size=3`` and ``maximum_samples_per_batch=['num_tokens', 9]``
        the test checks, for both ``self.instances`` and
        ``self.lazy_instances``, that the batches together contain as many
        instances as ``self.instances`` and that in every batch the longest
        ``num_tokens`` padding length of the ``text`` field multiplied by the
        number of instances in the batch does not exceed 9.
        """
        # pylint: disable=protected-access
        for dataset in (self.instances, self.lazy_instances):
            iterator = BasicIterator(
                batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            batches = list(iterator._create_batches(dataset, shuffle=False))

            # No instance may be lost when batches are cut short.
            instance_total = sum(len(batch.instances) for batch in batches)
            assert instance_total == len(self.instances)

            # Every batch must stay within the limit once padded to its
            # longest member.
            for batch in batches:
                longest = max(
                    instance.get_padding_lengths()['text']['num_tokens']
                    for instance in batch.instances)
                assert longest * len(batch.instances) <= 9