def test_random_splitter(self, ratios: List[int],
                         dataset_iterator: DatasetIteratorIF):
    """Random splitting must partition the dataset.

    The concatenated contents of all splits, once sorted, must equal the sorted
    dataset: no sample may be dropped or duplicated across splits.
    """
    splitter = Splitter(RandomSplitterImpl(ratios=ratios))
    splits = splitter.split(dataset_iterator)
    flattened = [sample for split in splits for sample in split]
    assert sorted(flattened) == sorted(dataset_iterator)
def test_stratified_splitter_no_seed(
        self, ratios: List[float],
        dataset_iterator_stratifiable: DatasetIteratorIF):
    """Stratified splitting without a seed must still partition the dataset.

    Gathering every sample from every split and sorting the result must
    reproduce the sorted dataset exactly.
    """
    stratified_impl = StratifiedSplitterImpl(ratios=ratios)
    splits = Splitter(stratified_impl).split(dataset_iterator_stratifiable)
    all_samples = []
    for split in splits:
        all_samples.extend(split)
    assert sorted(all_samples) == sorted(dataset_iterator_stratifiable)
def test_stratification_seed(
        self, ratios: List[float],
        dataset_iterator_stratifiable: DatasetIteratorIF):
    """With ``seed=100``, each split's target sum must match a fixed value.

    The target is read as element ``[1]`` of each sample. The first five
    splits are checked, in order, against the totals 15, 15, 10, 5 and 5.
    """
    splitter = Splitter(StratifiedSplitterImpl(ratios=ratios, seed=100))
    splits = splitter.split(dataset_iterator_stratifiable)
    # NOTE(review): the expected totals differ between splits, so this pins
    # fixed per-split sums rather than an equal distribution; presumably they
    # scale with ``ratios`` -- confirm against the fixture.
    expected_target_sums = (15, 15, 10, 5, 5)
    for split_index, expected_sum in enumerate(expected_target_sums):
        target_sum = sum(sample[1] for sample in splits[split_index])
        assert target_sum == expected_sum
def test_stratified_splitter_seed_reproducable(
        self, ratios: List[float],
        dataset_iterator_stratifiable: DatasetIteratorIF):
    """Two independently built splitters with the same seed give identical splits.

    Both runs must produce the same number of splits, and each pair of
    corresponding splits must contain the same samples in the same order.
    """

    def _split_with_seed():
        # Build a fresh impl + splitter per run so nothing is shared between runs.
        splitter = Splitter(StratifiedSplitterImpl(ratios=ratios, seed=100))
        return list(splitter.split(dataset_iterator_stratifiable))

    splits_old = _split_with_seed()
    splits_new = _split_with_seed()
    # zip() stops at the shorter sequence, so compare the split counts first;
    # otherwise a run producing fewer splits would pass unnoticed.
    assert len(splits_old) == len(splits_new)
    for split_old, split_new in zip(splits_old, splits_new):
        assert list(split_old) == list(split_new)
def test_nested_cv_splitter(self, num_outer_loop_folds: int,
                            num_inner_loop_folds: int,
                            inner_stratification: bool,
                            outer_stratification: bool, shuffle: bool,
                            big_dataset_iterator: DatasetIteratorIF):
    """Nested CV splitting must yield disjoint and, optionally, stratified folds.

    Checks performed:
      * outer folds are pairwise disjoint (compared by their ``indices``);
      * within each inner CV, the inner folds are pairwise disjoint;
      * if ``outer_stratification``: every outer fold holds exactly
        ``count // num_outer_loop_folds`` samples of each class, where
        ``count`` is the class count in the whole dataset;
      * if ``inner_stratification``: every inner fold belonging to outer fold
        ``i`` holds exactly
        ``count_i * (num_outer_loop_folds - 1) // num_inner_loop_folds``
        samples of each class, where ``count_i`` is the class count in outer
        fold ``i``.

    Samples are iterated as ``(_, target)`` pairs.
    """
    import itertools

    splitter_impl = NestedCVSplitterImpl(
        num_outer_loop_folds=num_outer_loop_folds,
        num_inner_loop_folds=num_inner_loop_folds,
        inner_stratification=inner_stratification,
        outer_stratification=outer_stratification,
        shuffle=shuffle)
    splitter = Splitter(splitter_impl)
    outer_folds, inner_folds = splitter.split(big_dataset_iterator)

    # Outer folds must not share any index (each unordered pair checked once).
    for fold_a, fold_b in itertools.combinations(outer_folds, 2):
        assert set(fold_a.indices).isdisjoint(fold_b.indices)

    # Within each inner CV, the inner folds must not share any index.
    for inner_cv in inner_folds:
        for fold_a, fold_b in itertools.combinations(inner_cv, 2):
            assert set(fold_a.indices).isdisjoint(fold_b.indices)

    if outer_stratification:
        class_counts = collections.Counter(
            t for _, t in big_dataset_iterator)
        expected_per_fold = {
            target_class: count // num_outer_loop_folds
            for target_class, count in class_counts.items()
        }
        for fold in outer_folds:
            # Dict equality checks both the set of classes and every count; a
            # class missing on either side is an assertion failure with a
            # readable diff rather than a KeyError.
            fold_class_counts = dict(
                collections.Counter(t for _, t in fold))
            assert fold_class_counts == expected_per_fold

    if inner_stratification:
        for outer_index, inner_cv in enumerate(inner_folds):
            # NOTE(review): expected inner-fold counts are outer fold
            # ``outer_index``'s counts scaled by (num_outer_loop_folds - 1),
            # presumably because the inner CV runs on the remaining outer
            # folds -- confirm against NestedCVSplitterImpl.
            class_counts = collections.Counter(
                t for _, t in outer_folds[outer_index])
            expected_per_fold = {
                target_class:
                    count * (num_outer_loop_folds - 1) // num_inner_loop_folds
                for target_class, count in class_counts.items()
            }
            for fold in inner_cv:
                fold_class_counts = dict(
                    collections.Counter(t for _, t in fold))
                assert fold_class_counts == expected_per_fold