def main():
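    # parse only --env-type here; all remaining CLI flags are forwarded to the
    # env-specific argument parser selected below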
    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='ant_semicircle_sparse')
    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')

    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)
    elif env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

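    # run on the GPU only when CUDA is available and the config requests it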
    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)

    dataset, goals = off_utl.load_dataset(data_dir=args.data_dir, args=args, arr_type='numpy')
    # dataset, goals = off_utl.load_dataset(args)
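    # optionally relabel rewards of the stored rollouts across tasks (hindsight relabelling)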
    if args.hindsight_relabelling:
        print('Performing reward relabelling...')
        dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, args)

    if args.policy_replaying:
        mix_dataset, mix_goals = off_utl.load_replaying_dataset(data_dir=args.replaying_data_dir, args=args)
        print('Performing policy replaying...')
        dataset, goals = off_utl.mix_policy_rollouts(dataset, goals, mix_dataset, mix_goals, args)

    # vis test tasks
    # vis_train_tasks(env.unwrapped, goals)     # not with GridNavi

    if args.save_model:
        dir_prefix = getattr(args, 'save_dir_prefix', None) or ''
        args.full_save_path = os.path.join(args.save_dir, args.env_name,
                                           dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, goals, args)
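
# (assumed) script entry point; the original module presumably ends with this guard
if __name__ == '__main__':
    main()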
Example #2
    def __init__(self, args):
        """
        Seeds everything.
        Initialises: logger, environments, policy (+storage +optimiser).
        """

        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed)

        # initialize tensorboard logger
        if self.args.log_tensorboard:
            self.tb_logger = TBLogger(self.args)

        self.args, env = off_utl.expand_args(self.args, include_act_space=True)
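        # pick the RL algorithm from the action space: DQN for discrete
        # actions, SAC for continuous ones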
        if self.args.act_space.__class__.__name__ == "Discrete":
            self.args.policy = 'dqn'
        else:
            self.args.policy = 'sac'

        # load buffers with data
        if 'load_data' not in self.args or self.args.load_data:
            # env is passed only for the optional relabelling step
            goals, augmented_obs_dim = self.load_buffer(env)
            self.args.augmented_obs_dim = augmented_obs_dim
            self.goals = goals

        # initialize policy
        self.initialize_policy()

        # load vae for inference in evaluation
        self.load_vae()

        # create environment for evaluation
        self.env = make_env(
            self.args.env_name,
            self.args.max_rollouts_per_task,
            presampled_tasks=self.args.presampled_tasks,
            seed=self.args.seed,
        )
        # n_tasks=self.args.num_eval_tasks)
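        # for GridNavi, align the eval env's goal set with the goals found in the offline data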
        if self.args.env_name == 'GridNavi-v2':
            self.env.unwrapped.goals = [
                tuple(goal.astype(int)) for goal in self.goals
            ]
Example #3
def _train_vae(log_dir,
               offline_buffer_path,
               saved_tasks_path,
               env_type,
               seed,
               path_length,
               meta_episode_len,
               load_buffer_kwargs=None,
               **kwargs):
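    # debug: write a marker file to verify that log_dir is writable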
    with open(os.path.join(log_dir, 'test.txt'), 'w') as f:
        f.write("hello from train_vae_offline.py")
    if load_buffer_kwargs is None:
        load_buffer_kwargs = {}
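    # seed every RNG source (python, numpy, torch CPU and GPU) for reproducibility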
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    parser.add_argument('--env-type', default='ant_semicircle_sparse')
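    # forward any extra keyword arguments as CLI-style flags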
    extra_args = []
    for k, v in kwargs.items():
        extra_args.append('--{}'.format(k))
        extra_args.append(str(v))
    args, rest_args = parser.parse_known_args(args=extra_args)

    # --- Mujoco ---
    if env_type == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
        args.env_name = 'HalfCheetahVel-v0'
    elif env_type == 'ant_dir':
        # TODO: replace with ant_dir env
        args = args_ant_semicircle_sparse.get_args(rest_args)
        args.env_name = 'AntDir-v0'
    elif env_type == 'walker':
        args = args_walker_param.get_args(rest_args)
    elif env_type == 'hopper':
        args = args_hopper_param.get_args(rest_args)
    elif env_type == 'humanoid':
        args = args_humanoid_dir.get_args(rest_args)
    else:
        raise ValueError('Unknown env_type: {}'.format(env_type))

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)
    args.save_dir = os.path.join(log_dir, 'trained_vae')

    args.trajectory_len = path_length
    task_data = joblib.load(saved_tasks_path)
    tasks = task_data['tasks']
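    # load the offline (PEARL-style) replay buffer into per-task datasets and goals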
    print("loading dataset")
    with open(os.path.join(log_dir, 'tmp1.txt'), 'w') as f:
        f.write("train_vae_offline.py: start loading dataset")
    dataset, goals = off_utl.load_pearl_buffer(
        pretrain_buffer_path=offline_buffer_path,
        tasks=tasks,
        add_done_info=env.add_done_info,
        path_length=path_length,
        meta_episode_len=meta_episode_len,
        **load_buffer_kwargs)
    with open(os.path.join(log_dir, 'tmp1.txt'), 'a') as f:
        f.write("train_vae_offline.py: done loading dataset")
    print("done loading dataset")
    for data in dataset:
        print(data[0].shape)

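    # cast all arrays to float32, the dtype torch modules expect by default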
    dataset = [[x.astype(np.float32) for x in d] for d in dataset]

    if args.save_model:
        dir_prefix = getattr(args, 'save_dir_prefix', None) or ''
        args.full_save_path = os.path.join(
            args.save_dir, args.env_name,
            dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, args)
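
# Hypothetical invocation sketch; every path and value below is a placeholder,
# not taken from the original project:
# _train_vae(log_dir='/tmp/vae_run',
#            offline_buffer_path='/tmp/pearl_buffer.pkl',
#            saved_tasks_path='/tmp/tasks.pkl',
#            env_type='cheetah_vel',
#            seed=0,
#            path_length=200,
#            meta_episode_len=2)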