Пример #1
0
    def _run_outer_loop(
        self,
        input_data: InputDataset,
        outer_split: Split,
        data_splitter: DataSplitter,
    ) -> OuterLoopResults:
        """Run the feature-elimination rounds for a single outer split.

        Starting from every feature index in ``range(input_data.n_features)``,
        each round evaluates the current feature set on every inner split of
        ``outer_split``, stores the per-split evaluations under the tuple of
        that feature set, and then asks ``_remove_features`` for the next
        feature set. Rounds continue while at least ``self._minimum_features``
        features remain.

        Args:
            input_data: Dataset handed to ``data_splitter.split_data``; its
                ``n_features`` attribute defines the initial feature set.
            outer_split: Outer split whose inner splits are iterated.
            data_splitter: Provides ``iter_inner_splits`` and ``split_data``.

        Returns:
            The value produced by ``_create_outer_loop_results`` from the
            evaluations collected in every round.
        """
        evaluations_by_feature_set = {}
        remaining = list(range(input_data.n_features))

        # NOTE(review): termination presumably relies on `_remove_features`
        # returning a strictly shorter feature list each round -- confirm.
        while len(remaining) >= self._minimum_features:
            # Lazily produce the data for each inner split so that splitting
            # and evaluating stay interleaved, one inner split at a time.
            inner_data_stream = (
                data_splitter.split_data(input_data, split, remaining)
                for split in data_splitter.iter_inner_splits(outer_split)
            )
            round_evaluations = [
                self._feature_evaluator.evaluate_features(inner_data, remaining)
                for inner_data in inner_data_stream
            ]

            # Key by tuple: the feature list itself is not hashable.
            evaluations_by_feature_set[tuple(remaining)] = round_evaluations
            remaining = self._remove_features(remaining, round_evaluations)

        return self._create_outer_loop_results(
            evaluations_by_feature_set, input_data, outer_split, data_splitter
        )
Пример #2
0
def test_split_separation(dataset):
    """Check that outer-test, inner-test and inner-train indices partition the data.

    For every (outer, inner) split pair the three index sets together must be
    exactly the indices 0..11, each appearing once, and the two inner sets
    together must equal the outer training indices.
    """
    splitter = DataSplitter(
        n_outer=5, n_inner=4, random_state=0, input_data=dataset
    )
    for outer in splitter.iter_outer_splits():
        for inner in splitter.iter_inner_splits(outer):
            # Joining all three parts must cover every index exactly once.
            combined = [
                *outer.test_indices,
                *inner.test_indices,
                *inner.train_indices,
            ]
            assert len(combined) == 12
            assert sorted(combined) == list(range(12))

            # The inner split must only redistribute the outer training data.
            inner_union = [*inner.test_indices, *inner.train_indices]
            assert sorted(inner_union) == sorted(outer.train_indices)
Пример #3
0
def test_make_splits_grouped(grouped_dataset):
    """Check that no group is shared between the parts of any split.

    For every (outer, inner) split pair, the groups of the inner-train,
    inner-test (validation) and outer-test samples must be pairwise disjoint.
    """
    ds = DataSplitter(n_outer=5,
                      n_inner=4,
                      random_state=0,
                      input_data=grouped_dataset)
    groups = grouped_dataset.groups
    assert ds
    for outer_split in ds.iter_outer_splits():
        test_groups = groups[outer_split.test_indices]
        for inner_split in ds.iter_inner_splits(outer_split):
            train_groups = groups[inner_split.train_indices]
            valid_groups = groups[inner_split.test_indices]
            # isdisjoint states the intent directly and avoids building an
            # intersection set only to test it for emptiness.
            assert set(train_groups).isdisjoint(valid_groups)
            assert set(test_groups).isdisjoint(valid_groups)
            assert set(train_groups).isdisjoint(test_groups)
Пример #4
0
def test_iter_inner_splits(dataset):
    """Check that each yielded inner split matches the splitter's stored split.

    Every inner split must be truthy and equal to the entry of ``_splits``
    keyed by the ``(outer id, inner id)`` pair.
    """
    splitter = DataSplitter(
        n_outer=5, n_inner=4, random_state=0, input_data=dataset
    )
    for outer in splitter.iter_outer_splits():
        for inner in splitter.iter_inner_splits(outer):
            assert inner
            split_key = (outer.id, inner.id)
            assert inner == splitter._splits[split_key]