Example #1
0
    def test_multiple_cursors(self):
        """Alternate epochs over two datasets and check each resumes on its own.

        One ``BasicIterator`` (``batch_size=1``, ``instances_per_epoch=2``) is
        used to read two datasets in alternation.  The assertions require the
        first epoch on each dataset to yield instances 0 and 1, and the second
        epoch on each dataset to yield instances 2 and 3, even though the
        other dataset was read in between.  The scenario is run once with two
        list copies of ``self.instances`` and once with two ``_LazyInstances``
        wrappers around them.
        """
        # pylint: disable=protected-access
        dataset_pairs = [
            (self.instances[:], self.instances[:]),
            (_LazyInstances(lambda: (i for i in self.instances)),
             _LazyInstances(lambda: (i for i in self.instances))),
        ]

        for dataset_a, dataset_b in dataset_pairs:
            iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
            iterator.index_with(self.vocab)

            # (dataset to read, indices of the instances expected in that epoch)
            schedule = [
                (dataset_a, (0, 1)),  # first epoch through dataset1
                (dataset_b, (0, 1)),  # first epoch through dataset2
                (dataset_a, (2, 3)),  # second epoch through dataset1
                (dataset_b, (2, 3)),  # second epoch through dataset2
            ]
            for dataset, expected_indices in schedule:
                epoch_batches = list(iterator._create_batches(dataset, shuffle=False))
                produced = [batch.instances for batch in epoch_batches]
                # batch_size=1, so every expected batch holds a single instance.
                expected = [[self.instances[index]] for index in expected_indices]
                assert produced == expected
Example #2
0
 def test_few_instances_per_epoch(self):
     """Run three consecutive epochs of three instances and check each batch.

     With ``batch_size=2`` and ``instances_per_epoch=3`` every epoch is
     expected to produce a batch of two followed by a batch of one.  The
     expected instances continue from where the previous epoch stopped
     (0, 1, 2 / 3, 4, 0 / 1, 2, 3).  The check is repeated for
     ``self.instances`` and ``self.lazy_instances``.
     """
     # pylint: disable=protected-access
     # Expected instance indices, grouped by batch, for epochs one to three.
     expected_indices_per_epoch = (
         ([0, 1], [2]),
         ([3, 4], [0]),
         ([1, 2], [3]),
     )
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, instances_per_epoch=3)
         for index_groups in expected_indices_per_epoch:
             # Each epoch: 3 instances -> [2, 1]
             epoch_batches = list(iterator._create_batches(test_instances, shuffle=False))
             observed = [batch.instances for batch in epoch_batches]
             expected = [[self.instances[i] for i in group] for group in index_groups]
             assert observed == expected
Example #3
0
    def test_multiple_cursors(self):
        """Interleave epochs over two datasets using a single iterator.

        For a pair of list copies of ``self.instances`` and then a pair of
        ``_LazyInstances`` wrappers, the test reads dataset 1, dataset 2,
        dataset 1, dataset 2 with ``batch_size=1`` and
        ``instances_per_epoch=2``.  It asserts that each dataset's first
        epoch yields instances 0 and 1 and its second epoch yields
        instances 2 and 3.
        """

        def make_lazy():
            # Fresh lazy wrapper that re-generates self.instances on demand.
            return _LazyInstances(lambda: (i for i in self.instances))

        def run_epoch(batcher, dataset):
            # Materialise one unshuffled epoch and group its instances by batch.
            epoch_batches = list(batcher._create_batches(dataset, shuffle=False))
            return [batch.instances for batch in epoch_batches]

        lazy_pair = (make_lazy(), make_lazy())
        eager_pair = (self.instances[:], self.instances[:])

        for first_dataset, second_dataset in (eager_pair, lazy_pair):
            iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
            iterator.index_with(self.vocab)

            # First epoch through each dataset.
            assert run_epoch(iterator, first_dataset) == [[self.instances[0]], [self.instances[1]]]
            assert run_epoch(iterator, second_dataset) == [[self.instances[0]], [self.instances[1]]]

            # Second epoch through each dataset.
            assert run_epoch(iterator, first_dataset) == [[self.instances[2]], [self.instances[3]]]
            assert run_epoch(iterator, second_dataset) == [[self.instances[2]], [self.instances[3]]]
