def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )
    args.num_parallel *= args.frame_skip
    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(args.num_frames / args.num_parallel / args.num_steps_per_rollout)

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = args.num_parallel * args.num_steps_per_rollout // args.mini_batch_size

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)
    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )
    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir, "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):
        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(agent.optimizer, update, args.num_updates, args.lr, args.final_lr)
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(agent.optimizer, update, 0.99, args.lr, args.final_lr)

        for step in range(args.num_steps_per_rollout):
            # sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(rollouts.observations[step])

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(obs, action, action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
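

# NOTE: update_linear_schedule and update_exponential_schedule are called above but not
# defined in this file. The sketch below shows what they are assumed to do (anneal the
# optimiser's learning rate per update towards final_lr); the repo's actual
# implementations may differ.
def update_linear_schedule(optimizer, update, total_num_updates, initial_lr, final_lr):
    """Sketch: linearly anneal the learning rate from initial_lr to final_lr."""
    lr = initial_lr - (initial_lr - final_lr) * (update / float(total_num_updates))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def update_exponential_schedule(optimizer, update, decay, initial_lr, final_lr):
    """Sketch: multiply the initial learning rate by decay**update, floored at final_lr."""
    lr = max(initial_lr * (decay ** update), final_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr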
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---

    print("Setting up wandb logging.")
    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [QModel(obs_shape=train_envs.observation_space.shape,
                        action_space=train_envs.action_space,
                        hidden_size=args.hidden_size).to(device)
                 for _ in range(args.num_q_aux)]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get max number of q_aux functions to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)), args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes, C, H, W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates:
    # number of frames ÷ (number of policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define AUP coefficient and the per-update factor that grows it
    # exponentially from aup_coef_start to aup_coef_end
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) / args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP penalty
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action, num_classes=train_envs.action_space.n).squeeze(dim=1),
                            dim=1)
                        # calculate the penalty relative to action index 4 (procgen no-op)
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) - action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the scaled AUP penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward

                # log the intrinsic reward from the first env
                aup_measures['intrinsic_reward'].append(aup_coef * intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # increase aup coefficient after every update
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time

            # calculates whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log({'aup/intrinsic_reward': aup_measures['intrinsic_reward'][i],
                               'aup/episode_complete': aup_measures['episode_complete'][i]},
                              step=step)
                    step += args.num_processes

            wandb.log(
                {'misc/timesteps': frames,
                 'misc/fps': fps,
                 'misc/explained_variance': float(ev),
                 'losses/total_loss': total_loss,
                 'losses/value_loss': value_loss,
                 'losses/action_loss': action_loss,
                 'losses/dist_entropy': dist_entropy,
                 'train/mean_episodic_return':
                     utl_math.safe_mean([episode_info['r'] for episode_info in episode_info_buf]),
                 'train/mean_episodic_length':
                     utl_math.safe_mean([episode_info['l'] for episode_info in episode_info_buf]),
                 'eval/mean_episodic_return':
                     utl_math.safe_mean([episode_info['r'] for episode_info in eval_episode_info_buf]),
                 'eval/mean_episodic_length':
                     utl_math.safe_mean([episode_info['l'] for episode_info in eval_episode_info_buf])},
                step=frames)

        # --- SAVE MODEL ---
        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args, actor_critic=actor_critic, device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({'test/train_levels': level,
                       'test/train_returns': episodic_return[0][i]})
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({'test/test_levels': level,
                       'test/test_returns': episodic_return[1][i]})

        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(episodic_return[1])
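

# NOTE: for reference, the AUP penalty computed inline in the rollout loop above can be
# expressed as a standalone helper. This is a sketch, assuming each Q_aux model's
# get_action_value(obs) returns a (num_envs, num_actions) tensor and that action index 4
# is procgen's no-op; it is not part of the original training script.
import torch


def aup_penalty(q_aux_models, obs, action, noop_index=4):
    """Mean over Q_aux models of |Q(s, a_taken) - Q(s, noop)|, shape (num_envs, 1)."""
    penalty = 0.0
    with torch.no_grad():
        for q_model in q_aux_models:
            q = q_model.get_action_value(obs)       # (num_envs, num_actions)
            q_taken = q.gather(1, action.long())    # value of the action actually taken
            q_noop = q[:, noop_index].unsqueeze(1)  # value of the no-op action
            penalty = penalty + (q_taken - q_noop).abs()
    return penalty / len(q_aux_models)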
def main():
    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name, project=args.proj_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---

    print("Gathering data for R_aux Model.")
    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of updates = number of frames ÷ (policy steps per update × number of processes)
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create tensor to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes,
                           *envs.observation_space.shape)

    # reset environments
    obs = envs.reset()  # obs.shape = (n_env, C, H, W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout random policy to collect num_batch of experience and store it
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n, (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)

    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define the CB-VAE whose encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(), lr=args.cb_vae_learning_rate)

    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: calculate gradients
            loss.backward()
            # update parameters of generator
            optimiser.step()
            # save loss per mini-batch
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({'cb_vae/loss': batch_loss / obs_data.size(0),
                   'cb_vae/time_taken': time.time() - start_time,
                   'cb_vae/epoch': epoch})

        indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
        measures['true_images'].append(obs[indices].detach().cpu().numpy())
        measures['recon_images'].append(recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples, figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})

    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples, args.cb_vae_num_samples, figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train a PPO agent whose value head is replaced with an action-value head,
    # trained on R_aux instead of the environment reward R
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()
    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs, C, H, W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates:
    # number of frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)

    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store it
        for step in range(args.policy_num_steps):

            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)

            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time

            # calculates whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {'q_aux_misc/timesteps': frames,
                 'q_aux_misc/fps': fps,
                 'q_aux_misc/explained_variance': float(ev),
                 'q_aux_losses/total_loss': total_loss,
                 'q_aux_losses/value_loss': value_loss,
                 'q_aux_losses/action_loss': action_loss,
                 'q_aux_losses/dist_entropy': dist_entropy,
                 'q_aux_train/mean_episodic_return':
                     utl_math.safe_mean([episode_info['r'] for episode_info in episode_info_buf]),
                 'q_aux_train/mean_episodic_length':
                     utl_math.safe_mean([episode_info['l'] for episode_info in episode_info_buf])})

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
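

# NOTE: cb_vae_loss is used above but not defined in this file. Below is a minimal
# sketch of a continuous-Bernoulli VAE objective (negative CB log-likelihood of the
# reconstruction plus the standard Gaussian KL term), assuming recon_batch holds
# probabilities in (0, 1) and obs is scaled to [0, 1]; the repo's implementation may
# weight or reduce the terms differently.
import torch


def cb_vae_loss(recon_batch, obs, mu, log_var):
    """Sketch of the CB-VAE loss: reconstruction term + KL(q(z|x) || N(0, I))."""
    recon_log_lik = torch.distributions.ContinuousBernoulli(probs=recon_batch).log_prob(obs)
    recon_loss = -recon_log_lik.sum(dim=(1, 2, 3)).mean()
    kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1).mean()
    return recon_loss + kl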
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model', type=str, default='ppo',
                        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---
    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = (args.num_processes - args.num_val_envs
                           if args.num_val_envs > 0 else args.num_processes)

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')
    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs > 0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs > 0 and --use_dist_matching=False then you must also have '
            '--dist_matching_coef=0.')
    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.')
    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---

    print("Setting up wandb logging.")
    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name, name=args.run_name, group=args.group_name,
               config=args, monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels * args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)

    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = \
            make_rep_analysis_envs(args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all")  # to log gradients of actor-critic network

    update_start_time = time.time()
    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(), val_envs.reset()])  # obs.shape = (n_envs, C, H, W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs, C, H, W)

    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffers for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates:
    # number of frames ÷ (number of policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(rollouts.obs[step])

            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma, args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---
        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(eval_envs=eval_envs,
                                                      actor_critic=actor_critic,
                                                      device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(args=args,
                                                   train1_envs=analyse_rep_train1_envs,
                                                   train2_envs=analyse_rep_train2_envs,
                                                   val_envs=analyse_rep_val_envs,
                                                   test_envs=analyse_rep_test_envs,
                                                   actor_critic=actor_critic,
                                                   device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (args.num_processes * args.policy_num_steps) / (
                update_end_time - update_start_time)
            update_start_time = update_end_time

            # calculates whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {'misc/timesteps': frames,
                 'misc/fps': fps,
                 'misc/explained_variance': float(ev),
                 'losses/total_loss': total_loss,
                 'losses/value_loss': value_loss,
                 'losses/action_loss': action_loss,
                 'losses/dist_entropy': dist_entropy,
                 'train/mean_episodic_return':
                     utl_math.safe_mean([episode_info['r'] for episode_info in train_episode_info_buf]),
                 'train/mean_episodic_length':
                     utl_math.safe_mean([episode_info['l'] for episode_info in train_episode_info_buf]),
                 'eval/mean_episodic_return':
                     utl_math.safe_mean([episode_info['r'] for episode_info in eval_episode_info_buf]),
                 'eval/mean_episodic_length':
                     utl_math.safe_mean([episode_info['l'] for episode_info in eval_episode_info_buf])},
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {'losses/dist_matching_loss': dist_matching_loss,
                     'val/mean_episodic_return':
                         utl_math.safe_mean([episode_info['r'] for episode_info in val_episode_info_buf]),
                     'val/mean_episodic_length':
                         utl_math.safe_mean([episode_info['l'] for episode_info in val_episode_info_buf])},
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log({"analysis/" + key: val for key, val in rep_measures.items()},
                          step=iter_idx)

        # --- SAVE MODEL ---
        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---
    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(args.train_start_level,
                                    args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({'test/train_levels': level,
                       'test/train_returns': episodic_return[0][i]})
        test_levels = torch.arange(args.test_start_level,
                                   args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({'test/test_levels': level,
                       'test/test_returns': episodic_return[1][i]})

        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
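

# NOTE: utl.sf01 and utl_math.explained_variance are imported helpers used in the
# logging code above. The sketches below show the conventional definitions (as in
# OpenAI baselines) they appear to follow, written here for numpy arrays; the repo's
# versions may additionally move tensors off the GPU before converting.
import numpy as np


def sf01(arr):
    """Swap the first two axes (steps, envs) and flatten them into one batch axis."""
    arr = np.asarray(arr)
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])


def explained_variance(ypred, y):
    """1 - Var(y - ypred) / Var(y): close to 1 means a good value-function fit,
    <= 0 means no better than predicting the mean."""
    var_y = np.var(y)
    return np.nan if var_y == 0 else 1 - np.var(y - ypred) / var_y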