def test_multiple_cursors(self):
    """Interleave epochs over two datasets that share a single iterator.

    Every dataset is expected to keep its own position: with
    ``instances_per_epoch=2`` a second pass over one dataset continues
    at instance 2 even though the other dataset was read in between.
    The scenario runs once with eager lists and once with lazy wrappers.
    """
    # pylint: disable=protected-access
    lazy_pair = tuple(
        _LazyInstances(lambda: (i for i in self.instances)) for _ in range(2)
    )
    eager_pair = tuple(self.instances[:] for _ in range(2))

    opening_epoch = [[self.instances[0]], [self.instances[1]]]
    following_epoch = [[self.instances[2]], [self.instances[3]]]

    for dataset_a, dataset_b in (eager_pair, lazy_pair):
        iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
        iterator.index_with(self.vocab)

        # Order matters: alternate between the two datasets so that a
        # shared cursor would be detected.
        schedule = (
            (dataset_a, opening_epoch),
            (dataset_b, opening_epoch),
            (dataset_a, following_epoch),
            (dataset_b, following_epoch),
        )
        for dataset, expected in schedule:
            batches = list(iterator._create_batches(dataset, shuffle=False))
            assert [batch.instances for batch in batches] == expected
def test_few_instances_per_epoch(self):
    """Check epochs that are shorter than the dataset.

    With ``batch_size=2`` and ``instances_per_epoch=3`` every epoch is
    expected to produce a batch of two followed by a batch of one, and
    successive epochs continue from where the previous one ended,
    wrapping around after instance 4.
    """
    # pylint: disable=protected-access
    # Indices into self.instances, one entry per epoch -> [2, 1] batches.
    epoch_layouts = (
        ((0, 1), (2,)),
        ((3, 4), (0,)),
        ((1, 2), (3,)),
    )
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=3)
        for layout in epoch_layouts:
            batches = list(
                iterator._create_batches(test_instances, shuffle=False)
            )
            produced = [batch.instances for batch in batches]
            assert produced == [
                [self.instances[index] for index in group] for group in layout
            ]
def test_multiple_cursors(self):
    """Verify that one iterator tracks a separate position per dataset.

    Two datasets are read alternately with ``instances_per_epoch=2``;
    each must yield instances 0-1 on its first pass and 2-3 on its
    second, for both eager and lazy inputs.
    """
    lazy_instances1 = _LazyInstances(lambda: (i for i in self.instances))
    lazy_instances2 = _LazyInstances(lambda: (i for i in self.instances))
    eager_instances1 = self.instances[:]
    eager_instances2 = self.instances[:]

    pairs = [
        (eager_instances1, eager_instances2),
        (lazy_instances1, lazy_instances2),
    ]
    for first, second in pairs:
        iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
        iterator.index_with(self.vocab)

        def run_epoch(dataset):
            # Materialize one epoch and expose the instances of each batch.
            epoch = list(iterator._create_batches(dataset, shuffle=False))
            return [batch.instances for batch in epoch]

        # First epoch through each dataset.
        assert run_epoch(first) == [[self.instances[0]], [self.instances[1]]]
        assert run_epoch(second) == [[self.instances[0]], [self.instances[1]]]

        # Second epoch through each dataset.
        assert run_epoch(first) == [[self.instances[2]], [self.instances[3]]]
        assert run_epoch(second) == [[self.instances[2]], [self.instances[3]]]
def test_create_batches_groups_correctly(self):
    """Unshuffled batching with ``batch_size=2`` groups instances 2, 2, 1."""
    # pylint: disable=protected-access
    layout = ((0, 1), (2, 3), (4,))
    expected = [[self.instances[index] for index in group] for group in layout]

    batcher = BasicIterator(batch_size=2)
    # NOTE(review): the return value is compared directly against a list
    # of lists, so this presumably targets an API where _create_batches
    # returns plain grouped instances -- confirm against the iterator.
    produced = batcher._create_batches(self.dataset, shuffle=False)
    assert produced == expected
