def my_policy_1155152886(num_envs=1):
    """Load the trained agent used for evaluation.

    Expects the working directory to be "ierg5350-assignment/assignment5/"
    so that the relative checkpoint path below resolves. Run this file
    directly to make sure everything is fine.

    :param num_envs: number of vectorized environments the policy will serve.
    :return: a PolicyAPI instance wrapping the "final" checkpoint.
    """
    log_dir = 'data/cCarRacing-v0_PPO_12-01_21-39'
    suffix = 'final'
    return PolicyAPI(
        "cCarRacing-v0",
        num_envs=num_envs,
        log_dir=log_dir,
        suffix=suffix,
    )
def student_compute_action_function(num_envs=1):
    """Load the student's agent for evaluation.

    Expects the working directory to be "ierg6130-assignment/assignment4/".
    You can rewrite this function completely for a custom agent, but it must
    run bug-free; document any changes in report_SID.md. Run this file
    directly to make sure everything is fine.

    :param num_envs: number of vectorized environments the policy will serve.
    :return: a PolicyAPI instance restored from the configured checkpoint.
    """
    # [TODO] rewrite this function — log dir and suffix are placeholders.
    my_agent_log_dir = ""
    my_agent_suffix = ""
    checkpoint_path = osp.join(
        my_agent_log_dir, "checkpoint-{}.pkl".format(my_agent_suffix))
    # Existence check is advisory only: print either way, never raise here.
    if osp.exists(checkpoint_path):
        print("Found your checkpoint at {}!".format(checkpoint_path))
    else:
        print("Can't find anything at {}!".format(checkpoint_path))
    return PolicyAPI(
        num_envs=num_envs,
        log_dir=my_agent_log_dir,
        suffix=my_agent_suffix,
    )
def my_policy_zhenghao(num_envs=1):
    """Load the "zhenghao" checkpoint from the shared "data/alphacar" dir."""
    return PolicyAPI(
        "cCarRacing-v0",
        num_envs=num_envs,
        log_dir="data/alphacar",
        suffix="zhenghao",
    )
def my_policy_1155156694(num_envs=1):
    """Load the "alphacar" checkpoint from the shared "data/alphacar" dir."""
    return PolicyAPI(
        "cCarRacing-v0",
        num_envs=num_envs,
        log_dir="data/alphacar",
        suffix="alphacar",
    )
def _split_restore_path(restore_path):
    """Split a checkpoint path into ``(log_dir, suffix)``.

    E.g. ``"data/run/checkpoint-final.pkl"`` -> ``("data/run", "final")``.
    Assumes the basename has the exact form ``checkpoint-<suffix>.pkl``;
    raises IndexError otherwise.
    """
    restore_log_dir = os.path.dirname(restore_path)
    restore_suffix = os.path.basename(
        restore_path).split("checkpoint-")[1].split(".pkl")[0]
    return restore_log_dir, restore_suffix


def train(args):
    """Build environments and a trainer from CLI args and run training.

    Saves a "final" checkpoint into a timestamped log directory when
    training finishes or is interrupted with Ctrl-C.

    :param args: parsed CLI namespace; reads algo, num_envs, lr, entropy,
        env_id, seed, opponent, restore, log_dir, test, num_eval_envs,
        action_repeat.
    :raises ValueError: for an unknown algorithm or a failed restore.
    """
    # Verify algorithm and config.
    algo = args.algo
    if algo == "PPO":
        config = ppo_config
    else:
        raise ValueError("args.algo must in [PPO]")
    config.num_envs = args.num_envs
    config.lr = args.lr
    config.entropy_loss_weight = args.entropy
    assert args.env_id in ["cPong-v0", "cCarRacing-v0"], args.env_id

    # Seed the environments and set up torch.
    seed = args.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_num_threads(1)

    # Create vectorized environments. Self-play against an opponent uses
    # the two-player variant of the racing environment.
    num_envs = args.num_envs
    env_id = args.env_id if not args.opponent else "cCarRacingDouble-v0"

    # Clean log directory, timestamped so runs never collide.
    log_dir = verify_log_dir(
        args.log_dir, "{}_{}_{}".format(
            env_id, algo,
            datetime.datetime.now().strftime("%m-%d_%H-%M")))

    if args.opponent:
        assert args.num_eval_envs == 0
        from competitive_rl.car_racing import make_competitive_car_racing
        from load_agents import PolicyAPI
        # The opponent is restored from the same checkpoint given via
        # --restore (shared parsing with the trainer-restore branch below).
        restore_log_dir, restore_suffix = _split_restore_path(args.restore)
        opponent_policy = PolicyAPI(
            "cCarRacing-v0", num_envs=1, log_dir=restore_log_dir,
            suffix=restore_suffix)
        envs = make_competitive_car_racing(
            opponent_policy=opponent_policy, num_envs=num_envs,
            asynchronous=not args.test)
    else:
        envs = make_envs(
            env_id=env_id, seed=seed, log_dir=log_dir, num_envs=num_envs,
            asynchronous=not args.test, resized_dim=config.resized_dim,
            action_repeat=args.action_repeat)
    if args.num_eval_envs > 0:
        eval_envs = make_envs(
            env_id=env_id, seed=seed, log_dir=log_dir,
            num_envs=args.num_eval_envs, asynchronous=not args.test,
            resized_dim=config.resized_dim,
            action_repeat=args.action_repeat)
    else:
        eval_envs = None

    # Setup trainer.
    if algo == "PPO":
        trainer = PPOTrainer(envs, config)
    else:
        raise ValueError("Unknown algorithm {}".format(algo))

    if args.restore:
        restore_log_dir, restore_suffix = _split_restore_path(args.restore)
        success = trainer.load_w(restore_log_dir, restore_suffix)
        if not success:
            raise ValueError(
                "We can't restore your agent. The log_dir is {} and the suffix is {}"
                .format(restore_log_dir, restore_suffix))

    # Start training.
    print("Start training!")
    obs = envs.reset()
    raw_obs = trainer.process_obs(obs)
    processed_obs = trainer.model.world_model(raw_obs)
    trainer.rollouts.before_update(obs, processed_obs)
    try:
        _train(trainer, envs, eval_envs, config, num_envs, algo, log_dir,
               False, False)
    except KeyboardInterrupt:
        print(
            "The training is stopped by user. The log directory is {}. Now we finish the training."
            .format(log_dir))
    # Save the final weights and close environments whether training
    # finished normally or was interrupted.
    trainer.save_w(log_dir, "final")
    envs.close()