Example #4
0
 def test_create_batches_groups_correctly(self):
     """Check ``_create_batches`` groups ``self.dataset`` two instances at a time.

     The value returned by ``_create_batches`` is compared directly with
     the expected groups: instances 0-1, instances 2-3 and a final group
     holding only instance 4.
     """
     # pylint: disable=protected-access
     batcher = BasicIterator(batch_size=2)
     actual_groups = batcher._create_batches(self.dataset, shuffle=False)
     expected_groups = [
         [self.instances[0], self.instances[1]],
         [self.instances[2], self.instances[3]],
         [self.instances[4]],
     ]
     assert actual_groups == expected_groups
Example #5
0
 def test_create_batches_groups_correctly(self):
     """Check that unshuffled batching of ``self.instances`` yields 2, 2, 1.

     The instances of every batch produced by ``_create_batches`` with
     ``batch_size=2`` are collected and compared against the expected
     groups: instances 0-1, instances 2-3 and instance 4 alone.
     """
     # pylint: disable=protected-access
     batcher = BasicIterator(batch_size=2)
     produced_batches = list(batcher._create_batches(self.instances, shuffle=False))
     actual_groups = [batch.instances for batch in produced_batches]
     expected_groups = [
         [self.instances[0], self.instances[1]],
         [self.instances[2], self.instances[3]],
         [self.instances[4]],
     ]
     assert actual_groups == expected_groups
Example #6
0
    def test_shuffle(self):
        """Check that shuffling changes the batches but not what they contain.

        For ``self.instances`` and ``self.lazy_instances`` the test builds one
        unshuffled and one shuffled epoch (``batch_size=2``,
        ``instances_per_epoch=100``) and asserts that:

        * both epochs contain the same number of batches,
        * the two lists of batches compare unequal, and
        * every instance object occurs equally often in both epochs.
        """
        # pylint: disable=protected-access
        for test_instances in (self.instances, self.lazy_instances):

            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            in_order_batches = list(iterator._create_batches(test_instances, shuffle=False))
            shuffled_batches = list(iterator._create_batches(test_instances, shuffle=True))

            assert len(in_order_batches) == len(shuffled_batches)

            # With 100 instances, shuffling better change the order.
            assert in_order_batches != shuffled_batches

            # But not the counts of the instances.  Tally by ``id`` instead of
            # by the instance itself: ``Counter`` keys must be hashable, and
            # counting object identities avoids relying on the instances
            # defining ``__hash__``.
            in_order_counts = Counter(id(instance) for batch in in_order_batches for instance in batch)
            shuffled_counts = Counter(id(instance) for batch in shuffled_batches for instance in batch)
            assert in_order_counts == shuffled_counts
Example #7
0
 def test_create_batches_groups_correctly(self):
     """Check unshuffled batching groups five instances as 2, 2, 1.

     Runs against both ``self.instances`` and ``self.lazy_instances`` with
     ``batch_size=2`` and compares the instances of each produced batch to
     the expected groups (indices 0-1, 2-3 and 4).
     """
     # pylint: disable=protected-access
     # Indices into self.instances expected in each successive batch.
     expected_index_groups = ([0, 1], [2, 3], [4])
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         produced_batches = list(iterator._create_batches(test_instances, shuffle=False))
         observed = [batch.instances for batch in produced_batches]
         expected = [[self.instances[i] for i in group] for group in expected_index_groups]
         assert observed == expected
Example #8
0
    def test_shuffle(self):
        """Compare a shuffled epoch against an unshuffled one.

        For ``self.instances`` and ``self.lazy_instances`` the test creates
        both epochs with ``batch_size=2`` and ``instances_per_epoch=100`` and
        asserts that the batch counts match, that the two batch lists compare
        unequal, and that each instance object (keyed by ``id``) occurs the
        same number of times in both.
        """
        # pylint: disable=protected-access

        def identity_counts(batches):
            # Tally how often each instance object, keyed by id(), occurs.
            return Counter(id(instance) for batch in batches for instance in batch)

        for test_instances in (self.instances, self.lazy_instances):
            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            ordered = list(iterator._create_batches(test_instances, shuffle=False))
            shuffled = list(iterator._create_batches(test_instances, shuffle=True))

            assert len(ordered) == len(shuffled)

            # With 100 instances, shuffling better change the order.
            assert ordered != shuffled

            # But not the counts of the instances.
            ordered_tally = identity_counts(ordered)
            shuffled_tally = identity_counts(shuffled)
            assert ordered_tally == shuffled_tally