def test_create_batches_groups_correctly(self):
    """Unshuffled batching with ``batch_size=2`` groups instances 2, 2, 1."""
    # pylint: disable=protected-access
    batcher = BasicIterator(batch_size=2)
    produced = [
        batch.instances
        for batch in list(batcher._create_batches(self.instances, shuffle=False))
    ]

    first_pair = [self.instances[0], self.instances[1]]
    second_pair = [self.instances[2], self.instances[3]]
    remainder = [self.instances[4]]
    assert produced == [first_pair, second_pair, remainder]
def test_shuffle(self):
    """Shuffling must reorder batches without adding or dropping instances.

    For both eager and lazy inputs, an unshuffled and a shuffled epoch of
    100 instances are produced by the same iterator. The two epochs must
    contain the same number of batches, differ in order, and contain
    every instance object the same number of times.
    """
    # pylint: disable=protected-access
    def occurrences(batches):
        # Count by object identity rather than by the instance itself, so
        # the check does not require instances to be hashable.
        return Counter(id(instance) for batch in batches for instance in batch)

    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

        in_order_batches = list(
            iterator._create_batches(test_instances, shuffle=False)
        )
        shuffled_batches = list(
            iterator._create_batches(test_instances, shuffle=True)
        )

        assert len(in_order_batches) == len(shuffled_batches)

        # With 100 instances, shuffling had better change the order.
        assert in_order_batches != shuffled_batches

        # But not the counts of the instances.
        assert occurrences(in_order_batches) == occurrences(shuffled_batches)
def test_create_batches_groups_correctly(self):
    """Eager and lazy inputs are both batched as 2, 2, 1 without shuffling."""
    # pylint: disable=protected-access
    layout = ((0, 1), (2, 3), (4,))
    for test_instances in (self.instances, self.lazy_instances):
        batcher = BasicIterator(batch_size=2)
        batches = list(batcher._create_batches(test_instances, shuffle=False))
        produced = [batch.instances for batch in batches]
        assert produced == [
            [self.instances[index] for index in group] for group in layout
        ]
def test_shuffle(self):
    """Shuffling must reorder batches without adding or dropping instances.

    An unshuffled and a shuffled epoch of 100 instances are drawn from
    the same iterator; they must have equal length, differ in order, and
    contain each instance object (by ``id``) equally often.
    """
    # pylint: disable=protected-access
    def identity_counts(batches):
        # Tally how often each instance object appears across all batches.
        return Counter(id(instance) for batch in batches for instance in batch)

    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

        ordered = list(iterator._create_batches(test_instances, shuffle=False))
        shuffled = list(iterator._create_batches(test_instances, shuffle=True))

        assert len(ordered) == len(shuffled)

        # With 100 instances, shuffling better change the order.
        assert ordered != shuffled

        # But not the counts of the instances.
        assert identity_counts(ordered) == identity_counts(shuffled)
def test_max_instances_in_memory(self):
    """Check batching when at most three instances are held at once.

    With ``batch_size=2`` and ``max_instances_in_memory=3`` one epoch
    over five instances is expected to be split into batches of 2, 1, 2.
    """
    # pylint: disable=protected-access
    # One epoch: 5 instances -> [2, 1, 2]
    layout = ((0, 1), (2,), (3, 4))
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
        batches = list(iterator._create_batches(test_instances, shuffle=False))
        produced = [batch.instances for batch in batches]
        assert produced == [
            [self.instances[index] for index in group] for group in layout
        ]
def test_max_instances_in_memory(self):
    """``max_instances_in_memory=3`` splits five instances into 2, 1, 2."""
    # pylint: disable=protected-access
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)

        # One epoch: 5 instances -> [2, 1, 2]
        produced = [
            batch.instances
            for batch in list(
                iterator._create_batches(test_instances, shuffle=False)
            )
        ]

        leading_pair = [self.instances[0], self.instances[1]]
        leftover = [self.instances[2]]
        trailing_pair = [self.instances[3], self.instances[4]]
        assert produced == [leading_pair, leftover, trailing_pair]
