def test_iterable_dataset_shard_with_length(self):
    """Check iteration and ``len()`` of `IterableDatasetShard` over a sized dataset.

    A 100-element list is sharded across 2 processes with batches of 4, first
    with ``drop_last=True`` and then with ``drop_last=False``. In both cases
    the items yielded by each shard and the length each shard reports are
    compared against a hand-built expectation.
    """
    total, batch_size, world_size = 100, 4, 2
    # Largest element count that fills one whole batch on every process:
    # 100 - 100 % (4 * 2) = 96.
    cutoff = total - total % (batch_size * world_size)

    def build_shards(drop_last):
        # Each shard wraps its own fresh copy of the dataset.
        return [
            IterableDatasetShard(
                list(range(total)),
                batch_size=batch_size,
                drop_last=drop_last,
                num_processes=world_size,
                process_index=rank,
            )
            for rank in range(world_size)
        ]

    def assert_shards_match(shards, expected):
        # Contents first, then the length each shard reports.
        self.assertListEqual([list(piece) for piece in shards], expected)
        self.assertListEqual([len(piece) for piece in shards], [len(items) for items in expected])

    shards = build_shards(drop_last=True)

    # Expected layout with drop_last=True: consecutive batches of 4 alternate
    # between the two processes, and the 4 leftover elements (96..99) are
    # dropped because they cannot fill a batch on both processes.
    expected = [[] for _ in range(world_size)]
    for batch_no, start in enumerate(range(0, cutoff, batch_size)):
        expected[batch_no % world_size].extend(range(start, start + batch_size))

    assert_shards_match(shards, expected)

    shards = build_shards(drop_last=False)

    # Expected layout with drop_last=False: process 0 additionally receives the
    # leftover 96..99, and process 1 gets a final full batch by wrapping around
    # to the beginning of the dataset (0..3).
    expected[0].extend(range(cutoff, total))
    expected[1].extend(range(0, batch_size))

    assert_shards_match(shards, expected)
def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
    """Verify that `IterableDatasetShard` splits ``dataset`` consistently across processes.

    Args:
        dataset: Iterable dataset exposing ``generator.manual_seed`` (presumably
            a randomly generated dataset whose content depends on that seed).
        batch_size: Per-process batch size handed to every shard.
        drop_last: Forwarded to every shard; when false, the reference is
            repeated so wrapped-around samples can be compared.
        num_processes: Number of shards to build, one per process index.
        epoch: Used both as the seed for the reference pass and as the
            argument to each shard's ``set_epoch``.

    The checks are: every shard yields a multiple of ``batch_size`` samples,
    all shards yield the same number of samples, every shard reports the full
    dataset size in ``num_examples``, interleaving the shards batch by batch
    reproduces the reference order, and each shard matches the corresponding
    `ShardSampler` built over the reference list.
    """
    # Seed the base dataset with the epoch to obtain the reference ordering.
    dataset.generator.manual_seed(epoch)
    expected = list(dataset)

    shards = [
        IterableDatasetShard(
            dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            num_processes=num_processes,
            process_index=rank,
        )
        for rank in range(num_processes)
    ]
    for shard in shards:
        shard.set_epoch(epoch)
    per_process = [list(shard) for shard in shards]

    for samples in per_process:
        # Each shard holds a whole number of batches...
        self.assertTrue(len(samples) % batch_size == 0)
        # ...and all shards are the same size.
        self.assertEqual(len(samples), len(per_process[0]))

    for shard in shards:
        # After iteration, every shard knows the total dataset size.
        self.assertEqual(shard.num_examples, len(expected))

    # Rebuild the global order by taking one batch from each process in turn.
    interleaved = []
    for offset in range(0, len(per_process[0]), batch_size):
        for samples in per_process:
            interleaved.extend(samples[offset:offset + batch_size])

    # With drop_last=False the shards wrap around to the start of the dataset
    # to complete their last batches, so repeat the reference until it is at
    # least as long as what was observed.
    if not drop_last:
        while len(expected) < len(interleaved):
            expected *= 2
    self.assertListEqual(interleaved, expected[:len(interleaved)])

    # Equivalence check: ShardSampler over the materialized reference must
    # yield the same per-process samples as IterableDatasetShard.
    dataset.generator.manual_seed(epoch)
    expected = list(dataset)
    samplers = [
        ShardSampler(
            expected,
            batch_size=batch_size,
            drop_last=drop_last,
            num_processes=num_processes,
            process_index=rank,
        )
        for rank in range(num_processes)
    ]
    for samples, sampler in zip(per_process, samplers):
        self.assertListEqual(samples, list(sampler))