def experiment(variant):
    """Train a ConvVAE on data from ``generate_vae_dataset``.

    Expects ``variant`` to contain: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``algo_kwargs``, ``num_epochs``, ``save_period``,
    and optionally ``beta_schedule_kwargs`` and ``gpu_id``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3)
    # BUG FIX: select the requested GPU *before* moving the model onto
    # ptu.device. Previously set_device ran after m.to(ptu.device), so the
    # model could land on the wrong GPU when gpu_id was given.
    gpu_id = variant.get("gpu_id", None)
    if gpu_id is not None:
        ptu.set_device(gpu_id)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def experiment(variant):
    """Fit a VAE (or plain auto-encoder) on a flat-vector dataset.

    The beta-schedule control points, when present, are patched in place
    from ``beta``/``flat_x``/``ramp_x`` before the schedule is built.
    """
    from railrl.core import logger

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_dataset()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        # Overwrite the schedule's anchor points in place.
        sched_kwargs = variant['beta_schedule_kwargs']
        sched_kwargs['y_values'][2] = variant['beta']
        sched_kwargs['x_values'][1] = variant['flat_x']
        sched_kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(**sched_kwargs)
    else:
        beta_schedule = None

    output_scale = 1
    model_class = (
        AutoEncoder if variant['algo_kwargs']['is_auto_encoder'] else VAE
    )
    model = model_class(
        representation_size,
        train_data.shape[1],
        output_scale=output_scale,
        **variant['vae_kwargs']
    )
    trainer = VAETrainer(train_data, test_data, model, beta=beta,
                         beta_schedule=beta_schedule,
                         **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
def train_vae(variant, return_data=False):
    """Build a VAE via ``get_vae`` and train it, logging diagnostics
    every epoch and checkpointing the model every 50 epochs.

    Returns the trained model; also the train/test datasets when
    ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    dataset_fn = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = dataset_fn(
        variant['generate_vae_dataset_kwargs'])

    # Only the linear-dynamics variant needs a real action dimension.
    action_dim = (train_dataset.data['actions'].shape[2]
                  if use_linear_dynamics else 0)
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_class(model, beta=beta,
                            beta_schedule=beta_schedule,
                            **variant['algo_kwargs'])

    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if epoch % save_period == 0:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')

    if return_data:
        return model, train_dataset, test_dataset
    return model
def experiment(variant):
    """Skew-fit style VAE training with extra loss logging against a
    fixed uniform image dataset loaded from
    ``variant['uniform_dataset_path']``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    uniform_dataset = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(uniform_dataset['image_desired_goal'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        # The schedule's final value is always this run's beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    model = variant['vae'](representation_size,
                           decoder_output_activation=nn.Sigmoid(),
                           **variant['vae_kwargs'])
    model.to(ptu.device)

    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            model, uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'])
        trainer.test_epoch(epoch,
                           save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE on pre-rendered Sawyer torque-control images.

    Loads five .npy shards of 10k flattened frames each from a hard-coded
    path; the first 90% of the concatenated array becomes the train split.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    num_divisions = 5
    images = np.zeros((num_divisions * 10000, 21168))
    for i in range(num_divisions):
        imgs = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
            + str(i + 1) + '.npy')
        images[i * 10000:(i + 1) * 10000] = imgs
        print(i)
    mid = int(num_divisions * 10000 * .9)
    train_data, test_data = images[:mid], images[mid:]
    info = dict()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        kwargs = variant['beta_schedule_kwargs']
        kwargs['y_values'][2] = variant['beta']
        kwargs['x_values'][1] = variant['flat_x']
        kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        # FIX: use m.to(ptu.device) instead of m.cuda() so the model goes
        # to the GPU selected via pytorch_util, consistent with the other
        # experiment variants in this file.
        m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def experiment(variant):
    """Skew-fit VAE training on door data from a joblib snapshot, with
    extra diagnostics computed on a generated uniform dataset.

    ``variant['file']`` must point to a joblib dump containing ``obs``
    (image array) and ``size`` (number of valid rows).
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    data = joblib.load(variant['file'])
    obs = data['obs']
    size = int(data['size'])
    dataset = obs[:size, :]
    # 90/10 train/test split of the valid rows.
    n = int(size * .9)
    train_data = dataset[:n, :]
    test_data = dataset[n:, :]
    print('SIZE: ', size)
    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs']
    )
    # FIX: logger.get_snapshot_dir() was redundantly called twice; one
    # call is sufficient.
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # The schedule's final value is always this run's beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = variant['vae'](representation_size,
                       decoder_output_activation=nn.Sigmoid(),
                       **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.log_loss_under_uniform(uniform_dataset)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
                t.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        t.update_train_weights()
