Example #1
0
    def test_len(self):
        """Check that ``len()`` agrees with the number of batches yielded.

        Each of the 10 indices is randomly assigned to group 0 or 1.
        ``len()`` is queried both after iterating the sampler and before
        iterating a fresh one. It is also queried more than once, so the
        reported length must not depend on when or how often it is asked.
        """
        drop_uneven = True
        batch_size = 3
        dataset = list(range(10))
        group_ids = [random.randint(0, 1) for _ in dataset]
        sampler = RandomSampler(dataset)

        # Case 1: iterate first, then ask for the length. Ask twice so that a
        # repeated len() call must give the same answer as the first one.
        iter_first = GroupedBatchSampler(sampler, group_ids, batch_size,
                                         drop_uneven)
        batches = list(iter_first)
        for _ in range(2):
            self.assertEqual(len(batches), len(iter_first))

        # Case 2: ask for the length *before* iterating, then compare it with
        # the batch count and with a post-iteration len() call.
        len_first = GroupedBatchSampler(sampler, group_ids, batch_size,
                                        drop_uneven)
        reported_len = len(len_first)
        batches = list(len_first)
        self.assertEqual(len(batches), reported_len)
        self.assertEqual(len(batches), len(len_first))
Example #2
0
 def test_respect_order_simple(self):
     """Flattening the batches must reproduce the sequential dataset.

     The 40 indices form four groups of 10 consecutive indices
     (``i // 10``), and ``drop_uneven`` is False. For every batch size
     tried, concatenating the yielded batches must give back ``dataset``
     in its original order.
     """
     drop_uneven = False
     dataset = list(range(40))
     group_ids = [i // 10 for i in dataset]
     sampler = SequentialSampler(dataset)
     for batch_size in [1, 3, 5, 6]:
         # subTest names the failing batch_size and keeps checking the rest.
         with self.subTest(batch_size=batch_size):
             batch_sampler = GroupedBatchSampler(sampler, group_ids,
                                                 batch_size, drop_uneven)
             result = list(batch_sampler)
             merged_result = list(itertools.chain.from_iterable(result))
             self.assertEqual(merged_result, dataset)
Example #3
0
    def test_permute_subset_sampler_drop_uneven(self):
        """Grouping follows a permuted subset sampler and drops short batches.

        The sampler is given the indices [5, 0, 6, 1, 3, 8]. By
        ``group_ids`` these split into group 0 -> [0, 6, 1, 3] and
        group 1 -> [5, 8]. With ``batch_size`` 3 and ``drop_uneven`` True,
        only the full batch [0, 6, 1] is expected. The leftovers [3] and
        [5, 8] are shorter than 3. The expected order [0, 6, 1] (rather than
        sorted) relies on the sampler's index order being preserved within
        a group.
        """
        batch_size = 3
        drop_uneven = True
        # Group id for every dataset index 0-9; the sampler only visits a
        # subset of them, so no separate dataset list is needed.
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([5, 0, 6, 1, 3, 8])

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size,
                                            drop_uneven)
        result = list(batch_sampler)

        expected = [[0, 6, 1]]
        self.assertEqual(result, expected)
Example #4
0
    def test_subset_sampler(self):
        """Only the subset's indices are batched, grouped by ``group_ids``.

        The sampler is given the indices [0, 3, 5, 6, 7, 8]. By
        ``group_ids`` these split into group 0 -> [0, 3, 6] and
        group 1 -> [5, 7, 8]. Each group fills exactly one batch of 3, so
        nothing is left over even though ``drop_uneven`` is False.
        """
        batch_size = 3
        drop_uneven = False
        # Group id for every dataset index 0-9; the sampler only visits a
        # subset of them, so no separate dataset list is needed.
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([0, 3, 5, 6, 7, 8])

        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size,
                                            drop_uneven)
        result = list(batch_sampler)

        expected = [[0, 3, 6], [5, 7, 8]]
        self.assertEqual(result, expected)
Example #5
0
    def test_respect_order_drop_uneven(self):
        """With ``drop_uneven`` True, only full batches survive, in order.

        Sequential indices 0-9 split by ``group_ids`` into
        group 0 -> [0, 1, 3, 6, 9] and group 1 -> [2, 4, 5, 7, 8].
        Batching by 3 leaves the short batches [6, 9] and [7, 8], which are
        expected to be discarded.
        """
        drop_uneven, batch_size = True, 3
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        dataset = list(range(10))
        sampler = SequentialSampler(dataset)
        grouped = GroupedBatchSampler(sampler, group_ids, batch_size,
                                      drop_uneven)

        batches = list(grouped)

        self.assertEqual(batches, [[0, 1, 3], [2, 4, 5]])
Example #6
0
    def test_respect_order(self):
        """Check exact batch contents and order for several batch sizes.

        Indices 0-9 are split into two interleaved groups by ``group_ids``.
        With ``drop_uneven`` False, the expected batches cover every index
        exactly once and keep each group's short trailing batch. They appear
        in ascending order of their first index.
        """
        drop_uneven = False
        dataset = list(range(10))
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SequentialSampler(dataset)

        # Each batch size is paired explicitly with the batches it should
        # produce, instead of being matched to a list entry by position.
        cases = [
            (1, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
            (3, [[0, 1, 3], [2, 4, 5], [6, 9], [7, 8]]),
            (4, [[0, 1, 3, 6], [2, 4, 5, 7], [8], [9]]),
        ]

        for batch_size, expected in cases:
            # subTest names the failing batch_size and keeps checking the rest.
            with self.subTest(batch_size=batch_size):
                batch_sampler = GroupedBatchSampler(sampler, group_ids,
                                                    batch_size, drop_uneven)
                result = list(batch_sampler)
                self.assertEqual(result, expected)