def load_buffer(self, env):
    if self.args.hindsight_relabelling:
        # load as numpy arrays -- without arr_type='numpy' the whole dataset
        # would be pushed to the GPU and blow up memory
        dataset, goals = off_utl.load_dataset(
            data_dir=self.args.relabelled_data_dir,
            args=self.args,
            num_tasks=self.args.num_train_tasks,
            allow_dense_data_loading=False,
            arr_type='numpy')
        dataset = off_utl.batch_to_trajectories(dataset, self.args)
        dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, self.args)  # reward relabelling
        dataset = off_utl.trajectories_to_batch(dataset)
    else:
        dataset, goals = off_utl.load_dataset(
            data_dir=self.args.relabelled_data_dir,
            args=self.args,
            num_tasks=self.args.num_train_tasks,
            allow_dense_data_loading=False,
            arr_type='numpy')

    augmented_obs_dim = dataset[0][0].shape[1]
    self.storage = MultiTaskPolicyStorage(
        max_replay_buffer_size=max(d[0].shape[0] for d in dataset),
        obs_dim=dataset[0][0].shape[1],
        action_space=self.args.act_space,
        tasks=range(len(goals)),
        trajectory_len=self.args.trajectory_len)

    # each task's data is a tuple of (obs, actions, rewards, next_obs, terminals) arrays;
    # `task_data` avoids shadowing the built-in `set`
    for task, task_data in enumerate(dataset):
        self.storage.add_samples(
            task,
            observations=task_data[0],
            actions=task_data[1],
            rewards=task_data[2],
            next_observations=task_data[3],
            terminals=task_data[4])
    return goals, augmented_obs_dim
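
# Illustrative sketch (not part of the original code): the per-task layout that
# load_buffer above assumes for `dataset`. Each entry is a tuple of five numpy
# arrays with a shared leading dimension, consumed positionally by add_samples.
# The sizes and the helper name below are hypothetical.
import numpy as np

def _make_dummy_task(num_transitions=200, obs_dim=5, act_dim=2):
    return (
        np.zeros((num_transitions, obs_dim)),  # observations       (index 0)
        np.zeros((num_transitions, act_dim)),  # actions            (index 1)
        np.zeros((num_transitions, 1)),        # rewards            (index 2)
        np.zeros((num_transitions, obs_dim)),  # next_observations  (index 3)
        np.zeros((num_transitions, 1)),        # terminals          (index 4)
    )

# dataset = [_make_dummy_task() for _ in range(num_train_tasks)]
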
def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='ant_semicircle_sparse')
    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)
    elif env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    # --- MuJoCo ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)

    dataset, goals = off_utl.load_dataset(data_dir=args.data_dir, args=args, arr_type='numpy')
    # dataset, goals = off_utl.load_dataset(args)

    if args.hindsight_relabelling:
        print('Performing reward relabelling...')
        dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, args)
    if args.policy_replaying:
        mix_dataset, mix_goals = off_utl.load_replaying_dataset(data_dir=args.replaying_data_dir, args=args)
        print('Performing policy replaying...')
        dataset, goals = off_utl.mix_policy_rollouts(dataset, goals, mix_dataset, mix_goals, args)

    # visualise test tasks
    # vis_train_tasks(env.unwrapped, goals)  # not with GridNavi

    if args.save_model:
        dir_prefix = args.save_dir_prefix if hasattr(args, 'save_dir_prefix') \
            and args.save_dir_prefix is not None else ''
        args.full_save_path = os.path.join(
            args.save_dir, args.env_name,
            dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, goals, args)
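
# Assumed script entry point (standard pattern; not shown in this excerpt).
if __name__ == '__main__':
    main()
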
def offline_experiment(doodad_config, variant):
    """Doodad entry point: optionally relabel the dataset into BAMDP form
    (state relabelling with a trained VAE), then train the offline meta-learner."""
    save_doodad_config(doodad_config)

    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    parser.add_argument('--env-type', default='ant_semicircle_sparse')
    args, rest_args = parser.parse_known_args(args=[])
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    # --- MuJoCo ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    vae_args = config_utl.load_config_file(
        os.path.join(args.vae_dir, args.env_name, args.vae_model_name, 'online_config.json'))
    args = config_utl.merge_configs(vae_args, args)  # the order of the inputs to this function is important

    # transform the data into a BAMDP dataset (state relabelling)
    if args.transform_data_bamdp:
        # load the trained VAE used for state relabelling
        vae_models_path = os.path.join(args.vae_dir, args.env_name, args.vae_model_name, 'models')
        vae = VAE(args)
        off_utl.load_trained_vae(vae, vae_models_path)

        # load the data and relabel it
        save_data_path = os.path.join(args.main_data_dir, args.env_name, args.relabelled_data_dir)
        os.makedirs(save_data_path)
        dataset, goals = off_utl.load_dataset(data_dir=args.data_dir, args=args, arr_type='numpy')
        bamdp_dataset = off_utl.transform_mdps_ds_to_bamdp_ds(dataset, vae, args)
        # save the relabelled data
        off_utl.save_dataset(save_data_path, bamdp_dataset, goals)

    learner = OfflineMetaLearner(args)
    learner.train()

def collect_hindsight_data():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    if env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)

    # arguments required because we reuse the VAE utility functions
    args.main_data_dir = args.main_save_dir
    args.trajectory_len = 50
    args.num_trajs_per_task = None
    args.num_rollouts = 10

    set_gpu_mode(torch.cuda.is_available())

    if hasattr(args, 'save_buffer') and args.save_buffer:
        os.makedirs(args.main_save_dir, exist_ok=True)

    # load the task goals from the previously saved data
    _, goals = off_utl.load_dataset(data_dir=args.save_dir, args=args, arr_type='numpy')

    args.save_dir = "hindsight_data"
    args.save_data_path = os.path.join(args.main_data_dir, args.env_name, args.save_dir)
    os.makedirs(args.save_data_path)

    models_dir = './trained_agents'
    all_dirs = os.listdir(models_dir)

    for i, goal in enumerate(goals):
        print("Collecting rollouts for task number", i + 1)
        collect_rollout_per_goal(args, goal, all_dirs)
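
# Assumed entry point for the hindsight-data collection script
# (standard pattern; not shown in this excerpt).
if __name__ == '__main__':
    collect_hindsight_data()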