def get_task_dim(args):
    env = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                        gamma=args.policy_gamma, device=device,
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                        tasks=None,
                        )
    return env.task_dim
def get_num_tasks(args):
    env = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                        gamma=args.policy_gamma, device=device,
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                        tasks=None,
                        )
    try:
        num_tasks = env.num_tasks
    except AttributeError:
        num_tasks = None
    return num_tasks
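# Usage sketch (illustrative, not from the repo): the two helpers above build a
# throwaway vectorised env purely to read its metadata, e.g. before the task
# decoder is constructed. The config import path and the empty override list
# below are assumptions made for illustration.
def _example_query_env_metadata():
    from config.gridworld import args_grid_varibad  # hypothetical import path
    args = args_grid_varibad.get_args([])
    task_dim = get_task_dim(args)    # dimensionality of the task description
    num_tasks = get_num_tasks(args)  # number of discrete tasks, or None
    return task_dim, num_tasks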
def __init__(self, args):
    self.args = args
    utl.seed(self.args.seed, self.args.deterministic_execution)

    # calculate number of updates and keep count of frames/iterations
    self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
    self.frames = 0
    self.iter_idx = 0

    # initialise tensorboard logger
    self.logger = TBLogger(self.args, self.args.exp_label)

    # initialise environments
    self.envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                              gamma=args.policy_gamma, device=device,
                              episodes_per_task=self.args.max_rollouts_per_task,
                              normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                              )

    # calculate what the maximum length of the trajectories is
    self.args.max_trajectory_len = self.envs._max_episode_steps
    self.args.max_trajectory_len *= self.args.max_rollouts_per_task

    # get policy input dimensions
    self.args.state_dim = self.envs.observation_space.shape[0]
    self.args.task_dim = self.envs.task_dim
    self.args.belief_dim = self.envs.belief_dim
    self.args.num_states = self.envs.num_states

    # get policy output (action) dimensions
    self.args.action_space = self.envs.action_space
    if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
        self.args.action_dim = 1
    elif isinstance(self.envs.action_space, gym.spaces.multi_discrete.MultiDiscrete):
        self.args.action_dim = self.envs.action_space.nvec[0]
    else:
        self.args.action_dim = self.envs.action_space.shape[0]

    # initialise VAE and policy
    self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)
    self.policy_storage = self.initialise_policy_storage()
    self.policy = self.initialise_policy()
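# Worked example for the num_updates arithmetic above (the numbers are
# illustrative, not the repo's defaults): each policy update consumes
# policy_num_steps frames in each of num_processes parallel envs, so the
# total frame budget fixes the number of updates.
def _example_num_updates(num_frames=1e8, policy_num_steps=400, num_processes=16):
    # 1e8 frames / (400 steps per update per process * 16 processes) = 15625 updates
    return int(num_frames) // policy_num_steps // num_processes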
def __init__(self, args):
    self.args = args
    utl.seed(self.args.seed, self.args.deterministic_execution)

    # count number of frames and number of meta-iterations
    self.frames = 0
    self.iter_idx = 0

    # initialise tensorboard logger
    self.logger = TBLogger(self.args, self.args.exp_label)

    # initialise environments
    self.envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                              gamma=args.policy_gamma, log_dir=args.agent_log_dir,
                              device=device, allow_early_resets=False,
                              episodes_per_task=self.args.max_rollouts_per_task,
                              obs_rms=None, ret_rms=None,
                              )

    # calculate what the maximum length of the trajectories is
    args.max_trajectory_len = self.envs._max_episode_steps
    args.max_trajectory_len *= self.args.max_rollouts_per_task

    # calculate number of meta updates
    self.args.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes

    # get action / observation dimensions
    if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
        self.args.action_dim = 1
    else:
        self.args.action_dim = self.envs.action_space.shape[0]
    self.args.obs_dim = self.envs.observation_space.shape[0]
    self.args.num_states = self.envs.num_states if self.args.env_name.startswith('Grid') else None
    self.args.act_space = self.envs.action_space

    self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)

    self.initialise_policy()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='gridworld_varibad')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---

    if env == 'gridworld_belief_oracle':
        args = args_grid_belief_oracle.get_args(rest_args)
    elif env == 'gridworld_varibad':
        args = args_grid_varibad.get_args(rest_args)
    elif env == 'gridworld_rl2':
        args = args_grid_rl2.get_args(rest_args)

    # --- PointRobot 2D Navigation ---

    elif env == 'pointrobot_multitask':
        args = args_pointrobot_multitask.get_args(rest_args)
    elif env == 'pointrobot_varibad':
        args = args_pointrobot_varibad.get_args(rest_args)
    elif env == 'pointrobot_rl2':
        args = args_pointrobot_rl2.get_args(rest_args)
    elif env == 'pointrobot_humplik':
        args = args_pointrobot_humplik.get_args(rest_args)

    # --- MUJOCO ---

    # - CheetahDir -
    elif env == 'cheetah_dir_multitask':
        args = args_cheetah_dir_multitask.get_args(rest_args)
    elif env == 'cheetah_dir_expert':
        args = args_cheetah_dir_expert.get_args(rest_args)
    elif env == 'cheetah_dir_varibad':
        args = args_cheetah_dir_varibad.get_args(rest_args)
    elif env == 'cheetah_dir_rl2':
        args = args_cheetah_dir_rl2.get_args(rest_args)

    # - CheetahVel -
    elif env == 'cheetah_vel_multitask':
        args = args_cheetah_vel_multitask.get_args(rest_args)
    elif env == 'cheetah_vel_expert':
        args = args_cheetah_vel_expert.get_args(rest_args)
    elif env == 'cheetah_vel_avg':
        args = args_cheetah_vel_avg.get_args(rest_args)
    elif env == 'cheetah_vel_varibad':
        args = args_cheetah_vel_varibad.get_args(rest_args)
    elif env == 'cheetah_vel_rl2':
        args = args_cheetah_vel_rl2.get_args(rest_args)

    # - AntDir -
    elif env == 'ant_dir_multitask':
        args = args_ant_dir_multitask.get_args(rest_args)
    elif env == 'ant_dir_expert':
        args = args_ant_dir_expert.get_args(rest_args)
    elif env == 'ant_dir_varibad':
        args = args_ant_dir_varibad.get_args(rest_args)
    elif env == 'ant_dir_rl2':
        args = args_ant_dir_rl2.get_args(rest_args)

    # - AntGoal -
    elif env == 'ant_goal_multitask':
        args = args_ant_goal_multitask.get_args(rest_args)
    elif env == 'ant_goal_expert':
        args = args_ant_goal_expert.get_args(rest_args)
    elif env == 'ant_goal_varibad':
        args = args_ant_goal_varibad.get_args(rest_args)
    elif env == 'ant_goal_humplik':
        args = args_ant_goal_humplik.get_args(rest_args)
    elif env == 'ant_goal_rl2':
        args = args_ant_goal_rl2.get_args(rest_args)

    # - Walker -
    elif env == 'walker_multitask':
        args = args_walker_multitask.get_args(rest_args)
    elif env == 'walker_expert':
        args = args_walker_expert.get_args(rest_args)
    elif env == 'walker_avg':
        args = args_walker_avg.get_args(rest_args)
    elif env == 'walker_varibad':
        args = args_walker_varibad.get_args(rest_args)
    elif env == 'walker_rl2':
        args = args_walker_rl2.get_args(rest_args)

    # - HumanoidDir -
    elif env == 'humanoid_dir_multitask':
        args = args_humanoid_dir_multitask.get_args(rest_args)
    elif env == 'humanoid_dir_expert':
        args = args_humanoid_dir_expert.get_args(rest_args)
    elif env == 'humanoid_dir_varibad':
        args = args_humanoid_dir_varibad.get_args(rest_args)
    elif env == 'humanoid_dir_rl2':
        args = args_humanoid_dir_rl2.get_args(rest_args)
    else:
        raise Exception("Invalid Environment")

    # warning for deterministic execution
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1]
    if args.norm_actions_pre_sampling or args.norm_actions_post_sampling:
        envs = make_vec_envs(env_name=args.env_name, seed=0, num_processes=args.num_processes,
                             gamma=args.policy_gamma, device='cpu',
                             episodes_per_task=args.max_rollouts_per_task,
                             normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                             tasks=None,
                             )
        assert np.unique(envs.action_space.low) == [-1]
        assert np.unique(envs.action_space.high) == [1]

    # clean up arguments
    if args.disable_metalearner or args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed
        args.action_space = None

        if args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
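# How this entry point is typically driven (illustrative, not from the repo):
# --env-type selects a config module, and everything left in rest_args is
# forwarded to that module's get_args(), so per-experiment defaults can be
# overridden from the command line, e.g.
#
#   python main.py --env-type gridworld_varibad
#   python main.py --env-type cheetah_dir_rl2 --seed 73
#
# (--seed is assumed to be defined in the config modules; other flag names
# vary per experiment). A programmatic equivalent, useful in notebooks/tests:
def _example_run_gridworld_varibad():
    import sys
    sys.argv = ['main.py', '--env-type', 'gridworld_varibad']
    main()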
def load_and_render(self, load_iter):
    save_path = os.path.join(
        '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
        'models')
    self.policy.actor_critic = torch.load(os.path.join(save_path, "policy{0}.pt".format(load_iter)))
    self.vae.encoder = torch.load(os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

    args = self.args
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    num_processes = 1
    num_episodes = 100
    num_steps = 1999

    # initialise environments
    envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=num_processes,
                         gamma=args.policy_gamma, log_dir=args.agent_log_dir,
                         device=device, allow_early_resets=False,
                         episodes_per_task=self.args.max_rollouts_per_task,
                         obs_rms=None, ret_rms=None,
                         )

    # reset latent state to prior
    latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(num_processes)

    for episode_idx in range(num_episodes):

        (prev_obs_raw, prev_obs_normalised) = envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action, _ = utl.select_action(
                    args=self.args,
                    policy=self.policy,
                    obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw,
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                    deterministic=True)

            # observe reward and next obs
            (next_obs_raw, next_obs_normalised), (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action)

            # render
            envs.venv.venv.envs[0].env.env.env.env.render()

            # update the hidden state
            latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                encoder=self.vae.encoder,
                next_obs=next_obs_raw,
                action=action,
                reward=rew_raw,
                done=None,
                hidden_state=hidden_state)

            prev_obs_normalised = next_obs_normalised
            prev_obs_raw = next_obs_raw

            if done[0]:
                break
def visualise_behaviour(args,
                        policy,
                        image_folder,
                        iter_idx,
                        ret_rms,
                        tasks,
                        encoder=None,
                        reward_decoder=None,
                        state_decoder=None,
                        task_decoder=None,
                        compute_rew_reconstruction_loss=None,
                        compute_task_reconstruction_loss=None,
                        compute_state_reconstruction_loss=None,
                        compute_kl_loss=None,
                        ):
    # initialise environment
    env = make_vec_envs(env_name=args.env_name,
                        seed=args.seed * 42 + iter_idx,
                        num_processes=1,
                        gamma=args.policy_gamma,
                        device=device,
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy,
                        ret_rms=ret_rms,
                        rank_offset=args.num_processes + 42,  # not sure if the temp folders would otherwise clash
                        tasks=tasks,
                        )
    episode_task = torch.from_numpy(np.array(env.get_task())).to(device).float()

    # get a sample rollout
    unwrapped_env = env.venv.unwrapped.envs[0]
    if hasattr(env.venv.unwrapped.envs[0], 'unwrapped'):
        unwrapped_env = unwrapped_env.unwrapped

    if hasattr(unwrapped_env, 'visualise_behaviour'):
        # if possible, get it from the env directly
        # (this might visualise other things in addition)
        traj = unwrapped_env.visualise_behaviour(env=env,
                                                 args=args,
                                                 policy=policy,
                                                 iter_idx=iter_idx,
                                                 encoder=encoder,
                                                 reward_decoder=reward_decoder,
                                                 state_decoder=state_decoder,
                                                 task_decoder=task_decoder,
                                                 image_folder=image_folder,
                                                 )
    else:
        traj = get_test_rollout(args, env, policy, encoder)

    latent_means, latent_logvars, episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, episode_returns = traj

    if latent_means is not None:
        plot_latents(latent_means, latent_logvars,
                     image_folder=image_folder,
                     iter_idx=iter_idx)

        if not (args.disable_decoder and args.disable_kl_term):
            plot_vae_loss(args,
                          latent_means,
                          latent_logvars,
                          episode_prev_obs,
                          episode_next_obs,
                          episode_actions,
                          episode_rewards,
                          episode_task,
                          image_folder=image_folder,
                          iter_idx=iter_idx,
                          reward_decoder=reward_decoder,
                          state_decoder=state_decoder,
                          task_decoder=task_decoder,
                          compute_task_reconstruction_loss=compute_task_reconstruction_loss,
                          compute_rew_reconstruction_loss=compute_rew_reconstruction_loss,
                          compute_state_reconstruction_loss=compute_state_reconstruction_loss,
                          compute_kl_loss=compute_kl_loss,
                          )

    env.close()
def evaluate(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             num_episodes=None,
             ):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=device,
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(device)

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=True)

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]

            if encoder is not None:
                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=encoder,
                    next_obs=state,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1

            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

    envs.close()

    return returns_per_episode[:, :num_episodes]
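# Sketch of how evaluate()'s output might be reduced for logging (the
# aggregation below is an assumption, not code from this repo): the returned
# tensor has shape [num_processes, num_episodes], so averaging over dim 0
# gives one mean return per within-task episode index.
def _example_log_eval_returns(args, policy, encoder, iter_idx):
    returns = evaluate(args, policy, ret_rms=None, iter_idx=iter_idx, tasks=None, encoder=encoder)
    mean_returns = returns.mean(dim=0)  # shape: [num_episodes]
    for episode_idx, ret in enumerate(mean_returns):
        print('episode {}: average return {:.2f}'.format(episode_idx + 1, ret.item()))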
def __init__(self, args):

    self.args = args
    utl.seed(self.args.seed, self.args.deterministic_execution)

    # calculate number of updates and keep count of frames/iterations
    self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
    self.frames = 0
    self.iter_idx = -1

    # initialise tensorboard logger
    self.logger = TBLogger(self.args, self.args.exp_label)

    # initialise environments
    self.envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                              gamma=args.policy_gamma, device=device,
                              episodes_per_task=self.args.max_rollouts_per_task,
                              normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                              tasks=None,
                              )

    if self.args.single_task_mode:
        # get the current tasks (which will be num_process many different tasks)
        self.train_tasks = self.envs.get_task()
        # set the tasks to the first task (i.e. just a random task)
        self.train_tasks[1:] = self.train_tasks[0]
        # make it a list
        self.train_tasks = [t for t in self.train_tasks]
        # re-initialise environments with those tasks
        self.envs = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                                  gamma=args.policy_gamma, device=device,
                                  episodes_per_task=self.args.max_rollouts_per_task,
                                  normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                                  tasks=self.train_tasks,
                                  )
        # save the training tasks so we can evaluate on the same envs later
        utl.save_obj(self.train_tasks, self.logger.full_output_folder, "train_tasks")
    else:
        self.train_tasks = None

    # calculate what the maximum length of the trajectories is
    args.max_trajectory_len = self.envs._max_episode_steps
    args.max_trajectory_len *= self.args.max_rollouts_per_task

    # get policy input dimensions
    self.args.state_dim = self.envs.observation_space.shape[0]
    self.args.task_dim = self.envs.task_dim
    self.args.belief_dim = self.envs.belief_dim
    self.args.num_states = self.envs.num_states

    # get policy output (action) dimensions
    self.args.action_space = self.envs.action_space
    if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
        self.args.action_dim = 1
    else:
        self.args.action_dim = self.envs.action_space.shape[0]

    # initialise policy
    self.policy_storage = self.initialise_policy_storage()
    self.policy = self.initialise_policy()