def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )
    args.num_parallel *= args.frame_skip

    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(
        args.num_frames / args.num_parallel / args.num_steps_per_rollout
    )

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = (
        args.num_parallel * args.num_steps_per_rollout // args.mini_batch_size
    )

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)

    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )

    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir, "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):
        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(
                agent.optimizer, update, args.num_updates, args.lr, args.final_lr
            )
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(
                agent.optimizer, update, 0.99, args.lr, args.final_lr
            )

        for step in range(args.num_steps_per_rollout):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.observations[step]
                )

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(
                obs, action, action_log_prob, value, reward, masks, bad_masks
            )

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
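# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). The loop above calls
# update_linear_schedule() and update_exponential_schedule(), which are
# defined elsewhere in the repo. The helpers below only illustrate what those
# calls appear to do, inferred from their call sites: anneal the optimizer's
# learning rate from args.lr towards args.final_lr. Names ending in _sketch
# are hypothetical and not part of the repo's API.

def linear_lr_schedule_sketch(optimizer, update, total_updates, initial_lr, final_lr):
    """Linearly interpolate the learning rate from initial_lr to final_lr."""
    frac = min(update / float(total_updates), 1.0)
    lr = initial_lr + frac * (final_lr - initial_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def exponential_lr_schedule_sketch(optimizer, update, gamma, initial_lr, final_lr):
    """Decay the learning rate geometrically, clipped below at final_lr."""
    lr = max(initial_lr * gamma ** update, final_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr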
def main():
    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name,
               project=args.proj_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")

    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of updates = number of frames ÷ (policy steps per update × number of processes)
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create tensor to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)

    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout random policy to collect num_batch of experience and store it
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n, (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)

    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE whose encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for the observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(), lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: calculate gradients
            loss.backward()
            # update CB-VAE parameters
            optimiser.step()
            # accumulate loss per mini-batch
            batch_loss += loss.item() * obs.size(0)

        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })

        indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
        measures['true_images'].append(obs[indices].detach().cpu().numpy())
        measures['recon_images'].append(
            recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})

    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train a PPO agent with the value head replaced by an action-value head,
    # training on R_aux instead of the environment reward R
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()

    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation into rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of updates = number of frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)

    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store it
        for step in range(args.policy_num_steps):
            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
                # obtain reward R_aux from the encoder of the CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes * args.policy_num_steps) / (update_end_time -
                                                               update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the
            # returns (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'q_aux_misc/timesteps': frames,
                'q_aux_misc/fps': fps,
                'q_aux_misc/explained_variance': float(ev),
                'q_aux_losses/total_loss': total_loss,
                'q_aux_losses/value_loss': value_loss,
                'q_aux_losses/action_loss': action_loss,
                'q_aux_losses/dist_entropy': dist_entropy,
                'q_aux_train/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
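# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). cb_vae_loss() above is
# imported from elsewhere in the repo. As a rough illustration of the kind of
# objective being minimised, a VAE loss of this form combines a per-pixel
# reconstruction term with the KL divergence between the Gaussian posterior
# q(z|x) = N(mu, exp(log_var)) and a standard normal prior. The repo's CB-VAE
# uses a continuous-Bernoulli likelihood whose normalising constant is omitted
# here; vae_loss_sketch is a hypothetical name.
import torch
import torch.nn.functional as F


def vae_loss_sketch(recon_x, x, mu, log_var):
    # Reconstruction error for inputs scaled to [0, 1].
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Average over the mini-batch.
    return (recon + kld) / x.size(0)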
def main(_seed, _config, _run):
    args = init(_seed, _config, _run)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)
    cleanup_log_dir(args.log_dir)
    cleanup_log_dir(args.log_dir + "_test")

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)
    envs.set_mirror(args.use_phase_mirror)
    test_envs = make_vec_envs(env_name, args.seed, args.num_tests,
                              args.log_dir + "_test")
    test_envs.set_mirror(args.use_phase_mirror)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.use_curriculum:
        curriculum = 0
        print("curriculum", curriculum)
        envs.update_curriculum(curriculum)

    if args.use_specialist:
        specialist = 0
        print("specialist", specialist)
        envs.update_specialist(specialist)

    if args.use_threshold_sampling:
        sampling_threshold = 200
        first_sampling = False
        uniform_sampling = True
        uniform_every = 500000
        uniform_counter = 1
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)
        prob_filter = np.zeros((11, 11))
        prob_filter[5, 5] = 1

    if args.use_adaptive_sampling:
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)

    if args.plot_prob:
        import matplotlib.pyplot as plt

        fig = plt.figure()
        plt.show(block=False)
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_base.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
        actor_critic.reset_dist()
    else:
        controller = SoftsignActor(dummy_env)
        actor_critic = Policy(controller, num_ensembles=args.num_ensembles)

    mirror_function = None
    if args.use_mirror:
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    device = "cuda:0" if args.cuda else "cpu"
    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )

    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)
    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)
    test_episode_rewards = deque(maxlen=args.num_tests)

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(log_dir=args.experiment_dir,
                              console_log_interval=args.log_interval)

    update_values = False

    if args.save_sampling_prob:
        sampling_prob_list = []

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        ac_state_dict = copy.deepcopy(actor_critic).cpu().state_dict()

        if update_values and args.use_threshold_sampling:
            envs.update_curriculum(5)
        elif (not update_values) and args.use_threshold_sampling and first_sampling:
            envs.update_specialist(0)

        if args.use_threshold_sampling and not uniform_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs, None, None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                    if evaluate_envs.update_terrain:
                        evaluate_counter += 1
                        temp_states = evaluate_envs.create_temp_states()
                        with torch.no_grad():
                            temp_states = torch.from_numpy(temp_states).float().to(device)
                            value_samples = actor_critic.get_ensemble_values(
                                temp_states, None, None)
                        # yaw_size = dummy_env.yaw_samples.shape[0]
                        mean = value_samples.mean(dim=-1)
                        # mean = value_samples.min(dim=-1)[0]
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        # metric = metric / (metric.abs().max())
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                        if evaluate_counter >= 5:
                            total_metric /= total_metric.abs().max()
                            # total_metric[total_metric < 0.7] = 0
                            print("metric", total_metric)
                            # threshold1: 150, 0.9 l2; threshold2: 10, 0.85 l1;
                            # threshold3: 10, 0.85, l1, 0.40 gap; threshold4: 20, 0.85, l1, yaw 10
                            sampling_probs = (
                                -10 * (total_metric - args.curriculum_threshold).abs()
                            ).softmax(dim=1).view(yaw_size, pitch_size)
                            if args.save_sampling_prob:
                                sampling_prob_list.append(sampling_probs.cpu().numpy())
                            sample_probs = np.zeros(
                                (args.num_processes, yaw_size, pitch_size))
                            # print("prob", sampling_probs)
                            for i in range(args.num_processes):
                                sample_probs[i, :, :] = np.copy(
                                    sampling_probs.cpu().numpy().astype(np.float64))
                            envs.update_sample_prob(sample_probs)
                            break
        elif args.use_threshold_sampling and uniform_sampling:
            envs.update_curriculum(5)

        # Commented-out variant of the threshold-sampling block above for a 3D
        # (yaw, pitch, r) terrain grid:
        # if args.use_threshold_sampling and not uniform_sampling:
        #     obs = evaluate_envs.reset()
        #     yaw_size = dummy_env.yaw_samples.shape[0]
        #     pitch_size = dummy_env.pitch_samples.shape[0]
        #     r_size = dummy_env.r_samples.shape[0]
        #     total_metric = torch.zeros(1, yaw_size * pitch_size * r_size).to(device)
        #     evaluate_counter = 0
        #     while True:
        #         obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        #         with torch.no_grad():
        #             _, action, _, _ = actor_critic.act(
        #                 obs, None, None, deterministic=True
        #             )
        #         cpu_actions = action.squeeze().cpu().numpy()
        #         obs, reward, done, info = evaluate_envs.step(cpu_actions)
        #         if done:
        #             obs = evaluate_envs.reset()
        #             if evaluate_envs.update_terrain:
        #                 evaluate_counter += 1
        #                 temp_states = evaluate_envs.create_temp_states()
        #                 with torch.no_grad():
        #                     temp_states = torch.from_numpy(temp_states).float().to(device)
        #                     value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
        #                 mean = value_samples.mean(dim=-1)
        #                 # mean = value_samples.min(dim=-1)[0]
        #                 metric = mean.clone()
        #                 metric = metric.view(yaw_size, pitch_size, r_size)
        #                 # metric = metric / (metric.abs().max())
        #                 metric = metric.view(1, yaw_size * pitch_size * r_size)
        #                 total_metric += metric
        #                 if evaluate_counter >= 5:
        #                     total_metric /= (total_metric.abs().max())
        #                     # total_metric[total_metric < 0.7] = 0
        #                     # print("metric", total_metric)
        #                     # threshold1: 150, 0.9 l2; threshold2: 10, 0.85 l1;
        #                     # threshold3: 10, 0.85, l1, 0.40 gap; threshold4: 3d grid, 10, 0.85, l1
        #                     sampling_probs = (-10 * (total_metric - 0.85).abs()).softmax(dim=1).view(yaw_size, pitch_size, r_size)
        #                     sample_probs = np.zeros((args.num_processes, yaw_size, pitch_size, r_size))
        #                     # print("prob", sampling_probs)
        #                     for i in range(args.num_processes):
        #                         sample_probs[i, :, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
        #                     envs.update_sample_prob(sample_probs)
        #                     break
        # elif args.use_threshold_sampling and uniform_sampling:
        #     envs.update_curriculum(5)

        if args.use_adaptive_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs, None, None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                    if evaluate_envs.update_terrain:
                        evaluate_counter += 1
                        temp_states = evaluate_envs.create_temp_states()
                        with torch.no_grad():
                            temp_states = torch.from_numpy(temp_states).float().to(device)
                            value_samples = actor_critic.get_ensemble_values(
                                temp_states, None, None)
                        mean = value_samples.mean(dim=-1)
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        # metric = metric / metric.abs().max()
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                        # sampling_probs = (-30 * metric).softmax(dim=1).view(size, size)
                        # sample_probs = np.zeros((args.num_processes, size, size))
                        # for i in range(args.num_processes):
                        #     sample_probs[i, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
                        # envs.update_sample_prob(sample_probs)
                        if evaluate_counter >= 5:
                            total_metric /= total_metric.abs().max()
                            print("metric", total_metric)
                            sampling_probs = (-10 * total_metric).softmax(dim=1).view(
                                yaw_size, pitch_size)
                            sample_probs = np.zeros(
                                (args.num_processes, yaw_size, pitch_size))
                            for i in range(args.num_processes):
                                sample_probs[i, :, :] = np.copy(
                                    sampling_probs.cpu().numpy().astype(np.float64))
                            envs.update_sample_prob(sample_probs)
                            break

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                    deterministic=update_values,
                )
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()

            if args.plot_prob and step == 0:
                temp_states = envs.create_temp_states()
                with torch.no_grad():
                    temp_states = torch.from_numpy(temp_states).float().to(device)
                    value_samples = actor_critic.get_value(temp_states, None, None)
                size = dummy_env.yaw_samples.shape[0]
                v = value_samples.mean(dim=0).view(size, size).cpu().numpy()
                vs = value_samples.var(dim=0).view(size, size).cpu().numpy()
                ax1.pcolormesh(v)
                ax2.pcolormesh(vs)
                print(np.round(v, 2))
                fig.canvas.draw()

            # Commented-out per-step variants of the adaptive and threshold
            # sampling updates:
            # if args.use_adaptive_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_value(temp_states, None, None)
            #     size = dummy_env.yaw_samples.shape[0]
            #     sample_probs = (-value_samples / 5).softmax(dim=1).view(args.num_processes, size, size)
            #     envs.update_sample_prob(sample_probs.cpu().numpy())
            # if args.use_threshold_sampling and not uniform_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
            #     size = dummy_env.yaw_samples.shape[0]
            #     mean = value_samples.mean(dim=-1)
            #     std = value_samples.std(dim=-1)
            #
            #     # using std
            #     metric = std.clone()
            #     metric = metric.view(args.num_processes, size, size)
            #     value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            #     value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            #     metric = metric / metric.max() + value_filter
            #     metric = metric.view(args.num_processes, size * size)
            #     sample_probs = (30 * metric).softmax(dim=1).view(args.num_processes, size, size)
            #
            #     # using value estimate
            #     metric = mean.clone()
            #     metric = metric.view(args.num_processes, size, size)
            #     value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            #     value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            #     metric = metric / metric.abs().max() - value_filter
            #     metric = metric.view(args.num_processes, size * size)
            #     sample_probs = (-30 * metric).softmax(dim=1).view(args.num_processes, size, size)
            #     if args.plot_prob and step == 0:
            #         # print(sample_probs.cpu().numpy()[0, :, :])
            #         ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #         print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #         fig.canvas.draw()
            #     envs.update_sample_prob(sample_probs.cpu().numpy())
            #
            #     # using value threshold
            #     metric = mean.clone()
            #     metric = metric.view(args.num_processes, size, size)
            #     metric = metric / metric.abs().max()  # - value_filter
            #     metric = metric.view(args.num_processes, size * size)
            #     sample_probs = (-30 * (metric - 0.8) ** 2).softmax(dim=1).view(args.num_processes, size, size)
            #     if args.plot_prob and step == 0:
            #         ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #         print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #         fig.canvas.draw()
            #     envs.update_sample_prob(sample_probs.cpu().numpy())

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        obs = test_envs.reset()

        if args.use_threshold_sampling:
            if uniform_counter % uniform_every == 0:
                uniform_sampling = True
                uniform_counter = 0
            else:
                uniform_sampling = False
            uniform_counter += 1
            if uniform_sampling:
                envs.update_curriculum(5)
                print("uniform")

        # print("max_step", dummy_env._max_episode_steps)
        for step in range(dummy_env._max_episode_steps):
            # Sample actions
            with torch.no_grad():
                obs = torch.from_numpy(obs).float().to(device)
                _, action, _, _ = actor_critic.act(obs, None, None,
                                                   deterministic=True)
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = test_envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()

            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    # print(info["episode"]["r"])
                    test_episode_rewards.append(info["episode"]["r"])

        if args.use_curriculum and np.mean(episode_rewards) > 1000 and curriculum <= 4:
            curriculum += 1
            print("curriculum", curriculum)
            envs.update_curriculum(curriculum)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda)

        if update_values:
            value_loss = agent.update_values(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        # update_values = (not update_values)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (frame_count >= next_checkpoint
                or j == num_updates - 1) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        if args.save_sampling_prob:
            import pickle

            with open('{}_sampling_prob85.pkl'.format(env_name), 'wb') as fp:
                pickle.dump(sampling_prob_list, fp)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()

        if args.use_specialist and np.mean(episode_rewards) > 1000 and specialist <= 4:
            specialist_name = "{}_specialist_{:d}.pt".format(env_name, int(specialist))
            specialist_model = actor_critic
            if args.cuda:
                specialist_model = copy.deepcopy(actor_critic).cpu()
            torch.save(specialist_model, os.path.join(args.save_dir, specialist_name))
            specialist += 1
            envs.update_specialist(specialist)

        # if args.use_threshold_sampling and np.mean(episode_rewards) > 1000 and curriculum <= 4:
        #     first_sampling = False
        #     curriculum += 1
        #     print("curriculum", curriculum)
        #     envs.update_curriculum(curriculum)
        #     prob_filter[5 - curriculum:5 + curriculum + 1, 5 - curriculum:5 + curriculum + 1] = 1

        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch({
                "iter": j + 1,
                "total_num_steps": total_num_steps,
                "fps": int(total_num_steps / (end - start)),
                "entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
                "stats": {"rew": episode_rewards},
                "test_stats": {"rew": test_episode_rewards},
            })
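# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). The threshold-sampling
# branch above turns normalised ensemble value estimates over the (yaw, pitch)
# terrain grid into a sampling distribution that concentrates probability on
# cells whose value is close to args.curriculum_threshold. The helper below
# restates that transformation in isolation; the temperature of 10 and the
# mean-over-ensembles reduction mirror the code above, but the function name
# and signature are illustrative only.
import torch


def terrain_sampling_probs_sketch(value_samples, threshold, temperature=10.0):
    """value_samples: (grid_cells, num_ensembles) ensemble value estimates."""
    # Average the ensemble members and normalise to roughly [-1, 1].
    metric = value_samples.mean(dim=-1)
    metric = metric / metric.abs().max()
    # Cells whose value sits closest to the threshold get the highest probability.
    logits = -temperature * (metric - threshold).abs()
    return torch.softmax(logits, dim=0)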
def main(_seed, _config, _run):
    args = init(_seed, _config, _run, post_config=post_config)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)
    cleanup_log_dir(args.log_dir)

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_best.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
    else:
        if args.mirror_method == MirrorMethods.net2:
            controller = SymmetricNetV2(
                *dummy_env.unwrapped.mirror_sizes,
                num_layers=6,
                hidden_size=256,
                tanh_finish=True,
            )
        else:
            controller = SoftsignActor(dummy_env)
            if args.mirror_method == MirrorMethods.net:
                controller = SymmetricNet(controller,
                                          *dummy_env.unwrapped.sym_act_inds)
        actor_critic = Policy(controller)
        if args.sym_value_net:
            actor_critic.critic = SymmetricVNet(actor_critic.critic,
                                                controller.state_dim)

    mirror_function = None
    if (args.mirror_method == MirrorMethods.traj
            or args.mirror_method == MirrorMethods.loss):
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )

    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)
    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(
        log_dir=args.experiment_dir, console_log_interval=args.log_interval
    )

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                )
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1]
            ).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (
            frame_count >= next_checkpoint or j == num_updates - 1
        ) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()

        drive = 1
        if drive:
            # print("save")
            torch.save(
                save_model,
                os.path.join("/content/gdrive/My Drive/darwin", model_name),
            )
        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            drive = 1
            if drive:
                # print("max_ep_reward", max_ep_reward)
                torch.save(
                    save_model,
                    os.path.join("/content/gdrive/My Drive/darwin", model_name),
                )
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch(
                {
                    "iter": j + 1,
                    "total_num_steps": total_num_steps,
                    "fps": int(total_num_steps / (end - start)),
                    "entropy": dist_entropy,
                    "value_loss": value_loss,
                    "action_loss": action_loss,
                    "stats": {"rew": episode_rewards},
                }
            )
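# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). get_mirror_function() above
# comes from the repo and builds the mirror_function handed to PPO for the
# "traj"/"loss" mirror methods. Conceptually it uses index lists provided by
# the environment (dummy_env.unwrapped.get_mirror_indices()) to produce the
# mirrored version of a batch of observations or actions, e.g. negating some
# dimensions and swapping left/right pairs. The helper below shows that idea
# with assumed index semantics; it is not the repo's implementation.
import torch


def mirror_tensor_sketch(x, neg_inds, right_inds, left_inds):
    """Return a mirrored copy of a (batch, dim) tensor x."""
    mirrored = x.clone()
    # Negate the dimensions that flip sign under a left/right reflection.
    mirrored[:, neg_inds] *= -1
    # Swap the left- and right-side dimensions.
    mirrored[:, right_inds], mirrored[:, left_inds] = (
        x[:, left_inds],
        x[:, right_inds],
    )
    return mirrored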
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---

    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")

    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")

    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [
            QModel(obs_shape=train_envs.observation_space.shape,
                   action_space=train_envs.action_space,
                   hidden_size=args.hidden_size).to(device)
            for _ in range(args.num_q_aux)
        ]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get max number of q_aux functions to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()

    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)

    # insert initial observation into rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of updates = number of frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define AUP coefficient
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) / args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(dim=1),
                            dim=1)
                        # calculate the penalty w.r.t. the reference action (index 4)
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the intrinsic penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward

                # log the intrinsic reward from the first env.
                aup_measures['intrinsic_reward'].append(aup_coef *
                                                        intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # increase the AUP coefficient after every update
        # (multiplicative schedule from aup_coef_start to aup_coef_end)
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes * args.policy_num_steps) / (update_end_time -
                                                               update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the
            # returns (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log(
                        {
                            'aup/intrinsic_reward': aup_measures['intrinsic_reward'][i],
                            'aup/episode_complete': aup_measures['episode_complete'][i]
                        },
                        step=step)
                    step += args.num_processes

            wandb.log(
                {
                    'misc/timesteps': frames,
                    'misc/fps': fps,
                    'misc/explained_variance': float(ev),
                    'losses/total_loss': total_loss,
                    'losses/value_loss': value_loss,
                    'losses/action_loss': action_loss,
                    'losses/dist_entropy': dist_entropy,
                    'train/mean_episodic_return': utl_math.safe_mean(
                        [episode_info['r'] for episode_info in episode_info_buf]),
                    'train/mean_episodic_length': utl_math.safe_mean(
                        [episode_info['l'] for episode_info in episode_info_buf]),
                    'eval/mean_episodic_return': utl_math.safe_mean(
                        [episode_info['r'] for episode_info in eval_episode_info_buf]),
                    'eval/mean_episodic_length': utl_math.safe_mean(
                        [episode_info['l'] for episode_info in eval_episode_info_buf])
                },
                step=frames)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")

        episodic_return = utl_test.test(args=args,
                                        actor_critic=actor_critic,
                                        device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })

        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
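# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). The AUP branch above
# penalises the extrinsic reward by how much the chosen action changes each
# auxiliary action-value function relative to a reference action (index 4,
# which the loop above treats as the no-op). The helper below restates that
# penalty for a single auxiliary Q tensor; the name and signature are
# illustrative only.
import torch


def aup_penalty_sketch(action_values, action, noop_index=4):
    """action_values: (batch, num_actions); action: (batch, 1) int64 indices."""
    q_taken = action_values.gather(1, action)           # Q_aux(s, a)
    q_noop = action_values[:, noop_index].unsqueeze(1)  # Q_aux(s, no-op)
    return (q_taken - q_noop).abs()                     # (batch, 1) penalty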
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument(
        '--model', type=str, default='ppo',
        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = (args.num_processes - args.num_val_envs
                           if args.num_val_envs > 0 else args.num_processes)

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')
    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs > 0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs > 0 and --use_dist_matching=False then you must '
            'also have --dist_matching_coef=0.')
    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---

    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels * args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)

    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = make_rep_analysis_envs(
            args, device)

    print("Setting up Actor-Critic model and Training algorithm.")

    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all")  # to log gradients of actor-critic network

    update_start_time = time.time()

    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(), val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation into rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffers for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of updates = number of frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(
                    action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(
                    args=args,
                    train1_envs=analyse_rep_train1_envs,
                    train2_envs=analyse_rep_train2_envs,
                    val_envs=analyse_rep_val_envs,
                    test_envs=analyse_rep_test_envs,
                    actor_critic=actor_critic,
                    device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes * args.policy_num_steps) / (update_end_time -
                                                               update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the
            # returns (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {
                    'misc/timesteps': frames,
                    'misc/fps': fps,
                    'misc/explained_variance': float(ev),
                    'losses/total_loss': total_loss,
                    'losses/value_loss': value_loss,
                    'losses/action_loss': action_loss,
                    'losses/dist_entropy': dist_entropy,
                    'train/mean_episodic_return': utl_math.safe_mean(
                        [episode_info['r'] for episode_info in train_episode_info_buf]),
                    'train/mean_episodic_length': utl_math.safe_mean(
                        [episode_info['l'] for episode_info in train_episode_info_buf]),
                    'eval/mean_episodic_return': utl_math.safe_mean(
                        [episode_info['r'] for episode_info in eval_episode_info_buf]),
                    'eval/mean_episodic_length': utl_math.safe_mean(
                        [episode_info['l'] for episode_info in eval_episode_info_buf])
                },
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {
                        'losses/dist_matching_loss': dist_matching_loss,
                        'val/mean_episodic_return': utl_math.safe_mean(
                            [episode_info['r'] for episode_info in val_episode_info_buf]),
                        'val/mean_episodic_length': utl_math.safe_mean(
                            [episode_info['l'] for episode_info in val_episode_info_buf])
                    },
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log(
                    {
                        "analysis/" + key: val
                        for key, val in rep_measures.items()
                    },
                    step=iter_idx)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")

        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })

        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
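# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script). utl_math.explained_variance()
# is used above purely for logging. A common definition (as in OpenAI baselines)
# is 1 - Var(returns - value_preds) / Var(returns): 1 means the value function
# predicts the returns perfectly, 0 means it is no better than a constant
# prediction, and negative values mean it is worse. The sketch below follows
# that definition; the repo's utility may differ in detail.
import numpy as np


def explained_variance_sketch(value_preds, returns):
    var_returns = np.var(returns)
    if var_returns == 0:
        return np.nan
    return 1.0 - np.var(returns - value_preds) / var_returns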