def __init__(
        self,
        trainer,
        replay_buffer,
        demo_train_buffer,
        demo_test_buffer,
        model=None,
        model_path=None,
        input_model=None,
        input_model_path=None,
        reward_fn=None,
        env=None,
        demo_paths=None,  # list of dicts; defaults to []
        normalize=False,
        demo_train_split=0.9,
        demo_data_split=1,
        add_demos_to_replay_buffer=True,
        condition_input_encoding=False,
        bc_num_pretrain_steps=0,
        bc_batch_size=64,
        bc_weight=1.0,
        rl_weight=1.0,
        q_num_pretrain_steps=0,
        weight_decay=0,
        eval_policy=None,
        recompute_reward=False,
        object_list=None,
        env_info_key=None,
        obs_key=None,
        load_terminals=True,
        delete_after_loading=False,
        data_filter_fn=lambda x: True,  # Return True to add a path, False to ignore it
        **kwargs):
    # Use a None sentinel to avoid a mutable default argument.
    demo_paths = [] if demo_paths is None else demo_paths
    super().__init__(trainer,
                     replay_buffer,
                     demo_train_buffer,
                     demo_test_buffer,
                     demo_paths,
                     demo_train_split,
                     demo_data_split,
                     add_demos_to_replay_buffer,
                     bc_num_pretrain_steps,
                     bc_batch_size,
                     bc_weight,
                     rl_weight,
                     q_num_pretrain_steps,
                     weight_decay,
                     eval_policy,
                     recompute_reward,
                     env_info_key,
                     obs_key,
                     load_terminals,
                     delete_after_loading,
                     data_filter_fn,
                     **kwargs)
    if model is None:
        self.model = load_local_or_remote_file(
            model_path, delete_after_loading=delete_after_loading)
    else:
        self.model = model
    if input_model is None:
        self.input_model = load_local_or_remote_file(
            input_model_path, delete_after_loading=delete_after_loading)
    else:
        self.input_model = input_model
    self.condition_input_encoding = condition_input_encoding
    self.reward_fn = reward_fn
    self.normalize = normalize
    self.object_list = object_list
    self.env = env
def __init__(
        self,
        wrapped_env,
        clusterer,
        mode='train',
        reward_params=None,
):
    self.quick_init(locals())
    super().__init__(wrapped_env)
    if isinstance(clusterer, str):
        self.clusterer = load_local_or_remote_file(clusterer)
    else:
        self.clusterer = clusterer
    self._num_clusters = self.clusterer.num_clusters
    self.task = {}
    # Guard against the None default so reward_params.get() below is safe.
    if reward_params is None:
        reward_params = dict()
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get('type', 's_given_z')
    spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
    spaces['context'] = Discrete(n=self._num_clusters)
    self.observation_space = Dict(spaces)
    assert self.reward_type in ('wrapped_env', 's_given_z')
    self._set_clusterer_attributes()
def __init__(
        self,
        wrapped_env,
        disc,
        mode='train',
        reward_params=None,
        unsupervised_reward_weight=0.,
        reward_weight=0.,
        noise_scale=0.,
):
    self.quick_init(locals())
    super().__init__(wrapped_env)
    if isinstance(disc, str):
        self.disc = load_local_or_remote_file(disc)
    else:
        self.disc = disc
    self._num_skills = self.disc.num_skills
    # Uniform prior over skills.
    self._p_z = np.full(self._num_skills, 1.0 / self._num_skills)
    self.task = {'context': -1}
    # Guard against the None default so reward_params.get() below is safe.
    if reward_params is None:
        reward_params = dict()
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get('type', 'diayn')
    # TODO: Check that the TIAYN reward is set properly.
    spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
    spaces['context'] = Discrete(n=self._num_skills)
    self.observation_space = Dict(spaces)
    self.unsupervised_reward_weight = unsupervised_reward_weight
    self.reward_weight = reward_weight
    self.noise_scale = noise_scale
    assert self.reward_type in ('wrapped_env', 'diayn',
                                'wrapped_env + diayn', 'tiayn',
                                'wrapped_env + tiayn')
def main():
    debug = True
    dry = False
    mode = 'here_no_doodad'
    suffix = ''
    nseeds = 1
    gpu = True

    path_parts = __file__.split('/')
    suffix = '' if suffix is None else '--{}'.format(suffix)
    exp_name = 'pearl-awac-{}--{}{}'.format(
        path_parts[-2].replace('_', '-'),
        path_parts[-1].split('.')[0].replace('_', '-'),
        suffix,
    )
    if debug or dry:
        exp_name = 'dev--' + exp_name
        mode = 'here_no_doodad'
        nseeds = 1
    if dry:
        mode = 'here_no_doodad'
    print(exp_name)

    task_data = load_local_or_remote_file(
        "examples/smac/ant_tasks.joblib",  # TODO: update to point to correct file
        file_type='joblib')
    tasks = task_data['tasks']
    search_space = {
        'seed': list(range(nseeds)),
    }
    variant = DEFAULT_PEARL_CONFIG.copy()
    variant["env_name"] = "ant-dir"
    # variant["train_task_idxs"] = list(range(100))
    # variant["eval_task_idxs"] = list(range(100, 120))
    variant["env_params"]["fixed_tasks"] = [t['goal'] for t in tasks]
    variant["env_params"]["direction_in_degrees"] = True
    variant["trainer_kwargs"]["train_context_decoder"] = True
    variant["trainer_kwargs"]["backprop_q_loss_into_encoder"] = True
    variant["saved_tasks_path"] = "examples/smac/ant_tasks.joblib"  # TODO: update to point to correct file

    sweeper = hyp.DeterministicHyperparameterSweeper(
        search_space,
        default_parameters=variant,
    )
    for exp_id, variant in enumerate(sweeper.iterate_hyperparameters()):
        variant['exp_id'] = exp_id
        run_experiment(
            pearl_experiment,
            unpack_variant=True,
            exp_prefix=exp_name,
            mode=mode,
            variant=variant,
            time_in_mins=3 * 24 * 60 - 1,
            use_gpu=gpu,
        )
def resume(variant):
    data = load_local_or_remote_file(variant.get("pretrained_algorithm_path"),
                                     map_location="cuda")
    algo = data['algorithm']
    algo.num_epochs = variant['num_epochs']
    post_pretrain_hyperparams = variant["trainer_kwargs"].get(
        "post_pretrain_hyperparams", {})
    algo.trainer.set_algorithm_weights(**post_pretrain_hyperparams)
    algo.train()
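# Minimal usage sketch for resume() (hypothetical paths and values). The
# variant keys shown are exactly the ones resume() reads;
# "post_pretrain_hyperparams" is forwarded to trainer.set_algorithm_weights().
def _example_resume():
    variant = {
        "pretrained_algorithm_path": "path/to/pretrain_algorithm.pt",
        "num_epochs": 100,
        "trainer_kwargs": {
            # hypothetical weight override applied after pretraining
            "post_pretrain_hyperparams": {"bc_weight": 0.0},
        },
    }
    resume(variant)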
def load_demo_path(self, demo_path, on_policy=True):
    data = list(load_local_or_remote_file(demo_path))
    # First N paths go to the training buffers, the rest to the test buffer.
    N = int(len(data) * self.demo_train_split)
    print("using", N, "paths for training")
    if self.add_demos_to_replay_buffer:
        for path in data[:N]:
            self.load_path(path, self.replay_buffer)
    if on_policy:
        for path in data[:N]:
            self.load_path(path, self.demo_train_buffer)
        for path in data[N:]:
            self.load_path(path, self.demo_test_buffer)
def __init__(
        self,
        datapath,
        representation_size,
        initialize_encodings=True,  # Set to True if you plan to re-encode presampled images
):
    self._presampled_goals = load_local_or_remote_file(datapath)
    self.representation_size = representation_size
    self._num_presampled_goals = self._presampled_goals[list(
        self._presampled_goals)[0]].shape[0]
    if initialize_encodings:
        self._presampled_goals['initial_latent_state'] = np.zeros(
            (self._num_presampled_goals, self.representation_size))
    self._set_spaces()
def encode_dataset(dataset_path):
    # Relies on a module-level `vqvae`; encodes on CPU, then moves it back.
    data = load_local_or_remote_file(dataset_path)
    data = data.item()

    all_data = []
    vqvae.to('cpu')
    for i in tqdm(range(data["observations"].shape[0])):
        obs = ptu.from_numpy(data["observations"][i] / 255.0)
        latent = vqvae.encode(obs, cont=False)
        all_data.append(latent)
    vqvae.to('cuda')

    encodings = ptu.get_numpy(torch.cat(all_data, dim=0))
    return encodings
def encode_dataset(path, object_list):
    # Relies on enclosing-scope `vqvae`, `data_filter_fn`, and `data_size`
    # (this is the nested helper used by train_pixelcnn below).
    data = load_local_or_remote_file(path)
    data = data.item()
    data = data_filter_fn(data)

    all_data = []
    n = min(data["observations"].shape[0], data_size)
    for i in tqdm(range(n)):
        obs = ptu.from_numpy(data["observations"][i] / 255.0)
        latent = vqvae.encode(obs, cont=False)
        all_data.append(latent)

    encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
    return encodings
def concatenate_datasets(data_list):
    from rlkit.util.io import load_local_or_remote_file
    obs, envs, dataset = [], [], {}
    for path in data_list:
        curr_data = load_local_or_remote_file(path)
        curr_data = curr_data.item()
        n_random_steps = curr_data['observations'].shape[1]
        imlength = curr_data['observations'].shape[2]
        # Repeat each env entry once per step so 'env' stays aligned with the
        # flattened observations.
        curr_data['env'] = np.repeat(curr_data['env'], n_random_steps, axis=0)
        curr_data['observations'] = curr_data['observations'].reshape(
            -1, 1, imlength)
        obs.append(curr_data['observations'])
        envs.append(curr_data['env'])
    dataset['observations'] = np.concatenate(obs, axis=0)
    dataset['env'] = np.concatenate(envs, axis=0)
    return dataset
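# Usage sketch for concatenate_datasets() (hypothetical file names). Each .npy
# file is expected to contain a pickled dict with 'observations' of shape
# (n_trajs, n_random_steps, imlength) and a per-trajectory 'env' array; the
# result flattens both to one entry per step.
def _example_concatenate_datasets():
    dataset = concatenate_datasets(['demos/run1.npy', 'demos/run2.npy'])
    print(dataset['observations'].shape, dataset['env'].shape)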
def load_demo_path(self,
                   path,
                   is_demo,
                   obs_dict,
                   train_split=None,
                   data_split=None,
                   sync_dir=None):
    print("loading off-policy path", path)
    if sync_dir is not None:
        sync_down_folder(sync_dir)
        paths = glob.glob(get_absolute_path(path))
    else:
        paths = [path]
    data = []
    for filename in paths:
        data.extend(
            list(
                load_local_or_remote_file(
                    filename,
                    delete_after_loading=self.delete_after_loading)))
    # if not is_demo:
    #     data = [data]
    #     random.shuffle(data)
    if train_split is None:
        train_split = self.demo_train_split
    if data_split is None:
        data_split = self.demo_data_split
    # M = number of training paths; N = total number of paths kept.
    M = int(len(data) * train_split * data_split)
    N = int(len(data) * data_split)
    print("using", M, "of", N, "paths for training")
    if self.add_demos_to_replay_buffer:
        for path in data[:M]:
            self.load_path(path, self.replay_buffer, obs_dict=obs_dict)
    if is_demo:
        for path in data[:M]:
            self.load_path(path, self.demo_train_buffer, obs_dict=obs_dict)
        for path in data[M:N]:
            self.load_path(path, self.demo_test_buffer, obs_dict=obs_dict)
def __init__(
        self,
        wrapped_env,
        disc,
        mode='train',
        reward_params=None,
):
    self.quick_init(locals())
    super().__init__(wrapped_env)
    if isinstance(disc, str):
        self.disc = load_local_or_remote_file(disc)
    else:
        self.disc = disc
    self._num_skills = self.disc.num_skills
    # Uniform prior over skills.
    self._p_z = np.full(self._num_skills, 1.0 / self._num_skills)
    self.task = {'context': -1}
    # Guard against the None default so reward_params.get() below is safe.
    if reward_params is None:
        reward_params = dict()
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get('type', 'diayn')
    spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
    spaces['context'] = Discrete(n=self._num_skills)
    self.observation_space = Dict(spaces)
    assert self.reward_type in ('wrapped_env', 'diayn')
def experiment(variant):
    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return

    normalize_env = variant.get("normalize_env", True)
    env_id = variant.get("env_id", None)
    env_params = ENV_PARAMS.get(env_id, {})
    variant.update(env_params)
    env_class = variant.get("env_class", None)
    env_kwargs = variant.get("env_kwargs", {})

    expl_env = make(env_id, env_class, env_kwargs, normalize_env)
    eval_env = make(env_id, env_class, env_kwargs, normalize_env)

    if variant.get("add_env_demos", False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])
    if variant.get("add_env_offpolicy_data", False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    if hasattr(expl_env, "info_sizes"):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()

    qf_kwargs = variant.get("qf_kwargs", {})
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **qf_kwargs)
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **qf_kwargs)
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **qf_kwargs)
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **qf_kwargs)

    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy_kwargs = variant["policy_kwargs"]
    policy_path = variant.get("policy_path", False)
    if policy_path:
        policy = load_local_or_remote_file(policy_path)
    else:
        policy = policy_class(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **policy_kwargs,
        )

    buffer_policy_path = variant.get("buffer_policy_path", False)
    if buffer_policy_path:
        buffer_policy = load_local_or_remote_file(buffer_policy_path)
    else:
        buffer_policy_class = variant.get("buffer_policy_class", policy_class)
        buffer_policy = buffer_policy_class(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **variant.get("buffer_policy_kwargs", policy_kwargs),
        )

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )

    expl_policy = policy
    exploration_kwargs = variant.get("exploration_kwargs", {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)
        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == "ou":
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs["noise"],
                min_sigma=exploration_kwargs["noise"],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == "gauss_eps":
            es = GaussianAndEpsilonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs["noise"],
                min_sigma=exploration_kwargs["noise"],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            raise ValueError(
                "unknown exploration strategy: {}".format(
                    exploration_strategy))

    main_replay_buffer_kwargs = dict(
        max_replay_buffer_size=variant["replay_buffer_size"],
        env=expl_env,
    )
    replay_buffer_kwargs = dict(
        max_replay_buffer_size=variant["replay_buffer_size"],
        env=expl_env,
    )

    replay_buffer = variant.get("replay_buffer_class",
                                EnvReplayBuffer)(**main_replay_buffer_kwargs)
    if variant.get("use_validation_buffer", False):
        train_replay_buffer = replay_buffer
        validation_replay_buffer = variant.get(
            "replay_buffer_class",
            EnvReplayBuffer)(**main_replay_buffer_kwargs)
        replay_buffer = SplitReplayBuffer(train_replay_buffer,
                                          validation_replay_buffer, 0.9)

    trainer_class = variant.get("trainer_class", AWACTrainer)
    trainer = trainer_class(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        buffer_policy=buffer_policy,
        **variant["trainer_kwargs"],
    )

    if variant["collection_mode"] == "online":
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant["max_path_length"],
            batch_size=variant["batch_size"],
            num_epochs=variant["num_epochs"],
            num_eval_steps_per_epoch=variant["num_eval_steps_per_epoch"],
            num_expl_steps_per_train_loop=variant[
                "num_expl_steps_per_train_loop"],
            num_trains_per_train_loop=variant["num_trains_per_train_loop"],
            min_num_steps_before_training=variant[
                "min_num_steps_before_training"],
        )
    else:
        expl_path_collector = MdpPathCollector(
            expl_env,
            expl_policy,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant["max_path_length"],
            batch_size=variant["batch_size"],
            num_epochs=variant["num_epochs"],
            num_eval_steps_per_epoch=variant["num_eval_steps_per_epoch"],
            num_expl_steps_per_train_loop=variant[
                "num_expl_steps_per_train_loop"],
            num_trains_per_train_loop=variant["num_trains_per_train_loop"],
            min_num_steps_before_training=variant[
                "min_num_steps_before_training"],
        )
    algorithm.to(ptu.device)

    demo_train_buffer = EnvReplayBuffer(**replay_buffer_kwargs)
    demo_test_buffer = EnvReplayBuffer(**replay_buffer_kwargs)

    if variant.get("save_video", False):
        if variant.get("presampled_goals", None):
            variant["image_env_kwargs"][
                "presampled_goals"] = load_local_or_remote_file(
                    variant["presampled_goals"]).item()

        def get_img_env(env):
            renderer = EnvRenderer(**variant["renderer_kwargs"])
            img_env = InsertImageEnv(GymToMultiEnv(env), renderer=renderer)
            return img_env

        image_eval_env = ImageEnv(GymToMultiEnv(eval_env),
                                  **variant["image_env_kwargs"])
        # image_eval_env = get_img_env(eval_env)
        image_eval_path_collector = ObsDictPathCollector(
            image_eval_env,
            eval_policy,
            observation_key="state_observation",
        )
        image_expl_env = ImageEnv(GymToMultiEnv(expl_env),
                                  **variant["image_env_kwargs"])
        # image_expl_env = get_img_env(expl_env)
        image_expl_path_collector = ObsDictPathCollector(
            image_expl_env,
            expl_policy,
            observation_key="state_observation",
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)

    if variant.get("save_paths", False):
        algorithm.post_train_funcs.append(save_paths)

    if variant.get("load_demos", False):
        path_loader_class = variant.get("path_loader_class", MDPPathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs,
        )
        path_loader.load_demos()
    if variant.get("load_env_dataset_demos", False):
        path_loader_class = variant.get("path_loader_class", HDF5PathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs,
        )
        path_loader.load_demos(expl_env.get_dataset())

    if variant.get("save_initial_buffers", False):
        buffers = dict(
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
        )
        buffer_path = osp.join(logger.get_snapshot_dir(), "buffers.p")
        pickle.dump(buffers, open(buffer_path, "wb"))

    if variant.get("pretrain_buffer_policy", False):
        trainer.pretrain_policy_with_bc(
            buffer_policy,
            replay_buffer.train_replay_buffer,
            replay_buffer.validation_replay_buffer,
            10000,
            label="buffer",
        )
    if variant.get("pretrain_policy", False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get("pretrain_rl", False):
        trainer.pretrain_q_with_bc_data()

    if variant.get("save_pretrained_algorithm", False):
        p_path = osp.join(logger.get_snapshot_dir(), "pretrain_algorithm.p")
        pt_path = osp.join(logger.get_snapshot_dir(), "pretrain_algorithm.pt")
        data = algorithm._get_snapshot()
        data["algorithm"] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))

    if variant.get("train_rl", True):
        algorithm.train()
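# A minimal variant sketch for experiment() (hypothetical values). These are
# the keys the function indexes directly, i.e. without a .get() default;
# everything else (exploration_kwargs, save_video, load_demos, ...) is optional.
def _example_experiment_variant():
    return dict(
        env_id='half-cheetah',  # hypothetical entry understood by make()
        policy_kwargs=dict(hidden_sizes=[256, 256]),
        qf_kwargs=dict(hidden_sizes=[256, 256]),
        trainer_kwargs=dict(),
        replay_buffer_size=int(1e6),
        collection_mode='batch',
        max_path_length=1000,
        batch_size=256,
        num_epochs=100,
        num_eval_steps_per_epoch=1000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=1000,
    )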
def generate_LSTM_vae_only_dataset(variant,
                                   segmented=False,
                                   segmentation_method='color'):
    from multiworld.core.image_env import ImageEnv, unormalize_image

    env_id = variant.get('env_id', None)
    N = variant.get('N', 500)
    test_p = variant.get('test_p', 0.9)
    imsize = variant.get('imsize', 48)
    num_channels = variant.get('num_channels', 3)
    init_camera = variant.get('init_camera', None)
    occlusion_prob = variant.get('occlusion_prob', 0)
    occlusion_level = variant.get('occlusion_level', 0.5)
    segmentation_kwargs = variant.get('segmentation_kwargs', {})
    if segmentation_kwargs.get('segment') is not None:
        segmented = segmentation_kwargs.get('segment')

    assert env_id is not None, 'you must provide an env id!'

    obj = 'puck-pos'
    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        obj = 'door-angle'

    pjhome = os.environ['PJHOME']
    if segmented:
        if 'unet' in segmentation_method:
            seg_name = 'seg-unet'
        else:
            seg_name = 'seg-' + segmentation_method
    else:
        seg_name = 'no-seg'
    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        seg_name += '-2'

    data_file_path = osp.join(
        pjhome, 'data/local/pre-train-lstm',
        'vae-only-{}-{}-{}-{}-{}.npy'.format(env_id, seg_name, N,
                                             occlusion_prob, occlusion_level))
    obj_state_path = osp.join(
        pjhome, 'data/local/pre-train-lstm',
        'vae-only-{}-{}-{}-{}-{}-{}.npy'.format(env_id, seg_name, N,
                                                occlusion_prob,
                                                occlusion_level, obj))
    print(data_file_path)

    # Reuse a previously generated dataset if one exists and is big enough.
    if osp.exists(data_file_path):
        all_data = np.load(data_file_path)
        if len(all_data) >= N:
            print("load stored data at: ", data_file_path)
            n = int(len(all_data) * test_p)
            train_dataset = all_data[:n]
            test_dataset = all_data[n:]
            obj_states = np.load(obj_state_path)
            info = {'obj_state': obj_states}
            return train_dataset, test_dataset, info

    if segmented:
        print(
            "generating lstm vae pretrain only dataset with segmented images using method: ",
            segmentation_method)
        if segmentation_method == 'unet':
            segment_func = segment_image_unet
        else:
            raise NotImplementedError
    else:
        print("generating lstm vae pretrain only dataset with original images")

    info = {}
    dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8)
    imgs = []
    obj_states = None

    if env_id == 'SawyerDoorHookResetFreeEnv-v1':
        from rlkit.util.io import load_local_or_remote_file
        pjhome = os.environ['PJHOME']
        pre_sampled_goal_path = osp.join(
            pjhome, 'data/local/pre-train-vae/door_original_dataset.npy')
        goal_dict = np.load(pre_sampled_goal_path, allow_pickle=True).item()
        imgs = goal_dict['image_desired_goal']
        door_angles = goal_dict['state_desired_goal'][:, -1]
        obj_states = door_angles[:, np.newaxis]
    elif env_id == 'SawyerPickupEnvYZEasy-v0':
        from rlkit.util.io import load_local_or_remote_file
        pjhome = os.environ['PJHOME']
        pre_sampled_goal_path = osp.join(
            pjhome, 'data/local/pre-train-vae/pickup-original-dataset.npy')
        goal_dict = load_local_or_remote_file(pre_sampled_goal_path).item()
        imgs = goal_dict['image_desired_goal']
        puck_pos = goal_dict['state_desired_goal'][:, 3:]
        obj_states = puck_pos
    else:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(env_id)
        if not isinstance(env, ImageEnv):
            env = ImageEnv(
                env,
                imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        env.reset()
        info['env'] = env

        puck_pos = np.zeros((N, 2), dtype=float)
        for i in range(N):
            print("lstm vae pretrain only dataset generation, number: ", i)
            if env_id == 'SawyerPushHurdle-v0':
                obs, puck_p = _generate_sawyerhurdle_dataset(
                    env, return_puck_pos=True, segmented=segmented)
            elif env_id == 'SawyerPushHurdleMiddle-v0':
                obs, puck_p = _generate_sawyerhurdlemiddle_dataset(
                    env, return_puck_pos=True)
            elif env_id == 'SawyerPushNIPSEasy-v0':
                obs, puck_p = _generate_sawyerpushnipseasy_dataset(
                    env, return_puck_pos=True)
            elif env_id == 'SawyerPushHurdleResetFreeEnv-v0':
                obs, puck_p = _generate_sawyerhurldeblockresetfree_dataset(
                    env, return_puck_pos=True)
            else:
                raise NotImplementedError
            # NOTE: this is already a normalized image of dtype np.float64.
            img = obs['image_observation']
            imgs.append(img)
            puck_pos[i] = puck_p
        obj_states = puck_pos

    # now we segment the images
    for i in range(N):
        print("segmenting image ", i)
        img = imgs[i]
        if segmented:
            dataset[i, :] = segment_func(img,
                                         normalize=False,
                                         **segmentation_kwargs)
            # manually drop some pixels, so as to make occlusions
            p = np.random.rand()
            if p < occlusion_prob:
                mask = (np.random.uniform(low=0, high=1,
                                          size=(imsize, imsize)) >
                        occlusion_level).astype(np.uint8)
                img = dataset[i].reshape(3, imsize, imsize).transpose()
                img[mask < 1] = 0
                dataset[i] = img.transpose().flatten()
        else:
            dataset[i, :] = unormalize_image(img)

    # add the trajectory dimension: batch_size x traj_len (= 1) x imlen
    dataset = dataset[:, np.newaxis, :]
    obj_states = obj_states[:, np.newaxis, :]
    info['obj_state'] = obj_states

    n = int(N * test_p)
    train_dataset = dataset[:n]
    test_dataset = dataset[n:]

    if N >= 500:
        print('save data to: ', data_file_path)
        all_data = np.concatenate([train_dataset, test_dataset], axis=0)
        np.save(data_file_path, all_data)
        np.save(obj_state_path, obj_states)

    return train_dataset, test_dataset, info
def train_pixelcnn(
        vqvae=None,
        vqvae_path=None,
        num_epochs=100,
        batch_size=32,
        n_layers=15,
        dataset_path=None,
        save=True,
        save_period=10,
        cached_dataset_path=False,
        trainer_kwargs=None,
        model_kwargs=None,
        data_filter_fn=lambda x: x,
        debug=False,
        data_size=float('inf'),
        num_train_batches_per_epoch=None,
        num_test_batches_per_epoch=None,
        train_img_loader=None,
        test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQ-VAE + define args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
    vqvae.to(ptu.device)
    vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels
    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)
        all_data = []
        n = min(data["observations"].shape[0], data_size)
        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)
        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'], None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(
                new_path, 'COND_LATENT_BLOCK', batch_size)
    # train_dataset = InfiniteBatchLoader(train_loader)
    # test_dataset = InfiniteBatchLoader(test_loader)
    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len ** 2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")
    best_loss = float('inf')
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        if train_img_loader:
            # Sample batch size assumed equal to the training batch size;
            # the batch loaders only exist in the image-loader branch.
            test_data = test_batch_loader.random_batch(batch_size)["x"]
            train_data = train_batch_loader.random_batch(batch_size)["x"]
            trainer.dump_samples(epoch, test_data, test=True)
            trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()
        cur_loss = stats["test/loss"]
        if cur_loss < best_loss:
            best_loss = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            # Stop as soon as the test loss stops improving.
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
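# Usage sketch for train_pixelcnn() (hypothetical paths): fit a PixelCNN prior
# on top of a saved VQ-VAE from pre-collected train/test observation datasets.
# dataset_path must be a dict with 'train' and 'test' entries, as read by the
# nested encode_dataset() helper.
def _example_train_pixelcnn():
    return train_pixelcnn(
        vqvae_path='path/to/vqvae.pt',
        dataset_path={'train': 'path/to/train_data.npy',
                      'test': 'path/to/test_data.npy'},
        num_epochs=50,
        batch_size=32,
    )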
def generate_vae_dataset(variant):
    """
    Default dataset generation function: if no custom pre-train VAE dataset
    generation function is provided, this one is used to collect the dataset
    for training the VAE.
    """
    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    import gym
    import multiworld
    multiworld.register_all_envs()

    print("generating vae dataset with original images")
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')

    info = {}
    if dataset_path is not None:
        print('load vae training dataset from: ', dataset_path)
        pjhome = os.environ['PJHOME']
        dataset = np.load(osp.join(pjhome, dataset_path),
                          allow_pickle=True).item()
        if isinstance(dataset, dict):
            dataset = dataset['image_desired_goal']
        dataset = unormalize_image(dataset)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()
            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = \
                    non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(
                        N * random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    print("using totally random rollouts")
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]
                # NOTE (yufei): this is already a normalized image of dtype
                # np.float64.
                img = obs['image_observation']
                dataset[i, :] = unormalize_image(img)
            np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
def prep_sample_data(cached_path):
    data = load_local_or_remote_file(cached_path).item()
    train_data = data['train']
    test_data = data['test']
    return train_data, test_data
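# Usage sketch for prep_sample_data() (hypothetical path). The cached file is
# expected to hold a pickled dict with 'train' and 'test' arrays, e.g. one
# produced by np.save(path, {'train': train_encodings, 'test': test_encodings}).
def _example_prep_sample_data():
    train_data, test_data = prep_sample_data('path/to/pixelcnn_data.npy')
    return train_data, test_data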
def pearl_experiment(
        qf_kwargs=None,
        vf_kwargs=None,
        trainer_kwargs=None,
        algo_kwargs=None,
        context_encoder_kwargs=None,
        context_decoder_kwargs=None,
        policy_kwargs=None,
        env_name=None,
        env_params=None,
        latent_dim=None,
        # video/debug
        debug=False,
        _debug_do_not_sqrt=False,
        # PEARL
        n_train_tasks=0,
        n_eval_tasks=0,
        use_next_obs_in_context=False,
        saved_tasks_path=None,
        tags=None,
):
    del tags
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}

    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]

    expl_env = NormalizedBoxEnv(base_env)
    eval_env = NormalizedBoxEnv(ENVS[env_name](**env_params))
    eval_env.tasks = expl_env.tasks
    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    vf = ConcatMlp(input_size=obs_dim + latent_dim,
                   output_size=1,
                   **vf_kwargs)
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    context_encoder = MlpEncoder(input_size=context_encoder_input_dim,
                                 output_size=context_encoder_output_dim,
                                 hidden_sizes=[200, 200, 200],
                                 **context_encoder_kwargs)
    context_decoder = MlpDecoder(input_size=obs_dim + action_dim + latent_dim,
                                 output_size=1,
                                 **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_do_not_sqrt=_debug_do_not_sqrt,
    )
    trainer = PEARLSoftActorCriticTrainer(latent_dim=latent_dim,
                                          agent=agent,
                                          qf1=qf1,
                                          qf2=qf2,
                                          vf=vf,
                                          reward_predictor=reward_predictor,
                                          context_encoder=context_encoder,
                                          context_decoder=context_decoder,
                                          **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)

    saved_path = logger.save_extra_data(
        data=dict(
            tasks=expl_env.tasks,
            train_task_indices=train_task_idxs,
            eval_task_indices=eval_task_idxs,
            train_tasks=train_tasks,
            eval_tasks=eval_tasks,
        ),
        file_name='tasks_description',
    )
    print('saved tasks description to', saved_path)
    saved_path = logger.save_extra_data(
        expl_env.tasks,
        file_name='tasks',
        mode='pickle',
    )
    print('saved raw tasks to', saved_path)

    algorithm.to(ptu.device)
    algorithm.train()
def smac_experiment(
        trainer_kwargs=None,
        algo_kwargs=None,
        qf_kwargs=None,
        policy_kwargs=None,
        context_encoder_kwargs=None,
        context_decoder_kwargs=None,
        env_name=None,
        env_params=None,
        path_loader_kwargs=None,
        latent_dim=None,
        policy_class="TanhGaussianPolicy",
        # video/debug
        debug=False,
        use_dummy_encoder=False,
        networks_ignore_context=False,
        use_ground_truth_context=False,
        save_video=False,
        save_video_period=False,
        # Pre-train params
        pretrain_rl=False,
        pretrain_offline_algo_kwargs=None,
        pretrain_buffer_kwargs=None,
        load_buffer_kwargs=None,
        saved_tasks_path=None,
        macaw_format_base_path=None,  # overrides saved_tasks_path and load_buffer_kwargs
        load_macaw_buffer_kwargs=None,
        train_task_idxs=None,
        eval_task_idxs=None,
        relabel_offline_dataset=False,
        skip_initial_data_collection_if_pretrained=False,
        relabel_kwargs=None,
        # PEARL
        n_train_tasks=0,
        n_eval_tasks=0,
        use_next_obs_in_context=False,
        tags=None,
        online_trainer_kwargs=None,
):
    if not skip_initial_data_collection_if_pretrained:
        raise NotImplementedError("deprecated! make sure to skip it!")
    if relabel_kwargs is None:
        relabel_kwargs = {}
    del tags
    pretrain_buffer_kwargs = pretrain_buffer_kwargs or {}
    context_decoder_kwargs = context_decoder_kwargs or {}
    pretrain_offline_algo_kwargs = pretrain_offline_algo_kwargs or {}
    online_trainer_kwargs = online_trainer_kwargs or {}
    register_pearl_envs()
    env_params = env_params or {}
    context_encoder_kwargs = context_encoder_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    path_loader_kwargs = path_loader_kwargs or {}
    load_macaw_buffer_kwargs = load_macaw_buffer_kwargs or {}

    base_env = ENVS[env_name](**env_params)
    if saved_tasks_path:
        task_data = load_local_or_remote_file(saved_tasks_path,
                                              file_type='joblib')
        tasks = task_data['tasks']
        train_task_idxs = task_data['train_task_indices']
        eval_task_idxs = task_data['eval_task_indices']
        base_env.tasks = tasks
    elif macaw_format_base_path is not None:
        tasks = pickle.load(
            open('{}/tasks.pkl'.format(macaw_format_base_path), 'rb'))
        base_env.tasks = tasks
    else:
        tasks = base_env.tasks
        task_indices = base_env.get_all_task_idx()
        train_task_idxs = list(task_indices[:n_train_tasks])
        eval_task_idxs = list(task_indices[-n_eval_tasks:])
    if hasattr(base_env, 'task_to_vec'):
        train_tasks = [base_env.task_to_vec(tasks[i]) for i in train_task_idxs]
        eval_tasks = [base_env.task_to_vec(tasks[i]) for i in eval_task_idxs]
    else:
        train_tasks = [tasks[i] for i in train_task_idxs]
        eval_tasks = [tasks[i] for i in eval_task_idxs]
    if use_ground_truth_context:
        latent_dim = len(train_tasks[0])

    expl_env = NormalizedBoxEnv(base_env)
    reward_dim = 1

    if debug:
        algo_kwargs['max_path_length'] = 50
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100

    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    if use_next_obs_in_context:
        context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim
    else:
        context_encoder_input_dim = obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim + latent_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    if isinstance(policy_class, str):
        policy_class = policy_class_from_str(policy_class)
    policy = policy_class(
        obs_dim=obs_dim + latent_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )
    encoder_class = DummyMlpEncoder if use_dummy_encoder else MlpEncoder
    context_encoder = encoder_class(
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
        hidden_sizes=[200, 200, 200],
        use_ground_truth_context=use_ground_truth_context,
        **context_encoder_kwargs)
    context_decoder = MlpDecoder(input_size=obs_dim + action_dim + latent_dim,
                                 output_size=1,
                                 **context_decoder_kwargs)
    reward_predictor = context_decoder
    agent = SmacAgent(
        latent_dim,
        context_encoder,
        policy,
        reward_predictor,
        use_next_obs_in_context=use_next_obs_in_context,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
    )
    trainer = SmacTrainer(
        agent=agent,
        env=expl_env,
        latent_dim=latent_dim,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        reward_predictor=reward_predictor,
        context_encoder=context_encoder,
        context_decoder=context_decoder,
        _debug_ignore_context=networks_ignore_context,
        _debug_use_ground_truth_context=use_ground_truth_context,
        **trainer_kwargs)
    algorithm = MetaRLAlgorithm(
        agent=agent,
        env=expl_env,
        trainer=trainer,
        train_task_indices=train_task_idxs,
        eval_task_indices=eval_task_idxs,
        train_tasks=train_tasks,
        eval_tasks=eval_tasks,
        use_next_obs_in_context=use_next_obs_in_context,
        use_ground_truth_context=use_ground_truth_context,
        env_info_sizes=get_env_info_sizes(expl_env),
        **algo_kwargs)

    if macaw_format_base_path:
        load_macaw_buffer_onto_algo(algo=algorithm,
                                    base_directory=macaw_format_base_path,
                                    train_task_idxs=train_task_idxs,
                                    **load_macaw_buffer_kwargs)
    elif load_buffer_kwargs:
        load_buffer_onto_algo(algorithm, **load_buffer_kwargs)
    if relabel_offline_dataset:
        relabel_offline_data(algorithm,
                             tasks=tasks,
                             env=expl_env.wrapped_env,
                             **relabel_kwargs)
    if path_loader_kwargs:
        replay_buffer = algorithm.replay_buffer.task_buffers[0]
        enc_replay_buffer = algorithm.enc_replay_buffer.task_buffers[0]
        demo_test_buffer = EnvReplayBuffer(env=expl_env,
                                           **pretrain_buffer_kwargs)
        path_loader = MDPPathLoader(trainer,
                                    replay_buffer=replay_buffer,
                                    demo_train_buffer=enc_replay_buffer,
                                    demo_test_buffer=demo_test_buffer,
                                    **path_loader_kwargs)
        path_loader.load_demos()

    if pretrain_rl:
        eval_pearl_fn = EvalPearl(algorithm, train_task_idxs, eval_task_idxs)
        pretrain_algo = OfflineMetaRLAlgorithm(
            meta_replay_buffer=algorithm.meta_replay_buffer,
            replay_buffer=algorithm.replay_buffer,
            task_embedding_replay_buffer=algorithm.enc_replay_buffer,
            trainer=trainer,
            train_tasks=train_task_idxs,
            extra_eval_fns=[eval_pearl_fn],
            use_meta_learning_buffer=algorithm.use_meta_learning_buffer,
            **pretrain_offline_algo_kwargs)
        pretrain_algo.to(ptu.device)
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain.csv',
                                  relative_to_snapshot_dir=True)
        pretrain_algo.train()
        logger.remove_tabular_output('pretrain.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

    if skip_initial_data_collection_if_pretrained:
        algorithm.num_initial_steps = 0

    algorithm.trainer.configure(**online_trainer_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
def generate_vae_dataset(cfgs):
    env_id = cfgs.ENV.id
    img_size = cfgs.ENV.imsize
    init_camera = cfgs.ENV.init_camera
    N = cfgs.GENERATE_VAE_DATASET.N
    use_cached = cfgs.GENERATE_VAE_DATASET.use_cached
    n_random_steps = cfgs.GENERATE_VAE_DATASET.n_random_steps
    dataset_path = cfgs.GENERATE_VAE_DATASET.dataset_path  # FIXME
    non_presampled_goal_img_is_garbage = \
        cfgs.GENERATE_VAE_DATASET.non_presampled_goal_img_is_garbage
    random_and_oracle_policy_data_split = \
        cfgs.GENERATE_VAE_DATASET.random_and_oracle_policy_data_split
    random_and_oracle_policy_data = \
        cfgs.GENERATE_VAE_DATASET.random_and_oracle_policy_data
    random_rollout_data = cfgs.GENERATE_VAE_DATASET.random_rollout_data
    oracle_dataset_using_set_to_goal = \
        cfgs.GENERATE_VAE_DATASET.oracle_dataset_using_set_to_goal
    num_channels = cfgs.VAE.input_channels
    policy_file = cfgs.POLICY.model_path

    from roworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        filename = "/tmp/{}_N{}_{}_size{}_random_oracle_split_{}.npy".format(
            env_id,
            str(N),
            init_camera.__name__ if init_camera else '',
            img_size,
            random_and_oracle_policy_data_split,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()
            assert env_id is not None
            import gym
            import roworld
            roworld.register_all_envs()
            env = gym.make(env_id)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    img_size,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                env.imsize = img_size
                env.non_presampled_goal_img_is_garbage = \
                    non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env

            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, img_size * img_size * num_channels),
                               dtype=np.uint8)
            obs = env.reset()
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(
                        N * random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        # Randomly obtain observation
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        # Obtain observation with policy
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    env.reset()
                    # The output obs will be the last observation after
                    # stepping n_random_steps
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]
                img = obs['image_observation']
                dataset[i, :] = unormalize_image(img)
                if cfgs.GENERATE_VAE_DATASET.show:
                    img = img.reshape(3, img_size, img_size).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("Done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    n = int(N * cfgs.GENERATE_VAE_DATASET.ratio)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
def get_envs(cfgs):
    from roworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = cfgs.get('render', False)
    reward_params = cfgs.get("reward_params", dict())
    do_state_exp = cfgs.get("do_state_exp", False)  # TODO
    vae_path = cfgs.VAE_TRAINER.get("vae_path", None)
    init_camera = cfgs.ENV.get("init_camera", None)
    presample_goals = cfgs.SKEW_FIT.get('presample_goals', False)
    presample_image_goals_only = cfgs.SKEW_FIT.get(
        'presample_image_goals_only', False)
    presampled_goals_path = cfgs.SKEW_FIT.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(vae_path) \
        if isinstance(vae_path, str) else vae_path

    if cfgs.ENV.id:
        import gym
        import roworld
        roworld.register_all_envs()
        env = gym.make(cfgs.ENV.id)
    else:
        env = cfgs.ENV.cls(**cfgs.ENV.kwargs)

    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                cfgs.ENV.imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            # This will fail for online-parallel as presampled_goals will not
            # be serialized. Also don't use this for online-vae.
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **cfgs.get('vae_wrapped_env_kwargs',
                                                   {}))
                presampled_goals = cfgs['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=cfgs.get('env_id', None),
                    **cfgs['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(
                env,
                cfgs.ENV.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
                presampled_goals=presampled_goals,
            )
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                presampled_goals=presampled_goals,
                sample_from_true_prior=True,
            )
            print("Presampling all goals only")
        else:
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                goal_sampling_mode='vae_prior',
                sample_from_true_prior=True,
            )
            if presample_image_goals_only:
                presampled_goals = cfgs['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **cfgs['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")
        env = vae_env
    return env
def __init__(self, filename):
    self.data = load_local_or_remote_file(filename)
def generate_vae_dataset(variant):
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')

    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = np.load(filename)
            print("loaded data from saved file", filename)
        else:
            now = time.time()
            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=
                    non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = \
                    non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            dataset = np.zeros((N, imsize * imsize * num_channels),
                               dtype=np.uint8)
            for i in range(N):
                if random_and_oracle_policy_data:
                    num_random_steps = int(
                        N * random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                elif random_rollout_data:
                    if i % n_random_steps == 0:
                        g = dict(
                            state_desired_goal=env.sample_goal_for_rollout())
                        env.set_to_goal(g)
                        policy.reset()
                        # env.reset()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    obs = env.step(u)[0]
                else:
                    env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]
                img = obs['image_observation']
                dataset[i, :] = unormalize_image(img)
                if show:
                    img = img.reshape(3, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
                    # radius = input('waiting...')
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
def getdata(variant):
    skewfit_variant = variant['skewfit_variant']
    print('-------------------------------')
    skewfit_preprocess_variant(skewfit_variant)
    skewfit_variant['render'] = True
    vae_environment = get_envs(skewfit_variant)
    print('done loading vae_env')

    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    tag = variant.get('tag', '')

    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu

    info = {}
    if dataset_path is not None:
        dataset = load_local_or_remote_file(dataset_path)
        N = dataset.shape[0]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__ if init_camera else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        now = time.time()
        if env_id is not None:
            import gym
            import multiworld
            multiworld.register_all_envs()
            env = gym.make(env_id)
        else:
            if vae_dataset_specific_env_kwargs is None:
                vae_dataset_specific_env_kwargs = {}
            for key, val in env_kwargs.items():
                if key not in vae_dataset_specific_env_kwargs:
                    vae_dataset_specific_env_kwargs[key] = val
            env = env_class(**vae_dataset_specific_env_kwargs)
        if not isinstance(env, ImageEnv):
            print("using(ImageEnv)")
            env = ImageEnv(
                env,
                imsize,
                init_camera=init_camera,
                transpose=True,
                normalize=True,
                non_presampled_goal_img_is_garbage=
                non_presampled_goal_img_is_garbage,
            )
        else:
            imsize = env.imsize
            env.non_presampled_goal_img_is_garbage = \
                non_presampled_goal_img_is_garbage
        env.reset()
        info['env'] = env
        if random_and_oracle_policy_data:
            policy_file = load_local_or_remote_file(policy_file)
            policy = policy_file['policy']
            policy.to(ptu.device)
        if random_rollout_data:
            from rlkit.exploration_strategies.ou_strategy import OUStrategy
            policy = OUStrategy(env.action_space)

        dataset = np.zeros((N, imsize * imsize * num_channels),
                           dtype=np.uint8)
        for i in range(10):
            NP = []
            print(i)
            # Render a sampled goal state, encode it, and decode it back.
            goal = env.sample_goal()
            env.set_to_goal(goal)
            obs = env._get_obs()
            img_1 = obs['image_observation']
            img_1 = img_1.reshape(3, imsize, imsize).transpose()
            NP.append(img_1)
            if i % 3 == 0:
                cv2.imshow('img1', img_1)
                cv2.waitKey(1)
            encoded_1 = vae_environment._get_encoded(obs['image_observation'])
            print(encoded_1)
            NP.append(encoded_1)
            img_1_reconstruct = vae_environment._get_img(
                encoded_1).transpose()
            NP.append(img_1_reconstruct)
            if i % 3 == 0:
                cv2.imshow('img1_reconstruction', img_1_reconstruct)
                cv2.waitKey(1)

            # Generate a new state from the goal and record the same triple.
            env.reset()
            instr = env.generate_new_state(goal)
            if i % 3 == 0:
                print(instr)
            obs = env._get_obs()
            img_2 = obs['image_observation']
            img_2 = img_2.reshape(3, imsize, imsize).transpose()
            NP.append(img_2)
            if i % 3 == 0:
                cv2.imshow('img2', img_2)
                cv2.waitKey(1)
            encoded_2 = vae_environment._get_encoded(obs['image_observation'])
            NP.append(encoded_2)
            img_2_reconstruct = vae_environment._get_img(
                encoded_2).transpose()
            NP.append(img_2_reconstruct)
            NP.append(instr)
            if i % 3 == 0:
                cv2.imshow('img2_reconstruct', img_2_reconstruct)
                cv2.waitKey(1)

            # Mixed shapes (images, latents, instruction), so use object dtype.
            NP = np.array(NP, dtype=object)
            idx = str(i)
            name = "/home/xiaomin/Downloads/IFIG_DATA_1/" + idx + ".npy"
            np.save(open(name, 'wb'), NP)
            # radius = input('waiting...')

        # Serialize the encoder/decoder bound methods with dill so they can be
        # reloaded and reused outside this environment.
        import dill
        import pickle
        get_encoded = dill.dumps(vae_environment._get_encoded)
        with open(
                "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                "wb") as fp:
            pickle.dump(get_encoded, fp)
        with open(
                "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_encoded_1000_epochs_one_puck.txt",
                "rb") as fp:
            b = pickle.load(fp)
        func_get_encoded = dill.loads(b)
        encoded = func_get_encoded(obs['image_observation'])
        print(encoded)
        print('------------------------------')

        get_img = dill.dumps(vae_environment._get_img)
        with open(
                "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                "wb") as fp:
            pickle.dump(get_img, fp)
        with open(
                "/home/xiaomin/Downloads/IFIG_encoder_decoder/get_img_1000_epochs_one_puck.txt",
                "rb") as fp:
            c = pickle.load(fp)
        func_get_img = dill.loads(c)
        img_1_reconstruct = func_get_img(encoded).transpose()
        print(img_1_reconstruct)
        cv2.imshow('test', img_1_reconstruct)
        cv2.waitKey(0)

        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)
def train_vae_and_update_variant(variant):
    # Pretrain the scene-VAE and the ROLL object-VAE/LSTM.
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']

    # Prepare the background subtractor needed to perform segmentation.
    if 'unet' in skewfit_variant['segmentation_method']:
        print("training opencv background model!")
        v = train_vae_variant['generate_lstm_dataset_kwargs']
        env_id = v.get('env_id', None)
        env_id_invis = invisiable_env_id[env_id]
        import gym
        import multiworld
        multiworld.register_all_envs()
        obj_invisible_env = gym.make(env_id_invis)
        init_camera = v.get('init_camera', None)

        presampled_goals = None
        if skewfit_variant.get("presampled_goals_path") is not None:
            presampled_goals = load_local_or_remote_file(
                skewfit_variant['presampled_goals_path']).item()
            print("presampled goal path is: ",
                  skewfit_variant['presampled_goals_path'])

        obj_invisible_env = ImageEnv(
            obj_invisible_env,
            v.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )

        train_num = 2000 if 'Push' in env_id else 4000
        train_bgsb(obj_invisible_env, train_num=train_num)

    if skewfit_variant.get('vae_path', None) is None:
        # Train new VAEs: one original scene-VAE and one segmented ROLL VAE.
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        vaes, vae_train_datas, vae_test_datas = train_vae(
            train_vae_variant,
            skewfit_variant=skewfit_variant,
            return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_datas
            skewfit_variant['vae_test_data'] = vae_test_datas
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vaes  # just pass the VAEs directly
    else:
        # Load pre-trained VAEs.
        print("loading pretrained scene-/object-VAE from: {}".format(
            skewfit_variant['vae_path']))
        data = torch.load(osp.join(skewfit_variant['vae_path'], 'params.pkl'))
        vae_original = data['vae_original']
        vae_segmented = data['lstm_segmented']
        skewfit_variant['vae_path'] = [vae_segmented, vae_original]

        generate_vae_dataset_fctn = train_vae_variant.get(
            'generate_vae_data_fctn', generate_vae_dataset)
        generate_lstm_dataset_fctn = train_vae_variant.get(
            'generate_lstm_data_fctn')
        assert generate_lstm_dataset_fctn is not None, \
            "Must provide a custom generate-LSTM-pretraining-dataset function!"

        train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
            train_vae_variant['generate_lstm_dataset_kwargs'],
            segmented=True,
            segmentation_method=skewfit_variant['segmentation_method'])
        train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
            train_vae_variant['generate_vae_dataset_kwargs'])
        train_datas = [train_data_lstm, train_data_ori]
        test_datas = [test_data_lstm, test_data_ori]
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = train_datas
            skewfit_variant['vae_test_data'] = test_datas
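# A minimal sketch of the nested variant this launcher expects. Key names come
# from the reads above; every value is an illustrative assumption, including the
# env id.
example_variant = dict(
    skewfit_variant=dict(
        segmentation_method='unet',  # 'unet' triggers background-model training
        vae_path=None,               # None => train fresh scene- and object-VAEs
        presampled_goals_path=None,
        save_vae_data=True,
    ),
    train_vae_variant=dict(
        generate_lstm_dataset_kwargs=dict(
            env_id='SawyerPushNIPSEasy-v0',  # hypothetical
            imsize=48,
            init_camera=None,
        ),
        generate_vae_dataset_kwargs=dict(N=2000, imsize=48),
        generate_lstm_data_fctn=None,  # must be set when vae_path is provided
    ),
)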
def __init__(
        self,
        wrapped_env,
        vae,
        vae_input_key_prefix='image',
        sample_from_true_prior=False,
        decode_goals=False,
        decode_goals_on_reset=True,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        goal_sampling_mode="vae_prior",
        presampled_goals_path=None,
        num_goals_to_presample=0,
        imsize=84,
        obs_size=None,
        norm_order=2,
        epsilon=20,
        presampled_goals=None,
):
    if reward_params is None:
        reward_params = dict()
    super().__init__(wrapped_env)
    if type(vae) is str:
        self.vae = load_local_or_remote_file(vae)
    else:
        self.vae = vae
    self.representation_size = self.vae.representation_size
    self.input_channels = self.vae.input_channels
    self.sample_from_true_prior = sample_from_true_prior
    self._decode_goals = decode_goals
    self.render_goals = render_goals
    self.render_rollouts = render_rollouts
    self.default_kwargs = dict(
        decode_goals=decode_goals,
        render_goals=render_goals,
        render_rollouts=render_rollouts,
    )
    self.imsize = imsize
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get("type", 'latent_distance')
    self.norm_order = self.reward_params.get("norm_order", norm_order)
    self.epsilon = self.reward_params.get("epsilon", epsilon)
    self.reward_min_variance = self.reward_params.get("min_variance", 0)
    self.decode_goals_on_reset = decode_goals_on_reset
    latent_space = Box(
        -10 * np.ones(obs_size or self.representation_size),
        10 * np.ones(obs_size or self.representation_size),
        dtype=np.float32,
    )
    spaces = self.wrapped_env.observation_space.spaces
    spaces['observation'] = latent_space
    spaces['desired_goal'] = latent_space
    spaces['achieved_goal'] = latent_space
    spaces['latent_observation'] = latent_space
    spaces['latent_desired_goal'] = latent_space
    spaces['latent_achieved_goal'] = latent_space
    self.observation_space = Dict(spaces)
    self._presampled_goals = presampled_goals
    if num_goals_to_presample > 0:
        self._presampled_goals = self.wrapped_env.sample_goals(
            num_goals_to_presample)
    if presampled_goals_path is not None:
        self._presampled_goals = load_local_or_remote_file(
            presampled_goals_path)
    if self._presampled_goals is None:
        self.num_goals_presampled = 0
    else:
        self.num_goals_presampled = self._presampled_goals[random.choice(
            list(self._presampled_goals))].shape[0]
    self.vae_input_key_prefix = vae_input_key_prefix
    assert vae_input_key_prefix in {'image', 'image_proprio'}
    self.vae_input_observation_key = vae_input_key_prefix + '_observation'
    self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
    self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
    self._mode_map = {}
    self.desired_goal = {'latent_desired_goal': latent_space.sample()}
    self._initial_obs = None
    self._custom_goal_sampler = None
    self._goal_sampling_mode = goal_sampling_mode
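# Minimal wrapping sketch for the constructor above, assuming a multiworld env
# already wrapped in an ImageEnv and a VAE checkpoint on disk; both names are
# placeholders.
vae_env = VAEWrappedEnv(
    image_env,                  # an ImageEnv instance
    'path/to/vae.pkl',          # str => loaded via load_local_or_remote_file
    reward_params=dict(type='latent_distance'),
)
obs = vae_env.reset()
z = obs['latent_observation']   # lies in the Box latent space built above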
def get_envs(variant):
    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
    if presample_goals:
        """
        This will fail for online-parallel as presampled_goals will not be
        serialized. Also don't use this for online-vae.
        """
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        image_env = ImageEnv(env,
                             variant.get('imsize'),
                             init_camera=init_camera,
                             transpose=True,
                             normalize=True,
                             presampled_goals=presampled_goals,
                             **variant.get('image_env_kwargs', {}))
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                presampled_goals=presampled_goals,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        print("Presampling all goals only")
    else:
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    env = vae_env
    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    env.add_mode('eval', testing_mode)
    env.add_mode('train', training_mode)
    env.add_mode('relabeling', training_mode)
    # relabeling_env.disable_render()
    env.add_mode("video_vae", 'video_vae')
    env.add_mode("video_env", 'video_env')
    return env
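# Illustrative get_envs call; assumes the variant supplies either an env_id or
# an env_class, plus a trained VAE. The id below is hypothetical.
env = get_envs(dict(
    env_id='SawyerPushNIPSEasy-v0',
    vae_path='path/to/vae.pkl',
    imsize=48,
    reward_params=dict(type='latent_distance'),
))
# Modes registered above ('train', 'eval', 'relabeling', 'video_vae',
# 'video_env') can then be selected, e.g.:
env.mode('train')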
def __init__(
        self,
        wrapped_env,
        vae,
        vae_input_key_prefix='image',
        use_vae_goals=True,
        sample_from_true_prior=False,
        decode_goals=False,
        render_goals=False,
        render_rollouts=False,
        reward_params=None,
        mode="train",
        imsize=84,
        obs_size=None,
        epsilon=20,
        presampled_goals=None,
):
    self.quick_init(locals())
    if reward_params is None:
        reward_params = dict()
    super().__init__(wrapped_env)
    if type(vae) is str:
        self.vae = load_local_or_remote_file(vae)
    else:
        self.vae = vae
    self.representation_size = self.vae.representation_size
    self.input_channels = self.vae.input_channels
    self._use_vae_goals = use_vae_goals
    self.sample_from_true_prior = sample_from_true_prior
    self.decode_goals = decode_goals
    self.render_goals = render_goals
    self.render_rollouts = render_rollouts
    self.default_kwargs = dict(
        decode_goals=decode_goals,
        render_goals=render_goals,
        render_rollouts=render_rollouts,
    )
    self.imsize = imsize
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get("type", 'latent_distance')
    self.epsilon = self.reward_params.get("epsilon", epsilon)
    self.reward_min_variance = self.reward_params.get("min_variance", 0)
    latent_space = Box(
        -10 * np.ones(obs_size or self.representation_size),
        10 * np.ones(obs_size or self.representation_size),
        dtype=np.float32,
    )
    spaces = copy.deepcopy(self.wrapped_env.observation_space.spaces)
    spaces['observation'] = latent_space
    spaces['desired_goal'] = latent_space
    spaces['achieved_goal'] = latent_space
    spaces['latent_observation'] = latent_space
    spaces['latent_desired_goal'] = latent_space
    spaces['latent_achieved_goal'] = latent_space
    self.observation_space = Dict(spaces)
    self.mode(mode)
    self._presampled_goals = presampled_goals
    if self._presampled_goals is None:
        self.num_goals_presampled = 0
    else:
        self.num_goals_presampled = (
            presampled_goals[list(presampled_goals)[0]].shape[0])
    self.vae_input_key_prefix = vae_input_key_prefix
    assert vae_input_key_prefix in {'image'}
    self.vae_input_observation_key = vae_input_key_prefix + '_observation'
    self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
    self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
    self._mode_map = {}
    self.desired_goal = {'latent_desired_goal': latent_space.sample()}
    self._initial_obs = None
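# The presampled-goals bookkeeping above assumes a dict of equal-length arrays;
# the key names and shapes below are illustrative only. num_goals_presampled is
# read off the leading dimension of the first key.
import numpy as np

presampled = dict(
    latent_desired_goal=np.zeros((500, 16)),
    image_desired_goal=np.zeros((500, 48 * 48 * 3)),
)
# VAEWrappedEnv(..., presampled_goals=presampled) would then set
# num_goals_presampled = 500.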
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(env,
                                 variant.get('imsize'),
                                 init_camera=init_camera,
                                 transpose=True,
                                 normalize=True,
                                 presampled_goals=presampled_goals,
                                 **variant.get('image_env_kwargs', {}))
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    presampled_goals=presampled_goals,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            print("Presampling all goals only")
        else:
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")
        env = vae_env
    return env
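# For pure state-space experiments this version of get_envs skips the image and
# VAE wrapping entirely; a hedged sketch (the env id is hypothetical):
env = get_envs(dict(
    do_state_exp=True,
    env_id='SawyerPushNIPSEasy-v0',
))
# env is the raw state env here, not a VAEWrappedEnv.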
def generate_vae_dataset(variant):
    print(variant)
    from tqdm import tqdm
    env_class = variant.get('env_class', None)
    env_kwargs = variant.get('env_kwargs', None)
    env_id = variant.get('env_id', None)
    N = variant.get('N', 10000)
    batch_size = variant.get('batch_size', 128)
    test_p = variant.get('test_p', 0.9)
    use_cached = variant.get('use_cached', True)
    imsize = variant.get('imsize', 84)
    num_channels = variant.get('num_channels', 3)
    show = variant.get('show', False)
    init_camera = variant.get('init_camera', None)
    dataset_path = variant.get('dataset_path', None)
    augment_data = variant.get('augment_data', False)
    data_filter_fn = variant.get('data_filter_fn', lambda x: x)
    delete_after_loading = variant.get('delete_after_loading', False)
    oracle_dataset_using_set_to_goal = variant.get(
        'oracle_dataset_using_set_to_goal', False)
    random_rollout_data = variant.get('random_rollout_data', False)
    random_rollout_data_set_to_goal = variant.get(
        'random_rollout_data_set_to_goal', True)
    random_and_oracle_policy_data = variant.get(
        'random_and_oracle_policy_data', False)
    random_and_oracle_policy_data_split = variant.get(
        'random_and_oracle_policy_data_split', 0)
    policy_file = variant.get('policy_file', None)
    n_random_steps = variant.get('n_random_steps', 100)
    vae_dataset_specific_env_kwargs = variant.get(
        'vae_dataset_specific_env_kwargs', None)
    save_file_prefix = variant.get('save_file_prefix', None)
    non_presampled_goal_img_is_garbage = variant.get(
        'non_presampled_goal_img_is_garbage', None)
    conditional_vae_dataset = variant.get('conditional_vae_dataset', False)
    use_env_labels = variant.get('use_env_labels', False)
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    enviorment_dataset = variant.get('enviorment_dataset', False)
    save_trajectories = variant.get('save_trajectories', False)
    save_trajectories = (
        save_trajectories or use_linear_dynamics or conditional_vae_dataset)
    tag = variant.get('tag', '')

    assert N % n_random_steps == 0, \
        "Fix N/horizon or dataset generation will fail"

    from multiworld.core.image_env import ImageEnv, unormalize_image
    import rlkit.torch.pytorch_util as ptu
    from rlkit.util.io import load_local_or_remote_file
    from rlkit.data_management.dataset import (
        TrajectoryDataset, ImageObservationDataset, InitialObservationDataset,
        EnvironmentDataset, ConditionalDynamicsDataset,
        InitialObservationNumpyDataset, InfiniteBatchLoader,
        InitialObservationNumpyJitteringDataset)
    from torch.utils import data  # for data.DataLoader below

    info = {}
    use_test_dataset = False
    if dataset_path is not None:
        if type(dataset_path) == str:
            dataset = load_local_or_remote_file(
                dataset_path, delete_after_loading=delete_after_loading)
            dataset = dataset.item()
            N = (dataset['observations'].shape[0] *
                 dataset['observations'].shape[1])
            n_random_steps = dataset['observations'].shape[1]
        if isinstance(dataset_path, list):
            dataset = concatenate_datasets(dataset_path)
            N = (dataset['observations'].shape[0] *
                 dataset['observations'].shape[1])
            n_random_steps = dataset['observations'].shape[1]
        if isinstance(dataset_path, dict):
            if type(dataset_path['train']) == str:
                dataset = load_local_or_remote_file(
                    dataset_path['train'],
                    delete_after_loading=delete_after_loading)
                dataset = dataset.item()
            elif isinstance(dataset_path['train'], list):
                dataset = concatenate_datasets(dataset_path['train'])
            if type(dataset_path['test']) == str:
                test_dataset = load_local_or_remote_file(
                    dataset_path['test'],
                    delete_after_loading=delete_after_loading)
                test_dataset = test_dataset.item()
            elif isinstance(dataset_path['test'], list):
                test_dataset = concatenate_datasets(dataset_path['test'])
            N = (dataset['observations'].shape[0] *
                 dataset['observations'].shape[1])
            n_random_steps = dataset['observations'].shape[1]
            use_test_dataset = True
    else:
        if env_kwargs is None:
            env_kwargs = {}
        if save_file_prefix is None:
            save_file_prefix = env_id
        if save_file_prefix is None:
            save_file_prefix = env_class.__name__
        filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format(
            save_file_prefix,
            str(N),
            init_camera.__name__
            if init_camera and hasattr(init_camera, '__name__') else '',
            imsize,
            random_and_oracle_policy_data_split,
            tag,
        )
        if use_cached and osp.isfile(filename):
            dataset = load_local_or_remote_file(
                filename, delete_after_loading=delete_after_loading)
            if conditional_vae_dataset:
                dataset = dataset.item()
            print("loaded data from saved file", filename)
        else:
            now = time.time()
            if env_id is not None:
                import gym
                import multiworld
                multiworld.register_all_envs()
                env = gym.make(env_id)
            else:
                if vae_dataset_specific_env_kwargs is None:
                    vae_dataset_specific_env_kwargs = {}
                for key, val in env_kwargs.items():
                    if key not in vae_dataset_specific_env_kwargs:
                        vae_dataset_specific_env_kwargs[key] = val
                env = env_class(**vae_dataset_specific_env_kwargs)
            if not isinstance(env, ImageEnv):
                env = ImageEnv(
                    env,
                    imsize,
                    init_camera=init_camera,
                    transpose=True,
                    normalize=True,
                    non_presampled_goal_img_is_garbage=non_presampled_goal_img_is_garbage,
                )
            else:
                imsize = env.imsize
                env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage
            env.reset()
            info['env'] = env
            if random_and_oracle_policy_data:
                policy_file = load_local_or_remote_file(policy_file)
                policy = policy_file['policy']
                policy.to(ptu.device)
            if random_rollout_data:
                from rlkit.exploration_strategies.ou_strategy import OUStrategy
                policy = OUStrategy(env.action_space)

            if save_trajectories:
                dataset = {
                    'observations':
                        np.zeros((N // n_random_steps, n_random_steps,
                                  imsize * imsize * num_channels),
                                 dtype=np.uint8),
                    'actions':
                        np.zeros((N // n_random_steps, n_random_steps,
                                  env.action_space.shape[0]),
                                 dtype=np.float64),
                    'env':
                        np.zeros((N // n_random_steps,
                                  imsize * imsize * num_channels),
                                 dtype=np.uint8),
                }
            else:
                dataset = np.zeros((N, imsize * imsize * num_channels),
                                   dtype=np.uint8)
            labels = []
            for i in tqdm(range(N)):
                if random_and_oracle_policy_data:
                    num_random_steps = int(
                        N * random_and_oracle_policy_data_split)
                    if i < num_random_steps:
                        env.reset()
                        for _ in range(n_random_steps):
                            obs = env.step(env.action_space.sample())[0]
                    else:
                        obs = env.reset()
                        policy.reset()
                        for _ in range(n_random_steps):
                            policy_obs = np.hstack((
                                obs['state_observation'],
                                obs['state_desired_goal'],
                            ))
                            action, _ = policy.get_action(policy_obs)
                            obs, _, _, _ = env.step(action)
                elif random_rollout_data:
                    # add data where just the puck moves
                    if i % n_random_steps == 0:
                        env.reset()
                        policy.reset()
                        env_img = env._get_obs()['image_observation']
                        if random_rollout_data_set_to_goal:
                            env.set_to_goal(env.get_goal())
                    obs = env._get_obs()
                    u = policy.get_action_from_raw_action(
                        env.action_space.sample())
                    env.step(u)
                elif oracle_dataset_using_set_to_goal:
                    print(i)
                    goal = env.sample_goal()
                    env.set_to_goal(goal)
                    obs = env._get_obs()
                else:
                    env.reset()
                    for _ in range(n_random_steps):
                        obs = env.step(env.action_space.sample())[0]

                img = obs['image_observation']
                if use_env_labels:
                    labels.append(obs['label'])
                if save_trajectories:
                    dataset['observations'][
                        i // n_random_steps,
                        i % n_random_steps, :] = unormalize_image(img)
                    dataset['actions'][i // n_random_steps,
                                       i % n_random_steps, :] = u
                    dataset['env'][i // n_random_steps, :] = unormalize_image(
                        env_img)
                else:
                    dataset[i, :] = unormalize_image(img)
                if show:
                    img = img.reshape(3, imsize, imsize).transpose()
                    img = img[::-1, :, ::-1]
                    cv2.imshow('img', img)
                    cv2.waitKey(1)
            print("done making training data", filename, time.time() - now)
            np.save(filename, dataset)

    info['train_labels'] = []
    info['test_labels'] = []

    dataset = data_filter_fn(dataset)
    if use_linear_dynamics and conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]
        try:
            train_dataset = ConditionalDynamicsDataset({
                'observations': dataset['observations'][train_i, :, :],
                'actions': dataset['actions'][train_i, :, :],
                'env': dataset['env'][train_i, :]
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations': dataset['observations'][test_i, :, :],
                'actions': dataset['actions'][test_i, :, :],
                'env': dataset['env'][test_i, :]
            })
        except KeyError:
            # Fall back for datasets without a per-trajectory 'env' image.
            train_dataset = ConditionalDynamicsDataset({
                'observations': dataset['observations'][train_i, :, :],
                'actions': dataset['actions'][train_i, :, :],
            })
            test_dataset = ConditionalDynamicsDataset({
                'observations': dataset['observations'][test_i, :, :],
                'actions': dataset['actions'][test_i, :, :],
            })
    elif use_linear_dynamics:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        train_dataset = TrajectoryDataset({
            'observations': dataset['observations'][:n, :, :],
            'actions': dataset['actions'][:n, :, :]
        })
        test_dataset = TrajectoryDataset({
            'observations': dataset['observations'][n:, :, :],
            'actions': dataset['actions'][n:, :, :]
        })
    elif enviorment_dataset:
        n = int(n_random_steps * test_p)
        train_dataset = EnvironmentDataset({
            'observations': dataset['observations'][:, :n, :],
        })
        test_dataset = EnvironmentDataset({
            'observations': dataset['observations'][:, n:, :],
        })
    elif conditional_vae_dataset:
        num_trajectories = N // n_random_steps
        n = int(num_trajectories * test_p)
        indices = np.arange(num_trajectories)
        np.random.shuffle(indices)
        train_i, test_i = indices[:n], indices[n:]
        if augment_data:
            dataset_class = InitialObservationNumpyJitteringDataset
        else:
            dataset_class = InitialObservationNumpyDataset
        if 'env' not in dataset:
            dataset['env'] = dataset['observations'][:, 0]
        if use_test_dataset and ('env' not in test_dataset):
            test_dataset['env'] = test_dataset['observations'][:, 0]
        if use_test_dataset:
            train_dataset = dataset_class({
                'observations': dataset['observations'],
                'env': dataset['env']
            })
            test_dataset = dataset_class({
                'observations': test_dataset['observations'],
                'env': test_dataset['env']
            })
        else:
            train_dataset = dataset_class({
                'observations': dataset['observations'][train_i, :, :],
                'env': dataset['env'][train_i, :]
            })
            test_dataset = dataset_class({
                'observations': dataset['observations'][test_i, :, :],
                'env': dataset['env'][test_i, :]
            })
        train_batch_loader_kwargs = variant.get(
            'train_batch_loader_kwargs',
            dict(batch_size=batch_size, num_workers=0))
        test_batch_loader_kwargs = variant.get(
            'test_batch_loader_kwargs',
            dict(batch_size=batch_size, num_workers=0))
        train_data_loader = data.DataLoader(train_dataset,
                                            shuffle=True,
                                            drop_last=True,
                                            **train_batch_loader_kwargs)
        test_data_loader = data.DataLoader(test_dataset,
                                           shuffle=True,
                                           drop_last=True,
                                           **test_batch_loader_kwargs)
        train_dataset = InfiniteBatchLoader(train_data_loader)
        test_dataset = InfiniteBatchLoader(test_data_loader)
    else:
        n = int(N * test_p)
        train_dataset = ImageObservationDataset(dataset[:n, :])
        test_dataset = ImageObservationDataset(dataset[n:, :])
    return train_dataset, test_dataset, info
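# Sketch of a typical call; the sizes and env id are placeholders. With
# conditional_vae_dataset=True the function returns InfiniteBatchLoader wrappers
# around torch DataLoaders rather than raw datasets.
train_ds, test_ds, info = generate_vae_dataset(dict(
    N=10000,
    n_random_steps=100,   # N must be divisible by this (asserted above)
    imsize=48,
    env_id='SawyerPushNIPSEasy-v0',   # hypothetical id
    conditional_vae_dataset=True,
    batch_size=128,
))
# Assumption: InfiniteBatchLoader exposes a random_batch(batch_size) accessor.
batch = train_ds.random_batch(128)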