def test_preprocessed_splitter():
    """Exercise PreprocessedSplitter with a mean-subtracting preprocessor.

    Ten samples (0..9) are split with a ten-fold SingleFoldSplitter whose test
    fold index is 9.  The resulting train/valid/test views are checked once
    via ``get_train_valid_test`` and once via ``get_train_merged_valid_test``.

    The expected arrays are consistent with the mean being fitted on the
    training portion only: 3.5 in the first round and 4 in the second round.
    """
    class DemeanPreproc:
        """Test-only preprocessor that subtracts a stored mean."""

        def apply(self, dataset, can_fit=False):
            view = dataset.get_topological_view()
            if can_fit:
                # The mean is only (re)computed when fitting is permitted;
                # otherwise the previously stored value is reused.
                self.mean = np.mean(view)
            dataset.set_topological_view(view - self.mean)

    def topo_views(sets):
        # Fetch the three views in a fixed train/valid/test order.
        return tuple(
            sets[part].get_topological_view()
            for part in ('train', 'valid', 'test')
        )

    samples = np.arange(10)
    dataset = DenseDesignMatrixWrapper(
        topo_view=to_4d_array(samples),
        y=np.zeros(10),
    )
    fold_splitter = SingleFoldSplitter(n_folds=10, i_test_fold=9)
    preproc_splitter = PreprocessedSplitter(
        dataset_splitter=fold_splitter,
        preprocessor=DemeanPreproc(),
    )

    # First round: separate train, valid and test sets.
    train_view, valid_view, test_view = topo_views(
        preproc_splitter.get_train_valid_test(dataset))
    assert np.array_equal(
        train_view,
        to_4d_array([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]))
    assert np.array_equal(valid_view, to_4d_array([4.5]))
    assert np.array_equal(test_view, to_4d_array([5.5]))

    # Second round: validation data merged into the training set.
    train_view, valid_view, test_view = topo_views(
        preproc_splitter.get_train_merged_valid_test(dataset))
    assert np.array_equal(
        train_view,
        to_4d_array([-4, -3, -2, -1, 0, 1, 2, 3, 4]))
    assert np.array_equal(valid_view, to_4d_array([4]))
    assert np.array_equal(test_view, to_4d_array([5]))
def __init__(
    self,
    final_layer,
    dataset,
    splitter,
    preprocessor,
    iterator,
    loss_expression,
    updates_expression,
    updates_modifier,
    monitors,
    stop_criterion,
    remember_best_chan,
    run_after_early_stop,
    batch_modifier=None,
):
    """Store the configuration and build the helper objects derived from it.

    Most arguments are kept unchanged as attributes of the same name.
    Three helpers are constructed from the arguments:

    * ``dataset_provider`` -- ``PreprocessedSplitter(splitter, preprocessor)``
    * ``monitor_manager`` -- ``MonitorManager(monitors)``
    * ``remember_extension`` -- ``RememberBest(remember_best_chan)``

    ``splitter`` and ``remember_best_chan`` are not stored directly; they
    are only passed to the helpers above.  ``batch_modifier`` is optional
    and defaults to ``None``.
    """
    # Arguments stored as-is.
    self.final_layer = final_layer
    self.dataset = dataset
    self.preprocessor = preprocessor
    self.iterator = iterator
    self.loss_expression = loss_expression
    self.updates_expression = updates_expression
    self.updates_modifier = updates_modifier
    self.monitors = monitors
    self.stop_criterion = stop_criterion
    self.run_after_early_stop = run_after_early_stop
    self.batch_modifier = batch_modifier

    # Helper objects built from the arguments.
    self.dataset_provider = PreprocessedSplitter(splitter, preprocessor)
    self.monitor_manager = MonitorManager(monitors)
    self.remember_extension = RememberBest(remember_best_chan)