def experiment(variant):
    """Train a ConvVAE with a state-similarity debug head on paired
    (image, state) arrays for the Sawyer torque-control setup."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    # This data has both states and images, so generate_vae_dataset
    # cannot be used here.
    X = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    Y = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    # Keep the first 7 state dims and everything from dim 14 onward.
    Y = np.concatenate((Y[:, :7], Y[:, 14:]), axis=1)
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.1)
    info = dict()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )

    model = ConvVAE(representation_size,
                    input_channels=3,
                    state_sim_debug=True,
                    state_size=Y.shape[1],
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()

    trainer = ConvVAETrainer((X_train, Y_train), (X_test, Y_test), model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             state_sim_debug=True,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Minimal ConvVAE run: fixed 1001 epochs under a hard-coded
    piecewise-linear beta schedule that ramps up after epoch 400."""
    if variant["use_gpu"]:
        ptu.set_gpu_mode(True)
        ptu.set_device(variant["gpu_id"])

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data = get_data(10000)

    model = ConvVAE(representation_size, input_channels=3)
    schedule = PiecewiseLinearSchedule([0, 400, 800], [0.5, 0.5, beta])
    trainer = ConvVAETrainer(train_data, test_data, model,
                             beta_schedule=schedule)
    for epoch in range(1001):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train a configurable VAE class with optional skew-fit debug dumps
    and per-epoch train-weight updates."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        # Pin the schedule's final value to this run's beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    model = variant['vae'](representation_size, **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Basic ConvVAE training loop with a required beta schedule and a
    sample dump every epoch."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    # Unlike the sibling variants, beta_schedule_kwargs is mandatory here.
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """ConvVAE training where the beta schedule's final y-value is
    overridden by the run's beta."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    model = ConvVAE(representation_size, input_channels=3,
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=save_imgs,
                           save_scatterplot=save_imgs)
        if save_imgs:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train an ACAI model for a fixed 6001 epochs; beta is scheduled to
    ramp from 0 starting at epoch 2500."""
    lmbda = variant['lmbda']
    gamma = variant['gamma']
    mu = variant['mu']
    beta_schedule = PiecewiseLinearSchedule(
        [0, 2500, 3500], [0, 0, variant['beta']])
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])

    model = ACAI(representation_size, input_channels=3)
    trainer = ACAITrainer(train_data, test_data, model,
                          beta_schedule=beta_schedule,
                          gamma=gamma, mu=mu, lmbda=lmbda)
    for epoch in range(6001):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        if epoch % variant['save_period'] == 0:
            trainer.dump_samples(epoch)
