def testAlternateMetrics(self):
    """Run HyperBand with custom time/reward attribute names.

    The scheduler reads ``time_total_s`` as its time attribute and
    ``neg_mean_loss`` as its reward attribute. Results are built as
    ``TrainingResult`` objects carrying those fields.
    """
    def make_result(elapsed, score):
        return TrainingResult(time_total_s=elapsed, neg_mean_loss=score)

    scheduler = HyperBandScheduler(
        time_attr='time_total_s', reward_attr='neg_mean_loss')
    max_trials = self.default_statistics()["max_trials"]
    for _ in range(max_trials):
        scheduler.on_trial_add(None, Trial("__fake"))

    runner = _MockTrialRunner()
    bracket = scheduler._hyperbands[0][-1]
    for trial in bracket.current_trials():
        runner._launch_trial(trial)
    size_before = len(bracket.current_trials())

    # Report rewards 0, 1, 2, ... in bracket order. The last trial gets the
    # highest reward, and the final decision is asserted to be CONTINUE.
    for reward, trial in enumerate(bracket.current_trials()):
        decision = scheduler.on_trial_result(
            runner, trial, make_result(1, reward))
        if decision == TrialScheduler.CONTINUE:
            continue
        if decision == TrialScheduler.PAUSE:
            runner._pause_trial(trial)
        elif decision == TrialScheduler.STOP:
            # A trial the scheduler asks to stop must not already be terminated.
            self.assertNotEqual(trial.status, Trial.TERMINATED)
            self.stopTrial(trial, runner)

    size_after = len(bracket.current_trials())
    self.assertEqual(decision, TrialScheduler.CONTINUE)
    self.assertEqual(size_after, self.downscale(size_before, scheduler))
def testAlternateMetrics(self):
    """Run HyperBand with custom time/reward attribute names.

    The scheduler is configured with ``time_attr='time_total_s'`` and
    ``reward_attr='neg_mean_loss'``. Results are ``TrainingResult`` objects
    carrying those fields, and the mock runner applies each decision.
    """
    def make_result(elapsed, score):
        return TrainingResult(time_total_s=elapsed, neg_mean_loss=score)

    scheduler = HyperBandScheduler(
        time_attr='time_total_s', reward_attr='neg_mean_loss')
    max_trials = self.default_statistics()["max_trials"]
    for _ in range(max_trials):
        scheduler.on_trial_add(None, Trial("__fake"))

    runner = _MockTrialRunner(scheduler)
    bracket = scheduler._hyperbands[0][-1]
    for trial in bracket.current_trials():
        runner._launch_trial(trial)
    size_before = len(bracket.current_trials())

    # Report rewards 0, 1, 2, ... in bracket order. The last trial gets the
    # highest reward, and its decision is asserted to be CONTINUE below.
    for reward, trial in enumerate(bracket.current_trials()):
        decision = scheduler.on_trial_result(
            runner, trial, make_result(1, reward))
        runner.process_action(trial, decision)

    size_after = len(bracket.current_trials())
    self.assertEqual(decision, TrialScheduler.CONTINUE)
    self.assertEqual(size_after, self.downscale(size_before, scheduler))
def testAlternateMetrics(self):
    """HyperBand should work with renamed time/reward keys.

    Results are plain dicts keyed by ``time_total_s`` and ``neg_mean_loss``,
    matching the ``time_attr``/``reward_attr`` given to the scheduler.
    """
    sched = HyperBandScheduler(
        time_attr='time_total_s',
        reward_attr='neg_mean_loss',
    )
    trial_budget = self.default_statistics()["max_trials"]
    for _ in range(trial_budget):
        sched.on_trial_add(None, Trial("__fake"))
    runner = _MockTrialRunner(sched)

    last_bracket = sched._hyperbands[0][-1]
    for live in last_bracket.current_trials():
        runner._launch_trial(live)
    before = len(last_bracket.current_trials())

    # Rewards increase with position (0, 1, 2, ...), so the final trial
    # reported has the highest reward and is expected to CONTINUE.
    for position, live in enumerate(last_bracket.current_trials()):
        result = {"time_total_s": 1, "neg_mean_loss": position}
        action = sched.on_trial_result(runner, live, result)
        runner.process_action(live, action)
    after = len(last_bracket.current_trials())

    self.assertEqual(action, TrialScheduler.CONTINUE)
    self.assertEqual(after, self.downscale(before, sched))