def test_many_instances_per_epoch(self):
    """Check epochs that are longer than the dataset.

    With ``batch_size=2`` and ``instances_per_epoch=7`` each epoch is
    expected to yield batches of 2, 2, 2, 1, wrapping around the five
    instances and resuming in the next epoch where the last one ended.
    """
    # pylint: disable=protected-access
    # Indices into self.instances; each epoch is 7 instances -> [2, 2, 2, 1].
    epoch_layouts = (
        ((0, 1), (2, 3), (4, 0), (1,)),
        ((2, 3), (4, 0), (1, 2), (3,)),
    )
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=7)
        for layout in epoch_layouts:
            batches = list(
                iterator._create_batches(test_instances, shuffle=False)
            )
            produced = [batch.instances for batch in batches]
            assert produced == [
                [self.instances[index] for index in group] for group in layout
            ]
def test_create_batches_groups_correctly(self):
    """Eager and lazy inputs are both batched as 2, 2, 1 without shuffling."""
    expected = [
        [self.instances[0], self.instances[1]],
        [self.instances[2], self.instances[3]],
        [self.instances[4]],
    ]
    for dataset in (self.instances, self.lazy_instances):
        batcher = BasicIterator(batch_size=2)
        produced = [
            batch.instances
            for batch in list(batcher._create_batches(dataset, shuffle=False))
        ]
        assert produced == expected
def test_many_instances_per_epoch(self):
    """Epochs of seven instances wrap around a five-instance dataset.

    Two consecutive epochs are drawn with ``batch_size=2``; the second
    is expected to start at the instance following the first epoch's
    last one.
    """
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(batch_size=2, instances_per_epoch=7)

        def next_epoch():
            # Draw one full epoch and return the instances of every batch.
            epoch = list(iterator._create_batches(test_instances, shuffle=False))
            return [batch.instances for batch in epoch]

        # First epoch: 7 instances -> [2, 2, 2, 1]
        assert next_epoch() == [
            [self.instances[0], self.instances[1]],
            [self.instances[2], self.instances[3]],
            [self.instances[4], self.instances[0]],
            [self.instances[1]],
        ]

        # Second epoch: 7 instances -> [2, 2, 2, 1]
        assert next_epoch() == [
            [self.instances[2], self.instances[3]],
            [self.instances[4], self.instances[0]],
            [self.instances[1], self.instances[2]],
            [self.instances[3]],
        ]
def test_maximum_samples_per_batch(self):
    """Check batching under ``maximum_samples_per_batch=["tokens___tokens", 9]``.

    The statistics reported by ``self.get_batches_stats`` must cover all
    instances and match the expected batch lengths and sample sizes.
    """
    for test_instances in (self.instances, self.lazy_instances):
        iterator = BasicIterator(
            batch_size=3, maximum_samples_per_batch=["tokens___tokens", 9]
        )
        iterator.index_with(self.vocab)

        produced = list(iterator._create_batches(test_instances, shuffle=False))
        stats = self.get_batches_stats(produced)

        expectations = (
            # ensure all instances are in a batch
            ("total_instances", len(self.instances)),
            # ensure correct batch sizes
            ("batch_lengths", [2, 1, 1, 1]),
            # ensure correct sample sizes (<= 9)
            ("sample_sizes", [8, 3, 9, 1]),
        )
        for key, wanted in expectations:
            assert stats[key] == wanted
def test_maximum_samples_per_batch(self):
    """Check batching under ``maximum_samples_per_batch=['num_tokens', 9]``.

    Runs for eager and lazy inputs and compares the summary returned by
    ``self.get_batches_stats`` with the expected totals.
    """
    for dataset in (self.instances, self.lazy_instances):
        # pylint: disable=protected-access
        limits = ['num_tokens', 9]
        iterator = BasicIterator(batch_size=3, maximum_samples_per_batch=limits)
        iterator.index_with(self.vocab)

        stats = self.get_batches_stats(
            list(iterator._create_batches(dataset, shuffle=False))
        )

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(self.instances)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [2, 1, 1, 1]

        # ensure correct sample sizes (<= 9)
        assert stats['sample_sizes'] == [8, 3, 9, 1]
def test_maximum_samples_per_batch(self):
    """Batches built with a ``['num_tokens', 9]`` limit match expected stats."""
    expected_batch_lengths = [2, 1, 1, 1]
    expected_sample_sizes = [8, 3, 9, 1]

    for test_instances in (self.instances, self.lazy_instances):
        # pylint: disable=protected-access
        iterator = BasicIterator(
            batch_size=3,
            maximum_samples_per_batch=['num_tokens', 9],
        )
        iterator.index_with(self.vocab)

        batches = list(iterator._create_batches(test_instances, shuffle=False))
        summary = self.get_batches_stats(batches)

        # Every instance must land in some batch.
        assert summary['total_instances'] == len(self.instances)
        # Batch sizes must be as expected.
        assert summary['batch_lengths'] == expected_batch_lengths
        # Sample sizes must be as expected (each <= 9).
        assert summary['sample_sizes'] == expected_sample_sizes
