def test_shard(self):
    """Check the samples yielded by the two shards ``shard(2, 0)`` and ``shard(2, 1)``.

    The k-th sample from ``shard(2, 0)`` is checked against ``k * 2`` and the
    k-th sample from ``shard(2, 1)`` against ``k * 2 + 1``, using
    ``self.check_output_equal``.
    """
    base_sampler = SamplerHelper(self.train_ds)
    # Build both shards up front, before either one is iterated.
    # NOTE(review): the arguments are presumably (number of shards, shard
    # index) -- confirm against SamplerHelper.shard.
    shards = [base_sampler.shard(2, shard_idx) for shard_idx in (0, 1)]
    for shard_idx, shard in enumerate(shards):
        for step, sample in enumerate(shard):
            # Expected value interleaves the shards: 0, 2, 4, ... / 1, 3, 5, ...
            self.check_output_equal(step * 2 + shard_idx, sample)
def test_shard_default(self):
    """Check the samples yielded by ``shard()`` called with no arguments.

    The k-th sample is checked against ``k`` using ``self.check_output_equal``.
    """
    default_shard = SamplerHelper(self.train_ds).shard()
    for position, sample in enumerate(default_shard):
        self.check_output_equal(position, sample)