Example #9
0
 def test_max_instances_in_memory(self):
     """Check batching when ``max_instances_in_memory`` is 3.

     With ``batch_size=2`` the five instances are expected to come out as
     batches of sizes 2, 1 and 2 (indices 0-1, 2, 3-4).  The check is run
     for ``self.instances`` and ``self.lazy_instances``.
     """
     # pylint: disable=protected-access
     # One epoch: 5 instances -> [2, 1, 2]
     expected_index_groups = ((0, 1), (2,), (3, 4))
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
         produced_batches = list(iterator._create_batches(test_instances, shuffle=False))
         observed = [batch.instances for batch in produced_batches]
         expected = [[self.instances[i] for i in group] for group in expected_index_groups]
         assert observed == expected
Example #10
0
 def test_max_instances_in_memory(self):
     """Verify the batch layout produced with ``max_instances_in_memory=3``.

     For both ``self.instances`` and ``self.lazy_instances``, an
     unshuffled epoch with ``batch_size=2`` must give three batches holding
     instances 0-1, instance 2 and instances 3-4 respectively.
     """
     # pylint: disable=protected-access
     # One epoch: 5 instances -> [2, 1, 2]
     expected_groups = [
         [self.instances[0], self.instances[1]],
         [self.instances[2]],
         [self.instances[3], self.instances[4]],
     ]
     for dataset in (self.instances, self.lazy_instances):
         batcher = BasicIterator(batch_size=2, max_instances_in_memory=3)
         produced_batches = list(batcher._create_batches(dataset, shuffle=False))
         actual_groups = [batch.instances for batch in produced_batches]
         assert actual_groups == expected_groups
Example #11
0
    def test_many_instances_per_epoch(self):
        """Run two epochs of seven instances each and check every batch.

        With ``batch_size=2`` and ``instances_per_epoch=7`` each epoch is
        expected to produce batches of sizes 2, 2, 2 and 1.  The expected
        instances restart at index 0 after index 4, and the second epoch
        continues from where the first one stopped.  The check is repeated
        for ``self.instances`` and ``self.lazy_instances``.
        """
        # pylint: disable=protected-access
        # Expected instance indices, grouped by batch, for epochs one and two.
        expected_indices_per_epoch = (
            ([0, 1], [2, 3], [4, 0], [1]),
            ([2, 3], [4, 0], [1, 2], [3]),
        )
        for test_instances in (self.instances, self.lazy_instances):
            iterator = BasicIterator(batch_size=2, instances_per_epoch=7)
            for index_groups in expected_indices_per_epoch:
                # Each epoch: 7 instances -> [2, 2, 2, 1]
                epoch_batches = list(iterator._create_batches(test_instances, shuffle=False))
                observed = [batch.instances for batch in epoch_batches]
                expected = [[self.instances[i] for i in group] for group in index_groups]
                assert observed == expected
Example #12
0
    def test_create_batches_groups_correctly(self):
        """Check that ``batch_size=2`` yields groups of 2, 2 and 1 instances.

        Executed for ``self.instances`` and ``self.lazy_instances``; the
        instances of each unshuffled batch are compared with the expected
        groups (indices 0-1, 2-3 and 4).
        """
        expected_groups = [
            [self.instances[0], self.instances[1]],
            [self.instances[2], self.instances[3]],
            [self.instances[4]],
        ]
        for dataset in (self.instances, self.lazy_instances):
            batcher = BasicIterator(batch_size=2)
            produced_batches = list(batcher._create_batches(dataset, shuffle=False))
            actual_groups = [batch.instances for batch in produced_batches]
            assert actual_groups == expected_groups
