Code Example #1
File: offline_metalearner.py  Project: vitchyr/BOReL
def load_buffer(self, env):
    # Load as numpy arrays (arr_type='numpy'); materialising the full
    # dataset as GPU tensors would exhaust GPU memory.
    if self.args.hindsight_relabelling:
        dataset, goals = off_utl.load_dataset(
            data_dir=self.args.relabelled_data_dir,
            args=self.args,
            num_tasks=self.args.num_train_tasks,
            allow_dense_data_loading=False,
            arr_type='numpy')
        dataset = off_utl.batch_to_trajectories(dataset, self.args)
        dataset, goals = off_utl.mix_task_rollouts(
            dataset, env, goals, self.args)  # reward relabelling
        dataset = off_utl.trajectories_to_batch(dataset)
    else:
        dataset, goals = off_utl.load_dataset(
            data_dir=self.args.relabelled_data_dir,
            args=self.args,
            num_tasks=self.args.num_train_tasks,
            allow_dense_data_loading=False,
            arr_type='numpy')
    augmented_obs_dim = dataset[0][0].shape[1]
    self.storage = MultiTaskPolicyStorage(
        max_replay_buffer_size=max(d[0].shape[0] for d in dataset),
        obs_dim=dataset[0][0].shape[1],
        action_space=self.args.act_space,
        tasks=range(len(goals)),
        trajectory_len=self.args.trajectory_len)
    # Each per-task entry holds (observations, actions, rewards,
    # next_observations, terminals); `task_data` avoids shadowing the
    # built-in `set`.
    for task, task_data in enumerate(dataset):
        self.storage.add_samples(task,
                                 observations=task_data[0],
                                 actions=task_data[1],
                                 rewards=task_data[2],
                                 next_observations=task_data[3],
                                 terminals=task_data[4])
    return goals, augmented_obs_dim
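A minimal usage sketch, assuming the learner drives `load_buffer` during setup (the call site below is an assumption; `OfflineMetaLearner` appears in code example #3, and whether its constructor already calls `load_buffer` internally is not shown):

# Hypothetical call site -- not part of the original snippet.
learner = OfflineMetaLearner(args)
goals, augmented_obs_dim = learner.load_buffer(env)  # populates learner.storage per task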
Code Example #2
def main():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='ant_semicircle_sparse')
    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')

    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)
    elif env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)

    dataset, goals = off_utl.load_dataset(data_dir=args.data_dir, args=args, arr_type='numpy')
    # dataset, goals = off_utl.load_dataset(args)
    if args.hindsight_relabelling:
        print('Perform reward relabelling...')
        dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, args)

    if args.policy_replaying:
        mix_dataset, mix_goals = off_utl.load_replaying_dataset(data_dir=args.replaying_data_dir, args=args)
        print('Perform policy replaying...')
        dataset, goals = off_utl.mix_policy_rollouts(dataset, goals, mix_dataset, mix_goals, args)

    # Optionally visualize the training tasks (not supported for GridNavi):
    # vis_train_tasks(env.unwrapped, goals)

    if args.save_model:
        dir_prefix = getattr(args, 'save_dir_prefix', None)
        if dir_prefix is None:
            dir_prefix = ''
        args.full_save_path = os.path.join(args.save_dir, args.env_name,
                                           dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, goals, args)
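The snippet ends with the function body; running it as a script needs the usual entry-point guard, presumably present in the source file but not shown here:

if __name__ == '__main__':
    main()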
Code Example #3
File: run_experiment.py  Project: vitchyr/BOReL
def offline_experiment(doodad_config, variant):
    save_doodad_config(doodad_config)
    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    parser.add_argument('--env-type', default='ant_semicircle_sparse')
    args, rest_args = parser.parse_known_args(args=[])  # ignore the real CLI when launched via doodad
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    vae_args = config_utl.load_config_file(
        os.path.join(args.vae_dir, args.env_name, args.vae_model_name,
                     'online_config.json'))
    args = config_utl.merge_configs(
        vae_args, args)  # order of input to this function is important

    # Transform data to BAMDP (state relabelling)
    if args.transform_data_bamdp:
        # load VAE for state relabelling
        vae_models_path = os.path.join(args.vae_dir, args.env_name,
                                       args.vae_model_name, 'models')
        vae = VAE(args)
        off_utl.load_trained_vae(vae, vae_models_path)
        # load data and relabel
        save_data_path = os.path.join(args.main_data_dir, args.env_name,
                                      args.relabelled_data_dir)
        os.makedirs(save_data_path, exist_ok=True)  # tolerate an existing directory on re-runs
        dataset, goals = off_utl.load_dataset(data_dir=args.data_dir,
                                              args=args,
                                              arr_type='numpy')
        bamdp_dataset = off_utl.transform_mdps_ds_to_bamdp_ds(
            dataset, vae, args)
        # save relabelled data
        off_utl.save_dataset(save_data_path, bamdp_dataset, goals)

    learner = OfflineMetaLearner(args)

    learner.train()
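The comment on `merge_configs` warns that argument order matters, which suggests one namespace's fields override the other's. A hedged illustration of that convention follows; `merge_configs_sketch` is a hypothetical stand-in for exposition, not the project's actual implementation:

import argparse

def merge_configs_sketch(base, override):
    # Assumed semantics: fields in `override` win; everything else is
    # carried over from `base`.
    merged = dict(vars(base))
    merged.update(vars(override))
    return argparse.Namespace(**merged)

Under that reading, `merge_configs(vae_args, args)` would let the experiment's own args take precedence over the loaded VAE config; the actual precedence should be checked in config_utl.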
Code Example #4
def collect_hindsight_data():
    parser = argparse.ArgumentParser()

    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')

    args, rest_args = parser.parse_known_args()
    env = args.env_type

    if env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)

    # args required because the VAE utility functions below expect them
    args.main_data_dir = args.main_save_dir
    args.trajectory_len = 50
    args.num_trajs_per_task = None

    args.num_rollouts = 10

    set_gpu_mode(torch.cuda.is_available())

    if hasattr(args, 'save_buffer') and args.save_buffer:
        os.makedirs(args.main_save_dir, exist_ok=True)

    # Recover the task goals from the previously saved dataset (args.save_dir
    # still points at the original data directory at this point).
    _, goals = off_utl.load_dataset(data_dir=args.save_dir,
                                    args=args,
                                    arr_type='numpy')

    # Redirect saving to a dedicated hindsight-data directory.
    args.save_dir = "hindsight_data"
    args.save_data_path = os.path.join(args.main_data_dir, args.env_name,
                                       args.save_dir)
    os.makedirs(args.save_data_path, exist_ok=True)  # tolerate an existing directory on re-runs
    models_dir = './trained_agents'
    all_dirs = os.listdir(models_dir)
    for i, goal in enumerate(goals):
        print("start collect rollouts for task number ", i + 1)
        collect_rollout_per_goal(args, goal, all_dirs)
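`collect_rollout_per_goal` is not included in the snippet. A rough sketch of what it presumably does, under stated assumptions: every helper and agent method below is a hypothetical stand-in, the classic Gym step API is assumed, and only `num_rollouts`, `trajectory_len`, `env_name`, and `save_data_path` come from the script above:

def collect_rollout_per_goal_sketch(args, goal, all_dirs):
    # Hypothetical: pick the trained agent matching this goal, roll it out
    # num_rollouts times, and save the transitions for hindsight relabelling.
    agent = load_agent_for_goal(goal, all_dirs)        # hypothetical helper
    env = make_env_for_goal(args.env_name, goal)       # hypothetical helper
    rollouts = []
    for _ in range(args.num_rollouts):
        obs = env.reset()                              # classic Gym API assumed
        trajectory = []
        for _ in range(args.trajectory_len):
            action = agent.act(obs)                    # hypothetical agent API
            next_obs, reward, done, _ = env.step(action)
            trajectory.append((obs, action, reward, next_obs, done))
            obs = next_obs
        rollouts.append(trajectory)
    save_trajectories(args.save_data_path, goal, rollouts)  # hypothetical helper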