def testAlternateMetrics(self):
    """Drive MedianStoppingRule using custom time/reward attribute names.

    The rule is configured with ``time_attr='training_iteration'`` and
    ``reward_attr='neg_mean_loss'``, and every result handed to it carries
    exactly those two keys. The test asserts that every reported result,
    for both trials, yields ``TrialScheduler.CONTINUE`` -- including the
    two ``t2`` results reported after ``t1`` has been marked complete.
    """

    def make_result(t, rew):
        # Payload keyed by the attribute names the rule was configured with.
        return {"training_iteration": t, "neg_mean_loss": rew}

    rule = MedianStoppingRule(
        grace_period=0,
        min_samples_required=1,
        time_attr='training_iteration',
        reward_attr='neg_mean_loss')
    t1 = Trial("PPO")  # fed 0, 100, ..., 900 below: mean 450, max 900
    t2 = Trial("PPO")  # fed a constant 450 for five steps below

    for step in range(10):
        decision = rule.on_trial_result(
            None, t1, make_result(step, step * 100))
        self.assertEqual(decision, TrialScheduler.CONTINUE)

    for step in range(5):
        decision = rule.on_trial_result(None, t2, make_result(step, 450))
        self.assertEqual(decision, TrialScheduler.CONTINUE)

    rule.on_trial_complete(None, t1, make_result(10, 1000))

    # Both follow-up reports from t2 are expected to be allowed to continue.
    for step, reward in ((5, 450), (6, 0)):
        decision = rule.on_trial_result(None, t2, make_result(step, reward))
        self.assertEqual(decision, TrialScheduler.CONTINUE)
def _test_metrics(self, result_func, metric, mode):
    """Shared scenario asserting CONTINUE decisions for a given metric/mode.

    Args:
        result_func: Callable invoked as ``result_func(t, reward)`` to build
            each result object passed to the rule.
        metric: Forwarded unchanged as ``metric`` to ``MedianStoppingRule``.
        mode: Forwarded unchanged as ``mode`` to ``MedianStoppingRule``.

    Every ``on_trial_result`` call in the scenario is asserted to return
    ``TrialScheduler.CONTINUE``.
    """
    rule = MedianStoppingRule(
        grace_period=0,
        min_samples_required=1,
        time_attr="training_iteration",
        metric=metric,
        mode=mode)
    t1 = Trial("PPO")  # fed 0, 100, ..., 900 below: mean 450, max 900
    t2 = Trial("PPO")  # fed a constant 450 for five steps below
    runner = mock_trial_runner()

    for step in range(10):
        outcome = rule.on_trial_result(
            runner, t1, result_func(step, step * 100))
        self.assertEqual(outcome, TrialScheduler.CONTINUE)

    for step in range(5):
        outcome = rule.on_trial_result(runner, t2, result_func(step, 450))
        self.assertEqual(outcome, TrialScheduler.CONTINUE)

    rule.on_trial_complete(runner, t1, result_func(10, 1000))

    # Reports from t2 made after t1 has been completed.
    for step, reward in ((5, 450), (6, 0)):
        outcome = rule.on_trial_result(
            runner, t2, result_func(step, reward))
        self.assertEqual(outcome, TrialScheduler.CONTINUE)
