コード例 #1
0
    def test_disable_shuffle(self):
        """With ``shuffle=False`` the sampler must yield batches in one fixed order.

        The expected batches are listed in exactly the order the sampler should
        yield them. The full list is compared in one assertion, so a missing or
        extra batch fails the test rather than being skipped or crashing.
        """
        # padding_noise=0 keeps the grouping deterministic, so the assertion
        # below depends only on the shuffle setting.
        # NOTE(review): assumes the sampler's default padding_noise may perturb
        # the sort keys randomly — confirm against BucketBatchSampler.
        sampler = BucketBatchSampler(batch_size=2,
                                     padding_noise=0,
                                     sorting_keys=["text"],
                                     shuffle=False)

        grouped_instances = [
            [self.instances[idx] for idx in indices]
            for indices in sampler.get_batch_indices(self.instances)
        ]
        expected_groups = [
            [self.instances[4], self.instances[2]],
            [self.instances[0], self.instances[1]],
            [self.instances[3]],
        ]
        # Compare whole lists. A per-index loop over the produced batches would
        # pass silently when too few batches come out, and would raise
        # IndexError (not an assertion failure) when too many do.
        assert grouped_instances == expected_groups
コード例 #2
0
    def test_create_batches_groups_correctly(self):
        """Batches must match the expected groups, in any batch order.

        Each produced batch is matched against, and removed from, a pool of
        expected batches. The test passes only if every produced batch is
        found and the pool ends up empty, i.e. a multiset match.
        """
        sampler = BucketBatchSampler(batch_size=2,
                                     padding_noise=0,
                                     sorting_keys=["text"])

        batches = [
            [self.instances[i] for i in batch_indices]
            for batch_indices in sampler.get_batch_indices(self.instances)
        ]
        remaining = [
            [self.instances[4], self.instances[2]],
            [self.instances[0], self.instances[1]],
            [self.instances[3]],
        ]
        for batch in batches:
            # Removing each match keeps duplicates from being counted twice.
            assert batch in remaining
            remaining.remove(batch)
        # Nothing left over means no expected batch went missing.
        assert not remaining