Exemple #1
0
 def test_create_batches_groups_correctly(self):
     """With batch_size=2 and no shuffling, batches are (0, 1), (2, 3), (4,)."""
     # pylint: disable=protected-access
     lazy_iterator = LazyBasicIterator(batch_size=2)
     produced_batches = list(
         lazy_iterator._create_batches(self.dataset, shuffle=False))
     produced = [b.instances for b in produced_batches]
     # The trailing batch holds the single leftover instance.
     expected = [
         [self.instances[0], self.instances[1]],
         [self.instances[2], self.instances[3]],
         [self.instances[4]],
     ]
     assert produced == expected
Exemple #2
0
    def test_from_params(self):
        """``from_params`` uses batch_size 32 by default and honours an explicit value."""
        # pylint: disable=protected-access
        cases = (
            ({}, 32),                   # nothing given -> default value
            ({"batch_size": 10}, 10),   # explicit override
        )
        for raw_config, wanted_size in cases:
            built = LazyBasicIterator.from_params(Params(raw_config))
            assert built._batch_size == wanted_size
Exemple #3
0
    def test_small_epochs(self):
        """With instances_per_epoch=2, each call resumes where the previous one stopped.

        The expected windows wrap from instance 4 back to instance 0, i.e. the
        iterator loops around when it reaches the end of the data.
        """
        # pylint: disable=protected-access
        iterator = LazyBasicIterator(batch_size=2, instances_per_epoch=2)

        # Instance indices expected from six consecutive epochs, in call order.
        epoch_windows = [(0, 1), (2, 3), (4, 0), (1, 2), (3, 4), (0, 1)]
        for left, right in epoch_windows:
            epoch = list(
                iterator._create_batches(self.instance_iterable, shuffle=False))
            observed = [b.instances for b in epoch]
            assert observed == [[self.instances[left], self.instances[right]]]
Exemple #4
0
    def test_multiple_cursors(self):
        """Two lazy datasets over the same source keep independent epoch cursors."""
        # pylint: disable=protected-access
        def make_lazy():
            # Each call wraps a fresh generator factory over the shared iterable.
            return _LazyInstances(lambda: (i for i in self.instance_iterable))

        instances1 = make_lazy()
        instances2 = make_lazy()

        iterator = LazyBasicIterator(batch_size=1, instances_per_epoch=2)
        iterator.index_with(self.vocab)

        def one_epoch(source):
            produced = list(iterator._create_batches(source, shuffle=False))
            return [b.instances for b in produced]

        # Epochs are requested alternately: dataset1, dataset2, dataset1, dataset2.
        assert one_epoch(instances1) == [[self.instances[0]], [self.instances[1]]]
        assert one_epoch(instances2) == [[self.instances[0]], [self.instances[1]]]
        assert one_epoch(instances1) == [[self.instances[2]], [self.instances[3]]]
        assert one_epoch(instances2) == [[self.instances[2]], [self.instances[3]]]
Exemple #5
0
 def test_yield_one_epoch_iterates_over_the_data_once(self):
     """A one-epoch pass yields five instances whose token arrays pass the fixture check."""
     all_batches = list(
         LazyBasicIterator(batch_size=2)(self.instance_iterable, num_epochs=1))
     # Collect the single-token array of the "text" field for every instance.
     token_rows = []
     for batch in all_batches:
         for token_tensor in batch['text']["tokens"]:
             token_rows.append(tuple(token_tensor.data.cpu().numpy()))
     assert len(token_rows) == 5
     self.assert_instances_are_correct(token_rows)
Exemple #6
0
 def test_call_iterates_over_data_forever(self):
     """Without num_epochs the iterator keeps cycling over the data.

     The test expects 3 batches of size 2 per pass, so 18 batches
     correspond to six passes over the data.
     """
     passes = 6
     batches_per_pass = 3
     endless = LazyBasicIterator(batch_size=2)(self.instance_iterable)
     batches = [next(endless) for _ in range(passes * batches_per_pass)]
     # Flatten to the single-token array of the "text" field per instance.
     instances = [tuple(token_tensor.data.cpu().numpy())
                  for batch in batches
                  for token_tensor in batch['text']["tokens"]]
     assert len(instances) == 5 * passes
     self.assert_instances_are_correct(instances)
Exemple #7
0
    def test_multiple_cursors(self):
        """Separate LazyDataset objects each get their own epoch cursor."""
        # pylint: disable=protected-access
        built = []
        for _ in range(2):
            lazy = LazyDataset(lambda: iter(self.instances))
            lazy.index_instances(self.vocab)
            built.append(lazy)
        dataset1, dataset2 = built

        iterator = LazyBasicIterator(batch_size=1, instances_per_epoch=2)

        # (dataset, expected instance indices) in the order epochs are requested.
        schedule = [
            (dataset1, (0, 1)),  # first epoch through dataset1
            (dataset2, (0, 1)),  # first epoch through dataset2
            (dataset1, (2, 3)),  # second epoch through dataset1
            (dataset2, (2, 3)),  # second epoch through dataset2
        ]
        for dataset, indices in schedule:
            epoch = list(iterator._create_batches(dataset, shuffle=False))
            assert [b.instances for b in epoch] == [
                [self.instances[idx]] for idx in indices
            ]