コード例 #1
0
    def test_random_splitter(self, ratios: List[int],
                             dataset_iterator: DatasetIteratorIF):
        """Random splitting must neither drop nor duplicate any sample.

        Every sample of every split is pooled into one list; sorted, that
        list must equal the sorted original dataset (multiset equality).
        """
        splitter = Splitter(RandomSplitterImpl(ratios=ratios))
        splits = splitter.split(dataset_iterator)

        recombined = []
        for split in splits:
            recombined.extend(split)

        assert sorted(recombined) == sorted(dataset_iterator)
コード例 #2
0
ファイル: test_splitter.py プロジェクト: jeffmaxey/datastack
    def test_stratified_splitter_no_seed(
            self, ratios: List[float],
            dataset_iterator_stratifiable: DatasetIteratorIF):
        """Unseeded stratified splitting must neither drop nor duplicate samples.

        The union of all splits, compared as a sorted list, has to match the
        sorted input dataset exactly.
        """
        splits = Splitter(StratifiedSplitterImpl(ratios=ratios)).split(
            dataset_iterator_stratifiable)

        pooled = []
        for split in splits:
            pooled += split

        assert sorted(pooled) == sorted(dataset_iterator_stratifiable)
コード例 #3
0
ファイル: test_splitter.py プロジェクト: jeffmaxey/datastack
    def test_stratification_seed(
            self, ratios: List[float],
            dataset_iterator_stratifiable: DatasetIteratorIF):
        """Seeded stratified splitting yields pinned per-split target sums.

        With ``seed=100`` and the fixtures supplied to this test, the sum of
        ``sample[1]`` over the first five splits must be 15, 15, 10, 5 and 5.
        Intent: the target distribution should be equal among all splits.
        """
        splits = Splitter(
            StratifiedSplitterImpl(ratios=ratios, seed=100)
        ).split(dataset_iterator_stratifiable)

        expected_target_sums = (15, 15, 10, 5, 5)
        for position, expected_sum in enumerate(expected_target_sums):
            # Index explicitly (rather than zip) so that receiving fewer than
            # five splits still raises instead of passing silently.
            target_sum = sum(sample[1] for sample in splits[position])
            assert target_sum == expected_sum
コード例 #4
0
ファイル: test_splitter.py プロジェクト: jeffmaxey/datastack
    def test_stratified_splitter_seed_reproducable(
            self, ratios: List[float],
            dataset_iterator_stratifiable: DatasetIteratorIF):
        """Two stratified splitters built with the same seed give equal splits.

        Each run constructs its own ``StratifiedSplitterImpl(seed=100)`` and
        ``Splitter``, so the second run does not reuse the first splitter
        object. The number of splits is asserted explicitly before comparing
        contents: ``zip`` stops at the shorter input and would otherwise let
        a run producing a different number of splits pass unnoticed.
        """
        # The misspelled method name is kept on purpose: renaming would change
        # the collected pytest id.
        splits_old = list(
            Splitter(StratifiedSplitterImpl(ratios=ratios, seed=100))
            .split(dataset_iterator_stratifiable))
        splits_new = list(
            Splitter(StratifiedSplitterImpl(ratios=ratios, seed=100))
            .split(dataset_iterator_stratifiable))

        assert len(splits_old) == len(splits_new)
        for split_old, split_new in zip(splits_old, splits_new):
            # Compare element-by-element and in order.
            assert list(split_old) == list(split_new)
コード例 #5
0
ファイル: test_splitter.py プロジェクト: jeffmaxey/datastack
    def test_nested_cv_splitter(self, num_outer_loop_folds: int,
                                num_inner_loop_folds: int,
                                inner_stratification: bool,
                                outer_stratification: bool, shuffle: bool,
                                big_dataset_iterator: DatasetIteratorIF):
        """Check fold disjointness and class stratification of nested CV.

        Splits ``big_dataset_iterator`` with ``NestedCVSplitterImpl`` and
        verifies that:

        * no two outer folds share a sample index;
        * within each outer fold's inner loop, no two inner folds share a
          sample index;
        * with ``outer_stratification``, every outer fold holds exactly
          ``total_count // num_outer_loop_folds`` samples of each class;
        * with ``inner_stratification``, every inner fold of outer fold ``i``
          holds exactly ``remaining_count // num_inner_loop_folds`` samples of
          each class, where ``remaining_count`` counts the class over the
          samples not in outer fold ``i``.

        Classes absent from a fold are compared as a count of 0 rather than
        raising ``KeyError``.
        """
        splitter_impl = NestedCVSplitterImpl(
            num_outer_loop_folds=num_outer_loop_folds,
            num_inner_loop_folds=num_inner_loop_folds,
            inner_stratification=inner_stratification,
            outer_stratification=outer_stratification,
            shuffle=shuffle)
        splitter = Splitter(splitter_impl)
        outer_folds, inner_folds = splitter.split(big_dataset_iterator)

        # Outer folds must not overlap, and neither may the inner folds that
        # belong to the same outer fold.
        self._assert_folds_disjoint(outer_folds)
        for inner_loop in inner_folds:
            self._assert_folds_disjoint(inner_loop)

        total_class_counts = self._class_counts(big_dataset_iterator)

        if outer_stratification:
            expected = collections.Counter({
                target_class: count // num_outer_loop_folds
                for target_class, count in total_class_counts.items()
            })
            for fold in outer_folds:
                self._assert_class_counts_equal(expected,
                                                self._class_counts(fold))

        if inner_stratification:
            for i in range(len(inner_folds)):
                # NOTE(review): assumes inner_folds[i] partitions the samples
                # outside outer_folds[i] (the earlier `count * (K - 1)` formula
                # rested on the same assumption, but was only exact when the
                # outer folds were stratified too).
                remaining = (total_class_counts -
                             self._class_counts(outer_folds[i]))
                expected = collections.Counter({
                    target_class: count // num_inner_loop_folds
                    for target_class, count in remaining.items()
                })
                for fold in inner_folds[i]:
                    self._assert_class_counts_equal(expected,
                                                    self._class_counts(fold))

    @staticmethod
    def _assert_folds_disjoint(folds) -> None:
        """Assert that no two folds in ``folds`` share an entry of ``.indices``."""
        # Build each index set once; disjointness is symmetric, so every
        # unordered pair only needs to be checked a single time.
        index_sets = [set(fold.indices) for fold in folds]
        for position, indices in enumerate(index_sets):
            for other in index_sets[position + 1:]:
                assert indices.isdisjoint(other)

    @staticmethod
    def _class_counts(samples) -> collections.Counter:
        """Count the targets of an iterable of ``(sample, target)`` pairs."""
        return collections.Counter(target for _, target in samples)

    @staticmethod
    def _assert_class_counts_equal(expected: collections.Counter,
                                   actual: collections.Counter) -> None:
        """Assert per-class equality; a class missing on either side counts as 0."""
        # Counter indexing returns 0 for missing keys, so a class absent from
        # one side fails as a clean assertion instead of a KeyError.
        for target_class in set(expected) | set(actual):
            assert expected[target_class] == actual[target_class], target_class