seed=i, ) return lambda: env envs = utils.env.vectorize_env( [env_fn(i) for i in range(num_processes)], state_normalize=True, device=device, train=True, ) obs_space, action_space = envs.observation_space, envs.action_space init_obs = envs.reset() torch.manual_seed(0) agent = learn.PPO( obs_space, action_space, init_obs, hidden=256, num_mini_batch=32, clip_param=0.3, num_steps=num_steps, num_processes=num_processes, gamma=gamma, entropy_coef=0.01, linear_schedule=False, ) if args.train:
def train_ppo(env_class, steps, track_eps=25, log_interval=1, solved_at=90.0, continual_solved_at=90.0, care_about=None, num_processes=8, gamma=0.99, MaxT=400, num_steps=128, clip_param=0.3, linear_schedule=True, policy=None, ob_rms=None, eval_envs=None, eval_eps=-1, hidden=-1, entropy_coef=0, linear_schedule_mode=0, lr=3e-4, training_seed=0, verbosity=1, training_method=learn.PPO, log_extras={}, policy_class=learn.PolicyPPO, discrete=False): assert (verbosity in [1, 2]) is_continual = training_method.__name__ in ["PPO_EWC", "PPO_DM"] if is_continual: assert (care_about != None) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_env_steps = int(steps) if eval_envs != None: assert (eval_eps > 0) def env_fn(i): env = env_class(discrete=discrete) # env.debug['show_reasons'] = True env = utils.env.wrap_env( env, action_normalize=not discrete, time_limit=MaxT, deterministic=True, seed=i, ) return lambda: env envs = utils.env.vectorize_env( [env_fn(i) for i in range(num_processes)], state_normalize=True, device=device, train=True, ) if ob_rms != None: envs.ob_rms = ob_rms obs_space, action_space = envs.observation_space, envs.action_space init_obs = envs.reset() torch.manual_seed(training_seed) print("training_method = %s" % training_method.__name__) agent = training_method(obs_space, action_space, init_obs, clip_param=clip_param, num_steps=num_steps, lr=lr, num_processes=num_processes, gamma=gamma, policy=policy, hidden=hidden, linear_schedule=linear_schedule, entropy_coef=entropy_coef, linear_schedule_mode=linear_schedule_mode, policy_class=policy_class) num_updates = agent.compute_updates_needed(num_env_steps, num_processes) episode_rewards = collections.deque(maxlen=track_eps) s = collections.deque(maxlen=track_eps) log_dict = { 'r': episode_rewards, 'eps_done': 0, 'satisfactions': s, **log_extras } start = utils.timer() ret_steps = -1 for j in range(num_updates): agent.pre_step(j, num_updates) agent.step(envs, log=log_dict) vloss, piloss, ent = agent.train() if (j + 1) % log_interval == 0 and len(log_dict['r']) > 1: total_num_steps = (j + 1) * num_processes * num_steps elapsed = "Elapsed %s" % utils.timer_done(start) MeanR = np.mean(log_dict['r']) MedR = np.median(log_dict['r']) MinR = np.min(log_dict['r']) MaxR = np.max(log_dict['r']) if verbosity == 1: reward_stats = "MeanR:%.2f" % (MeanR) extra_stats = [reward_stats] elif verbosity == 2: reward_stats1 = "MeanR,MedR:%.2f,%.2f" % (MeanR, MedR) reward_stats2 = "MinR,MaxR:%.2f,%.2f" % (MinR, MaxR) reg_loss = None if type(ent) == list: ent, reg_loss = ent loss_stats = "Ent:%f, VLoss:%f, PiLoss:%f" % (ent, vloss, piloss) if reg_loss is not None: loss_stats += ", Reg:%f" % (reg_loss) extra_stats = [ reward_stats1, reward_stats2, loss_stats, ] reasons = "Reasons: %s" % (set(list(s))) stats = [ "Steps:%g" % total_num_steps, "Eps:%d" % log_dict['eps_done'], elapsed, *extra_stats, ] print(" ".join(stats)) print(reasons) if eval_envs != None: eval_rews = [] for eval_env in eval_envs: eval_rews += [ utils.env.evaluate_ppo(agent.actor_critic, None, eval_env, device, num_episodes=eval_eps, wrap=False, silent=True) ] eval_rews[-1] = round(eval_rews[-1], 2) if is_continual: eval_MeanR = np.mean( np.clip(eval_rews[:care_about], -100., 100.)) if not is_continual and care_about != None: eval_relevant_R = np.clip(eval_rews[care_about - 1], -100., 100.) print(eval_rews) # print("") sys.stdout.flush() if MeanR >= solved_at: if eval_envs != None: if is_continual: if eval_MeanR < continual_solved_at: continue if not is_continual and care_about != None: if eval_relevant_R < solved_at: continue print("Model solved! Continue") ret_steps = total_num_steps break if ret_steps == -1: print("Not solved.") ob_rms = utils.env.get_ob_rms(envs) assert (ob_rms != None) envs.close() return agent.actor_critic, ob_rms, ret_steps