def testMedianStoppingOnCompleteOnly(self):
    """Check that t2's decision flips once t1 is reported as complete.

    Before ``on_trial_complete`` is called for ``t1``, a ``t2`` result is
    expected to give CONTINUE; the next ``t2`` result afterwards is expected
    to give STOP.

    NOTE(review): ``result(a, b)`` and ``basicSetup`` are defined elsewhere;
    presumably ``a`` is the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
    t1, t2 = self.basicSetup(rule)

    before_completion = rule.on_trial_result(None, t2, result(100, 0))
    self.assertEqual(before_completion, TrialScheduler.CONTINUE)

    rule.on_trial_complete(None, t1, result(10, 1000))

    after_completion = rule.on_trial_result(None, t2, result(101, 0))
    self.assertEqual(after_completion, TrialScheduler.STOP)
def testMedianStoppingOnCompleteOnly(self):
    """Assert CONTINUE before t1 completes and STOP after it completes.

    The scenario reports one ``t2`` result, marks ``t1`` complete, then
    reports a second ``t2`` result; only the second is expected to STOP.

    NOTE(review): the ``result`` helper and ``basicSetup`` live outside this
    method; the meaning of ``result``'s two arguments is presumably
    (time, reward) -- confirm.
    """
    scheduler = MedianStoppingRule(grace_period=0, min_samples_required=1)
    t1, t2 = self.basicSetup(scheduler)

    # t1 has not been completed yet.
    first_decision = scheduler.on_trial_result(None, t2, result(100, 0))
    self.assertEqual(first_decision, TrialScheduler.CONTINUE)

    scheduler.on_trial_complete(None, t1, result(10, 1000))

    # Same trial, one step later, now that t1 is complete.
    second_decision = scheduler.on_trial_result(None, t2, result(101, 0))
    self.assertEqual(second_decision, TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
    """Check t3's decision with one vs. two completed trials.

    With ``min_samples_required=2``, a result from the new trial ``t3`` is
    expected to CONTINUE while only ``t1`` has been completed, and the same
    result is expected to STOP once ``t2`` has been completed as well.

    NOTE(review): ``result(a, b)`` is defined elsewhere; presumably ``a`` is
    the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
    t1, t2 = self.basicSetup(rule)

    rule.on_trial_complete(None, t1, result(10, 1000))
    t3 = Trial("PPO")

    # Only one trial has been completed so far.
    with_one_done = rule.on_trial_result(None, t3, result(3, 10))
    self.assertEqual(with_one_done, TrialScheduler.CONTINUE)

    rule.on_trial_complete(None, t2, result(10, 1000))

    # Two trials have now been completed.
    with_two_done = rule.on_trial_result(None, t3, result(3, 10))
    self.assertEqual(with_two_done, TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
    """Assert that t3 stops only after a second trial has been completed.

    The rule is built with ``min_samples_required=2``. The identical ``t3``
    result is reported twice: once after completing ``t1`` (expected
    CONTINUE) and once after also completing ``t2`` (expected STOP).

    NOTE(review): ``result`` and ``basicSetup`` are defined outside this
    method; ``result``'s arguments are presumably (time, reward) -- confirm.
    """
    scheduler = MedianStoppingRule(grace_period=0, min_samples_required=2)
    t1, t2 = self.basicSetup(scheduler)

    scheduler.on_trial_complete(None, t1, result(10, 1000))
    t3 = Trial("PPO")
    decision = scheduler.on_trial_result(None, t3, result(3, 10))
    self.assertEqual(decision, TrialScheduler.CONTINUE)

    scheduler.on_trial_complete(None, t2, result(10, 1000))
    decision = scheduler.on_trial_result(None, t3, result(3, 10))
    self.assertEqual(decision, TrialScheduler.STOP)
def testMedianStoppingConstantPerf(self):
    """Check t2's decisions at three reports after t1 is completed.

    After ``t1`` is marked complete, ``t2`` reports ``result(5, 450)`` and
    ``result(6, 0)`` (both expected CONTINUE) and then ``result(10, 450)``
    (expected STOP).

    NOTE(review): ``result(a, b)`` is defined elsewhere; presumably ``a`` is
    the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
    t1, t2 = self.basicSetup(rule)
    runner = mock_trial_runner()

    rule.on_trial_complete(runner, t1, result(10, 1000))

    # The first two reports are expected to be allowed to continue.
    for step, reward in ((5, 450), (6, 0)):
        decision = rule.on_trial_result(runner, t2, result(step, reward))
        self.assertEqual(decision, TrialScheduler.CONTINUE)

    final_decision = rule.on_trial_result(runner, t2, result(10, 450))
    self.assertEqual(final_decision, TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
    """Check a fresh trial's decisions once both setup trials are complete.

    With ``t1`` and ``t2`` both marked complete, the new trial ``t3``
    reports ``result(1, 260)`` (expected CONTINUE) followed by
    ``result(2, 260)`` (expected STOP).

    NOTE(review): ``result(a, b)`` is defined elsewhere; presumably ``a`` is
    the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
    t1, t2 = self.basicSetup(rule)
    runner = mock_trial_runner()

    # Mark both setup trials complete, in order, with the same final result.
    for finished in (t1, t2):
        rule.on_trial_complete(runner, finished, result(10, 1000))

    t3 = Trial("PPO")

    first_decision = rule.on_trial_result(runner, t3, result(1, 260))
    self.assertEqual(first_decision, TrialScheduler.CONTINUE)

    second_decision = rule.on_trial_result(runner, t3, result(2, 260))
    self.assertEqual(second_decision, TrialScheduler.STOP)
def testMedianStoppingSoftStop(self):
    """Check that ``hard_stop=False`` yields PAUSE where STOP would be asserted.

    The rule is built with ``hard_stop=False``. After ``t1`` and ``t2`` are
    marked complete, the new trial ``t3`` reports ``result(1, 260)``
    (expected CONTINUE) and ``result(2, 260)`` (expected PAUSE).

    NOTE(review): ``result(a, b)`` is defined elsewhere; presumably ``a`` is
    the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(
        grace_period=0, min_samples_required=1, hard_stop=False)
    t1, t2 = self.basicSetup(rule)

    rule.on_trial_complete(None, t1, result(10, 1000))
    rule.on_trial_complete(None, t2, result(10, 1000))
    t3 = Trial("PPO")

    first_decision = rule.on_trial_result(None, t3, result(1, 260))
    self.assertEqual(first_decision, TrialScheduler.CONTINUE)

    second_decision = rule.on_trial_result(None, t3, result(2, 260))
    self.assertEqual(second_decision, TrialScheduler.PAUSE)
def testMedianStoppingSoftStop(self):
    """Assert a PAUSE decision from a rule created with ``hard_stop=False``.

    Both setup trials are marked complete, then a third trial reports two
    results with the same second argument (260): the first is expected to
    CONTINUE, the second to PAUSE.

    NOTE(review): ``result`` and ``basicSetup`` are defined outside this
    method; ``result``'s arguments are presumably (time, reward) -- confirm.
    """
    scheduler = MedianStoppingRule(
        grace_period=0,
        min_samples_required=1,
        hard_stop=False)
    t1, t2 = self.basicSetup(scheduler)

    # Complete the setup trials in order before introducing the new one.
    for finished in (t1, t2):
        scheduler.on_trial_complete(None, finished, result(10, 1000))
    t3 = Trial("PPO")

    decision = scheduler.on_trial_result(None, t3, result(1, 260))
    self.assertEqual(decision, TrialScheduler.CONTINUE)

    decision = scheduler.on_trial_result(None, t3, result(2, 260))
    self.assertEqual(decision, TrialScheduler.PAUSE)
def testMedianStoppingMinSamples(self):
    """Check t3's decision before and after a second trial is completed.

    The rule is built with ``min_samples_required=2``. ``t3`` reports
    ``result(5, 10)`` after only ``t1`` has been completed (expected
    CONTINUE) and again after ``t2`` has been completed too (expected STOP).

    NOTE(review): ``result(a, b)`` is defined elsewhere; presumably ``a`` is
    the time step and ``b`` the reward -- confirm.
    """
    rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
    t1, t2 = self.basicSetup(rule)
    runner = mock_trial_runner()

    rule.on_trial_complete(runner, t1, result(10, 1000))
    t3 = Trial("PPO")

    # Insufficient samples to evaluate t3
    early_decision = rule.on_trial_result(runner, t3, result(5, 10))
    self.assertEqual(early_decision, TrialScheduler.CONTINUE)

    rule.on_trial_complete(runner, t2, result(5, 1000))

    # Sufficient samples to evaluate t3
    late_decision = rule.on_trial_result(runner, t3, result(5, 10))
    self.assertEqual(late_decision, TrialScheduler.STOP)