def testAlternateMetrics(self):
    """Drive the async HyperBand scheduler with custom time/reward fields.

    The scheduler is told to read ``training_iteration`` as its time
    attribute and ``neg_mean_loss`` as its reward attribute. The test feeds
    it ``TrainingResult`` objects carrying only those two fields and checks
    that every reported result is answered with CONTINUE.
    """
    def make_result(step, reward):
        # Populate only the two attributes the scheduler was configured with.
        return TrainingResult(training_iteration=step, neg_mean_loss=reward)

    sched = AsyncHyperBandScheduler(
        grace_period=1,
        time_attr='training_iteration',
        reward_attr='neg_mean_loss',
        brackets=1)
    climber = Trial("PPO")  # mean is 450, max 900, t_max=10
    flat = Trial("PPO")  # mean is 450, max 450, t_max=5
    for trial in (climber, flat):
        sched.on_trial_add(None, trial)

    # climber: reward grows linearly 0, 100, ..., 900.
    for step in range(10):
        decision = sched.on_trial_result(
            None, climber, make_result(step, step * 100))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
    # flat: constant reward of 450.
    for step in range(5):
        decision = sched.on_trial_result(None, flat, make_result(step, 450))
        self.assertEqual(decision, TrialScheduler.CONTINUE)

    sched.on_trial_complete(None, climber, make_result(10, 1000))
    # Once climber has completed, flat must still be allowed to continue,
    # including when it reports a reward of 0 at step 6.
    self.assertEqual(
        sched.on_trial_result(None, flat, make_result(5, 450)),
        TrialScheduler.CONTINUE)
    self.assertEqual(
        sched.on_trial_result(None, flat, make_result(6, 0)),
        TrialScheduler.CONTINUE)
def testAsyncHBOnComplete(self):
    """A late result from an existing trial is stopped after a completion.

    On top of the two trials from ``basicSetup``, a third trial is added and
    completed with ``result(10, 1000)``. A subsequent ``result(101, 0)`` from
    the second setup trial must be answered with STOP (presumably because
    iteration 101 is past ``max_t=10`` -- confirm against ``result``'s
    signature).
    """
    scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
    # Only the second setup trial is exercised; the first is not needed.
    _, t2 = self.basicSetup(scheduler)
    t3 = Trial("PPO")
    scheduler.on_trial_add(None, t3)
    scheduler.on_trial_complete(None, t3, result(10, 1000))
    self.assertEqual(
        scheduler.on_trial_result(None, t2, result(101, 0)),
        TrialScheduler.STOP)
def testAsyncHBUsesPercentile(self):
    """A late trial with a low reward is stopped on its early results.

    Both trials returned by ``basicSetup`` are completed with
    ``result(10, 1000)``. A newly added trial that reports a reward of 260
    at steps 1 and 2 must receive STOP both times.
    """
    sched = AsyncHyperBandScheduler(
        grace_period=1, max_t=10, reduction_factor=2, brackets=1)
    first, second = self.basicSetup(sched)
    for finished in (first, second):
        sched.on_trial_complete(None, finished, result(10, 1000))

    latecomer = Trial("PPO")
    sched.on_trial_add(None, latecomer)
    for step in (1, 2):
        self.assertEqual(
            sched.on_trial_result(None, latecomer, result(step, 260)),
            TrialScheduler.STOP)
def testAsyncHBGracePeriod(self):
    """A fractional grace period delays stopping of a weak trial.

    Both ``basicSetup`` trials are completed with ``result(10, 1000)``. A new
    trial then reports ``result(t, 10)`` for t = 1, 2, 3. With
    ``grace_period=2.5`` it must get CONTINUE at t=1 and t=2 and STOP at t=3.
    Assumes ``result``'s first argument is the time step; verify against the
    helper's definition.
    """
    scheduler = AsyncHyperBandScheduler(
        grace_period=2.5, reduction_factor=3, brackets=1)
    t1, t2 = self.basicSetup(scheduler)
    scheduler.on_trial_complete(None, t1, result(10, 1000))
    scheduler.on_trial_complete(None, t2, result(10, 1000))
    t3 = Trial("PPO")
    scheduler.on_trial_add(None, t3)
    # t=1 and t=2 are below the 2.5 grace period, so no stop yet.
    self.assertEqual(
        scheduler.on_trial_result(None, t3, result(1, 10)),
        TrialScheduler.CONTINUE)
    self.assertEqual(
        scheduler.on_trial_result(None, t3, result(2, 10)),
        TrialScheduler.CONTINUE)
    # t=3 is past the grace period; the low reward now gets the trial stopped.
    self.assertEqual(
        scheduler.on_trial_result(None, t3, result(3, 10)),
        TrialScheduler.STOP)