Example #13
0
    def test_many_instances_per_epoch(self):
        """Check two successive seven-instance epochs batch by batch.

        Using ``batch_size=2`` and ``instances_per_epoch=7``, the first epoch
        must yield instances 0, 1, 2, 3, 4, 0, 1 and the second epoch must
        yield 2, 3, 4, 0, 1, 2, 3, each split into batches of 2, 2, 2 and 1.
        Both ``self.instances`` and ``self.lazy_instances`` are exercised.
        """

        def pick(*indices):
            # Build one expected batch from indices into self.instances.
            return [self.instances[index] for index in indices]

        for dataset in (self.instances, self.lazy_instances):
            batcher = BasicIterator(batch_size=2, instances_per_epoch=7)

            # First epoch: 7 instances -> [2, 2, 2, 1]
            first_epoch = list(batcher._create_batches(dataset, shuffle=False))
            first_groups = [batch.instances for batch in first_epoch]
            assert first_groups == [pick(0, 1), pick(2, 3), pick(4, 0), pick(1)]

            # Second epoch: 7 instances -> [2, 2, 2, 1]
            second_epoch = list(batcher._create_batches(dataset, shuffle=False))
            second_groups = [batch.instances for batch in second_epoch]
            assert second_groups == [pick(2, 3), pick(4, 0), pick(1, 2), pick(3)]
Example #14
0
    def test_maximum_samples_per_batch(self):
        """Check batching under ``maximum_samples_per_batch=["tokens___tokens", 9]``.

        For ``self.instances`` and ``self.lazy_instances`` the batches are
        summarised with ``self.get_batches_stats`` and the summary must
        report every instance, batch lengths of ``[2, 1, 1, 1]`` and sample
        sizes of ``[8, 3, 9, 1]``.
        """
        for dataset in (self.instances, self.lazy_instances):
            batcher = BasicIterator(batch_size=3, maximum_samples_per_batch=["tokens___tokens", 9])
            batcher.index_with(self.vocab)
            produced_batches = list(batcher._create_batches(dataset, shuffle=False))
            summary = self.get_batches_stats(produced_batches)

            # ensure all instances are in a batch
            expected_total = len(self.instances)
            assert summary["total_instances"] == expected_total

            # ensure correct batch sizes
            expected_lengths = [2, 1, 1, 1]
            assert summary["batch_lengths"] == expected_lengths

            # ensure correct sample sizes (<= 9)
            expected_sizes = [8, 3, 9, 1]
            assert summary["sample_sizes"] == expected_sizes
