def test_sampling(self):
        """Verify the seeded draws of DistributedWeightedRandomSampler on ranks 0 and 1.

        The sampler is built without an explicit per-rank sample count, with
        shuffling disabled and a generator seeded with 0. The expected arrays
        below hold three values for each rank.
        """
        items = [1, 2, 3, 4, 5]
        item_weights = [1, 2, 3, 4, 5]
        seeded_rng = torch.Generator().manual_seed(0)
        sampler = DistributedWeightedRandomSampler(
            weights=item_weights, dataset=items, shuffle=False, generator=seeded_rng
        )
        # Map every index yielded by the sampler back to its data value.
        drawn = np.array([items[idx] for idx in sampler])

        # Expected values keyed by process rank; any other rank asserts nothing.
        expected_per_rank = {
            0: [5, 5, 5],
            1: [1, 4, 4],
        }
        expected = expected_per_rank.get(dist.get_rank())
        if expected is not None:
            np.testing.assert_allclose(drawn, np.array(expected))
    def test_num_samples(self):
        """Verify the seeded draws when five samples per rank are requested.

        The sampler is built with ``num_samples_per_rank=5``,
        ``replacement=True``, shuffling disabled and a generator seeded with
        123. Ranks 0 and 1 are each checked against a fixed five-value array.
        """
        items = [1, 2, 3, 4, 5]
        item_weights = [1, 2, 3, 4, 5]
        seeded_rng = torch.Generator().manual_seed(123)
        sampler = DistributedWeightedRandomSampler(
            weights=item_weights,
            num_samples_per_rank=5,
            replacement=True,
            dataset=items,
            shuffle=False,
            generator=seeded_rng,
        )
        # Map every index yielded by the sampler back to its data value.
        drawn = np.array([items[idx] for idx in sampler])

        # Expected values keyed by process rank; any other rank asserts nothing.
        expected_per_rank = {
            0: [3, 1, 5, 1, 5],
            1: [4, 2, 4, 2, 4],
        }
        expected = expected_per_rank.get(dist.get_rank())
        if expected is not None:
            np.testing.assert_allclose(drawn, np.array(expected))