def test_new(self, number, candidates, estimator_spec, best_candidate_index,
             subnetwork_reports_fn=None):
    """Verifies that `_Iteration` exposes the values it was constructed with.

    Args:
      number: Iteration number handed to `_Iteration`.
      candidates: Candidates handed to `_Iteration`.
      estimator_spec: Estimator spec handed to `_Iteration`.
      best_candidate_index: Best-candidate index handed to `_Iteration`.
      subnetwork_reports_fn: Optional zero-argument callable producing the
        subnetwork reports. When it is None, an empty dict is used instead.
    """
    # Reports come from an optional factory; fall back to an empty dict.
    reports = {} if subnetwork_reports_fn is None else subnetwork_reports_fn()
    manager = _TrainManager([], [], self.test_subdirectory, is_chief=True)
    iteration = _Iteration(
        number=number,
        candidates=candidates,
        subnetwork_specs=None,
        estimator_spec=estimator_spec,
        best_candidate_index=best_candidate_index,
        summaries=[],
        subnetwork_reports=reports,
        train_manager=manager)

    # Each constructor argument should be readable back as an attribute that
    # compares equal to what was passed in.
    expectations = (
        ("number", number),
        ("candidates", candidates),
        ("estimator_spec", estimator_spec),
        ("best_candidate_index", best_candidate_index),
        ("subnetwork_reports", reports),
    )
    for attribute, expected in expectations:
        self.assertEqual(getattr(iteration, attribute), expected)
def test_new_errors(self,
                    number=0,
                    candidates=lambda: [_dummy_candidate()],
                    estimator_spec=tu.dummy_estimator_spec(),
                    best_candidate_index=0,
                    subnetwork_reports=lambda: []):
    """Verifies that constructing `_Iteration` with these arguments raises.

    NOTE(review): the defaults are presumably overridden per test case by a
    parameterization decorator that is not visible here — confirm.

    Args:
      number: Iteration number handed to `_Iteration`.
      candidates: Zero-argument callable returning the candidates to use.
      estimator_spec: Estimator spec handed to `_Iteration`. The default is
        evaluated once, when this `def` statement executes.
      best_candidate_index: Best-candidate index handed to `_Iteration`.
      subnetwork_reports: Zero-argument callable returning the subnetwork
        reports to use.

    Raises:
      AssertionError: If no ValueError is raised.
    """
    with self.assertRaises(ValueError):
        # All arguments are materialized inside the context manager, so a
        # ValueError raised while building them also satisfies the check.
        candidate_list = candidates()
        report_list = subnetwork_reports()
        manager = _TrainManager([], [], self.test_subdirectory, is_chief=True)
        _Iteration(
            number=number,
            candidates=candidate_list,
            subnetwork_specs=None,
            estimator_spec=estimator_spec,
            best_candidate_index=best_candidate_index,
            summaries=[],
            subnetwork_reports=report_list,
            train_manager=manager)