def test_few_instances_per_epoch(self):
    """Check epoch contents when ``instances_per_epoch`` is below the dataset size.

    With ``batch_size=2`` and ``instances_per_epoch=3`` every epoch is expected
    to yield a batch of two followed by a batch of one. Each new epoch is
    expected to resume where the previous one stopped, wrapping back to the
    first instance after the instance at index 4 has been served.
    """
    # pylint: disable=protected-access
    # Instance indices expected in each batch, one entry per consecutive epoch.
    expected_epochs = (
        ((0, 1), (2,)),
        ((3, 4), (0,)),
        ((1, 2), (3,)),
    )
    for dataset in (self.instances, self.lazy_instances):
        # One iterator is reused for all epochs of a dataset, so whatever
        # position it tracks carries over from one epoch to the next.
        iterator = BasicIterator(batch_size=2, instances_per_epoch=3)
        for epoch_layout in expected_epochs:
            produced = [
                batch.instances
                for batch in iterator._create_batches(dataset, shuffle=False)
            ]
            wanted = [
                [self.instances[index] for index in group]
                for group in epoch_layout
            ]
            assert produced == wanted
def test_create_batches_groups_correctly(self):
    """Check that unshuffled batching groups instances in order, two per batch.

    The expectation covers instance indices 0 through 4: two full batches
    followed by a final batch holding the single leftover instance.
    """
    # pylint: disable=protected-access
    expected_layout = ((0, 1), (2, 3), (4,))
    for dataset in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2)
        produced = [
            batch.instances
            for batch in iterator._create_batches(dataset, shuffle=False)
        ]
        wanted = [
            [self.instances[index] for index in group]
            for group in expected_layout
        ]
        assert produced == wanted
def test_max_instances_in_memory(self):
    """Check batch layout when ``max_instances_in_memory`` caps each chunk at 3.

    With ``batch_size=2`` the expected grouping over instance indices 0-4 is
    [2, 1, 2]: the first three instances are batched as two plus one before
    the remaining two instances form the last batch.
    """
    # pylint: disable=protected-access
    expected_layout = ((0, 1), (2,), (3, 4))
    for dataset in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
        produced = [
            batch.instances
            for batch in iterator._create_batches(dataset, shuffle=False)
        ]
        wanted = [
            [self.instances[index] for index in group]
            for group in expected_layout
        ]
        assert produced == wanted
def test_shuffle(self):
    """Check that shuffling reorders batches without changing which instances appear.

    An epoch of 100 instances is requested so that the shuffled epoch is
    expected to differ from the in-order one, while the per-instance
    occurrence counts must stay identical.
    """
    # pylint: disable=protected-access

    def tally(batches):
        # Count how many times each instance occurs across all batches.
        counts = Counter()
        for batch in batches:
            for instance in batch:
                counts[instance] += 1
        return counts

    for dataset in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

        # The in-order epoch is drawn first, then the shuffled one, from the
        # same iterator.
        ordered = list(iterator._create_batches(dataset, shuffle=False))
        shuffled = list(iterator._create_batches(dataset, shuffle=True))

        assert len(ordered) == len(shuffled)

        # The order is expected to change ...
        assert ordered != shuffled

        # ... but the multiset of instances is not.
        ordered_counts = tally(ordered)
        shuffled_counts = tally(shuffled)
        assert ordered_counts == shuffled_counts
def test_multiple_cursors(self):
    """Check that one iterator tracks its epoch position separately per dataset.

    Two distinct but equal datasets are fed alternately to the same iterator
    (``batch_size=1``, ``instances_per_epoch=2``). Each dataset is expected to
    start at instance 0 on its first epoch and continue with instances 2 and 3
    on its second, regardless of the interleaved epochs on the other dataset.
    """
    # pylint: disable=protected-access
    eager_pair = (self.instances[:], self.instances[:])
    lazy_pair = (
        _LazyInstances(lambda: (instance for instance in self.instances)),
        _LazyInstances(lambda: (instance for instance in self.instances)),
    )

    for first_dataset, second_dataset in (eager_pair, lazy_pair):
        iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
        iterator.index_with(self.vocab)

        # Ordered schedule: which dataset to draw an epoch from, and the
        # instance indices that epoch is expected to produce.
        epoch_plan = (
            (first_dataset, (0, 1)),   # first epoch through dataset 1
            (second_dataset, (0, 1)),  # first epoch through dataset 2
            (first_dataset, (2, 3)),   # second epoch through dataset 1
            (second_dataset, (2, 3)),  # second epoch through dataset 2
        )
        for dataset, indices in epoch_plan:
            produced = [
                batch.instances
                for batch in iterator._create_batches(dataset, shuffle=False)
            ]
            assert produced == [[self.instances[index]] for index in indices]
def test_maximum_samples_per_batch(self):
    """Check that ``maximum_samples_per_batch`` bounds the padded size of every batch.

    With a limit of 9 on ``num_tokens``, every instance must still land in
    some batch, and for each batch the longest ``text`` token count times the
    number of instances in that batch must not exceed 9.
    """
    for dataset in (self.instances, self.lazy_instances):
        # pylint: disable=protected-access
        iterator = BasicIterator(
            batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
        batches = list(iterator._create_batches(dataset, shuffle=False))

        # No instance may be dropped by the size limit.
        total_batched = sum(len(batch.instances) for batch in batches)
        assert total_batched == len(self.instances)

        # Every batch must fit in the budget once padded to its longest member.
        for batch in batches:
            longest = max(
                instance.get_padding_lengths()['text']['num_tokens']
                for instance in batch.instances
            )
            assert longest * len(batch.instances) <= 9