def test_batch_sampler(self):
        """Check the batch counts produced when samplers are wrapped in BatchSampler.

        A dataset of 103 samples is batched with a batch size of 20 through
        both ``BucketingSampler`` and ``SimilarSizeSampler``.  For each sampler
        the reported length and the number of batches actually yielded by
        iteration must agree: 6 when ``drop_last`` is False and 5 when it is
        True.
        """
        batch_size = 20
        input_data = torch.randn((103, 1, 4, 10))

        # (drop_last, expected number of batches) pairs under test.
        expectations = ((False, 6), (True, 5))

        for drop_last, expected_batches in expectations:
            samplers = {
                "BucketingSampler": BucketingSampler(input_data,
                                                     batch_size=batch_size,
                                                     drop_last=drop_last,
                                                     num_bins=3,
                                                     replacement=False),
                "SimilarSizeSampler": SimilarSizeSampler(input_data,
                                                         batch_size=batch_size,
                                                         drop_last=drop_last,
                                                         replacement=False),
            }

            for sampler_name, sampler in samplers.items():
                batch_sampler = BatchSampler(sampler,
                                             batch_size=batch_size,
                                             drop_last=drop_last)
                context = f"{sampler_name} (drop_last={drop_last})"

                # The advertised length must match the expectation ...
                reported = len(batch_sampler)
                assert reported == expected_batches, (
                    f"{context}: len() returned {reported}, "
                    f"expected {expected_batches}")

                # ... and so must the number of batches actually yielded.
                yielded = sum(1 for _ in batch_sampler)
                assert yielded == expected_batches, (
                    f"{context}: iteration yielded {yielded} batches, "
                    f"expected {expected_batches}")