def test_preprocessed_splitter():
    """Check the sets returned by PreprocessedSplitter in both split rounds.

    Ten samples holding the values 0..9 are wrapped in a dataset and split
    with a ``SingleFoldSplitter`` (10 folds, test fold index 9).  The
    preprocessor subtracts a mean that is only (re)computed when it is
    allowed to fit, so the expected views below are the raw values shifted
    by 3.5 in the first round and by 4 in the second round, i.e. by the
    mean of the respective training values.
    """
    class DemeanPreproc():
        """Test-only preprocessor that subtracts a fitted global mean."""
        def apply(self, dataset, can_fit=False):
            view = dataset.get_topological_view()
            # The mean is only refreshed when fitting is permitted;
            # otherwise the previously stored value is reused.
            if can_fit:
                self.mean = np.mean(view)
            dataset.set_topological_view(view - self.mean)

    def topo_views(sets):
        # Fetch the topological views in train / valid / test order.
        return tuple(sets[part].get_topological_view()
                     for part in ('train', 'valid', 'test'))

    raw_values = np.arange(10)
    wrapped = DenseDesignMatrixWrapper(topo_view=to_4d_array(raw_values),
                                       y=np.zeros(10))
    fold_splitter = SingleFoldSplitter(n_folds=10, i_test_fold=9)
    provider = PreprocessedSplitter(dataset_splitter=fold_splitter,
                                    preprocessor=DemeanPreproc())

    # Round one: separate train, valid and test sets.
    train_view, valid_view, test_view = topo_views(
        provider.get_train_valid_test(wrapped))
    expected_train = to_4d_array(
        [-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5])
    assert np.array_equal(train_view, expected_train)
    assert np.array_equal(valid_view, to_4d_array([4.5]))
    assert np.array_equal(test_view, to_4d_array([5.5]))

    # Round two: the train set now has nine samples (valid merged in).
    train_view, valid_view, test_view = topo_views(
        provider.get_train_merged_valid_test(wrapped))
    expected_train = to_4d_array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
    assert np.array_equal(train_view, expected_train)
    assert np.array_equal(valid_view, to_4d_array([4]))
    assert np.array_equal(test_view, to_4d_array([5]))
示例#2
0
 def __init__(self, final_layer, dataset, splitter, preprocessor, iterator,
              loss_expression, updates_expression, updates_modifier,
              monitors, stop_criterion, remember_best_chan,
              run_after_early_stop, batch_modifier=None):
     """Store the given components and build the helpers derived from them.

     Most arguments are kept unchanged as attributes of the same name.
     The exceptions are:

     * ``splitter`` and ``preprocessor`` are combined into a
       ``PreprocessedSplitter`` stored as ``dataset_provider``
       (``preprocessor`` is additionally stored on its own, ``splitter``
       is not).
     * ``monitors`` is stored and also wrapped in a ``MonitorManager``
       stored as ``monitor_manager``.
     * ``remember_best_chan`` is only used to create the ``RememberBest``
       instance stored as ``remember_extension``.

     ``batch_modifier`` is optional and defaults to ``None``.
     """
     # Build the derived helper objects first, in the same order in which
     # they were originally constructed.
     provider = PreprocessedSplitter(splitter, preprocessor)
     manager = MonitorManager(monitors)
     best_tracker = RememberBest(remember_best_chan)

     # Model and data.
     self.final_layer = final_layer
     self.dataset = dataset
     self.dataset_provider = provider
     self.preprocessor = preprocessor
     self.iterator = iterator

     # Loss and parameter updates.
     self.loss_expression = loss_expression
     self.updates_expression = updates_expression
     self.updates_modifier = updates_modifier

     # Monitoring and stopping.
     self.monitors = monitors
     self.stop_criterion = stop_criterion
     self.monitor_manager = manager
     self.remember_extension = best_tracker
     self.run_after_early_stop = run_after_early_stop

     self.batch_modifier = batch_modifier