def test_meta_train_splitting_v2_multi_gpu_recurrent() -> None:
    """
    Meta-train with meta splitting networks v2 on the MetaWorld ML10 benchmark
    using several processes, the GPU, and a recurrent policy, comparing reward
    curves against the saved baseline.
    """

    # Start from the default v2 config.
    with open(META_SPLITTING_V2_CONFIG_PATH, "r") as fp:
        config = json.load(fp)

    # Trade updates for processes in both the meta-train and meta-test phases
    # (same order as: train updates, train processes, test updates, test processes).
    for phase in ("meta_train_config", "meta_test_config"):
        phase_cfg = config[phase]
        phase_cfg["num_updates"] //= MP_FACTOR
        phase_cfg["num_processes"] *= MP_FACTOR

    config["cuda"] = True

    # Switch the meta-train policy to a recurrent architecture.
    train_cfg = config["meta_train_config"]
    arch_cfg = train_cfg["architecture_config"]
    arch_cfg["recurrent"] = True
    arch_cfg["recurrent_hidden_size"] = 64

    # Name of the baseline metrics to compare against.
    train_cfg["baseline_metrics_filename"] = (
        "meta_splitting_v2_multi_gpu_recurrent"
    )

    meta_train(config)
def test_meta_train_splitting_v2() -> None:
    """
    Meta-train with meta splitting networks v2 on the MetaWorld ML10 benchmark
    using the default config, comparing reward curves against the saved
    baseline.
    """

    # Start from the default v2 config.
    with open(META_SPLITTING_V2_CONFIG_PATH, "r") as fp:
        config = json.load(fp)

    # Only the baseline name differs from the defaults.
    train_cfg = config["meta_train_config"]
    train_cfg["baseline_metrics_filename"] = "meta_splitting_v2"

    meta_train(config)
def test_meta_train_splitting_v1_recurrent() -> None:
    """
    Meta-train with meta splitting networks v1 on the MetaWorld ML10 benchmark
    using a recurrent policy, comparing reward curves against the saved
    baseline.
    """

    # Start from the default v1 config.
    with open(META_SPLITTING_V1_CONFIG_PATH, "r") as fp:
        config = json.load(fp)

    # Switch the meta-train policy to a recurrent architecture.
    train_cfg = config["meta_train_config"]
    arch_cfg = train_cfg["architecture_config"]
    arch_cfg["recurrent"] = True
    arch_cfg["recurrent_hidden_size"] = 64

    # Name of the baseline metrics to compare against.
    train_cfg["baseline_metrics_filename"] = "meta_splitting_v1_recurrent"

    meta_train(config)
def test_meta_train_splitting_v1_multi() -> None:
    """
    Meta-train with meta splitting networks v1 on the MetaWorld ML10 benchmark
    using several processes, comparing reward curves against the saved
    baseline.
    """

    # Start from the default v1 config.
    with open(META_SPLITTING_V1_CONFIG_PATH, "r") as fp:
        config = json.load(fp)

    # Trade updates for processes in both the meta-train and meta-test phases
    # (same order as: train updates, train processes, test updates, test processes).
    for phase in ("meta_train_config", "meta_test_config"):
        phase_cfg = config[phase]
        phase_cfg["num_updates"] //= MP_FACTOR
        phase_cfg["num_processes"] *= MP_FACTOR

    # Name of the baseline metrics to compare against.
    config["meta_train_config"][
        "baseline_metrics_filename"
    ] = "meta_splitting_v1_multi"

    meta_train(config)
if __name__ == "__main__":

    # Commands this script can dispatch. Kept in one place so the CLI help
    # text and the validation below cannot drift apart (the old help text
    # omitted 'meta_train').
    supported_commands = ("train", "tune", "meta_train")

    # Parse command and config filename from command line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "command",
        type=str,
        help="Command to run. One of: %s."
        % ", ".join("'%s'" % name for name in supported_commands),
    )
    parser.add_argument(
        "config_filename",
        type=str,
        help="Name of config file to load from.",
    )
    args = parser.parse_args()

    # Reject unknown commands before reading and parsing the config file.
    if args.command not in supported_commands:
        raise ValueError("Unsupported command: '%s'" % args.command)

    # Load config file.
    with open(args.config_filename, "r") as config_file:
        config = json.load(config_file)

    # Run specified command.
    if args.command == "train":
        train(config)
    elif args.command == "tune":
        tune(config)
    elif args.command == "meta_train":
        meta_train(config)