def test_maximum_samples_per_batch_packs_tightly(self):
    """Instances with 10, 4 and 3 tokens are packed as [10] and [4, 3].

    Uses ``maximum_samples_per_batch=["tokens___tokens", 11]`` and checks
    the summary produced by ``self.get_batches_stats``.
    """
    token_counts = [10, 4, 3]
    dataset = self.create_instances_from_token_counts(token_counts)

    iterator = BasicIterator(
        batch_size=3, maximum_samples_per_batch=["tokens___tokens", 11]
    )
    iterator.index_with(self.vocab)

    stats = self.get_batches_stats(
        list(iterator._create_batches(dataset, shuffle=False))
    )

    expectations = (
        # ensure all instances are in a batch
        ("total_instances", len(token_counts)),
        # ensure correct batch sizes
        ("batch_lengths", [1, 2]),
        # ensure correct sample sizes (<= 11)
        ("sample_sizes", [10, 8]),
    )
    for key, wanted in expectations:
        assert stats[key] == wanted
def test_maximum_samples_per_batch(self):
    """Check that a ``['num_tokens', 9]`` limit keeps every batch small.

    All instances must appear in some batch, and for each batch the
    largest ``num_tokens`` padding length of its ``text`` field times
    the batch size must not exceed 9.
    """
    for test_instances in (self.instances, self.lazy_instances):
        # pylint: disable=protected-access
        iterator = BasicIterator(
            batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
        batches = list(
            iterator._create_batches(test_instances, shuffle=False))

        # ensure all instances are in a batch
        total = sum(len(batch.instances) for batch in batches)
        assert total == len(self.instances)

        # ensure all batches are sufficiently small
        for batch in batches:
            longest = max(
                member.get_padding_lengths()['text']['num_tokens']
                for member in batch.instances
            )
            assert longest * len(batch.instances) <= 9
def test_maximum_samples_per_batch_packs_tightly(self):
    """Instances with 10, 4 and 3 tokens are packed as [10] and [4, 3].

    Uses ``maximum_samples_per_batch=['num_tokens', 11]`` and checks the
    summary produced by ``self.get_batches_stats``.
    """
    # pylint: disable=protected-access
    token_counts = [10, 4, 3]
    dataset = self.create_instances_from_token_counts(token_counts)

    limits = ['num_tokens', 11]
    iterator = BasicIterator(batch_size=3, maximum_samples_per_batch=limits)
    iterator.index_with(self.vocab)

    batches = list(iterator._create_batches(dataset, shuffle=False))
    summary = self.get_batches_stats(batches)

    # Every instance must land in some batch.
    assert summary['total_instances'] == len(token_counts)
    # Batch sizes must be as expected.
    assert summary['batch_lengths'] == [1, 2]
    # Sample sizes must be as expected (each <= 11).
    assert summary['sample_sizes'] == [10, 8]
def test_maximum_samples_per_batch(self):
    """Every batch must respect the ``['num_tokens', 9]`` sample limit."""
    sample_limit = 9
    for dataset in (self.instances, self.lazy_instances):
        # pylint: disable=protected-access
        iterator = BasicIterator(
            batch_size=3,
            maximum_samples_per_batch=['num_tokens', 9],
        )
        batches = list(iterator._create_batches(dataset, shuffle=False))

        # No instance may be lost: batch sizes must add up to the dataset.
        batch_sizes = [len(batch.instances) for batch in batches]
        assert sum(batch_sizes) == len(self.instances)

        # Longest 'text' sequence in the batch times the number of
        # instances in it must stay within the limit.
        for batch in batches:
            token_lengths = [
                member.get_padding_lengths()['text']['num_tokens']
                for member in batch.instances
            ]
            assert max(token_lengths) * len(batch.instances) <= sample_limit