Code example #1
0
File: trial_scheduler_test.py  Project: adgirish/ray
    def testAlternateMetrics(self):
        """HyperBand should work with a custom time attribute and reward attribute.

        The scheduler is configured to read time from ``time_total_s`` and
        reward from ``neg_mean_loss``. Every trial in the selected bracket
        then receives one result, and the test checks both the final action
        and the bracket size afterwards.
        """

        def make_result(elapsed, reward):
            # A result carrying only the two attributes the scheduler reads.
            return TrainingResult(time_total_s=elapsed, neg_mean_loss=reward)

        scheduler = HyperBandScheduler(
            time_attr='time_total_s', reward_attr='neg_mean_loss')
        stats = self.default_statistics()

        # Register stats["max_trials"] Trial("__fake") instances.
        for _ in range(stats["max_trials"]):
            scheduler.on_trial_add(None, Trial("__fake"))
        runner = _MockTrialRunner(scheduler)

        # Last bracket of the first entry in scheduler._hyperbands.
        bracket = scheduler._hyperbands[0][-1]

        for trial in bracket.current_trials():
            runner._launch_trial(trial)
        size_before = len(bracket.current_trials())

        # Report reward == position (0, 1, 2, ...) at time 1 for each trial,
        # and let the mock runner apply whatever action the scheduler returns.
        for reward, trial in enumerate(bracket.current_trials()):
            action = scheduler.on_trial_result(
                runner, trial, make_result(1, reward))
            runner.process_action(trial, action)

        size_after = len(bracket.current_trials())
        # The last action (issued for the highest-reward trial) must be CONTINUE.
        self.assertEqual(action, TrialScheduler.CONTINUE)
        # The bracket size must match self.downscale (defined elsewhere in this
        # class -- presumably the expected size after one reduction step).
        self.assertEqual(size_after, self.downscale(size_before, scheduler))
    def testAlternateMetrics(self):
        """HyperBand should work with a custom time attribute and reward attribute.

        The scheduler reads time from ``time_total_s`` and reward from
        ``neg_mean_loss``. Each trial's decision is applied by hand: paused,
        stopped, or left running.
        """
        def make_result(elapsed, reward):
            # A result carrying only the two attributes the scheduler reads.
            return TrainingResult(time_total_s=elapsed, neg_mean_loss=reward)

        scheduler = HyperBandScheduler(
            time_attr='time_total_s', reward_attr='neg_mean_loss')
        stats = self.default_statistics()

        # Register stats["max_trials"] Trial("__fake") instances.
        for _ in range(stats["max_trials"]):
            scheduler.on_trial_add(None, Trial("__fake"))
        runner = _MockTrialRunner()

        # Last bracket of the first entry in scheduler._hyperbands.
        bracket = scheduler._hyperbands[0][-1]

        for trial in bracket.current_trials():
            runner._launch_trial(trial)
        size_before = len(bracket.current_trials())

        # Report reward == position (0, 1, 2, ...) at time 1 for each trial.
        for reward, trial in enumerate(bracket.current_trials()):
            decision = scheduler.on_trial_result(
                runner, trial, make_result(1, reward))
            if decision == TrialScheduler.CONTINUE:
                # Trial keeps running; nothing to apply.
                continue
            if decision == TrialScheduler.PAUSE:
                runner._pause_trial(trial)
            elif decision == TrialScheduler.STOP:
                # A STOP must only be issued for a trial not yet terminated.
                self.assertNotEqual(trial.status, Trial.TERMINATED)
                self.stopTrial(trial, runner)

        size_after = len(bracket.current_trials())
        # The last decision (for the highest-reward trial) must be CONTINUE.
        self.assertEqual(decision, TrialScheduler.CONTINUE)
        # The bracket size must match self.downscale (defined elsewhere in this
        # class -- presumably the expected size after one reduction step).
        self.assertEqual(size_after, self.downscale(size_before, scheduler))
Code example #3
0
    def testAlternateMetrics(self):
        """HyperBand should work with a custom time attribute and reward attribute.

        Results are plain dicts keyed by ``time_total_s`` and
        ``neg_mean_loss``, matching the ``time_attr`` and ``reward_attr``
        given to the scheduler.
        """
        def make_result(elapsed, reward):
            # Plain-dict result holding only the two attributes under test.
            return {"time_total_s": elapsed, "neg_mean_loss": reward}

        scheduler = HyperBandScheduler(
            time_attr='time_total_s', reward_attr='neg_mean_loss')
        stats = self.default_statistics()

        # Register stats["max_trials"] Trial("__fake") instances.
        for _ in range(stats["max_trials"]):
            scheduler.on_trial_add(None, Trial("__fake"))
        runner = _MockTrialRunner(scheduler)

        # Last bracket of the first entry in scheduler._hyperbands.
        bracket = scheduler._hyperbands[0][-1]

        for trial in bracket.current_trials():
            runner._launch_trial(trial)
        size_before = len(bracket.current_trials())

        # Report reward == position (0, 1, 2, ...) at time 1 for each trial,
        # and let the mock runner apply whatever action the scheduler returns.
        for reward, trial in enumerate(bracket.current_trials()):
            action = scheduler.on_trial_result(
                runner, trial, make_result(1, reward))
            runner.process_action(trial, action)

        size_after = len(bracket.current_trials())
        # The last action (issued for the highest-reward trial) must be CONTINUE.
        self.assertEqual(action, TrialScheduler.CONTINUE)
        # The bracket size must match self.downscale (defined elsewhere in this
        # class -- presumably the expected size after one reduction step).
        self.assertEqual(size_after, self.downscale(size_before, scheduler))