def train_vae(variant, return_data=False):
    """Train a (Conv)VAE on demo-generated data and pickle it as vae.pkl.

    The VAE architecture defaults to the imsize-84/48 presets when no
    explicit architecture is given. Returns the model, plus the train/test
    data when ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    # Fall back to the preset architecture matching the image size.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    if variant['algo_kwargs'].get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        # FIX: removed the unreachable SpatialAutoEncoder construction that
        # followed this raise — code after an unconditional raise never runs.
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
            if dump_skew_debug_plots:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer):
    """Replay buffer that keeps decoded (image) observations alongside
    latent ones, re-encodes them whenever the VAE changes, and optionally
    (a) adds exploration-bonus rewards to sampled batches and
    (b) skews sampling toward high-priority transitions for VAE training.
    """

    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        """
        :param vae: model used to (re-)encode decoded observations; it is
            also queried directly by several reward/priority functions.
        :param exploration_rewards_type: key into the type_to_function
            dispatch table below; 'None' disables the bonus.
        :param vae_priority_type: same key space; selects the function
            that computes per-transition VAE sampling priorities.
        :param start_skew_epoch: skewed sampling only activates once
            self.epoch exceeds this value (see refresh_latents).
        :param power: exponent applied to priorities when building
            sampling probabilities.
        """
        if internal_keys is None:
            internal_keys = []
        # Ensure the decoded-image keys are tracked by the parent buffer.
        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        # The exploration-reward scale can either be constant or follow a
        # piecewise-linear schedule over epochs.
        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1),
                                             dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1),
                                               dtype=np.float32)
        self._vae_sample_probs = None

        # A learned forward model is only needed for this reward type.
        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        # Dispatch table shared by exploration rewards and VAE priorities.
        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(
                env=self.env, **exploration_counter_kwargs)
        self.epoch = 0

    def add_path(self, path):
        # Optionally decode VAE-sampled goals before handing the path to
        # the parent buffer.
        if self.decode_vae_goals:
            self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        # decoding the self-sampled vae images should be done in batch (here)
        # rather than in the env for efficiency
        desired_goals = flatten_dict(
            path['observations'],
            [self.desired_goal_key]
        )[self.desired_goal_key]
        desired_decoded_goals = self.env._decode(desired_goals)
        desired_decoded_goals = desired_decoded_goals.reshape(
            len(desired_decoded_goals),
            -1
        )
        for idx, next_obs in enumerate(path['observations']):
            path['observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]
            path['next_observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]

    def random_batch(self, batch_size):
        """Sample a batch from the parent buffer and, when enabled, add
        scheduled exploration-bonus rewards to batch['rewards']."""
        batch = super().random_batch(batch_size)
        exploration_rewards_scale = float(
            self.explr_reward_scale_schedule.get_value(self.epoch))
        if self._give_explr_reward_bonus:
            batch_idxs = batch['indices'].flatten()
            batch['exploration_rewards'] = \
                self._exploration_rewards[batch_idxs]
            batch['rewards'] += \
                exploration_rewards_scale * batch['exploration_rewards']
        return batch

    def get_diagnostics(self):
        """Return summary stats over the current sample priorities/probs
        (zeros when priorities have not been computed yet)."""
        if self._vae_sample_probs is None or \
                self._vae_sample_priorities is None:
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                np.zeros(self._size),
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                np.zeros(self._size),
            ))
        else:
            vae_sample_priorities = self._vae_sample_priorities[:self._size]
            vae_sample_probs = self._vae_sample_probs[:self._size]
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                vae_sample_priorities,
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                vae_sample_probs,
            ))
        return stats

    def refresh_latents(self, epoch):
        """Re-encode all stored decoded observations with the current VAE
        (in chunks of 512), recompute exploration rewards and sampling
        priorities, and update the VAE's latent-distribution stats."""
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512
        next_idx = min(batch_size, self._size)
        if self.exploration_rewards_type == 'hash_count':
            # you have to count everything then compute exploration rewards
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
                self.update_hash_count(normalized_imgs)
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)
        cur_idx = 0
        # Running sums used to compute mean/std of latents over the buffer.
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            self._obs[self.observation_key][idxs] = \
                self.env._encode(self._obs[self.decoded_obs_key][idxs])
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_obs_key][idxs])
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be
            # invalid, so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(
                    self._next_obs[self.decoded_desired_goal_key][idxs])
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(
                    self._next_obs[self.decoded_achieved_goal_key][idxs])
            normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    # Same function selected for both: reuse the rewards.
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(
                self._obs[self.observation_key][idxs], 2).sum(axis=0)
            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)
        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(
            obs_square_sum / self._size - np.power(self.vae.dist_mu, 2))
        if self._prioritize_vae_samples:
            """
            priority^power is calculated in the priority function
            for image_bernoulli_prob or image_gaussian_inv_prob and
            directly here if not.
            """
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = \
                    relative_probs_from_log_probs(
                        self._vae_sample_priorities[:self._size]
                    )
                self._vae_sample_probs = \
                    self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = \
                    self._vae_sample_priorities[:self._size] ** self.power
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def sample_weighted_indices(self, batch_size):
        """Sample buffer indices, skewed by priority probabilities when
        prioritization is active and the skew epoch has been reached."""
        if (
                self._prioritize_vae_samples and
                self._vae_sample_probs is not None and
                self.skew
        ):
            indices = np.random.choice(
                len(self._vae_sample_probs),
                batch_size,
                p=self._vae_sample_probs,
            )
            assert (
                    np.max(self._vae_sample_probs) <= 1 and
                    np.min(self._vae_sample_probs) >= 0
            )
        else:
            indices = self._sample_indices(batch_size)
        return indices

    def _sample_goals_from_env(self, batch_size):
        self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return self.env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """
        Samples goals from weighted replay buffer for relabeling or exploration.
        Returns None if replay buffer is empty.

        Example of what might be returned:
        dict(
            image_desired_goals: image_achieved_goals[weighted_indices],
            latent_desired_goals: latent_desired_goals[weighted_indices],
        )
        """
        if self._size == 0:
            return None
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )
        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        next_latent_obs = \
            self._next_obs[self.achieved_goal_key][weighted_idxs]
        return {
            self.decoded_desired_goal_key: next_image_obs,
            self.desired_goal_key: next_latent_obs
        }

    def random_vae_training_data(self, batch_size, epoch):
        # epoch no longer needed. Using self.skew in sample_weighted_indices
        # instead.
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )
        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        observations = ptu.from_numpy(next_image_obs)
        return dict(
            observations=observations,
        )

    def reconstruction_mse(self, next_vae_obs, indices):
        """Per-sample sum of squared reconstruction error under the VAE."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        error = torch_input - recon_next_vae_obs
        mse = torch.sum(error ** 2, dim=1)
        return ptu.get_numpy(mse)

    def gaussian_inv_prob(self, next_vae_obs, indices):
        # exp of the reconstruction MSE (inverse-probability proxy).
        return np.exp(self.reconstruction_mse(next_vae_obs, indices))

    def binary_cross_entropy(self, next_vae_obs, indices):
        """Per-sample cross-entropy of inputs against reconstructions."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        error = - torch_input * torch.log(
            torch.clamp(
                recon_next_vae_obs,
                min=1e-30,  # corresponds to about -70
            )
        )
        bce = torch.sum(error, dim=1)
        return ptu.get_numpy(bce)

    def bernoulli_inv_prob(self, next_vae_obs, indices):
        """Inverse of the per-sample Bernoulli likelihood of the input
        under the reconstruction."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        prob = (
                torch_input * recon_next_vae_obs
                + (1 - torch_input) * (1 - recon_next_vae_obs)
        ).prod(dim=1)
        return ptu.get_numpy(1 / prob)

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        # Delegates to the shared p(x) estimator; kwargs come from
        # self.priority_function_kwargs.
        return compute_p_x_np_to_np(
            self.vae,
            next_vae_obs,
            power=self.power,
            **kwargs
        )

    def forward_model_error(self, next_vae_obs, indices):
        """Prediction error of the learned dynamics model on (s, a) -> s'."""
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]

        state_action_pair = ptu.from_numpy(np.c_[obs, actions])
        prediction = self.dynamics_model(state_action_pair)
        mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))
        return ptu.get_numpy(mse)

    def latent_novelty(self, next_vae_obs, indices):
        # Squared Mahalanobis-style distance using the VAE's running
        # latent mean/std (set in refresh_latents).
        distances = ((self.env._encode(next_vae_obs) - self.vae.dist_mu) /
                     self.vae.dist_std) ** 2
        return distances.sum(axis=1)

    def latent_novelty_true_prior(self, next_vae_obs, indices):
        # Distance from the origin, i.e. novelty under a standard-normal
        # prior.
        distances = self.env._encode(next_vae_obs) ** 2
        return distances.sum(axis=1)

    def _kl_np_to_np(self, next_vae_obs, indices):
        """KL divergence of the encoder posterior from the unit Gaussian."""
        torch_input = ptu.from_numpy(next_vae_obs)
        mu, log_var = self.vae.encode(torch_input)
        return ptu.get_numpy(
            - torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        )

    def update_hash_count(self, next_vae_obs):
        """Increment hash-based visit counts for the encoded means."""
        torch_input = ptu.from_numpy(next_vae_obs)
        mus, log_vars = self.vae.encode(torch_input)
        mus = ptu.get_numpy(mus)
        self.exploration_counter.increment_counts(mus)
        return None

    def hash_count_reward(self, next_vae_obs, indices):
        obs = self.env._encode(next_vae_obs)
        return self.exploration_counter.compute_count_based_reward(obs)

    def no_reward(self, next_vae_obs, indices):
        # 'None' reward type: a zero bonus for every sample.
        return np.zeros((len(next_vae_obs), 1))

    def initialize_dynamics_model(self):
        """Create the MLP forward model, its optimizer, and MSE loss used
        by forward_model_error / train_dynamics_model."""
        obs_dim = self._obs[self.observation_key].shape[1]
        self.dynamics_model = Mlp(
            hidden_sizes=[128, 128],
            output_size=obs_dim,
            input_size=obs_dim + self._action_dim,
        )
        self.dynamics_model.to(ptu.device)
        self.dynamics_optimizer = Adam(self.dynamics_model.parameters())
        self.dynamics_loss = MSELoss()

    def train_dynamics_model(self, batches=50, batch_size=100):
        """Run a few SGD steps on the dynamics model; no-op unless the
        forward-model exploration reward is in use."""
        if not self.use_dynamics_model:
            return
        for _ in range(batches):
            indices = self._sample_indices(batch_size)
            self.dynamics_optimizer.zero_grad()
            obs = self._obs[self.observation_key][indices]
            next_obs = self._next_obs[self.observation_key][indices]
            actions = self._actions[indices]
            if self.exploration_rewards_type == 'inverse_model_error':
                # Swap direction: predict obs from next_obs instead.
                obs, next_obs = next_obs, obs
            state_action_pair = ptu.from_numpy(np.c_[obs, actions])
            prediction = self.dynamics_model(state_action_pair)
            mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))
            mse.backward()
            self.dynamics_optimizer.step()

    def log_loss_under_uniform(self, model, data, batch_size, rl_logger,
                               priority_function_kwargs):
        """Evaluate log-prob (under three sampling methods), KL, and MSE
        of ``model`` on ``data`` in batches, writing means into
        ``rl_logger``. Mutates priority_function_kwargs['sampling_method']
        in place."""
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], batch_size):
            img = data[i:min(data.shape[0], i + batch_size), :]
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, \
                latent_distribution_params = self.vae(torch_img)

            priority_function_kwargs['sampling_method'] = \
                'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = \
                'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(
                model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            mse = F.mse_loss(torch_img, reconstructions,
                             reduction='elementwise_mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior)
        rl_logger["Uniform Data Log Prob (Biased)"] = \
            np.mean(log_probs_biased)
        rl_logger["Uniform Data Log Prob (Importance)"] = \
            np.mean(log_probs_importance)
        rl_logger["Uniform Data KL"] = np.mean(kles)
        rl_logger["Uniform Data MSE"] = np.mean(mses)

    def _get_sorted_idx_and_train_weights(self):
        # Pair each buffer index with its sampling probability, sorted
        # ascending by probability.
        idx_and_weights = zip(range(len(self._vae_sample_probs)),
                              self._vae_sample_probs)
        return sorted(idx_and_weights, key=lambda x: x[1])
def train_vae(variant, return_data=False):
    """Build and train a VAE (or auto-encoder variant) per ``variant``.

    Generates train/test datasets, constructs the model selected by the
    variant flags, runs the epoch loop with periodic image dumps and
    snapshots, and returns the trained model (plus the datasets when
    ``return_data`` is True).

    Note: ``variant`` is mutated in place — dataset kwargs, vae kwargs, and
    ``algo_kwargs`` all receive derived entries before use.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    # Older variants used the key 'latent_sizes' for the same value.
    representation_size = variant.get(
        "representation_size", variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)

    # --- dataset ---------------------------------------------------------
    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = use_linear_dynamics
    dataset_kwargs['batch_size'] = variant['algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = make_dataset(dataset_kwargs)
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    # --- schedules -------------------------------------------------------
    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            variant['algo_kwargs']['context_schedule'] = \
                PiecewiseLinearSchedule(**schedule)
        else:
            variant['algo_kwargs']['context_schedule'] = \
                ConstantSchedule(schedule)

    # --- model -----------------------------------------------------------
    decoder_activation = (
        torch.nn.Sigmoid()
        if variant.get('decoder_activation', None) == 'sigmoid'
        else identity
    )
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture:
        # Fall back to the per-imsize default architecture when none given.
        defaults = {
            84: conv_vae.imsize84_default_architecture,
            48: conv_vae.imsize48_default_architecture,
        }
        architecture = defaults.get(variant.get('imsize'), architecture)
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        extra_kwargs = {'action_dim': action_dim} if use_linear_dynamics else {}
        model = vae_class(
            representation_size,
            decoder_output_activation=decoder_activation,
            **extra_kwargs,
            **variant['vae_kwargs'])
    model.to(ptu.device)

    # --- training loop ---------------------------------------------------
    trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_class(
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if epoch % save_period == 0:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic parameter snapshots (every 50 epochs, including epoch 0).
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')

    if return_data:
        return model, train_dataset, test_dataset
    return model
def __init__(
        self,
        vae,
        *args,
        decoded_obs_key='image_observation',
        decoded_achieved_goal_key='image_achieved_goal',
        decoded_desired_goal_key='image_desired_goal',
        exploration_rewards_type='None',
        exploration_rewards_scale=1.0,
        vae_priority_type='None',
        start_skew_epoch=0,
        power=1.0,
        internal_keys=None,
        exploration_schedule_kwargs=None,
        priority_function_kwargs=None,
        exploration_counter_kwargs=None,
        relabeling_goal_sampling_mode='vae_prior',
        decode_vae_goals=False,
        **kwargs
):
    """Construct a VAE-backed relabeling replay buffer.

    Wires up exploration-reward and VAE-sample-priority machinery around a
    base buffer (the enclosing class header is outside this chunk — the
    parent's ``__init__`` receives ``internal_keys`` plus ``*args, **kwargs``).

    :param vae: encoder/decoder model used for novelty rewards and priorities.
    :param exploration_rewards_type: key into the dispatch table below;
        the string 'None' (not the ``None`` object) disables rewards.
    :param vae_priority_type: same convention, for sampling priorities.
    :param power: exponent applied to priorities; 0 disables prioritization.
    """
    if internal_keys is None:
        internal_keys = []

    # Ensure all decoded keys are tracked by the underlying buffer.
    for key in [
        decoded_obs_key,
        decoded_achieved_goal_key,
        decoded_desired_goal_key
    ]:
        if key not in internal_keys:
            internal_keys.append(key)
    # Parent init must run first: self.env and self.max_size are used below.
    super().__init__(internal_keys=internal_keys, *args, **kwargs)
    # assert isinstance(self.env, VAEWrappedEnv)
    self.vae = vae
    self.decoded_obs_key = decoded_obs_key
    self.decoded_desired_goal_key = decoded_desired_goal_key
    self.decoded_achieved_goal_key = decoded_achieved_goal_key
    self.exploration_rewards_type = exploration_rewards_type
    self.exploration_rewards_scale = exploration_rewards_scale
    self.start_skew_epoch = start_skew_epoch
    self.vae_priority_type = vae_priority_type
    self.power = power
    self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
    self.decode_vae_goals = decode_vae_goals

    # Constant scale unless a piecewise-linear schedule is configured.
    if exploration_schedule_kwargs is None:
        self.explr_reward_scale_schedule = \
            ConstantSchedule(self.exploration_rewards_scale)
    else:
        self.explr_reward_scale_schedule = \
            PiecewiseLinearSchedule(**exploration_schedule_kwargs)

    self._give_explr_reward_bonus = (
        exploration_rewards_type != 'None'
        and exploration_rewards_scale != 0.
    )
    self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
    self._prioritize_vae_samples = (
        vae_priority_type != 'None'
        and power != 0.
    )
    self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
    self._vae_sample_probs = None

    # A learned dynamics model is only needed for forward-model error rewards.
    self.use_dynamics_model = (
        self.exploration_rewards_type == 'forward_model_error'
    )
    if self.use_dynamics_model:
        self.initialize_dynamics_model()

    # Dispatch table shared by exploration rewards and VAE priorities.
    type_to_function = {
        'reconstruction_error': self.reconstruction_mse,
        'bce': self.binary_cross_entropy,
        'latent_distance': self.latent_novelty,
        'latent_distance_true_prior': self.latent_novelty_true_prior,
        'forward_model_error': self.forward_model_error,
        'gaussian_inv_prob': self.gaussian_inv_prob,
        'bernoulli_inv_prob': self.bernoulli_inv_prob,
        'vae_prob': self.vae_prob,
        'hash_count': self.hash_count_reward,
        'None': self.no_reward,
    }

    self.exploration_reward_func = (
        type_to_function[self.exploration_rewards_type]
    )
    self.vae_prioritization_func = (
        type_to_function[self.vae_priority_type]
    )

    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    if self.exploration_rewards_type == 'hash_count':
        if exploration_counter_kwargs is None:
            exploration_counter_kwargs = dict()
        # Requires self.env, hence must come after super().__init__().
        self.exploration_counter = CountExploration(
            env=self.env, **exploration_counter_kwargs)
    self.epoch = 0
def train_vae(variant):
    """Train a VAE on a pre-generated dataset and return the trained model.

    Loads a pickled dict dataset from ``dataset_path``, builds one of several
    VAE variants, wraps the env for visualization, then runs the epoch loop
    with periodic reconstruction/sample dumps and optional rollout videos.
    Logging is redirected to 'vae_progress.csv' for the duration and restored
    at the end.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    # Swap tabular outputs so VAE pre-training logs to its own CSV.
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)

    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    # NOTE(review): despite the name, ``test_p`` is the fraction used for
    # TRAINING — the first n = int(N * test_p) rows go to train_data.
    dataset_path = variant['generate_vae_dataset_kwargs'].get(
        'dataset_path', None)
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    # Dataset was saved with np.save on a dict; .item() unwraps the 0-d array.
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {}
    test_data = {}
    for k in dataset.keys():
        train_data[k] = dataset[k][:n, :]
        test_data[k] = dataset[k][n:, :]

    # setup vae
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif variant.get('vae_type', None) == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get(
                'init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(
        image_env,
        m,
        imsize=image_env.imsize,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        # Closure over vae_env / vis_variant / random_policy / rollout_function.
        # Dumps reconstructions, samples, latent plots/histograms, and
        # (optionally) rollout videos for this epoch.
        vis_list = vis_variant.get('vis_list', [])
        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch, n_recon=vis_variant.get(
                    'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                'latent_histogram', 'latent_histogram_mu',
                'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=False,
                                              use_true_prior=True)
        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger
            vae_env.compute_goal_encodings()
            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            variant['dump_video_kwargs']['epoch'] = epoch
            # temporary_mode switches the env mode for the duration of the dump.
            temporary_mode(vae_env, mode='video_env', func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env, mode='video_vae', func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=variant['dump_video_kwargs'])

    # train vae
    for epoch in range(variant['num_epochs']):
        #for epoch in range(2000):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == variant['num_epochs'] - 1)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == variant['num_epochs'] - 1)
        t.train_epoch(epoch)
        '''if epoch % 500 == 0 or epoch == variant['num_epochs']-1:
            t.test_epoch(
                epoch,
                save_reconstruction=save_vis,
                save_interpolation=save_vis,
                save_vae=save_vae,
            )
        if epoch % 200 == 0 or epoch == variant['num_epochs']-1:
            visualization_post_processing(save_video, save_video, epoch)'''
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        if epoch % 300 == 0 or epoch == variant['num_epochs'] - 1:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    # Restore the original tabular output for the downstream RL run.
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")
    return m