Example #1
0
 def _test_metrics(self, result_func, metric, mode):
     """Assert the scheduler returns CONTINUE for every result of two trials.

     Args:
         result_func: Callable invoked as ``result_func(t, value)`` to build
             each result handed to the scheduler.
         metric: Passed through as ``AsyncHyperBandScheduler(metric=...)``.
         mode: Passed through as ``AsyncHyperBandScheduler(mode=...)``.
     """
     sched = AsyncHyperBandScheduler(
         grace_period=1,
         time_attr="training_iteration",
         metric=metric,
         mode=mode,
         brackets=1)
     ramping = Trial("PPO")  # mean is 450, max 900, t_max=10
     flat = Trial("PPO")  # mean is 450, max 450, t_max=5
     for trial in (ramping, flat):
         sched.on_trial_add(None, trial)

     # First trial: ten results whose value grows by 100 per step.
     for step in range(10):
         verdict = sched.on_trial_result(
             None, ramping, result_func(step, step * 100))
         self.assertEqual(verdict, TrialScheduler.CONTINUE)

     # Second trial: five results with a constant value of 450.
     for step in range(5):
         verdict = sched.on_trial_result(
             None, flat, result_func(step, 450))
         self.assertEqual(verdict, TrialScheduler.CONTINUE)

     sched.on_trial_complete(None, ramping, result_func(10, 1000))

     # Results reported for the second trial after the first one completed.
     for step, value in ((5, 450), (6, 0)):
         verdict = sched.on_trial_result(
             None, flat, result_func(step, value))
         self.assertEqual(verdict, TrialScheduler.CONTINUE)
Example #2
0
    def testAlternateMetrics(self):
        """Run two trials scored by ``neg_mean_loss`` and expect CONTINUE.

        The scheduler is configured with ``reward_attr='neg_mean_loss'`` and
        every reported result is asserted to yield ``CONTINUE``.
        """
        def make_result(step, reward):
            return {"training_iteration": step, "neg_mean_loss": reward}

        # NOTE(review): later Ray Tune releases replace `reward_attr` with
        # `metric`/`mode` -- confirm against the targeted version.
        sched = AsyncHyperBandScheduler(
            grace_period=1,
            time_attr="training_iteration",
            reward_attr="neg_mean_loss",
            brackets=1)
        ramping = Trial("PPO")  # mean is 450, max 900, t_max=10
        flat = Trial("PPO")  # mean is 450, max 450, t_max=5
        for trial in (ramping, flat):
            sched.on_trial_add(None, trial)

        # First trial: ten results whose reward grows by 100 per step.
        for step in range(10):
            verdict = sched.on_trial_result(
                None, ramping, make_result(step, step * 100))
            self.assertEqual(verdict, TrialScheduler.CONTINUE)

        # Second trial: five results with a constant reward of 450.
        for step in range(5):
            verdict = sched.on_trial_result(
                None, flat, make_result(step, 450))
            self.assertEqual(verdict, TrialScheduler.CONTINUE)

        sched.on_trial_complete(None, ramping, make_result(10, 1000))

        # Results reported for the second trial after the first completed.
        for step, reward in ((5, 450), (6, 0)):
            verdict = sched.on_trial_result(
                None, flat, make_result(step, reward))
            self.assertEqual(verdict, TrialScheduler.CONTINUE)
Example #3
0
    def testAlternateMetrics(self):
        """Check CONTINUE decisions when rewards are read from ``neg_mean_loss``.

        Two trials report results keyed by ``training_iteration`` and
        ``neg_mean_loss``; each scheduler decision must be ``CONTINUE``.
        """
        # NOTE(review): later Ray Tune releases replace `reward_attr` with
        # `metric`/`mode` -- confirm against the targeted version.
        asha = AsyncHyperBandScheduler(
            grace_period=1,
            time_attr='training_iteration',
            reward_attr='neg_mean_loss',
            brackets=1)

        def report(step, reward):
            return dict(training_iteration=step, neg_mean_loss=reward)

        def expect_continue(trial, step, reward):
            # Feed one result to the scheduler and require CONTINUE.
            decision = asha.on_trial_result(None, trial, report(step, reward))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        long_trial = Trial("PPO")  # mean is 450, max 900, t_max=10
        short_trial = Trial("PPO")  # mean is 450, max 450, t_max=5
        asha.on_trial_add(None, long_trial)
        asha.on_trial_add(None, short_trial)

        for step in range(10):
            expect_continue(long_trial, step, step * 100)
        for step in range(5):
            expect_continue(short_trial, step, 450)

        asha.on_trial_complete(None, long_trial, report(10, 1000))

        # The short trial keeps reporting after the long one has completed.
        expect_continue(short_trial, 5, 450)
        expect_continue(short_trial, 6, 0)
