Пример #1
0
def main(argv):
    """Build trials from command-line arguments, run them, and exit.

    Args:
        argv: Argument list handed to ``parser.parse_args``.

    Trials come from one of two sources:

    - the YAML file named by ``--config-file``, expanded with
      ``parse_to_trials``;
    - otherwise, a single ``Trial`` built directly from the remaining flags.

    All trials are scheduled by a ``TrialRunner`` that uses a
    ``MedianStoppingRule``. The runner's debug string is printed after every
    step.

    When the runner finishes, the process exits with ``sys.exit(1)`` if any
    trial's status is not ``Trial.TERMINATED``. Otherwise "Exit 0" is printed
    and the function returns None.
    """
    args = parser.parse_args(argv)
    runner = TrialRunner(MedianStoppingRule())

    if args.config_file:
        with open(args.config_file) as f:
            # safe_load: the experiment file is plain data. A bare
            # yaml.load() used an unsafe loader on older PyYAML and raises
            # TypeError on PyYAML >= 6, where the Loader argument is required.
            config = yaml.safe_load(f)
        for trial in parse_to_trials(config):
            runner.add_trial(trial)
    else:
        runner.add_trial(
            Trial(args.env, args.alg, args.config, args.local_dir, None,
                  args.resources, args.stop, args.checkpoint_freq,
                  args.restore, args.upload_dir))

    # NOTE(review): ray.init runs after the trials are queued; presumably
    # Trial construction does not need a live Ray connection - confirm.
    ray.init(redis_address=args.redis_address,
             num_cpus=args.num_cpus,
             num_gpus=args.num_gpus)

    while not runner.is_finished():
        runner.step()
        print(runner.debug_string())

    # A single trial whose final status is not TERMINATED fails the whole run.
    for trial in runner.get_trials():
        if trial.status != Trial.TERMINATED:
            print("Exit 1")
            sys.exit(1)

    print("Exit 0")
Пример #2
0
 def testEval(self):
     """An ``eval`` entry in the config is replaced by its evaluated value."""
     experiment = {
         "env": "Pong-v0",
         "config": {
             "foo": {"eval": "2 + 2"},
         },
     }
     trials = parse_to_trials({"tune-pong": experiment})

     self.assertEqual(len(trials), 1)
     trial = trials[0]
     self.assertEqual(trial.config, {"foo": 4})
     # The evaluated value also shows up in the generated agent id.
     self.assertEqual(trial.agent_id, "0_foo=4")
Пример #3
0
 def testParseToTrials(self):
     """A plain spec with ``num_trials: 2`` yields two trials with ids 0 and 1."""
     experiment = {
         "env": "Pong-v0",
         "alg": "PPO",
         "num_trials": 2,
         "config": {"foo": "bar"},
     }
     trials = parse_to_trials({"tune-pong": experiment})

     self.assertEqual(len(trials), 2)
     first, second = trials[0], trials[1]
     self.assertEqual(first.env_name, "Pong-v0")
     self.assertEqual(first.config, {"foo": "bar"})
     self.assertEqual(first.alg, "PPO")
     self.assertEqual(first.agent_id, "0")
     # The spec gives no local_dir; the expected path is built from the
     # experiment name.
     self.assertEqual(first.local_dir, "/tmp/ray/tune-pong")
     self.assertEqual(second.agent_id, "1")
Пример #4
0
 def testGridSearchAndEval(self):
     """Grid-search and ``eval`` entries combine; ``num_trials: 1`` yields one trial."""
     config = {
         "qux": {"eval": "2 + 2"},
         "bar": {"grid_search": [True, False]},
         "foo": {"grid_search": [1, 2, 3]},
     }
     trials = parse_to_trials({
         "tune-pong": {"env": "Pong-v0", "num_trials": 1, "config": config},
     })

     self.assertEqual(len(trials), 1)
     trial = trials[0]
     self.assertEqual(trial.config, {"bar": True, "foo": 1, "qux": 4})
     self.assertEqual(trial.agent_id, "0_bar=True_foo=1_qux=4")
Пример #5
0
 def testGridSearch(self):
     """A 2x3 grid with ``num_trials: 6`` expands to every combination.

     ``bar`` varies fastest, so consecutive trials alternate True/False
     while ``foo`` steps through 1, 2, 3.
     """
     grid = {
         "bar": {"grid_search": [True, False]},
         "foo": {"grid_search": [1, 2, 3]},
     }
     trials = parse_to_trials({
         "tune-pong": {"env": "Pong-v0", "num_trials": 6, "config": grid},
     })

     self.assertEqual(len(trials), 6)
     expected_configs = [
         {"bar": bar, "foo": foo} for foo in (1, 2, 3) for bar in (True, False)
     ]
     for index, expected in enumerate(expected_configs):
         self.assertEqual(trials[index].config, expected)
     self.assertEqual(trials[0].agent_id, "0_bar=True_foo=1")
     self.assertEqual(trials[1].agent_id, "1_bar=False_foo=1")
Пример #6
0
# Forwarded to ray.init(redis_address=...) by the script entry point.
parser.add_argument(
    "--redis-address",
    type=str,
    default=None,
    help="The Redis address of the cluster.",
)
# Passed through to the Trial constructor when no config file is used.
parser.add_argument(
    "--restore",
    type=str,
    default=None,
    help="If specified, restore from this checkpoint.",
)
# When set, trials are built from this YAML file via parse_to_trials.
parser.add_argument(
    "-f",
    "--config-file",
    type=str,
    default=None,
    help="If specified, use config options from this file.",
)


if __name__ == "__main__":
    args = parser.parse_args()
    runner = TrialRunner()

    if args.config_file:
        with open(args.config_file) as f:
            config = yaml.load(f)
        for trial in parse_to_trials(config):
            runner.add_trial(trial)
    else:
        runner.add_trial(
            Trial(
                args.env, args.alg, args.config, args.local_dir, None,
                args.resources, args.stop, args.checkpoint_freq,
                args.restore, args.upload_dir))

    ray.init(redis_address=args.redis_address)

    while not runner.is_finished():
        runner.step()
        print(runner.debug_string())

    for trial in runner.get_trials():