def test_sampling(self):
    """Check the sampler's output against pinned per-rank values (seed 0).

    The sampler is built without an explicit ``num_samples_per_rank`` and
    with ``shuffle=False``. Only ranks 0 and 1 have pinned expectations;
    any other rank builds the samples but asserts nothing.
    """
    data = [1, 2, 3, 4, 5]
    weights = [1, 2, 3, 4, 5]
    seeded_generator = torch.Generator().manual_seed(0)

    sampler = DistributedWeightedRandomSampler(
        weights=weights,
        dataset=data,
        shuffle=False,
        generator=seeded_generator,
    )

    # Map the drawn indices back to data values so the comparison is on items.
    drawn_indices = list(sampler)
    samples = np.array([data[idx] for idx in drawn_indices])

    # Expected draws per rank for the fixed seed above.
    expected_by_rank = {
        0: np.array([5, 5, 5]),
        1: np.array([1, 4, 4]),
    }
    expected = expected_by_rank.get(dist.get_rank())
    if expected is not None:
        np.testing.assert_allclose(samples, expected)
def test_num_samples(self):
    """Check pinned per-rank output when ``num_samples_per_rank=5`` (seed 123).

    The sampler is built with ``replacement=True`` and ``shuffle=False``.
    Only ranks 0 and 1 have pinned expectations; any other rank builds the
    samples but asserts nothing.
    """
    data = [1, 2, 3, 4, 5]
    weights = [1, 2, 3, 4, 5]
    seeded_generator = torch.Generator().manual_seed(123)

    sampler = DistributedWeightedRandomSampler(
        weights=weights,
        num_samples_per_rank=5,
        replacement=True,
        dataset=data,
        shuffle=False,
        generator=seeded_generator,
    )

    # Map the drawn indices back to data values so the comparison is on items.
    drawn_indices = list(sampler)
    samples = np.array([data[idx] for idx in drawn_indices])

    # Expected draws per rank for the fixed seed above.
    expected_by_rank = {
        0: np.array([3, 1, 5, 1, 5]),
        1: np.array([4, 2, 4, 2, 4]),
    }
    expected = expected_by_rank.get(dist.get_rank())
    if expected is not None:
        np.testing.assert_allclose(samples, expected)