Esempio n. 1
0
 def testMedianStoppingOnCompleteOnly(self):
     """Check that ``t2`` is only stopped after ``t1`` has been completed.

     A zero-reward result for ``t2`` is expected to CONTINUE while ``t1``
     has not been reported complete in this test; once ``t1`` completes
     with ``result(10, 1000)``, the next zero-reward result for ``t2`` is
     expected to STOP.
     """
     scheduler = MedianStoppingRule(grace_period=0, min_samples_required=1)
     first, second = self.basicSetup(scheduler)
     decision = scheduler.on_trial_result(None, second, result(100, 0))
     self.assertEqual(decision, TrialScheduler.CONTINUE)
     scheduler.on_trial_complete(None, first, result(10, 1000))
     decision = scheduler.on_trial_result(None, second, result(101, 0))
     self.assertEqual(decision, TrialScheduler.STOP)
Esempio n. 2
0
 def testMedianStoppingUsesMedian(self):
     """Judge a fresh trial against two completed trials.

     Both setup trials are completed with ``result(10, 1000)``; a new trial
     reporting reward 260 is expected to CONTINUE at time 1 and STOP at
     time 2.
     """
     scheduler = MedianStoppingRule(grace_period=0, min_samples_required=1)
     first, second = self.basicSetup(scheduler)
     for finished in (first, second):
         scheduler.on_trial_complete(None, finished, result(10, 1000))
     newcomer = Trial("PPO")
     for step, expected in ((1, TrialScheduler.CONTINUE),
                            (2, TrialScheduler.STOP)):
         self.assertEqual(
             scheduler.on_trial_result(None, newcomer, result(step, 260)),
             expected)
 def testMedianStoppingMinSamples(self):
     """Check that stopping waits for ``min_samples_required`` completions.

     With ``min_samples_required=2``, the probe trial's result is expected
     to CONTINUE while only ``t1`` has been completed in this test, and to
     STOP once ``t2`` has been completed as well.
     """
     scheduler = MedianStoppingRule(grace_period=0, min_samples_required=2)
     first, second = self.basicSetup(scheduler)
     scheduler.on_trial_complete(None, first, result(10, 1000))
     probe = Trial("PPO")
     verdict = scheduler.on_trial_result(None, probe, result(3, 10))
     self.assertEqual(verdict, TrialScheduler.CONTINUE)
     scheduler.on_trial_complete(None, second, result(10, 1000))
     verdict = scheduler.on_trial_result(None, probe, result(3, 10))
     self.assertEqual(verdict, TrialScheduler.STOP)
 def testMedianStoppingConstantPerf(self):
     """Feed ``t2`` a fixed sequence of results after ``t1`` completes.

     The expected decisions are CONTINUE, CONTINUE, then STOP for the
     results ``(5, 450)``, ``(6, 0)`` and ``(10, 450)`` respectively.
     """
     scheduler = MedianStoppingRule(grace_period=0, min_samples_required=1)
     first, second = self.basicSetup(scheduler)
     scheduler.on_trial_complete(None, first, result(10, 1000))
     steps = [
         (5, 450, TrialScheduler.CONTINUE),
         (6, 0, TrialScheduler.CONTINUE),
         (10, 450, TrialScheduler.STOP),
     ]
     for t, rew, expected in steps:
         self.assertEqual(
             scheduler.on_trial_result(None, second, result(t, rew)),
             expected)
 def testMedianStoppingSoftStop(self):
     """Check that ``hard_stop=False`` yields PAUSE instead of STOP.

     Both setup trials are completed with ``result(10, 1000)``; a new trial
     reporting reward 260 is expected to CONTINUE at time 1 and to be
     PAUSEd at time 2.
     """
     scheduler = MedianStoppingRule(
         grace_period=0, min_samples_required=1, hard_stop=False)
     first, second = self.basicSetup(scheduler)
     for finished in (first, second):
         scheduler.on_trial_complete(None, finished, result(10, 1000))
     newcomer = Trial("PPO")
     for step, expected in ((1, TrialScheduler.CONTINUE),
                            (2, TrialScheduler.PAUSE)):
         self.assertEqual(
             scheduler.on_trial_result(None, newcomer, result(step, 260)),
             expected)
