def load_dataset(filename, test_p=0.9):
    dataset = load_local_or_remote_file(filename).item()
    num_trajectories = dataset["observations"].shape[0]
    n_random_steps = dataset["observations"].shape[1]
    n = int(num_trajectories * test_p)
    try:
        train_dataset = InitialObservationDataset({
            'observations': dataset['observations'][:n, :, :],
            'env': dataset['env'][:n, :],
        })
        test_dataset = InitialObservationDataset({
            'observations': dataset['observations'][n:, :, :],
            'env': dataset['env'][n:, :],
        })
    except KeyError:
        # Some datasets have no 'env' key; fall back to observations only.
        train_dataset = InitialObservationDataset({
            'observations': dataset['observations'][:n, :, :],
        })
        test_dataset = InitialObservationDataset({
            'observations': dataset['observations'][n:, :, :],
        })
    return train_dataset, test_dataset
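# Hedged usage sketch for load_dataset above. The .npy path is an assumption
# (a placeholder, not a file that ships with this code); the dataset file is
# expected to contain 'observations' (and optionally 'env') arrays.
train_data, test_data = load_dataset('/tmp/pusher_dataset.npy', test_p=0.9)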
def load_dataset(self, dataset_path):
    dataset = load_local_or_remote_file(dataset_path)
    dataset = dataset.item()

    observations = dataset['observations']
    actions = dataset['actions']
    # dataset['observations'].shape  # (2000, 50, 6912)
    # dataset['actions'].shape       # (2000, 50, 2)
    # dataset['env'].shape           # (2000, 6912)
    N, H, imlength = observations.shape

    self.vae.eval()
    for n in range(N):
        x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
        x = ptu.from_numpy(observations[n, :, :] / 255.0)
        latents = self.vae.encode(x, x0, distrib=False)

        r1, r2 = self.vae.latent_sizes
        conditioning = latents[0, r1:]
        goal = torch.cat(
            [ptu.randn(self.vae.latent_sizes[0]), conditioning])
        goal = ptu.get_numpy(goal)  # latents[-1, :]
        latents = ptu.get_numpy(latents)

        latent_delta = latents - goal
        distances = np.zeros((H - 1, 1))
        for i in range(H - 1):
            distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

        terminals = np.zeros((H - 1, 1))
        # terminals[-1, 0] = 1
        path = dict(
            observations=[],
            actions=actions[n, :H - 1, :],
            next_observations=[],
            rewards=-distances,
            terminals=terminals,
        )

        for t in range(H - 1):
            # reward = -np.linalg.norm(latent_delta[i, :])
            obs = dict(
                latent_observation=latents[t, :],
                latent_achieved_goal=latents[t, :],
                latent_desired_goal=goal,
            )
            next_obs = dict(
                latent_observation=latents[t + 1, :],
                latent_achieved_goal=latents[t + 1, :],
                latent_desired_goal=goal,
            )
            path['observations'].append(obs)
            path['next_observations'].append(next_obs)

        self.replay_buffer.add_path(path)
def __init__(
    self,
    wrapped_env,
    vae,
    pixel_cnn=None,
    vae_input_key_prefix='image',
    sample_from_true_prior=False,
    decode_goals=False,
    decode_goals_on_reset=True,
    render_goals=False,
    render_rollouts=False,
    reward_params=None,
    goal_sampling_mode="vae_prior",
    imsize=84,
    obs_size=None,
    norm_order=2,
    epsilon=20,
    presampled_goals=None,
):
    if reward_params is None:
        reward_params = dict()
    super().__init__(
        wrapped_env,
        vae,
        vae_input_key_prefix,
        sample_from_true_prior,
        decode_goals,
        decode_goals_on_reset,
        render_goals,
        render_rollouts,
        reward_params,
        goal_sampling_mode,
        imsize,
        obs_size,
        norm_order,
        epsilon,
        presampled_goals,
    )
    if type(pixel_cnn) is str:
        self.pixel_cnn = load_local_or_remote_file(pixel_cnn)
    self.representation_size = self.vae.representation_size
    self.imsize = self.vae.imsize
    print("Location: BiGAN WRAPPER")

    latent_space = Box(
        -10 * np.ones(obs_size or self.representation_size),
        10 * np.ones(obs_size or self.representation_size),
        dtype=np.float32,
    )
    spaces = self.wrapped_env.observation_space.spaces
    spaces['observation'] = latent_space
    spaces['desired_goal'] = latent_space
    spaces['achieved_goal'] = latent_space
    spaces['latent_observation'] = latent_space
    spaces['latent_desired_goal'] = latent_space
    spaces['latent_achieved_goal'] = latent_space
    self.observation_space = Dict(spaces)
def __init__(
    self,
    datapath,
):
    self._presampled_goals = load_local_or_remote_file(datapath)
    self._num_presampled_goals = self._presampled_goals[
        list(self._presampled_goals)[0]].shape[0]
    self._set_spaces()
def __init__(
    self,
    trainer,
    replay_buffer,
    demo_train_buffer,
    demo_test_buffer,
    model_path=None,
    reward_fn=None,
    env=None,
    demo_paths=[],  # list of dicts
    normalize=False,
    demo_train_split=0.9,
    demo_data_split=1,
    add_demos_to_replay_buffer=True,
    bc_num_pretrain_steps=0,
    bc_batch_size=64,
    bc_weight=1.0,
    rl_weight=1.0,
    q_num_pretrain_steps=0,
    weight_decay=0,
    eval_policy=None,
    recompute_reward=False,
    env_info_key=None,
    obs_key=None,
    load_terminals=True,
    do_preprocess=True,
    **kwargs
):
    super().__init__(
        trainer,
        replay_buffer,
        demo_train_buffer,
        demo_test_buffer,
        demo_paths,
        demo_train_split,
        demo_data_split,
        add_demos_to_replay_buffer,
        bc_num_pretrain_steps,
        bc_batch_size,
        bc_weight,
        rl_weight,
        q_num_pretrain_steps,
        weight_decay,
        eval_policy,
        recompute_reward,
        env_info_key,
        obs_key,
        load_terminals,
        **kwargs,
    )
    self.model = load_local_or_remote_file(model_path)
    self.reward_fn = reward_fn
    self.normalize = normalize
    self.env = env
    self.do_preprocess = do_preprocess
    print("ZEROING OUT GOALS")
def load_demos(self, demo_path):
    data = load_local_or_remote_file(demo_path)
    random.shuffle(data)
    N = int(len(data) * self.train_split)
    print("using", N, "paths for training")
    for path in data[:N]:
        self.load_path(path, self.replay_buffer)
    for path in data[N:]:
        self.load_path(path, self.test_replay_buffer)
def resume(variant):
    data = load_local_or_remote_file(
        variant.get("pretrained_algorithm_path"), map_location="cuda")
    algo = data['algorithm']
    algo.num_epochs = variant['num_epochs']

    post_pretrain_hyperparams = variant["trainer_kwargs"].get(
        "post_pretrain_hyperparams", {})
    algo.trainer.set_algorithm_weights(**post_pretrain_hyperparams)

    algo.train()
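# Hedged example of calling resume above. The checkpoint path and the hyperparameter
# values are placeholders, not values used anywhere in this codebase; the checkpoint
# is expected to be a torch snapshot that stores the algorithm under 'algorithm'.
resume(dict(
    pretrained_algorithm_path='/tmp/pretrain_algorithm.pt',
    num_epochs=500,
    trainer_kwargs=dict(post_pretrain_hyperparams=dict()),
))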
def load_demo_path(self, demo_path, on_policy=True):
    data = list(load_local_or_remote_file(demo_path))
    # if not on_policy:
    #     data = [data]
    # random.shuffle(data)
    N = int(len(data) * self.demo_train_split)
    print("using", N, "paths for training")
    if self.add_demos_to_replay_buffer:
        for path in data[:N]:
            self.load_path(path, self.replay_buffer)
    if on_policy:
        for path in data[:N]:
            self.load_path(path, self.demo_train_buffer)
        for path in data[N:]:
            self.load_path(path, self.demo_test_buffer)
def create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        saved_filename=None,
        save_to_filename=None,
        **kwargs
):
    if saved_filename is not None:
        sets = asset_loader.load_local_or_remote_file(saved_filename)
    else:
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        if isinstance(env, PickAndPlaceEnv):
            sets = sample_pnp_sets(env, renderer, **kwargs)
        else:
            raise NotImplementedError()
    if save_to_filename:
        save(sets, save_to_filename)
    return sets
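# Hedged usage sketch for create_sets: sample sets once and cache them so later
# runs can reload them via saved_filename. The env id and output path below are
# placeholders, and the env must be a PickAndPlaceEnv for sampling to work.
renderer = EnvRenderer(init_camera=None)
sets = create_sets(
    'PickAndPlace-v0', None, {}, renderer,
    save_to_filename='/tmp/pnp_sets.npy',
)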
def encode_dataset(dataset_path):
    data = load_local_or_remote_file(dataset_path)
    data = data.item()
    # resize_dataset(data)
    data["observations"] = data["observations"].reshape(-1, 50, imlength)

    all_data = []
    vqvae.to('cpu')
    for i in tqdm(range(data["observations"].shape[0])):
        obs = ptu.from_numpy(data["observations"][i] / 255.0)
        latent = vqvae.encode(obs, cont=False).reshape(-1, 50, discrete_size)
        all_data.append(latent)
    vqvae.to('cuda')

    encodings = ptu.get_numpy(torch.cat(all_data, dim=0))
    return encodings
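# Hedged usage sketch for encode_dataset above: turn a raw image dataset into
# discrete VQ-VAE latents. The globals vqvae, imlength and discrete_size are
# assumed to be defined by the surrounding script, and the path is a placeholder.
latent_dataset = encode_dataset('/tmp/pusher_dataset.npy')
print(latent_dataset.shape)  # expected (num_trajectories, 50, discrete_size)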
def load_demo_path(self, path, is_demo, obs_dict, train_split=None,
                   data_split=None, sync_dir=None):
    print("loading off-policy path", path)
    if sync_dir is not None:
        sync_down_folder(sync_dir)
        paths = glob.glob(get_absolute_path(path))
    else:
        paths = [path]

    data = []
    for filename in paths:
        data.extend(list(load_local_or_remote_file(filename)))
    # if not is_demo:
    #     data = [data]
    # random.shuffle(data)

    if train_split is None:
        train_split = self.demo_train_split
    if data_split is None:
        data_split = self.demo_data_split

    M = int(len(data) * train_split * data_split)
    N = int(len(data) * data_split)
    print("using", N, "paths for training")
    if self.add_demos_to_replay_buffer:
        for path in data[:M]:
            self.load_path(path, self.replay_buffer, obs_dict=obs_dict)
    if is_demo:
        for path in data[:M]:
            self.load_path(path, self.demo_train_buffer, obs_dict=obs_dict)
        for path in data[M:N]:
            self.load_path(path, self.demo_test_buffer, obs_dict=obs_dict)
def generate_trajectories(
        snapshot_path,
        max_path_length,
        num_steps,
        save_observation_keys,
):
    ptu.set_gpu_mode(True)
    snapshot = asset_loader.load_local_or_remote_file(
        snapshot_path,
        file_type='torch',
    )
    policy = snapshot['exploration/policy']
    env = snapshot['exploration/env']
    observation_key = snapshot['exploration/observation_key']
    context_keys_for_rl = snapshot['exploration/context_keys_for_policy']
    path_collector = ContextualPathCollector(
        env,
        policy,
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    policy.to(ptu.device)
    paths = path_collector.collect_new_paths(
        max_path_length,
        num_steps,
        True,
    )
    trajectories = []
    for path in paths:
        trajectory = dict(
            actions=path['actions'],
            terminals=path['terminals'],
        )
        for key in save_observation_keys:
            trajectory[key] = np.array([
                obs[key] for obs in path['full_observations']
            ])
            trajectory['next_' + key] = np.array([
                obs[key] for obs in path['full_next_observations']
            ])
        trajectories.append(trajectory)
    return trajectories
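# Hedged usage sketch for generate_trajectories above: roll out a saved policy
# snapshot and keep only the listed observation keys. The snapshot path and the
# key names are placeholders for whatever the snapshot actually contains.
trajectories = generate_trajectories(
    snapshot_path='/tmp/params.pt',
    max_path_length=100,
    num_steps=1000,
    save_observation_keys=['latent_observation', 'state_observation'],
)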
def encode_dataset(path, max_traj=None):
    data = load_local_or_remote_file(path)
    data = data.item()

    all_data = []
    n = data["observations"].shape[0]
    # Cap the number of encoded trajectories at max_traj when it is given.
    N = min(max_traj or n, n)
    # vqvae.to('cpu')  # 3X faster on a GPU
    for i in tqdm(range(N)):
        obs = ptu.from_numpy(data["observations"][i] / 255.0)
        latent = vqvae.encode(obs, cont=False)
        all_data.append(latent)

    encodings = ptu.get_numpy(torch.cat(all_data, dim=0))
    return encodings
def concatenate_datasets(data_list):
    prefix = '/home/ashvin/data/pusher_pucks/'
    obs, envs, actions, dataset = [], [], [], {}
    for path in data_list:
        curr_data = load_local_or_remote_file(prefix + path)
        curr_data = curr_data.item()
        n_random_steps = curr_data['observations'].shape[1]
        imlength = curr_data['observations'].shape[2]
        action_dim = curr_data['actions'].shape[2]

        curr_data['env'] = np.repeat(curr_data['env'], n_random_steps, axis=0)
        curr_data['observations'] = curr_data['observations'].reshape(
            -1, 1, imlength)
        curr_data['actions'] = curr_data['actions'].reshape(-1, 1, action_dim)

        obs.append(curr_data['observations'])
        envs.append(curr_data['env'])
        actions.append(curr_data['actions'])

    dataset['observations'] = np.concatenate(obs, axis=0)
    dataset['env'] = np.concatenate(envs, axis=0)
    dataset['actions'] = np.concatenate(actions, axis=0)
    return dataset
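# Hedged usage sketch for concatenate_datasets above: merge several raw pusher
# datasets into one flat dataset of single-step trajectories. The file names are
# placeholders relative to the hard-coded prefix inside the function.
merged = concatenate_datasets(['run0.npy', 'run1.npy', 'run2.npy'])
print(merged['observations'].shape, merged['env'].shape, merged['actions'].shape)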
def load_demo_path(self, path, is_demo, obs_dict, train_split=None):
    print("loading off-policy path", path)
    data = list(load_local_or_remote_file(path))
    # if not is_demo:
    #     data = [data]
    # random.shuffle(data)
    if train_split is None:
        train_split = self.demo_train_split
    N = int(len(data) * train_split)
    print("using", N, "paths for training")
    if self.add_demos_to_replay_buffer:
        for path in data[:N]:
            self.load_path(path, self.replay_buffer, obs_dict=obs_dict)
    if is_demo:
        for path in data[:N]:
            self.load_path(path, self.demo_train_buffer, obs_dict=obs_dict)
        for path in data[N:]:
            self.load_path(path, self.demo_test_buffer, obs_dict=obs_dict)
def __init__(
    self,
    wrapped_env,
    vae,
    reward_params=None,
    config_params=None,
    imsize=84,
    obs_size=None,
    vae_input_observation_key="image_observation",
    small_image_step=6,
):
    if config_params is None:
        config_params = dict()
    if reward_params is None:
        reward_params = dict()
    super().__init__(wrapped_env)
    if type(vae) is str:
        self.vae = load_local_or_remote_file(vae)
    else:
        self.vae = vae
    self.representation_size = self.vae.representation_size
    self.input_channels = self.vae.input_channels
    self.imsize = imsize
    self.config_params = config_params
    self.t = 0
    self.episode_num = 0
    self.reward_params = reward_params
    self.reward_type = self.reward_params.get("type", 'latent_distance')
    self.zT = self.reward_params["goal_latent"]
    self.z0 = self.reward_params["initial_latent"]
    self.dT = self.zT - self.z0
    self.small_image_step = small_image_step
    # if self.config_params["use_initial"]:
    #     self.dT = self.zT - self.z0
    # else:
    #     self.dT = self.zT
    self.vae_input_observation_key = vae_input_observation_key

    latent_size = obs_size or self.representation_size
    latent_space = Box(
        -10 * np.ones(latent_size),
        10 * np.ones(latent_size),
        dtype=np.float32,
    )
    goal_space = Box(
        np.zeros((0, )),
        np.zeros((0, )),
        dtype=np.float32,
    )
    spaces = self.wrapped_env.observation_space.spaces
    spaces['observation'] = latent_space
    spaces['desired_goal'] = goal_space
    spaces['achieved_goal'] = goal_space
    spaces['latent_observation'] = latent_space
    spaces['latent_desired_goal'] = goal_space
    spaces['latent_achieved_goal'] = goal_space

    concat_size = latent_size + spaces["state_observation"].low.size
    concat_space = Box(
        -10 * np.ones(concat_size),
        10 * np.ones(concat_size),
        dtype=np.float32,
    )
    spaces['concat_observation'] = concat_space

    small_image_size = 288 // self.small_image_step
    small_image_imglength = small_image_size * small_image_size * 3
    small_image_space = Box(
        0 * np.ones(small_image_imglength),
        1 * np.ones(small_image_imglength),
        dtype=np.float32,
    )
    spaces['small_image_observation'] = small_image_space

    small_image_observation_with_state_size = (
        small_image_imglength + spaces["state_observation"].low.size)
    small_image_observation_with_state_space = Box(
        0 * np.ones(small_image_observation_with_state_size),
        1 * np.ones(small_image_observation_with_state_size),
        dtype=np.float32,
    )
    spaces['small_image_observation_with_state'] = \
        small_image_observation_with_state_space
    self.observation_space = Dict(spaces)
def __init__(self, filename):
    self.data = load_local_or_remote_file(filename)
def load_encoder(encoder_file):
    encoder = load_local_or_remote_file(encoder_file)
    # TEMP #
    # encoder.representation_size = encoder.discrete_size * encoder.embedding_dim
    # TEMP #
    return encoder
def generate_vae_dataset(variant): print(variant) env_class = variant.get('env_class', None) env_kwargs = variant.get('env_kwargs', None) env_id = variant.get('env_id', None) N = variant.get('N', 10000) test_p = variant.get('test_p', 0.9) use_cached = variant.get('use_cached', True) imsize = variant.get('imsize', 84) num_channels = variant.get('num_channels', 3) show = variant.get('show', False) init_camera = variant.get('init_camera', None) dataset_path = variant.get('dataset_path', None) oracle_dataset_using_set_to_goal = variant.get( 'oracle_dataset_using_set_to_goal', False) random_rollout_data = variant.get('random_rollout_data', False) random_rollout_data_set_to_goal = variant.get( 'random_rollout_data_set_to_goal', True) random_and_oracle_policy_data = variant.get( 'random_and_oracle_policy_data', False) random_and_oracle_policy_data_split = variant.get( 'random_and_oracle_policy_data_split', 0) policy_file = variant.get('policy_file', None) n_random_steps = variant.get('n_random_steps', 100) vae_dataset_specific_env_kwargs = variant.get( 'vae_dataset_specific_env_kwargs', None) save_file_prefix = variant.get('save_file_prefix', None) non_presampled_goal_img_is_garbage = variant.get( 'non_presampled_goal_img_is_garbage', None) conditional_vae_dataset = variant.get('conditional_vae_dataset', False) use_env_labels = variant.get('use_env_labels', False) use_linear_dynamics = variant.get('use_linear_dynamics', False) enviorment_dataset = variant.get('enviorment_dataset', False) save_trajectories = variant.get('save_trajectories', False) save_trajectories = save_trajectories or use_linear_dynamics or conditional_vae_dataset tag = variant.get('tag', '') from multiworld.core.image_env import ImageEnv, unormalize_image import rlkit.torch.pytorch_util as ptu from rlkit.misc.asset_loader import load_local_or_remote_file from rlkit.data_management.dataset import \ TrajectoryDataset, ImageObservationDataset, InitialObservationDataset, EnvironmentDataset, ConditionalDynamicsDataset info = {} if dataset_path is not None: dataset = load_local_or_remote_file(dataset_path) dataset = dataset.item() N = dataset['observations'].shape[0] * dataset['observations'].shape[1] n_random_steps = dataset['observations'].shape[1] else: if env_kwargs is None: env_kwargs = {} if save_file_prefix is None: save_file_prefix = env_id if save_file_prefix is None: save_file_prefix = env_class.__name__ filename = "/tmp/{}_N{}_{}_imsize{}_random_oracle_split_{}{}.npy".format( save_file_prefix, str(N), init_camera.__name__ if init_camera and hasattr(init_camera, '__name__') else '', imsize, random_and_oracle_policy_data_split, tag, ) if use_cached and osp.isfile(filename): dataset = np.load(filename) if conditional_vae_dataset: dataset = dataset.item() print("loaded data from saved file", filename) else: now = time.time() if env_id is not None: import gym import multiworld multiworld.register_all_envs() env = gym.make(env_id) else: if vae_dataset_specific_env_kwargs is None: vae_dataset_specific_env_kwargs = {} for key, val in env_kwargs.items(): if key not in vae_dataset_specific_env_kwargs: vae_dataset_specific_env_kwargs[key] = val env = env_class(**vae_dataset_specific_env_kwargs) if not isinstance(env, ImageEnv): env = ImageEnv( env, imsize, init_camera=init_camera, transpose=True, normalize=True, non_presampled_goal_img_is_garbage= non_presampled_goal_img_is_garbage, ) else: imsize = env.imsize env.non_presampled_goal_img_is_garbage = non_presampled_goal_img_is_garbage env.reset() info['env'] = env if 
random_and_oracle_policy_data: policy_file = load_local_or_remote_file(policy_file) policy = policy_file['policy'] policy.to(ptu.device) if random_rollout_data: from rlkit.exploration_strategies.ou_strategy import OUStrategy policy = OUStrategy(env.action_space) if save_trajectories: dataset = { 'observations': np.zeros((N // n_random_steps, n_random_steps, imsize * imsize * num_channels), dtype=np.uint8), 'actions': np.zeros((N // n_random_steps, n_random_steps, env.action_space.shape[0]), dtype=np.float), 'env': np.zeros( (N // n_random_steps, imsize * imsize * num_channels), dtype=np.uint8), } else: dataset = np.zeros((N, imsize * imsize * num_channels), dtype=np.uint8) labels = [] for i in range(N): if random_and_oracle_policy_data: num_random_steps = int(N * random_and_oracle_policy_data_split) if i < num_random_steps: env.reset() for _ in range(n_random_steps): obs = env.step(env.action_space.sample())[0] else: obs = env.reset() policy.reset() for _ in range(n_random_steps): policy_obs = np.hstack(( obs['state_observation'], obs['state_desired_goal'], )) action, _ = policy.get_action(policy_obs) obs, _, _, _ = env.step(action) elif random_rollout_data: #ADD DATA WHERE JUST PUCK MOVES if i % n_random_steps == 0: env.reset() policy.reset() env_img = env._get_obs()['image_observation'] if random_rollout_data_set_to_goal: env.set_to_goal(env.get_goal()) obs = env._get_obs() u = policy.get_action_from_raw_action( env.action_space.sample()) env.step(u) elif oracle_dataset_using_set_to_goal: print(i) goal = env.sample_goal() env.set_to_goal(goal) obs = env._get_obs() else: env.reset() for _ in range(n_random_steps): obs = env.step(env.action_space.sample())[0] img = obs['image_observation'] if use_env_labels: labels.append(obs['label']) if save_trajectories: dataset['observations'][ i // n_random_steps, i % n_random_steps, :] = unormalize_image(img) dataset['actions'][i // n_random_steps, i % n_random_steps, :] = u dataset['env'][i // n_random_steps, :] = unormalize_image( env_img) else: dataset[i, :] = unormalize_image(img) if show: img = img.reshape(3, imsize, imsize).transpose() img = img[::-1, :, ::-1] cv2.imshow('img', img) cv2.waitKey(1) # radius = input('waiting...') print("done making training data", filename, time.time() - now) np.save(filename, dataset) np.save(filename[:-4] + 'labels.npy', np.array(labels)) info['train_labels'] = [] info['test_labels'] = [] if use_linear_dynamics and conditional_vae_dataset: num_trajectories = N // n_random_steps n = int(num_trajectories * test_p) train_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][:n, :, :], 'actions': dataset['actions'][:n, :, :], 'env': dataset['env'][:n, :] }) test_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][n:, :, :], 'actions': dataset['actions'][n:, :, :], 'env': dataset['env'][n:, :] }) num_trajectories = N // n_random_steps n = int(num_trajectories * test_p) indices = np.arange(num_trajectories) np.random.shuffle(indices) train_i, test_i = indices[:n], indices[n:] try: train_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][train_i, :, :], 'actions': dataset['actions'][train_i, :, :], 'env': dataset['env'][train_i, :] }) test_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][test_i, :, :], 'actions': dataset['actions'][test_i, :, :], 'env': dataset['env'][test_i, :] }) except: train_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][train_i, :, :], 'actions': 
dataset['actions'][train_i, :, :], }) test_dataset = ConditionalDynamicsDataset({ 'observations': dataset['observations'][test_i, :, :], 'actions': dataset['actions'][test_i, :, :], }) elif use_linear_dynamics: num_trajectories = N // n_random_steps n = int(num_trajectories * test_p) train_dataset = TrajectoryDataset({ 'observations': dataset['observations'][:n, :, :], 'actions': dataset['actions'][:n, :, :] }) test_dataset = TrajectoryDataset({ 'observations': dataset['observations'][n:, :, :], 'actions': dataset['actions'][n:, :, :] }) elif enviorment_dataset: n = int(n_random_steps * test_p) train_dataset = EnvironmentDataset({ 'observations': dataset['observations'][:, :n, :], }) test_dataset = EnvironmentDataset({ 'observations': dataset['observations'][:, n:, :], }) elif conditional_vae_dataset: num_trajectories = N // n_random_steps n = int(num_trajectories * test_p) indices = np.arange(num_trajectories) np.random.shuffle(indices) train_i, test_i = indices[:n], indices[n:] try: train_dataset = InitialObservationDataset({ 'observations': dataset['observations'][train_i, :, :], 'env': dataset['env'][train_i, :] }) test_dataset = InitialObservationDataset({ 'observations': dataset['observations'][test_i, :, :], 'env': dataset['env'][test_i, :] }) except: train_dataset = InitialObservationDataset({ 'observations': dataset['observations'][train_i, :, :], }) test_dataset = InitialObservationDataset({ 'observations': dataset['observations'][test_i, :, :], }) else: n = int(N * test_p) train_dataset = ImageObservationDataset(dataset[:n, :]) test_dataset = ImageObservationDataset(dataset[n:, :]) return train_dataset, test_dataset, info
def __init__( self, wrapped_env, model, vae_input_key_prefix='image', sample_from_true_prior=False, decode_goals=False, render_goals=False, render_rollouts=False, reward_params=None, goal_sampling_mode="vae_prior", imsize=84, obs_size=None, norm_order=2, epsilon=20, presampled_goals=None, ): if reward_params is None: reward_params = dict() super().__init__(wrapped_env) if type(model) is str: self.vae = load_local_or_remote_file(model) else: self.vae = model self.representation_size = self.vae.representation_size self.input_channels = self.vae.input_channels self.sample_from_true_prior = sample_from_true_prior self._decode_goals = decode_goals self.render_goals = render_goals self.render_rollouts = render_rollouts self.default_kwargs=dict( decode_goals=decode_goals, render_goals=render_goals, render_rollouts=render_rollouts, ) self.imsize = imsize self.reward_params = reward_params self.reward_type = self.reward_params.get("type", 'latent_distance') self.norm_order = self.reward_params.get("norm_order", norm_order) self.epsilon = self.reward_params.get("epsilon", epsilon) self.reward_min_variance = self.reward_params.get("min_variance", 0) latent_space = Box( -10 * np.ones(obs_size or self.representation_size), 10 * np.ones(obs_size or self.representation_size), dtype=np.float32, ) spaces = self.wrapped_env.observation_space.spaces spaces['observation'] = latent_space spaces['desired_goal'] = latent_space spaces['achieved_goal'] = latent_space spaces['latent_observation'] = latent_space spaces['latent_desired_goal'] = latent_space spaces['latent_achieved_goal'] = latent_space self.observation_space = Dict(spaces) self._presampled_goals = presampled_goals if self._presampled_goals is None: self.num_goals_presampled = 0 else: self.num_goals_presampled = presampled_goals[random.choice(list(presampled_goals))].shape[0] self.vae_input_key_prefix = vae_input_key_prefix assert vae_input_key_prefix in {'image', 'image_proprio'} self.vae_input_observation_key = vae_input_key_prefix + '_observation' self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal' self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal' self._mode_map = {} self.desired_goal = {'latent_desired_goal': latent_space.sample()} self._initial_obs = None self._custom_goal_sampler = None self._goal_sampling_mode = goal_sampling_mode
def experiment(variant): if variant.get("pretrained_algorithm_path", False): resume(variant) return normalize_env = variant.get('normalize_env', True) env_id = variant.get('env_id', None) env_class = variant.get('env_class', None) env_kwargs = variant.get('env_kwargs', {}) expl_env = make(env_id, env_class, env_kwargs, normalize_env) eval_env = make(env_id, env_class, env_kwargs, normalize_env) if variant.get('add_env_demos', False): variant["path_loader_kwargs"]["demo_paths"].append( variant["env_demo_path"]) if variant.get('add_env_offpolicy_data', False): variant["path_loader_kwargs"]["demo_paths"].append( variant["env_offpolicy_data_path"]) path_loader_kwargs = variant.get("path_loader_kwargs", {}) stack_obs = path_loader_kwargs.get("stack_obs", 1) if stack_obs > 1: expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs) eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs) obs_dim = expl_env.observation_space.low.size action_dim = eval_env.action_space.low.size if hasattr(expl_env, 'info_sizes'): env_info_sizes = expl_env.info_sizes else: env_info_sizes = dict() qf_kwargs = variant.get("qf_kwargs", {}) qf1 = ConcatMlp(input_size=obs_dim + action_dim, output_size=1, **qf_kwargs) qf2 = ConcatMlp(input_size=obs_dim + action_dim, output_size=1, **qf_kwargs) target_qf1 = ConcatMlp(input_size=obs_dim + action_dim, output_size=1, **qf_kwargs) target_qf2 = ConcatMlp(input_size=obs_dim + action_dim, output_size=1, **qf_kwargs) policy_class = variant.get("policy_class", TanhGaussianPolicy) policy_kwargs = variant['policy_kwargs'] policy_path = variant.get("policy_path", False) if policy_path: policy = load_local_or_remote_file(policy_path) else: policy = policy_class( obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs, ) buffer_policy_path = variant.get("buffer_policy_path", False) if buffer_policy_path: buffer_policy = load_local_or_remote_file(buffer_policy_path) else: buffer_policy_class = variant.get("buffer_policy_class", policy_class) buffer_policy = buffer_policy_class( obs_dim=obs_dim, action_dim=action_dim, **variant.get("buffer_policy_kwargs", policy_kwargs), ) eval_policy = MakeDeterministic(policy) eval_path_collector = MdpPathCollector( eval_env, eval_policy, ) expl_policy = policy exploration_kwargs = variant.get('exploration_kwargs', {}) if exploration_kwargs: if exploration_kwargs.get("deterministic_exploration", False): expl_policy = MakeDeterministic(policy) exploration_strategy = exploration_kwargs.get("strategy", None) if exploration_strategy is None: pass elif exploration_strategy == 'ou': es = OUStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) elif exploration_strategy == 'gauss_eps': es = GaussianAndEpislonStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], # constant sigma epsilon=0, ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) else: error if variant.get('replay_buffer_class', EnvReplayBuffer) == AWREnvReplayBuffer: main_replay_buffer_kwargs = variant['replay_buffer_kwargs'] main_replay_buffer_kwargs['env'] = expl_env main_replay_buffer_kwargs['qf1'] = qf1 main_replay_buffer_kwargs['qf2'] = qf2 main_replay_buffer_kwargs['policy'] = policy else: main_replay_buffer_kwargs = dict( max_replay_buffer_size=variant['replay_buffer_size'], env=expl_env, ) 
replay_buffer_kwargs = dict( max_replay_buffer_size=variant['replay_buffer_size'], env=expl_env, ) replay_buffer = variant.get('replay_buffer_class', EnvReplayBuffer)(**main_replay_buffer_kwargs, ) if variant.get('use_validation_buffer', False): train_replay_buffer = replay_buffer validation_replay_buffer = variant.get( 'replay_buffer_class', EnvReplayBuffer)(**main_replay_buffer_kwargs, ) replay_buffer = SplitReplayBuffer(train_replay_buffer, validation_replay_buffer, 0.9) trainer_class = variant.get("trainer_class", AWACTrainer) trainer = trainer_class(env=eval_env, policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, buffer_policy=buffer_policy, **variant['trainer_kwargs']) if variant['collection_mode'] == 'online': expl_path_collector = MdpStepCollector( expl_env, policy, ) algorithm = TorchOnlineRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant[ 'num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant[ 'min_num_steps_before_training'], ) else: expl_path_collector = MdpPathCollector( expl_env, expl_policy, ) algorithm = TorchBatchRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant[ 'num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant[ 'min_num_steps_before_training'], ) algorithm.to(ptu.device) demo_train_buffer = EnvReplayBuffer(**replay_buffer_kwargs, ) demo_test_buffer = EnvReplayBuffer(**replay_buffer_kwargs, ) if variant.get("save_video", False): if variant.get("presampled_goals", None): variant['image_env_kwargs'][ 'presampled_goals'] = load_local_or_remote_file( variant['presampled_goals']).item() def get_img_env(env): renderer = EnvRenderer(**variant["renderer_kwargs"]) img_env = InsertImageEnv(GymToMultiEnv(env), renderer=renderer) image_eval_env = ImageEnv(GymToMultiEnv(eval_env), **variant["image_env_kwargs"]) # image_eval_env = get_img_env(eval_env) image_eval_path_collector = ObsDictPathCollector( image_eval_env, eval_policy, observation_key="state_observation", ) image_expl_env = ImageEnv(GymToMultiEnv(expl_env), **variant["image_env_kwargs"]) # image_expl_env = get_img_env(expl_env) image_expl_path_collector = ObsDictPathCollector( image_expl_env, expl_policy, observation_key="state_observation", ) video_func = VideoSaveFunction( image_eval_env, variant, image_expl_path_collector, image_eval_path_collector, ) algorithm.post_train_funcs.append(video_func) if variant.get('save_paths', False): algorithm.post_train_funcs.append(save_paths) if variant.get('load_demos', False): path_loader_class = variant.get('path_loader_class', MDPPathLoader) path_loader = path_loader_class(trainer, replay_buffer=replay_buffer, demo_train_buffer=demo_train_buffer, 
demo_test_buffer=demo_test_buffer, **path_loader_kwargs) path_loader.load_demos() if variant.get('load_env_dataset_demos', False): path_loader_class = variant.get('path_loader_class', HDF5PathLoader) path_loader = path_loader_class(trainer, replay_buffer=replay_buffer, demo_train_buffer=demo_train_buffer, demo_test_buffer=demo_test_buffer, **path_loader_kwargs) path_loader.load_demos(expl_env.get_dataset()) if variant.get('save_initial_buffers', False): buffers = dict( replay_buffer=replay_buffer, demo_train_buffer=demo_train_buffer, demo_test_buffer=demo_test_buffer, ) buffer_path = osp.join(logger.get_snapshot_dir(), 'buffers.p') pickle.dump(buffers, open(buffer_path, "wb")) if variant.get('pretrain_buffer_policy', False): trainer.pretrain_policy_with_bc( buffer_policy, replay_buffer.train_replay_buffer, replay_buffer.validation_replay_buffer, 10000, label="buffer", ) if variant.get('pretrain_policy', False): trainer.pretrain_policy_with_bc( policy, demo_train_buffer, demo_test_buffer, trainer.bc_num_pretrain_steps, ) if variant.get('pretrain_rl', False): trainer.pretrain_q_with_bc_data() if variant.get('save_pretrained_algorithm', False): p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p') pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt') data = algorithm._get_snapshot() data['algorithm'] = algorithm torch.save(data, open(pt_path, "wb")) torch.save(data, open(p_path, "wb")) if variant.get('train_rl', True): algorithm.train()
import torch
import numpy as np
import pickle
from rlkit.misc.asset_loader import load_local_or_remote_file
import rlkit.torch.pytorch_util as ptu

vae_path = '/home/khazatsky/rail/data/rail-khazatsky/sasha/PCVAE/DCVAE/run20/id0/itr_600.pkl'
# vae_path = '/home/shikharbahl/research/rlkit-private/data/local/shikhar/corl2019/pointmass/real/run0/id0/vae.pkl'
vae = load_local_or_remote_file(vae_path)

dataset_path = '/home/khazatsky/rail/data/train_data.npy'
dataset = load_local_or_remote_file(dataset_path).item()

import matplotlib.pyplot as plt

traj = dataset['observations'][17]
n = traj.shape[0]
x0 = traj[0]
x0 = ptu.from_numpy(x0.reshape(1, -1))
goal = traj[-1]

vae = vae.cpu()
latent_goal = vae.encode(ptu.from_numpy(goal.reshape(1, -1)), x0, distrib=False)
decoded_goal, _ = vae.decode(latent_goal, x0)

log_probs = []
distances = []
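# Hedged sketch of how the lists above might be filled in: encode each frame of
# the trajectory and record its latent distance to the encoded goal. This loop is
# an assumption about the intended analysis, not part of the original script.
for t in range(n):
    x_t = ptu.from_numpy(traj[t].reshape(1, -1))
    latent_t = vae.encode(x_t, x0, distrib=False)
    distances.append(np.linalg.norm(ptu.get_numpy(latent_t - latent_goal)))
plt.plot(distances)
plt.show()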
def experiment(variant): # Use any random seed, and not the user provided seed seed = np.random.randint(10, 1000) # if not os.path.exists("./results"): # os.makedirs("./results") # if args.env_name == 'Multigoal-v0': # env = point_mass.MultiGoalEnv(distance_cost_coeff=10.0) env_name = variant["env_name"] env = gym.make(env_name) env_params = ENV_PARAMS[env_name] variant.update(env_params) env.seed(seed) torch.manual_seed(seed) np.random.seed(seed) state_dim = env.observation_space.shape[0] action_dim = env.action_space.shape[0] max_action = float(env.action_space.high[0]) print(state_dim, action_dim) print('Max action: ', max_action) log_dir = osp.join(railrl_logger.get_snapshot_dir(), "log") # don't clobber setup_logger(variant=variant, log_dir=log_dir) algo_kwargs = variant["algo_kwargs"] algo_name = variant["algorithm"] if algo_name == 'BCQ': policy = algos.BCQ(state_dim, action_dim, max_action) elif algo_name == 'TD3': policy = TD3.TD3(state_dim, action_dim, max_action) elif algo_name == 'BC': policy = algos.BCQ(state_dim, action_dim, max_action, cloning=True) elif algo_name == 'DQfD': policy = algos.DQfD(state_dim, action_dim, max_action, lambda_=variant["lamda"], margin_threshold=float( variant["margin_threshold"])) elif algo_name == 'KLControl': policy = algos.KLControl(2, state_dim, action_dim, max_action) elif algo_name == 'BEAR': policy = algos.BEAR( 2, state_dim, action_dim, max_action, delta_conf=0.1, use_bootstrap=False, **algo_kwargs, ) elif algo_name == 'BEAR_IS': policy = algos.BEAR_IS( 2, state_dim, action_dim, max_action, delta_conf=0.1, use_bootstrap=False, **algo_kwargs, ) # Load buffer replay_buffer = utils.ReplayBuffer() if variant["env_name"] == 'Multigoal-v0': replay_buffer.load_point_mass(buffer_name, bootstrap_dim=4, dist_cost_coeff=0.01) else: for off_policy_kwargs in variant.get("off_policy_data"): file_path = off_policy_kwargs.pop("path") demo_data = load_local_or_remote_file(file_path) replay_buffer.load_data(demo_data, bootstrap_dim=4, trajs=True, **off_policy_kwargs) evaluations = [] episode_num = 0 done = True training_iters = 0 while training_iters < variant["max_timesteps"]: pol_vals = policy.train(replay_buffer, iterations=int(variant["eval_freq"])) ret_eval, var_ret, median_ret = evaluate_policy(env, policy) evaluations.append(ret_eval) np.save(osp.join(log_dir, "results.npy"), evaluations) training_iters += variant["eval_freq"] print("Training iterations: " + str(training_iters)) logger.record_tabular('Training Epochs', int(training_iters // int(variant["eval_freq"]))) logger.record_tabular('AverageReturn', ret_eval) logger.record_tabular('VarianceReturn', var_ret) logger.record_tabular('MedianReturn', median_ret) logger.dump_tabular()
def rig_experiment( max_path_length, qf_kwargs, sac_trainer_kwargs, replay_buffer_kwargs, policy_kwargs, algo_kwargs, train_vae_kwargs, env_id=None, env_class=None, env_kwargs=None, observation_key='latent_observation', desired_goal_key='latent_desired_goal', state_goal_key='state_desired_goal', state_observation_key='state_observation', image_goal_key='image_desired_goal', exploration_policy_kwargs=None, evaluation_goal_sampling_mode=None, exploration_goal_sampling_mode=None, # Video parameters save_video=True, save_video_kwargs=None, renderer_kwargs=None, imsize=48, pretrained_vae_path="", init_camera=None, ): if exploration_policy_kwargs is None: exploration_policy_kwargs = {} if not save_video_kwargs: save_video_kwargs = {} if not renderer_kwargs: renderer_kwargs = {} renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs) def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs, goal_sampling_mode): state_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs) renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs) img_env = InsertImageEnv(state_env, renderer=renderer) encoded_env = EncoderWrappedEnv( img_env, model, dict(image_observation="latent_observation", ), ) if goal_sampling_mode == "vae_prior": latent_goal_distribution = PriorDistribution( model.representation_size, desired_goal_key, ) diagnostics = StateImageGoalDiagnosticsFn({}, ) elif goal_sampling_mode == "reset_of_env": state_goal_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs) state_goal_distribution = GoalDictDistributionFromMultitaskEnv( state_goal_env, desired_goal_keys=[state_goal_key], ) image_goal_distribution = AddImageDistribution( env=state_env, base_distribution=state_goal_distribution, image_goal_key=image_goal_key, renderer=renderer, ) latent_goal_distribution = AddLatentDistribution( image_goal_distribution, image_goal_key, desired_goal_key, model, ) if hasattr(state_goal_env, 'goal_conditioned_diagnostics'): diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics( state_goal_env.goal_conditioned_diagnostics, desired_goal_key=state_goal_key, observation_key=state_observation_key, ) else: state_goal_env.get_contextual_diagnostics diagnostics = state_goal_env.get_contextual_diagnostics else: raise NotImplementedError('unknown goal sampling method: %s' % goal_sampling_mode) reward_fn = DistanceRewardFn( observation_key=observation_key, desired_goal_key=desired_goal_key, ) env = ContextualEnv( encoded_env, context_distribution=latent_goal_distribution, reward_fn=reward_fn, observation_key=observation_key, contextual_diagnostics_fns=[diagnostics], ) return env, latent_goal_distribution, reward_fn if pretrained_vae_path: model = load_local_or_remote_file(pretrained_vae_path) else: model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class, imsize, init_camera) expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward( env_id, env_class, env_kwargs, exploration_goal_sampling_mode) eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward( env_id, env_class, env_kwargs, evaluation_goal_sampling_mode) context_key = desired_goal_key obs_dim = (expl_env.observation_space.spaces[observation_key].low.size + expl_env.observation_space.spaces[context_key].low.size) action_dim = expl_env.action_space.low.size def create_qf(): return ConcatMlp(input_size=obs_dim + action_dim, output_size=1, **qf_kwargs) qf1 = create_qf() qf2 = create_qf() target_qf1 = create_qf() target_qf2 = create_qf() policy = 
TanhGaussianPolicy(obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs) def concat_context_to_obs(batch, *args, **kwargs): obs = batch['observations'] next_obs = batch['next_observations'] context = batch[context_key] batch['observations'] = np.concatenate([obs, context], axis=1) batch['next_observations'] = np.concatenate([next_obs, context], axis=1) return batch replay_buffer = ContextualRelabelingReplayBuffer( env=eval_env, context_keys=[context_key], observation_keys=[observation_key], observation_key=observation_key, context_distribution=expl_context_distrib, sample_context_from_obs_dict_fn=RemapKeyFn( {context_key: observation_key}), reward_fn=eval_reward, post_process_batch_fn=concat_context_to_obs, **replay_buffer_kwargs) trainer = SACTrainer(env=expl_env, policy=policy, qf1=qf1, qf2=qf2, target_qf1=target_qf1, target_qf2=target_qf2, **sac_trainer_kwargs) eval_path_collector = ContextualPathCollector( eval_env, MakeDeterministic(policy), observation_key=observation_key, context_keys_for_policy=[ context_key, ], ) exploration_policy = create_exploration_policy(expl_env, policy, **exploration_policy_kwargs) expl_path_collector = ContextualPathCollector( expl_env, exploration_policy, observation_key=observation_key, context_keys_for_policy=[ context_key, ], ) algorithm = TorchBatchRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=max_path_length, **algo_kwargs) algorithm.to(ptu.device) if save_video: expl_video_func = RIGVideoSaveFunction( model, expl_path_collector, "train", decode_goal_image_key="image_decoded_goal", reconstruction_key="image_reconstruction", rows=2, columns=5, unnormalize=True, # max_path_length=200, imsize=48, image_format=renderer.output_image_format, **save_video_kwargs) algorithm.post_train_funcs.append(expl_video_func) eval_video_func = RIGVideoSaveFunction( model, eval_path_collector, "eval", goal_image_key=image_goal_key, decode_goal_image_key="image_decoded_goal", reconstruction_key="image_reconstruction", num_imgs=4, rows=2, columns=5, unnormalize=True, # max_path_length=200, imsize=48, image_format=renderer.output_image_format, **save_video_kwargs) algorithm.post_train_funcs.append(eval_video_func) algorithm.train()
def generate_goal_dataset_using_policy(
        env=None,
        num_goals=1000,
        use_cached_dataset=False,
        policy_file=None,
        show=False,
        path_length=500,
        save_file_prefix=None,
        env_id=None,
        tag='',
):
    if isinstance(env, ImageEnv):
        env_class_name = env._wrapped_env.__class__.__name__
    else:
        env_class_name = env._wrapped_env.wrapped_env.__class__.__name__
    if save_file_prefix is None and env_id is not None:
        save_file_prefix = env_id
    elif save_file_prefix is None:
        save_file_prefix = env_class_name
    filename = "/tmp/{}_N{}_imsize{}goals{}.npy".format(
        save_file_prefix,
        str(num_goals),
        env.imsize,
        tag,
    )
    if use_cached_dataset and osp.isfile(filename):
        goal_dict = np.load(filename).item()
        print("Loaded data from {}".format(filename))
        return goal_dict

    goal_generation_dict = dict()
    for goal_key, obs_key in [
        ('image_desired_goal', 'image_achieved_goal'),
        ('state_desired_goal', 'state_achieved_goal'),
    ]:
        goal_size = env.observation_space.spaces[goal_key].low.size
        goal_generation_dict[goal_key] = [goal_size, obs_key]

    goal_dict = dict()
    policy_file = load_local_or_remote_file(policy_file)
    policy = policy_file['policy']
    policy.to(ptu.device)
    for goal_key in goal_generation_dict:
        goal_size, obs_key = goal_generation_dict[goal_key]
        goal_dict[goal_key] = np.zeros((num_goals, goal_size))

    print('Generating Random Goals')
    for j in range(num_goals):
        obs = env.reset()
        policy.reset()
        for i in range(path_length):
            policy_obs = np.hstack((
                obs['state_observation'],
                obs['state_desired_goal'],
            ))
            action, _ = policy.get_action(policy_obs)
            obs, _, _, _ = env.step(action)
        if show:
            img = obs['image_observation']
            img = img.reshape(3, env.imsize, env.imsize).transpose()
            img = img[::-1, :, ::-1]
            cv2.imshow('img', img)
            cv2.waitKey(1)
        for goal_key in goal_generation_dict:
            goal_size, obs_key = goal_generation_dict[goal_key]
            goal_dict[goal_key][j, :] = obs[obs_key]

    np.save(filename, goal_dict)
    print("Saving file to {}".format(filename))
    return goal_dict
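# Hedged usage sketch for generate_goal_dataset_using_policy above: presample
# image and state goals by rolling out a trained policy in an image-wrapped env.
# `image_env` and the policy path are placeholders the caller must provide.
goal_dict = generate_goal_dataset_using_policy(
    env=image_env,
    num_goals=500,
    policy_file='/tmp/policy_params.pkl',
    path_length=100,
)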
def experiment(variant): if variant.get("pretrained_algorithm_path", False): resume(variant) return if 'env' in variant: env_params = ENV_PARAMS[variant['env']] variant.update(env_params) if 'env_id' in env_params: if env_params['env_id'] in ['pen-v0', 'pen-sparse-v0', 'door-v0', 'relocate-v0', 'hammer-v0', 'pen-sparse-v0', 'door-sparse-v0', 'relocate-sparse-v0', 'hammer-sparse-v0']: import mj_envs expl_env = gym.make(env_params['env_id']) eval_env = gym.make(env_params['env_id']) else: expl_env = NormalizedBoxEnv(variant['env_class']()) eval_env = NormalizedBoxEnv(variant['env_class']()) if variant.get('sparse_reward', False): expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward) eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward) if variant.get('add_env_demos', False): variant["path_loader_kwargs"]["demo_paths"].append(variant["env_demo_path"]) if variant.get('add_env_offpolicy_data', False): variant["path_loader_kwargs"]["demo_paths"].append(variant["env_offpolicy_data_path"]) else: expl_env = encoder_wrapped_env(variant) eval_env = encoder_wrapped_env(variant) path_loader_kwargs = variant.get("path_loader_kwargs", {}) stack_obs = path_loader_kwargs.get("stack_obs", 1) if stack_obs > 1: expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs) eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs) obs_dim = expl_env.observation_space.low.size action_dim = eval_env.action_space.low.size if hasattr(expl_env, 'info_sizes'): env_info_sizes = expl_env.info_sizes else: env_info_sizes = dict() M = variant['layer_size'] vf_kwargs = variant.get("vf_kwargs", {}) vf1 = ConcatMlp( input_size=obs_dim, output_size=1, hidden_sizes=[M, M], **vf_kwargs ) target_vf1 = ConcatMlp( input_size=obs_dim, output_size=1, hidden_sizes=[M, M], **vf_kwargs ) policy_class = variant.get("policy_class", TanhGaussianPolicy) policy_kwargs = variant['policy_kwargs'] policy = policy_class( obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs, ) target_policy = policy_class( obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs, ) buffer_policy_class = variant.get("buffer_policy_class", policy_class) buffer_policy = buffer_policy_class( obs_dim=obs_dim, action_dim=action_dim, **variant.get("buffer_policy_kwargs", policy_kwargs), ) eval_policy = MakeDeterministic(policy) eval_path_collector = MdpPathCollector( eval_env, eval_policy, ) expl_policy = policy exploration_kwargs = variant.get('exploration_kwargs', {}) if exploration_kwargs: if exploration_kwargs.get("deterministic_exploration", False): expl_policy = MakeDeterministic(policy) exploration_strategy = exploration_kwargs.get("strategy", None) if exploration_strategy is None: pass elif exploration_strategy == 'ou': es = OUStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) elif exploration_strategy == 'gauss_eps': es = GaussianAndEpislonStrategy( action_space=expl_env.action_space, max_sigma=exploration_kwargs['noise'], min_sigma=exploration_kwargs['noise'], # constant sigma epsilon=0, ) expl_policy = PolicyWrappedWithExplorationStrategy( exploration_strategy=es, policy=expl_policy, ) else: error if variant.get('replay_buffer_class', EnvReplayBuffer) == AWREnvReplayBuffer: main_replay_buffer_kwargs = variant['replay_buffer_kwargs'] main_replay_buffer_kwargs['env'] = expl_env main_replay_buffer_kwargs['qf1'] = qf1 main_replay_buffer_kwargs['qf2'] = qf2 
main_replay_buffer_kwargs['policy'] = policy else: main_replay_buffer_kwargs=dict( max_replay_buffer_size=variant['replay_buffer_size'], env=expl_env, ) replay_buffer_kwargs = dict( max_replay_buffer_size=variant['replay_buffer_size'], env=expl_env, ) replay_buffer = variant.get('replay_buffer_class', EnvReplayBuffer)( **main_replay_buffer_kwargs, ) if variant.get('use_validation_buffer', False): train_replay_buffer = replay_buffer validation_replay_buffer = variant.get('replay_buffer_class', EnvReplayBuffer)( **main_replay_buffer_kwargs, ) replay_buffer = SplitReplayBuffer(train_replay_buffer, validation_replay_buffer, 0.9) trainer_class = variant.get("trainer_class", QuinoaTrainer) trainer = trainer_class( env=eval_env, policy=policy, vf1=vf1, target_policy=target_policy, target_vf1=target_vf1, buffer_policy=buffer_policy, **variant['trainer_kwargs'] ) if variant['collection_mode'] == 'online': expl_path_collector = MdpStepCollector( expl_env, policy, ) algorithm = TorchOnlineRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant['min_num_steps_before_training'], ) else: expl_path_collector = MdpPathCollector( expl_env, expl_policy, ) algorithm = TorchBatchRLAlgorithm( trainer=trainer, exploration_env=expl_env, evaluation_env=eval_env, exploration_data_collector=expl_path_collector, evaluation_data_collector=eval_path_collector, replay_buffer=replay_buffer, max_path_length=variant['max_path_length'], batch_size=variant['batch_size'], num_epochs=variant['num_epochs'], num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'], num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'], num_trains_per_train_loop=variant['num_trains_per_train_loop'], min_num_steps_before_training=variant['min_num_steps_before_training'], ) algorithm.to(ptu.device) demo_train_buffer = EnvReplayBuffer( **replay_buffer_kwargs, ) demo_test_buffer = EnvReplayBuffer( **replay_buffer_kwargs, ) if variant.get("save_video", False): if variant.get("presampled_goals", None): variant['image_env_kwargs']['presampled_goals'] = load_local_or_remote_file(variant['presampled_goals']).item() image_eval_env = ImageEnv(GymToMultiEnv(eval_env), **variant["image_env_kwargs"]) image_eval_path_collector = ObsDictPathCollector( image_eval_env, eval_policy, observation_key="state_observation", ) image_expl_env = ImageEnv(GymToMultiEnv(expl_env), **variant["image_env_kwargs"]) image_expl_path_collector = ObsDictPathCollector( image_expl_env, expl_policy, observation_key="state_observation", ) video_func = VideoSaveFunction( image_eval_env, variant, image_expl_path_collector, image_eval_path_collector, ) algorithm.post_train_funcs.append(video_func) if variant.get('save_paths', False): algorithm.post_train_funcs.append(save_paths) if variant.get('load_demos', False): path_loader_class = variant.get('path_loader_class', MDPPathLoader) path_loader = path_loader_class(trainer, replay_buffer=replay_buffer, demo_train_buffer=demo_train_buffer, demo_test_buffer=demo_test_buffer, **path_loader_kwargs ) path_loader.load_demos() if 
variant.get('save_initial_buffers', False): buffers = dict( replay_buffer=replay_buffer, demo_train_buffer=demo_train_buffer, demo_test_buffer=demo_test_buffer, ) buffer_path = osp.join(logger.get_snapshot_dir(), 'buffers.p') pickle.dump(buffers, open(buffer_path, "wb")) if variant.get('pretrain_policy', False): trainer.pretrain_policy_with_bc() if variant.get('pretrain_rl', False): trainer.pretrain_q_with_bc_data() if variant.get('save_pretrained_algorithm', False): p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p') pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt') data = algorithm._get_snapshot() data['algorithm'] = algorithm torch.save(data, open(pt_path, "wb")) torch.save(data, open(p_path, "wb")) if variant.get('train_rl', True): algorithm.train()
def experiment(variant): render = variant.get("render", False) debug = variant.get("debug", False) if variant.get("pretrained_algorithm_path", False): resume(variant) return env_class = variant["env_class"] env_kwargs = variant["env_kwargs"] expl_env = env_class(**env_kwargs) eval_env = env_class(**env_kwargs) env = eval_env if variant.get('sparse_reward', False): expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward) eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward) if variant.get('add_env_demos', False): variant["path_loader_kwargs"]["demo_paths"].append(variant["env_demo_path"]) if variant.get('add_env_offpolicy_data', False): variant["path_loader_kwargs"]["demo_paths"].append(variant["env_offpolicy_data_path"]) if variant.get("use_masks", False): mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict()) expl_mask_distribution_kwargs = variant["expl_mask_distribution_kwargs"] expl_mask_distribution = DiscreteDistribution(**expl_mask_distribution_kwargs) expl_env = RewardMaskWrapper(env, expl_mask_distribution, **mask_wrapper_kwargs) eval_mask_distribution_kwargs = variant["eval_mask_distribution_kwargs"] eval_mask_distribution = DiscreteDistribution(**eval_mask_distribution_kwargs) eval_env = RewardMaskWrapper(env, eval_mask_distribution, **mask_wrapper_kwargs) env = eval_env path_loader_kwargs = variant.get("path_loader_kwargs", {}) stack_obs = path_loader_kwargs.get("stack_obs", 1) if stack_obs > 1: expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs) eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs) observation_key = variant.get('observation_key', 'latent_observation') desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal') achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal') obs_dim = ( env.observation_space.spaces[observation_key].low.size + env.observation_space.spaces[desired_goal_key].low.size ) action_dim = eval_env.action_space.low.size if hasattr(expl_env, 'info_sizes'): env_info_sizes = expl_env.info_sizes else: env_info_sizes = dict() replay_buffer_kwargs=dict( env=env, observation_key=observation_key, desired_goal_key=desired_goal_key, achieved_goal_key=achieved_goal_key, ) replay_buffer_kwargs.update(variant.get('replay_buffer_kwargs', dict())) replay_buffer = ConcatToObsWrapper( ObsDictRelabelingBuffer(**replay_buffer_kwargs), ["resampled_goals", ], ) replay_buffer_kwargs.update(variant.get('demo_replay_buffer_kwargs', dict())) demo_train_buffer = ConcatToObsWrapper( ObsDictRelabelingBuffer(**replay_buffer_kwargs), ["resampled_goals", ], ) demo_test_buffer = ConcatToObsWrapper( ObsDictRelabelingBuffer(**replay_buffer_kwargs), ["resampled_goals", ], ) M = variant['layer_size'] qf1 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) qf2 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) target_qf1 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) target_qf2 = ConcatMlp( input_size=obs_dim + action_dim, output_size=1, hidden_sizes=[M, M], ) policy_class = variant.get("policy_class", TanhGaussianPolicy) policy_kwargs = variant['policy_kwargs'] policy_path = variant.get("policy_path", False) if policy_path: policy = load_local_or_remote_file(policy_path) else: policy = policy_class( obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs, ) buffer_policy_path = variant.get("buffer_policy_path", False) if buffer_policy_path: buffer_policy = 
load_local_or_remote_file(buffer_policy_path)
    else:
        buffer_policy_class = variant.get("buffer_policy_class", policy_class)
        buffer_policy = buffer_policy_class(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **variant.get("buffer_policy_kwargs", policy_kwargs),
        )

    expl_policy = policy
    exploration_kwargs = variant.get('exploration_kwargs', {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)
        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == 'ou':
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == 'gauss_eps':
            es = GaussianAndEpislonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            # Fail fast on unknown strategies instead of the bare `error` name,
            # which only raised an uninformative NameError.
            raise ValueError(
                "Unknown exploration strategy: {}".format(exploration_strategy))

    trainer = AWACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        buffer_policy=buffer_policy,
        **variant['trainer_kwargs']
    )
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant['min_num_steps_before_training'],
        )
    else:
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant['min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)

    if variant.get("save_video", False):
        renderer_kwargs = variant.get("renderer_kwargs", {})
        save_video_kwargs = variant.get("save_video_kwargs", {})

        def get_video_func(env, policy, tag):
            renderer = EnvRenderer(**renderer_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=env,
                base_distribution=state_goal_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(env, renderer=renderer)
            rollout_function = partial(
                rf.multitask_rollout,
                max_path_length=variant['max_path_length'],
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
                return_dict_obs=True,
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key="state_achieved_goal",
            )
            contextual_env = ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
            )
            video_func = get_save_video_function(
                rollout_function,
                contextual_env,
                policy,
                tag=tag,
                imsize=renderer.width,
                image_format='CWH',
                **save_video_kwargs
            )
            return video_func

        expl_video_func = get_video_func(expl_env, expl_policy, "expl")
        eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), "eval")
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)
    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            **path_loader_kwargs
        )
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()
    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))

    algorithm.train()
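# Illustrative sketch only: a variant dict containing the keys the AWAC launcher
# above actually reads (collection_mode, trainer_kwargs, the epoch/step counts,
# and the optional exploration / save_video / demo flags). Every value here is a
# hypothetical placeholder, not a configuration shipped with this repository.
example_awac_variant = dict(
    collection_mode='batch',   # 'online' selects TorchOnlineRLAlgorithm instead
    trainer_kwargs=dict(),     # forwarded verbatim to AWACTrainer
    max_path_length=50,
    batch_size=1024,
    num_epochs=100,
    num_eval_steps_per_epoch=1000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
    exploration_kwargs=dict(strategy='gauss_eps', noise=0.3),
    save_video=False,
    load_demos=False,
    pretrain_policy=False,
    pretrain_rl=False,
    save_pretrained_algorithm=False,
)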
def _e2e_disentangled_experiment(max_path_length,
                                 encoder_kwargs,
                                 disentangled_qf_kwargs,
                                 qf_kwargs,
                                 twin_sac_trainer_kwargs,
                                 replay_buffer_kwargs,
                                 policy_kwargs,
                                 vae_evaluation_goal_sampling_mode,
                                 vae_exploration_goal_sampling_mode,
                                 base_env_evaluation_goal_sampling_mode,
                                 base_env_exploration_goal_sampling_mode,
                                 algo_kwargs,
                                 env_id=None,
                                 env_class=None,
                                 env_kwargs=None,
                                 observation_key='state_observation',
                                 desired_goal_key='state_desired_goal',
                                 achieved_goal_key='state_achieved_goal',
                                 latent_dim=2,
                                 vae_wrapped_env_kwargs=None,
                                 vae_path=None,
                                 vae_n_vae_training_kwargs=None,
                                 vectorized=False,
                                 save_video=True,
                                 save_video_kwargs=None,
                                 have_no_disentangled_encoder=False,
                                 **kwargs):
    if env_kwargs is None:
        env_kwargs = {}
    # Guard the other None defaults that are later unpacked or queried with .get().
    if vae_wrapped_env_kwargs is None:
        vae_wrapped_env_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env, vae, imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    eval_env = VAEWrappedEnv(eval_env, vae, imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = ConcatMlp(input_size=obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)

    def make_qf():
        if have_no_disentangled_encoder:
            return ConcatMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(encoder=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)

    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:
            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:
            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(
            rollout_function,
            eval_env,
            MakeDeterministic(policy),
            get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
            tag="eval",
            **save_video_kwargs)
        train_video_func = get_save_video_function(
            rollout_function,
            train_env,
            policy,
            get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
            tag="train",
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)

    algorithm.train()
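# Illustrative sketch only: how _e2e_disentangled_experiment might be invoked.
# The keyword names mirror the signature above; every value (hidden sizes,
# sampling modes, env id) is a hypothetical placeholder, not a configuration
# from this repository, so the call is left commented out.
# _e2e_disentangled_experiment(
#     max_path_length=50,
#     encoder_kwargs=dict(hidden_sizes=[64, 64]),
#     disentangled_qf_kwargs=dict(),
#     qf_kwargs=dict(hidden_sizes=[400, 300]),
#     twin_sac_trainer_kwargs=dict(),
#     replay_buffer_kwargs=dict(max_size=int(1e6)),
#     policy_kwargs=dict(hidden_sizes=[400, 300]),
#     vae_evaluation_goal_sampling_mode='reset_of_env',
#     vae_exploration_goal_sampling_mode='vae_prior',
#     base_env_evaluation_goal_sampling_mode='reset_of_env',
#     base_env_exploration_goal_sampling_mode='reset_of_env',
#     algo_kwargs=dict(batch_size=128, num_epochs=100),
#     env_id='SawyerPushNIPSEasy-v0',  # hypothetical env id
#     latent_dim=4,
#     save_video=False,
# )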
def HER_baseline_td3_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.her_td3 import HerTd3
    from rlkit.torch.networks import MergedCNN, CNNPolicy
    import torch
    from multiworld.core.image_env import ImageEnv
    from rlkit.misc.asset_loader import load_local_or_remote_file

    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        reward_type='image_sparse',
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    if presample_goals:
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=image_env, **variant['goal_generation_kwargs'])
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        env = ImageEnv(
            env,
            variant.get('imsize'),
            reward_type='image_distance',
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )
    else:
        env = image_env

    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'image_observation')
    desired_goal_key = variant.get('desired_goal_key', 'image_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    imsize = variant['imsize']
    action_dim = env.action_space.low.size
    qf1 = MergedCNN(input_width=imsize, input_height=imsize, output_size=1,
                    input_channels=3 * 2, added_fc_input_size=action_dim,
                    **variant['cnn_params'])
    qf2 = MergedCNN(input_width=imsize, input_height=imsize, output_size=1,
                    input_channels=3 * 2, added_fc_input_size=action_dim,
                    **variant['cnn_params'])
    policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    target_qf1 = MergedCNN(input_width=imsize, input_height=imsize, output_size=1,
                           input_channels=3 * 2, added_fc_input_size=action_dim,
                           **variant['cnn_params'])
    target_qf2 = MergedCNN(input_width=imsize, input_height=imsize, output_size=1,
                           input_channels=3 * 2, added_fc_input_size=action_dim,
                           **variant['cnn_params'])
    target_policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    algorithm.train()
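# Illustrative sketch only: the nested variant layout HER_baseline_td3_experiment
# expects above -- cnn_params, replay_buffer_kwargs, and an algo_kwargs dict that
# already carries base_kwargs / her_kwargs sub-dicts for the launcher to fill in.
# All values and the env id are hypothetical placeholders.
example_her_td3_variant = dict(
    env_id='SawyerPushNIPSEasy-v0',  # hypothetical env id
    imsize=84,
    render=False,
    presample_goals=False,
    cnn_params=dict(),               # MergedCNN / CNNPolicy settings go here
    replay_buffer_kwargs=dict(max_size=int(1e5)),
    algo_kwargs=dict(
        base_kwargs=dict(),          # launcher injects training_env and render flags
        her_kwargs=dict(),           # launcher injects observation/desired goal keys
    ),
)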
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrappers import VAEWrappedEnv, ConditionalVAEWrappedEnv
    from rlkit.misc.asset_loader import load_local_or_remote_file
    from rlkit.torch.vae.conditional_conv_vae import (
        CVAE, CDVAE, ACE, CADVAE, DeltaCVAE)

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only', False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))

    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {}))
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
                presampled_goals=presampled_goals,
                **variant.get('image_env_kwargs', {}))
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                presampled_goals=presampled_goals,
                **variant.get('vae_wrapped_env_kwargs', {}))
            print("Presampling all goals only")
        else:
            if (type(vae) is CVAE or type(vae) is CDVAE or type(vae) is ACE
                    or type(vae) is CADVAE or type(vae) is DeltaCVAE):
                vae_env = ConditionalVAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {}))
            else:
                vae_env = VAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {}))
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")
        env = vae_env
    return env
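# Illustrative sketch only: a variant that exercises the VAE-wrapping path of
# get_envs above (a pretrained VAE plus an image env built from env_id). The
# env id, path, and values are hypothetical placeholders; the keys are the ones
# get_envs reads.
rig_env_variant = dict(
    env_id='SawyerPushNIPSEasy-v0',         # hypothetical env id
    vae_path='path/to/pretrained_vae.pkl',  # hypothetical; a VAE object also works
    imsize=48,
    render=False,
    do_state_exp=False,
    presample_goals=False,
    presample_image_goals_only=False,
    reward_params=dict(type='latent_distance'),
    vae_wrapped_env_kwargs=dict(sample_from_true_prior=True),
)
# env = get_envs(rig_env_variant)  # returns the (optionally VAE-wrapped) env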