def testConfigSameEtaSmall(self):
    """With max_t=1, the first band has 5 slots and all but the first are None."""
    sched = HyperBandScheduler(max_t=1)
    # Keep adding fake trials until a second band appears in _hyperbands.
    # (The original kept an unused counter here; it has been dropped.)
    while len(sched._hyperbands) < 2:
        sched.on_trial_add(None, Trial("__fake"))
    self.assertEqual(len(sched._hyperbands[0]), 5)
    self.assertTrue(all(v is None for v in sched._hyperbands[0][1:]))
def schedulerSetup(self, num_trials):
    """Build a HyperBandScheduler(9, eta=3) and a mock runner.

    ``num_trials`` fake trials named ``t0``, ``t1``, ... are registered with
    the scheduler. Expected bracket layout (presumably (n, r) pairs):
    (3, 9);
    (5, 3) -> (2, 9);
    (9, 1) -> (3, 3) -> (1, 9);

    Returns:
        A ``(scheduler, runner)`` tuple.
    """
    sched = HyperBandScheduler(9, eta=3)
    for idx in range(num_trials):
        sched.on_trial_add(None, Trial("t%d" % idx, "__fake"))
    return sched, _MockTrialRunner()
def schedulerSetup(self, num_trials):
    """Set up a default HyperBandScheduler and a mock runner.

    The scheduler is built with default arguments, which give a largest
    per-bracket budget of 81 (not 9). ``num_trials`` fake trials are
    registered with it. Bracketing is placed as follows, as (n, r) pairs:
    (5, 81);
    (8, 27) -> (3, 81);
    (15, 9) -> (5, 27) -> (2, 81);
    (34, 3) -> (12, 9) -> (4, 27) -> (2, 81);
    (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);

    Returns:
        A ``(sched, runner)`` tuple; ``runner`` is a fresh ``_MockTrialRunner()``.
    """
    sched = HyperBandScheduler()
    for _ in range(num_trials):
        sched.on_trial_add(None, Trial("__fake"))
    runner = _MockTrialRunner()
    return sched, runner
def schedulerSetup(self, num_trials):
    """Set up a default HyperBandScheduler and a mock runner bound to it.

    The scheduler is built with default arguments, which give a largest
    per-bracket budget of 81 (not 9). ``num_trials`` fake trials are
    registered with it. Bracketing is placed as follows:
    (5, 81);
    (8, 27) -> (3, 54);
    (15, 9) -> (5, 27) -> (2, 45);
    (34, 3) -> (12, 9) -> (4, 27) -> (2, 42);
    (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 41);

    Returns:
        A ``(sched, runner)`` tuple; ``runner`` is ``_MockTrialRunner(sched)``.
    """
    sched = HyperBandScheduler()
    for _ in range(num_trials):
        sched.on_trial_add(None, Trial("__fake"))
    runner = _MockTrialRunner(sched)
    return sched, runner
def testConfigSameEta(self):
    """Check the first band's layout for the default scheduler and max_t=810.

    Both configurations should produce a first band of 5 brackets. The first
    bracket holds 5 trials, and the last bracket holds 81 trials. Only the
    per-bracket ``_r`` values differ, scaling with ``max_t``.
    """
    # Each case is: (constructor kwargs, (n, r) of the first bracket,
    # (n, r) of the last bracket).
    cases = [
        ({}, (5, 81), (81, 1)),
        ({"max_t": 810}, (5, 810), (81, 10)),
    ]
    for kwargs, (first_n, first_r), (last_n, last_r) in cases:
        sched = HyperBandScheduler(**kwargs)
        # Fill the current band with fake trials.
        while not sched._cur_band_filled():
            sched.on_trial_add(None, Trial("__fake"))
        band = sched._hyperbands[0]
        self.assertEqual(len(band), 5)
        self.assertEqual(band[0]._n, first_n)
        self.assertEqual(band[0]._r, first_r)
        self.assertEqual(band[-1]._n, last_n)
        self.assertEqual(band[-1]._r, last_r)