def main(argv):
    """Run Tune trials from command-line arguments and exit non-zero on failure.

    Trials come either from an experiment file given by ``--config-file``
    (parsed as YAML and expanded with ``parse_to_trials``) or from the
    individual command-line flags describing a single ``Trial``. All trials
    are scheduled by a ``TrialRunner`` using ``MedianStoppingRule`` and are
    stepped until the runner reports it is finished. The runner's debug
    string is printed after every step.

    Args:
        argv: Command-line arguments handed to the module-level ``parser``.

    Exits:
        Prints "Exit 1" and calls ``sys.exit(1)`` as soon as a trial is found
        whose status is not ``Trial.TERMINATED``. Otherwise prints "Exit 0"
        and returns normally.
    """
    args = parser.parse_args(argv)
    runner = TrialRunner(MedianStoppingRule())

    if args.config_file:
        # safe_load: the experiment file only needs plain YAML data.
        # yaml.load without an explicit Loader can build arbitrary Python
        # objects from the file and raises TypeError on PyYAML >= 6.
        with open(args.config_file) as f:
            config = yaml.safe_load(f)
        for trial in parse_to_trials(config):
            runner.add_trial(trial)
    else:
        # Single trial described entirely by command-line flags; the fifth
        # positional argument is passed as None.
        runner.add_trial(
            Trial(args.env, args.alg, args.config, args.local_dir, None,
                  args.resources, args.stop, args.checkpoint_freq,
                  args.restore, args.upload_dir))

    ray.init(redis_address=args.redis_address, num_cpus=args.num_cpus,
             num_gpus=args.num_gpus)

    while not runner.is_finished():
        runner.step()
        print(runner.debug_string())

    for trial in runner.get_trials():
        if trial.status != Trial.TERMINATED:
            print("Exit 1")
            sys.exit(1)

    print("Exit 0")
def testEval(self):
    """A config value of the form {"eval": expr} is replaced by its result."""
    experiment = {
        "env": "Pong-v0",
        "config": {
            "foo": {"eval": "2 + 2"},
        },
    }
    trials = parse_to_trials({"tune-pong": experiment})

    self.assertEqual(len(trials), 1)
    only = trials[0]
    self.assertEqual(only.config, {"foo": 4})
    # The agent id is the trial index followed by the resolved key=value pairs.
    self.assertEqual(only.agent_id, "0_foo=4")
def testParseToTrials(self):
    """num_trials copies of one experiment are produced with indexed agent ids."""
    experiment = {
        "env": "Pong-v0",
        "alg": "PPO",
        "num_trials": 2,
        "config": {"foo": "bar"},
    }
    trials = parse_to_trials({"tune-pong": experiment})

    self.assertEqual(len(trials), 2)
    first = trials[0]
    self.assertEqual(first.env_name, "Pong-v0")
    self.assertEqual(first.config, {"foo": "bar"})
    self.assertEqual(first.alg, "PPO")
    self.assertEqual(first.agent_id, "0")
    # The experiment name becomes the leaf directory under /tmp/ray.
    self.assertEqual(first.local_dir, "/tmp/ray/tune-pong")
    self.assertEqual(trials[1].agent_id, "1")
def testGridSearchAndEval(self):
    """Grid-search and eval values resolve together; num_trials caps the count."""
    experiment = {
        "env": "Pong-v0",
        "num_trials": 1,
        "config": {
            "qux": {"eval": "2 + 2"},
            "bar": {"grid_search": [True, False]},
            "foo": {"grid_search": [1, 2, 3]},
        },
    }
    trials = parse_to_trials({"tune-pong": experiment})

    # Only the first grid point is kept because num_trials is 1.
    self.assertEqual(len(trials), 1)
    only = trials[0]
    self.assertEqual(only.config, {"bar": True, "foo": 1, "qux": 4})
    self.assertEqual(only.agent_id, "0_bar=True_foo=1_qux=4")
def testGridSearch(self):
    """Two grid_search axes expand to their full cross product, in a fixed order."""
    experiment = {
        "env": "Pong-v0",
        "num_trials": 6,
        "config": {
            "bar": {"grid_search": [True, False]},
            "foo": {"grid_search": [1, 2, 3]},
        },
    }
    trials = parse_to_trials({"tune-pong": experiment})

    self.assertEqual(len(trials), 6)

    # Expected order: "bar" cycles fastest, then "foo" advances.
    expected_configs = [
        {"bar": True, "foo": 1},
        {"bar": False, "foo": 1},
        {"bar": True, "foo": 2},
        {"bar": False, "foo": 2},
        {"bar": True, "foo": 3},
        {"bar": False, "foo": 3},
    ]
    expected_ids = {
        0: "0_bar=True_foo=1",
        1: "1_bar=False_foo=1",
    }
    for idx, cfg in enumerate(expected_configs):
        self.assertEqual(trials[idx].config, cfg)
        if idx in expected_ids:
            self.assertEqual(trials[idx].agent_id, expected_ids[idx])
parser.add_argument("--redis-address", default=None, type=str, help="The Redis address of the cluster.") parser.add_argument("--restore", default=None, type=str, help="If specified, restore from this checkpoint.") parser.add_argument("-f", "--config-file", default=None, type=str, help="If specified, use config options from this file.") if __name__ == "__main__": args = parser.parse_args() runner = TrialRunner() if args.config_file: with open(args.config_file) as f: config = yaml.load(f) for trial in parse_to_trials(config): runner.add_trial(trial) else: runner.add_trial( Trial( args.env, args.alg, args.config, args.local_dir, None, args.resources, args.stop, args.checkpoint_freq, args.restore, args.upload_dir)) ray.init(redis_address=args.redis_address) while not runner.is_finished(): runner.step() print(runner.debug_string()) for trial in runner.get_trials():