def main():
    """Parse CLI arguments, configure logging and dispatch the chosen command.

    The output directory ``<save_folder>/<arch>`` is created if it does not
    exist yet and stored back on ``args.save_path``. A per-command log file
    ``log_<cmd>.txt`` is written inside it, and log records are also echoed
    to the console.

    Commands:
        train: call ``run_training(args)``.
        test:  call ``test_model(args)``.
        tune:  run a Ray Tune grid search over ``alpha`` using
               ``run_training`` as the trainable and a median stopping rule.
    Any other value of ``args.cmd`` is logged as an error.
    """
    args = parse_args()
    save_path = args.save_path = os.path.join(args.save_folder, args.arch)
    # Reuse an existing output directory (e.g. running ``test`` after
    # ``train`` with the same save folder) instead of aborting. The
    # try/except form is used rather than ``exist_ok=True`` so it also works
    # on interpreters whose ``os.makedirs`` lacks that argument.
    try:
        os.makedirs(save_path)
    except OSError:
        if not os.path.isdir(save_path):
            raise

    # config
    args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))

    handlers = [
        logging.FileHandler(args.logger_file, mode='w'),
        logging.StreamHandler()
    ]
    logging.basicConfig(level=logging.INFO,
                        datefmt='%m-%d-%y %H:%M',
                        format='%(asctime)s:%(message)s',
                        handlers=handlers)

    if args.cmd == 'train':
        logging.info('start training %s', args.arch)
        run_training(args)
    elif args.cmd == 'test':
        logging.info('start evaluating %s with checkpoints from %s',
                     args.arch, args.resume)
        test_model(args)
    elif args.cmd == 'tune':
        # Ray is only imported in this branch, so train/test do not need it.
        import ray
        import ray.tune as tune
        from ray.tune import Experiment
        from ray.tune.median_stopping_rule import MedianStoppingRule

        ray.init()
        sched = MedianStoppingRule(time_attr="timesteps_total",
                                   reward_attr="neg_mean_loss")
        # The lambda closes over ``args`` so every trial shares the CLI
        # settings and only ``cfg`` (the sampled hyper-parameters) varies.
        tune.register_trainable(
            "run_training",
            lambda cfg, reporter: run_training(args, cfg, reporter))
        experiment = Experiment(
            "train_rl",
            "run_training",
            trial_resources={"gpu": 1},
            config={"alpha": tune.grid_search([0.1, 0.01, 0.001])})
        tune.run_experiments(experiment, scheduler=sched, verbose=False)
    else:
        # Previously an unrecognised command fell through silently.
        logging.error('unknown command: %s', args.cmd)
    def testAlternateMetrics(self):
        """Drive the rule with ``training_iteration`` / ``neg_mean_loss``.

        The rule is configured with ``time_attr='training_iteration'`` and
        ``reward_attr='neg_mean_loss'``. Every reported result below is
        expected to yield CONTINUE.
        """

        def make_result(iteration, reward):
            # Build a result carrying only the two alternate metrics.
            return TrainingResult(training_iteration=iteration,
                                  neg_mean_loss=reward)

        scheduler = MedianStoppingRule(grace_period=0,
                                       min_samples_required=1,
                                       time_attr='training_iteration',
                                       reward_attr='neg_mean_loss')
        # rewards 0, 100, ..., 900 at iterations 0-9 (mean 450, max 900)
        first = Trial("PPO")
        # reward 450 at iterations 0-4 (mean 450, max 450)
        second = Trial("PPO")
        for step in range(10):
            decision = scheduler.on_trial_result(
                None, first, make_result(step, step * 100))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
        for step in range(5):
            decision = scheduler.on_trial_result(
                None, second, make_result(step, 450))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
        scheduler.on_trial_complete(None, first, make_result(10, 1000))
        for step, reward in ((5, 450), (6, 0)):
            decision = scheduler.on_trial_result(
                None, second, make_result(step, reward))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
Esempio n. 8
0
    def testAlternateMetrics(self):
        """Check the rule with non-default time and reward attribute names.

        Two trials report results keyed by ``training_iteration`` and
        ``neg_mean_loss``; the first is then completed, and all reported
        results (before and after completion) are expected to CONTINUE.
        """
        median_rule = MedianStoppingRule(
            grace_period=0, min_samples_required=1,
            time_attr='training_iteration', reward_attr='neg_mean_loss')

        def report(trial, iteration, reward):
            # Send one alternate-metric result and return the rule's decision.
            metrics = TrainingResult(training_iteration=iteration,
                                     neg_mean_loss=reward)
            return median_rule.on_trial_result(None, trial, metrics)

        # first: rewards 0..900 over iterations 0-9 (mean 450, max 900)
        # second: reward 450 over iterations 0-4 (mean 450, max 450)
        first, second = Trial("PPO"), Trial("PPO")
        schedule = [(first, i, i * 100) for i in range(10)]
        schedule += [(second, i, 450) for i in range(5)]
        for trial, iteration, reward in schedule:
            self.assertEqual(report(trial, iteration, reward),
                             TrialScheduler.CONTINUE)
        median_rule.on_trial_complete(
            None, first,
            TrainingResult(training_iteration=10, neg_mean_loss=1000))
        for iteration, reward in ((5, 450), (6, 0)):
            self.assertEqual(report(second, iteration, reward),
                             TrialScheduler.CONTINUE)