Example #4
0
 def testAsyncHBOnComplete(self):
     """Expect STOP for a running trial after a third trial has completed.

     A scheduler with ``max_t=10`` is prepared via ``basicSetup``; a third
     trial is added and completed, then the second trial reports a result.
     """
     sched = AsyncHyperBandScheduler(max_t=10, brackets=1)
     _, running = self.basicSetup(sched)

     finished = Trial("PPO")
     sched.on_trial_add(None, finished)
     sched.on_trial_complete(None, finished, result(10, 1000))

     decision = sched.on_trial_result(None, running, result(101, 0))
     self.assertEqual(decision, TrialScheduler.STOP)
Example #5
0
 def testAsyncHBOnComplete(self):
     """Verify that a late result is answered with STOP.

     After ``basicSetup`` and the completion of an extra trial, the second
     trial from the setup reports ``result(101, 0)`` and must be stopped.
     """
     asha = AsyncHyperBandScheduler(brackets=1, max_t=10)
     trials = self.basicSetup(asha)
     _, survivor = trials

     extra = Trial("PPO")
     asha.on_trial_add(None, extra)
     final_report = result(10, 1000)
     asha.on_trial_complete(None, extra, final_report)

     late_report = result(101, 0)
     self.assertEqual(
         asha.on_trial_result(None, survivor, late_report),
         TrialScheduler.STOP)
Example #6
0
 def testAsyncHBUsesPercentile(self):
     """Expect STOP for a new trial reporting 260 after two trials finished.

     Both trials from ``basicSetup`` are completed with ``result(10, 1000)``;
     a third trial then reports 260 at steps 1 and 2 and must get STOP both
     times.
     """
     sched = AsyncHyperBandScheduler(
         grace_period=1,
         max_t=10,
         reduction_factor=2,
         brackets=1)
     for done in self.basicSetup(sched):
         sched.on_trial_complete(None, done, result(10, 1000))

     newcomer = Trial("PPO")
     sched.on_trial_add(None, newcomer)
     for step in (1, 2):
         decision = sched.on_trial_result(None, newcomer, result(step, 260))
         self.assertEqual(decision, TrialScheduler.STOP)
Example #7
0
 def testAsyncHBUsesPercentile(self):
     """Check that a trial reporting 260 is stopped at steps 1 and 2.

     The scheduler first sees the two ``basicSetup`` trials complete with
     ``result(10, 1000)`` before the third trial starts reporting.
     """
     asha = AsyncHyperBandScheduler(
         grace_period=1, max_t=10, reduction_factor=2, brackets=1)
     first, second = self.basicSetup(asha)
     asha.on_trial_complete(None, first, result(10, 1000))
     asha.on_trial_complete(None, second, result(10, 1000))

     third = Trial("PPO")
     asha.on_trial_add(None, third)

     step = 1
     while step <= 2:
         outcome = asha.on_trial_result(None, third, result(step, 260))
         self.assertEqual(outcome, TrialScheduler.STOP)
         step += 1
Example #8
0
 def testAsyncHBGracePeriod(self):
     """Expect CONTINUE at steps 1 and 2 and STOP at step 3.

     The scheduler uses ``grace_period=2.5``; after both ``basicSetup``
     trials complete, a third trial reports a value of 10 at each step.
     """
     sched = AsyncHyperBandScheduler(
         grace_period=2.5, reduction_factor=3, brackets=1)
     for done in self.basicSetup(sched):
         sched.on_trial_complete(None, done, result(10, 1000))

     newcomer = Trial("PPO")
     sched.on_trial_add(None, newcomer)

     # (step, expected decision) in reporting order.
     expectations = (
         (1, TrialScheduler.CONTINUE),
         (2, TrialScheduler.CONTINUE),
         (3, TrialScheduler.STOP),
     )
     for step, expected in expectations:
         decision = sched.on_trial_result(None, newcomer, result(step, 10))
         self.assertEqual(decision, expected)
Example #9
0
 def testAsyncHBGracePeriod(self):
     """Check the decisions for a low-scoring trial around ``grace_period``.

     With ``grace_period=2.5`` the third trial, reporting 10 each time, must
     be continued at steps 1 and 2 and stopped at step 3.
     """
     asha = AsyncHyperBandScheduler(
         grace_period=2.5,
         reduction_factor=3,
         brackets=1)
     first, second = self.basicSetup(asha)
     asha.on_trial_complete(None, first, result(10, 1000))
     asha.on_trial_complete(None, second, result(10, 1000))

     third = Trial("PPO")
     asha.on_trial_add(None, third)

     # Steps 1 and 2 are expected to be continued.
     for step in (1, 2):
         outcome = asha.on_trial_result(None, third, result(step, 10))
         self.assertEqual(outcome, TrialScheduler.CONTINUE)

     # Step 3 is expected to be stopped.
     outcome = asha.on_trial_result(None, third, result(3, 10))
     self.assertEqual(outcome, TrialScheduler.STOP)