Example #15
0
    def test_maximum_samples_per_batch(self):
        """Check batching under ``maximum_samples_per_batch=['num_tokens', 9]``.

        Each of ``self.instances`` and ``self.lazy_instances`` is batched
        with ``batch_size=3``; the statistics returned by
        ``self.get_batches_stats`` must then match the expected totals,
        batch lengths and sample sizes.
        """
        for dataset in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            batcher = BasicIterator(batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            batcher.index_with(self.vocab)
            summary = self.get_batches_stats(
                list(batcher._create_batches(dataset, shuffle=False)))

            # Checked in order: all instances are in a batch, batch sizes
            # are correct, sample sizes are correct (<= 9).
            expectations = (
                ('total_instances', len(self.instances)),
                ('batch_lengths', [2, 1, 1, 1]),
                ('sample_sizes', [8, 3, 9, 1]),
            )
            for stat_name, expected_value in expectations:
                assert summary[stat_name] == expected_value
Example #16
0
    def test_maximum_samples_per_batch(self):
        """Batch with a ``['num_tokens', 9]`` sample limit and check the stats.

        Run for ``self.instances`` and ``self.lazy_instances``.  The summary
        from ``self.get_batches_stats`` has to account for all instances and
        report batch lengths ``[2, 1, 1, 1]`` and sample sizes
        ``[8, 3, 9, 1]``.
        """
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            iterator.index_with(self.vocab)
            produced_batches = list(iterator._create_batches(test_instances, shuffle=False))
            stats = self.get_batches_stats(produced_batches)

            # ensure all instances are in a batch
            instance_total = stats['total_instances']
            assert instance_total == len(self.instances)

            # ensure correct batch sizes
            batch_lengths = stats['batch_lengths']
            assert batch_lengths == [2, 1, 1, 1]

            # ensure correct sample sizes (<= 9)
            sample_sizes = stats['sample_sizes']
            assert sample_sizes == [8, 3, 9, 1]
Example #17
0
    def test_maximum_samples_per_batch_packs_tightly(self):
        """Check how instances with 10, 4 and 3 tokens are packed under a limit of 11.

        Instances are built with ``self.create_instances_from_token_counts``
        and batched with ``batch_size=3`` and
        ``maximum_samples_per_batch=["tokens___tokens", 11]``.  The statistics
        from ``self.get_batches_stats`` must show all three instances, batch
        lengths ``[1, 2]`` and sample sizes ``[10, 8]``.
        """
        token_counts = [10, 4, 3]
        dataset = self.create_instances_from_token_counts(token_counts)

        batcher = BasicIterator(batch_size=3, maximum_samples_per_batch=["tokens___tokens", 11])
        batcher.index_with(self.vocab)
        produced_batches = list(batcher._create_batches(dataset, shuffle=False))
        summary = self.get_batches_stats(produced_batches)

        # ensure all instances are in a batch
        expected_total = len(token_counts)
        assert summary["total_instances"] == expected_total

        # ensure correct batch sizes
        expected_lengths = [1, 2]
        assert summary["batch_lengths"] == expected_lengths

        # ensure correct sample sizes (<= 11)
        expected_sizes = [10, 8]
        assert summary["sample_sizes"] == expected_sizes
Example #18
0
    def test_maximum_samples_per_batch(self):
        """Check that ``maximum_samples_per_batch=['num_tokens', 9]`` is respected.

        For ``self.instances`` and ``self.lazy_instances`` the test asserts
        that the batches together contain as many instances as
        ``self.instances``, and that for each batch the largest
        ``['text']['num_tokens']`` padding length multiplied by the number of
        instances in the batch is at most 9.
        """
        for dataset in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            batcher = BasicIterator(batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            produced_batches = list(batcher._create_batches(dataset, shuffle=False))

            # ensure all instances are in a batch
            total_batched = sum(len(batch.instances) for batch in produced_batches)
            assert total_batched == len(self.instances)

            # ensure all batches are sufficiently small
            for batch in produced_batches:
                longest_sequence = max(
                    instance.get_padding_lengths()['text']['num_tokens']
                    for instance in batch.instances
                )
                assert longest_sequence * len(batch.instances) <= 9
Example #19
0
    def test_maximum_samples_per_batch_packs_tightly(self):
        """Batch instances of 10, 4 and 3 tokens with a ``['num_tokens', 11]`` limit.

        The instances come from ``self.create_instances_from_token_counts``.
        After batching with ``batch_size=3`` the statistics returned by
        ``self.get_batches_stats`` must report three instances in total,
        batch lengths ``[1, 2]`` and sample sizes ``[10, 8]``.
        """
        # pylint: disable=protected-access
        token_counts = [10, 4, 3]
        dataset = self.create_instances_from_token_counts(token_counts)

        batcher = BasicIterator(batch_size=3, maximum_samples_per_batch=['num_tokens', 11])
        batcher.index_with(self.vocab)
        summary = self.get_batches_stats(
            list(batcher._create_batches(dataset, shuffle=False)))

        # Checked in order: all instances are in a batch, batch sizes are
        # correct, sample sizes are correct (<= 11).
        expectations = (
            ('total_instances', len(token_counts)),
            ('batch_lengths', [1, 2]),
            ('sample_sizes', [10, 8]),
        )
        for stat_name, expected_value in expectations:
            assert summary[stat_name] == expected_value
Example #20
0
    def test_maximum_samples_per_batch(self):
        """Verify the ``['num_tokens', 9]`` sample limit on every batch.

        Run for ``self.instances`` and ``self.lazy_instances``.  The batches
        must contain ``len(self.instances)`` instances in total, and within
        each batch the maximum ``['text']['num_tokens']`` padding length
        times the batch's instance count must not exceed 9.
        """
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            produced_batches = list(iterator._create_batches(test_instances, shuffle=False))

            # ensure all instances are in a batch
            batch_sizes = [len(batch.instances) for batch in produced_batches]
            assert sum(batch_sizes) == len(self.instances)

            # ensure all batches are sufficiently small
            for batch in produced_batches:
                members = batch.instances
                token_lengths = [
                    member.get_padding_lengths()['text']['num_tokens'] for member in members
                ]
                padded_size = max(token_lengths) * len(members)
                assert padded_size <= 9