def main():
    """Parse the --env-type flag, build the env-specific args, load the offline
    dataset (optionally with reward relabelling and policy replaying), and train
    the task-inference VAE."""
    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    # parser.add_argument('--env-type', default='ant_semicircle_sparse')
    parser.add_argument('--env-type', default='point_robot_wind')
    # parser.add_argument('--env-type', default='escape_room')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---
    if env == 'gridworld':
        args = args_gridworld.get_args(rest_args)
    # --- PointRobot ---
    elif env == 'point_robot_sparse':
        args = args_point_robot_sparse.get_args(rest_args)
    elif env == 'escape_room':
        args = args_point_robot_barrier.get_args(rest_args)
    elif env == 'point_robot_wind':
        args = args_point_robot_rand_params.get_args(rest_args)
    # --- Mujoco ---
    elif env == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
    elif env == 'ant_semicircle_sparse':
        args = args_ant_semicircle_sparse.get_args(rest_args)

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)

    # load the offline dataset of per-task rollouts
    dataset, goals = off_utl.load_dataset(data_dir=args.data_dir, args=args, arr_type='numpy')
    # dataset, goals = off_utl.load_dataset(args)
    if args.hindsight_relabelling:
        print('Perform reward relabelling...')
        dataset, goals = off_utl.mix_task_rollouts(dataset, env, goals, args)
    if args.policy_replaying:
        mix_dataset, mix_goals = off_utl.load_replaying_dataset(data_dir=args.replaying_data_dir, args=args)
        print('Perform policy replaying...')
        dataset, goals = off_utl.mix_policy_rollouts(dataset, goals, mix_dataset, mix_goals, args)

    # vis test tasks
    # vis_train_tasks(env.unwrapped, goals)   # not with GridNavi

    if args.save_model:
        dir_prefix = args.save_dir_prefix if hasattr(args, 'save_dir_prefix') \
            and args.save_dir_prefix is not None else ''
        args.full_save_path = os.path.join(args.save_dir, args.env_name,
                                           dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, goals, args)
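
# Assumed entry point: main() above is written script-style, so a standard guard
# like the one below is presumably how it is invoked, e.g.
#   python train_vae_offline.py --env-type cheetah_vel
# (the exact flags and file name beyond what the excerpt shows are assumptions).
if __name__ == '__main__':
    main()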
def __init__(self, args):
    """
    Seeds everything.
    Initialises: logger, environments, policy (+storage +optimiser).
    """
    self.args = args

    # make sure everything has the same seed
    utl.seed(self.args.seed)

    # initialize tensorboard logger
    if self.args.log_tensorboard:
        self.tb_logger = TBLogger(self.args)

    self.args, env = off_utl.expand_args(self.args, include_act_space=True)
    if self.args.act_space.__class__.__name__ == "Discrete":
        self.args.policy = 'dqn'
    else:
        self.args.policy = 'sac'

    # load buffers with data
    if 'load_data' not in self.args or self.args.load_data:
        goals, augmented_obs_dim = self.load_buffer(env)   # env is input just for possible relabelling option
        self.args.augmented_obs_dim = augmented_obs_dim
        self.goals = goals

    # initialize policy
    self.initialize_policy()

    # load vae for inference in evaluation
    self.load_vae()

    # create environment for evaluation
    self.env = make_env(args.env_name,
                        args.max_rollouts_per_task,
                        presampled_tasks=args.presampled_tasks,
                        seed=args.seed)
    # n_tasks=self.args.num_eval_tasks)

    if self.args.env_name == 'GridNavi-v2':
        self.env.unwrapped.goals = [tuple(goal.astype(int)) for goal in self.goals]
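
# Hypothetical usage sketch for the constructor above. The enclosing class name and
# its training entry point are not shown in the excerpt, so the names below are
# illustrative assumptions only:
#
#   learner = OfflineMetaLearner(args)   # class name assumed
#   learner.train()                      # method name assumed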
def _train_vae(log_dir, offline_buffer_path, saved_tasks_path, env_type, seed,
               path_length, meta_episode_len, load_buffer_kwargs=None, **kwargs):
    """Seed all RNGs, build env-specific args for env_type, load an offline
    (PEARL-style) buffer from disk, and train the VAE on it."""
    with open(os.path.join(log_dir, 'test.txt'), 'w') as f:
        f.write("hello from train_vae_offline.py")
    if load_buffer_kwargs is None:
        load_buffer_kwargs = {}
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    parser = argparse.ArgumentParser()
    # parser.add_argument('--env-type', default='gridworld')
    # parser.add_argument('--env-type', default='point_robot_sparse')
    # parser.add_argument('--env-type', default='cheetah_vel')
    parser.add_argument('--env-type', default='ant_semicircle_sparse')
    # forward extra keyword arguments as '--key value' pairs for the arg parsers
    extra_args = []
    for k, v in kwargs.items():
        extra_args.append('--{}'.format(k))
        extra_args.append(str(v))
    args, rest_args = parser.parse_known_args(args=extra_args)

    # --- Mujoco ---
    if env_type == 'cheetah_vel':
        args = args_cheetah_vel.get_args(rest_args)
        args.env_name = 'HalfCheetahVel-v0'
    elif env_type == 'ant_dir':  # TODO: replace with ant_dir env
        args = args_ant_semicircle_sparse.get_args(rest_args)
        # note: registered after parse_known_args, so this has no effect here
        parser.add_argument('--env-name', default='AntSemiCircleSparse-v0')
        args.env_name = 'AntDir-v0'
    elif env_type == 'walker':
        args = args_walker_param.get_args(rest_args)
    elif env_type == 'hopper':
        args = args_hopper_param.get_args(rest_args)
    elif env_type == 'humanoid':
        args = args_humanoid_dir.get_args(rest_args)
    else:
        raise ValueError('Unknown env_type: {}'.format(env_type))

    set_gpu_mode(torch.cuda.is_available() and args.use_gpu)

    args, env = off_utl.expand_args(args)
    args.save_dir = os.path.join(log_dir, 'trained_vae')
    args.trajectory_len = path_length

    task_data = joblib.load(saved_tasks_path)
    tasks = task_data['tasks']

    print("loading dataset")
    with open(os.path.join(log_dir, 'tmp1.txt'), 'w') as f:
        f.write("train_vae_offline.py: start loading dataset")
    dataset, goals = off_utl.load_pearl_buffer(
        pretrain_buffer_path=offline_buffer_path,
        tasks=tasks,
        add_done_info=env.add_done_info,
        path_length=path_length,
        meta_episode_len=meta_episode_len,
        **load_buffer_kwargs)
    with open(os.path.join(log_dir, 'tmp1.txt'), 'a') as f:
        f.write("train_vae_offline.py: done loading dataset")
    print("done loading dataset")
    for data in dataset:
        print(data[0].shape)
    dataset = [[x.astype(np.float32) for x in d] for d in dataset]

    if args.save_model:
        dir_prefix = args.save_dir_prefix if hasattr(args, 'save_dir_prefix') \
            and args.save_dir_prefix is not None else ''
        args.full_save_path = os.path.join(
            args.save_dir, args.env_name,
            dir_prefix + datetime.datetime.now().strftime('__%d_%m_%H_%M_%S'))
        os.makedirs(args.full_save_path, exist_ok=True)
        config_utl.save_config_file(args, args.full_save_path)

    vae = VAE(args)
    train(vae, dataset, args)
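
# Hypothetical invocation sketch for _train_vae. All paths and hyperparameter values
# below are placeholders, not taken from the excerpt; extra VAE hyperparameters could
# be forwarded through **kwargs, which are turned into '--key value' pairs consumed
# by the env-specific arg parsers:
#
#   _train_vae(log_dir='/tmp/vae_run',                      # placeholder output dir
#              offline_buffer_path='buffers/pretrain.pkl',  # placeholder PEARL buffer
#              saved_tasks_path='buffers/tasks.pkl',        # placeholder task file
#              env_type='cheetah_vel',
#              seed=0,
#              path_length=200,
#              meta_episode_len=2)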