def test_get_intra_splits_invalid_groups(self):
        """Checks 'intra' splitting leaves out groups too small to split.

        Four groups with 2, 3, 4 and 5 items are split into 2, 3, 4 and 5
        equally weighted splits. For each number of splits the result is
        expected to be keyed only by the groups listed in the table below.
        """
        # 4 groups with 2, 3, 4, 5 items each.
        groups = ['A'] * 2 + ['B'] * 3 + ['C'] * 4 + ['D'] * 5
        items_and_groups = list(enumerate(groups))

        # (number of splits, groups expected to appear in the result).
        expected_groups_by_num_splits = [
            (2, {'A', 'B', 'C', 'D'}),
            (3, {'B', 'C', 'D'}),  # A shouldn't have a split.
            (4, {'C', 'D'}),  # A and B shouldn't have a split.
            (5, {'D'}),  # A, B and C shouldn't have a split.
        ]

        for split_num in range(3):
            for num_splits, expected_groups in expected_groups_by_num_splits:
                # Equal-probability splits named '0', '1', ...
                split_probs = [(str(i), 1.0 / num_splits)
                               for i in range(num_splits)]
                # subTest makes a failure report which combination broke.
                with self.subTest(split_num=split_num,
                                  num_splits=num_splits):
                    res = splitter.get_splits_by_group(
                        items_and_groups, split_probs, split_num, 'intra')
                    # Compare real sets so a failure shows a set diff.
                    self.assertEqual(expected_groups, set(res.keys()))
    def test_get_intra_splits_by_group_correct_splits(self):
        """Checks 'intra' splitting divides each group into train/dev/test.

        For every split number tried, the result must be keyed by all four
        groups, each group must contain exactly the 'train', 'dev' and
        'test' splits, the splits of a group must together hold as many
        items as the group has, and every item must belong to that group.
        """
        # 4 groups with 10, 15, 20, 25 items each.
        group_sizes = [('A', 10), ('B', 15), ('C', 20), ('D', 25)]
        groups = [name for name, size in group_sizes for _ in range(size)]
        items_and_groups = list(enumerate(groups))
        # Item -> group lookup; it does not depend on the split number.
        i2g = dict(items_and_groups)

        split_probs = [('train', 0.5), ('dev', 0.25), ('test', 0.25)]

        for split_num in range(3):
            res = splitter.get_splits_by_group(items_and_groups, split_probs,
                                               split_num, 'intra')

            # Check we have the expected 4 groups in the result.
            self.assertEqual({'A', 'B', 'C', 'D'}, set(res.keys()))

            # Check that every group has a train dev and test, and that
            # they have the correct items.
            for group, expected_size in group_sizes:
                cur_splits = res[group]
                # Check current split has train dev and test sets
                self.assertEqual({'train', 'dev', 'test'},
                                 set(cur_splits.keys()))
                # Check the current split has the correct number of items
                self.assertEqual(
                    expected_size,
                    sum(len(items) for items in cur_splits.values()))
                # Check no item leaked in from another group.
                for items in cur_splits.values():
                    for item in items:
                        self.assertEqual(group, i2g[item])
    def test_get_inter_splits_by_group_correct_splits(self):
        """Checks 'inter' splitting keeps each group inside a single split.

        For every split number tried, the result must contain exactly the
        'train', 'dev' and 'test' splits, the splits together must hold
        each item exactly once, no group may be spread over more than one
        split, and train/dev/test must hold 2/1/1 groups respectively.
        """
        # 4 groups with 1 or 2 items each.
        items_and_groups = [(1, 'A'), (2, 'A'), (3, 'B'), (4, 'B'),
                            (5, 'C'), (6, 'C'), (7, 'D')]
        split_probs = [('train', 0.5), ('dev', 0.25), ('test', 0.25)]
        split_order = ('train', 'dev', 'test')

        for split_num in range(3):
            res = splitter.get_splits_by_group(items_and_groups, split_probs,
                                               split_num, 'inter')

            # The result holds exactly the 3 requested splits.
            self.assertEqual({'train', 'dev', 'test'}, res.keys())

            # Together the splits hold every item once, with no duplicates.
            all_items = []
            for split_name in split_order:
                all_items.extend(res[split_name])
            self.assertCountEqual([1, 2, 3, 4, 5, 6, 7], all_items)

            # Collect the unique groups that ended up in each split.
            group_of = dict(items_and_groups)
            groups_per_split = {}
            for split_name, members in res.items():
                unique_groups = set()
                for member in members:
                    unique_groups.add(group_of[member])
                groups_per_split[split_name] = list(unique_groups)

            # Each group shows up in exactly one split.
            all_groups = []
            for split_name in split_order:
                all_groups = all_groups + groups_per_split[split_name]
            self.assertCountEqual(['A', 'B', 'C', 'D'], all_groups)

            # The number of groups in train/dev/test is as requested.
            for split_name, group_count in (('train', 2), ('dev', 1),
                                            ('test', 1)):
                self.assertLen(groups_per_split[split_name], group_count)