Esempio n. 1
0
 def test_new(self,
              number,
              candidates,
              estimator_spec,
              best_candidate_index,
              is_over_fn,
              subnetwork_reports_fn=None,
              step=0):
     """Checks that an `_Iteration` exposes the values it was built with.

     Args:
       number: Forwarded to `_Iteration` and compared with `.number`.
       candidates: Forwarded and compared with `.candidates`.
       estimator_spec: Forwarded and compared with `.estimator_spec`.
       best_candidate_index: Forwarded and compared with
         `.best_candidate_index`.
       is_over_fn: Zero-argument callable; forwarded, and its result is
         compared with the result of calling `.is_over_fn`.
       subnetwork_reports_fn: Optional zero-argument callable producing the
         subnetwork reports; when None an empty dict is used instead.
       step: Forwarded and compared with `.step`.
     """
     # The reports are materialized before the session is opened.
     subnetwork_reports = (
         {} if subnetwork_reports_fn is None else subnetwork_reports_fn())

     with self.test_session():
         built = _Iteration(
             number=number,
             candidates=candidates,
             estimator_spec=estimator_spec,
             best_candidate_index=best_candidate_index,
             summaries=[],
             is_over_fn=is_over_fn,
             subnetwork_reports=subnetwork_reports,
             step=step)

         # Every constructor argument must come back unchanged.
         self.assertEqual(built.number, number)
         self.assertEqual(built.candidates, candidates)
         self.assertEqual(built.estimator_spec, estimator_spec)
         self.assertEqual(built.best_candidate_index, best_candidate_index)
         self.assertEqual(built.is_over_fn(), is_over_fn())
         self.assertEqual(built.subnetwork_reports, subnetwork_reports)
         self.assertEqual(built.step, step)
Esempio n. 2
0
 def test_new(self,
              number,
              candidates,
              estimator_spec,
              best_candidate_index,
              subnetwork_reports_fn=None):
     """Checks that an `_Iteration` exposes the values it was built with.

     Args:
       number: Forwarded to `_Iteration` and compared with `.number`.
       candidates: Forwarded and compared with `.candidates`.
       estimator_spec: Forwarded and compared with `.estimator_spec`.
       best_candidate_index: Forwarded and compared with
         `.best_candidate_index`.
       subnetwork_reports_fn: Optional zero-argument callable producing the
         subnetwork reports; when None an empty dict is used instead.
     """
     subnetwork_reports = (
         {} if subnetwork_reports_fn is None else subnetwork_reports_fn())

     # Built after the reports so the original evaluation order is kept.
     train_manager = _TrainManager(
         [], [], self.test_subdirectory, is_chief=True)

     built = _Iteration(
         number=number,
         candidates=candidates,
         subnetwork_specs=None,
         estimator_spec=estimator_spec,
         best_candidate_index=best_candidate_index,
         summaries=[],
         subnetwork_reports=subnetwork_reports,
         train_manager=train_manager)

     # Every checked constructor argument must come back unchanged.
     self.assertEqual(built.number, number)
     self.assertEqual(built.candidates, candidates)
     self.assertEqual(built.estimator_spec, estimator_spec)
     self.assertEqual(built.best_candidate_index, best_candidate_index)
     self.assertEqual(built.subnetwork_reports, subnetwork_reports)
Esempio n. 3
0
 def test_new_errors(self,
                     number=0,
                     candidates=lambda: [_dummy_candidate()],
                     estimator_spec=tu.dummy_estimator_spec(),
                     best_candidate_index=0,
                     subnetwork_reports=lambda: []):
     """Checks that building an `_Iteration` from these args raises ValueError.

     Args:
       number: Forwarded to `_Iteration` as `number`.
       candidates: Zero-argument callable whose result is forwarded as
         `candidates`.
       estimator_spec: Forwarded as `estimator_spec`.
       best_candidate_index: Forwarded as `best_candidate_index`.
       subnetwork_reports: Zero-argument callable whose result is forwarded
         as `subnetwork_reports`.
     """
     # NOTE(review): the `estimator_spec` default is evaluated once, when this
     # method is defined, not on every call.
     with self.assertRaises(ValueError):
         # The callables are invoked inside the context, so a ValueError
         # raised by any of them also satisfies the assertion.
         candidate_list = candidates()
         reports = subnetwork_reports()
         train_manager = _TrainManager(
             [], [], self.test_subdirectory, is_chief=True)
         _Iteration(
             number=number,
             candidates=candidate_list,
             subnetwork_specs=None,
             estimator_spec=estimator_spec,
             best_candidate_index=best_candidate_index,
             summaries=[],
             subnetwork_reports=reports,
             train_manager=train_manager)
Esempio n. 4
0
 def test_new_errors(self,
                     number=0,
                     candidates=lambda: [_dummy_candidate()],
                     estimator_spec=tu.dummy_estimator_spec(),
                     best_candidate_index=0,
                     is_over_fn=lambda: True,
                     subnetwork_reports=lambda: []):
     """Checks that building an `_Iteration` from these args raises ValueError.

     Args:
       number: Forwarded to `_Iteration` as `number`.
       candidates: Zero-argument callable whose result is forwarded as
         `candidates`.
       estimator_spec: Forwarded as `estimator_spec`.
       best_candidate_index: Forwarded as `best_candidate_index`.
       is_over_fn: Forwarded as `is_over_fn` (not called here).
       subnetwork_reports: Zero-argument callable whose result is forwarded
         as `subnetwork_reports`.
     """
     # NOTE(review): the `estimator_spec` default is evaluated once, when this
     # method is defined, not on every call.
     with self.test_session(), self.assertRaises(ValueError):
         # The callables are invoked inside the context, so a ValueError
         # raised by either of them also satisfies the assertion.
         candidate_list = candidates()
         reports = subnetwork_reports()
         _Iteration(
             number=number,
             candidates=candidate_list,
             subnetwork_specs=None,
             estimator_spec=estimator_spec,
             best_candidate_index=best_candidate_index,
             summaries=[],
             is_over_fn=is_over_fn,
             subnetwork_reports=reports)