def get_non_linear_results(
        ob_space, encoder, latent_dim, batch_size=128, num_batches=10000,
) -> NonLinearResults:
    """Measure how well raw states can be non-linearly decoded from ``encoder``.

    Trains a fresh 2x64 MLP decoder (the encoder itself is NOT optimized --
    only ``decoder.parameters()`` are given to Adam) to reconstruct states
    sampled from ``ob_space``, then evaluates reconstruction MSE on a large
    held-out batch.

    :param ob_space: observation space; ``ob_space.low.size`` gives the flat
        state dimension, and it is the sampling source for ``get_batch``.
    :param encoder: maps state tensors to ``latent_dim``-sized latents.
    :param latent_dim: input size of the trained decoder.
    :param batch_size: per-step training batch size.
    :param num_batches: number of Adam training steps.
    :return: ``NonLinearResults`` with the final eval loss, the loss at step 0,
        and the fraction of the total loss reduction achieved after the 90%
        mark of training ("last 10 percent contribution").
    """
    state_dim = ob_space.low.size
    decoder = Mlp(
        hidden_sizes=[64, 64],
        output_size=state_dim,
        input_size=latent_dim,
    )
    decoder.to(ptu.device)
    optimizer = optim.Adam(decoder.parameters())
    initial_loss = last_10_percent_loss = 0
    for i in range(num_batches):
        states = get_batch(ob_space, batch_size)
        x = ptu.from_numpy(states)
        z = encoder(x)
        x_hat = decoder(z)
        loss = ((x - x_hat)**2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Snapshot the loss at the very first step and at the 90% mark so the
        # "last 10 percent" contribution can be computed afterwards.
        if i == 0:
            initial_loss = ptu.get_numpy(loss)
        if i == int(num_batches * 0.9):
            last_10_percent_loss = ptu.get_numpy(loss)
    # Final evaluation on one large batch (2**15 = 32768 samples), computed
    # in numpy after pulling the reconstruction off the device.
    eval_states = get_batch(ob_space, batch_size=2**15)
    x = ptu.from_numpy(eval_states)
    z = encoder(x)
    x_hat = decoder(z)
    reconstruction = ptu.get_numpy(x_hat)
    loss = ((eval_states - reconstruction)**2).mean()
    # NOTE(review): divides by (initial_loss - loss); if training made no
    # progress this is a zero division -- presumably callers avoid that case,
    # but confirm.
    last_10_percent_contribution = (
        (last_10_percent_loss - loss)) / (initial_loss - loss)
    # Free the throwaway decoder/optimizer before returning.
    del decoder, optimizer
    return NonLinearResults(
        loss=loss,
        initial_loss=initial_loss,
        last_10_percent_contribution=last_10_percent_contribution,
    )
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer):
    """Goal-relabeling replay buffer kept in sync with an online-trained VAE.

    On top of ``ObsDictRelabelingBuffer`` this adds:

    - storage of decoded (image) observations alongside their latents, with
      periodic re-encoding of every stored latent after the VAE is retrained
      (``refresh_latents``);
    - optional per-transition exploration-reward bonuses derived from the VAE
      (reconstruction error, latent novelty, hash counts, ...);
    - optional prioritized ("skewed") sampling of VAE training data, weighted
      by ``priority ** power``.
    """

    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        """
        :param vae: the VAE used to (re-)encode observations and score them.
        :param decoded_obs_key: obs-dict key holding the decoded observation.
        :param decoded_achieved_goal_key: key for the decoded achieved goal.
        :param decoded_desired_goal_key: key for the decoded desired goal.
        :param exploration_rewards_type: name of the bonus function (a key of
            the dispatch table below); 'None' disables the bonus.
        :param exploration_rewards_scale: bonus scale (or schedule baseline).
        :param vae_priority_type: name of the priority function used for
            skewed sampling; 'None' disables prioritization.
        :param start_skew_epoch: skewed sampling starts after this epoch.
        :param power: exponent applied to priorities before normalization.
        :param internal_keys: extra obs-dict keys to store; the three decoded
            keys are always appended.
        :param exploration_schedule_kwargs: if given, a
            ``PiecewiseLinearSchedule`` for the bonus scale instead of a
            constant one.
        :param priority_function_kwargs: extra kwargs forwarded to both the
            exploration-reward and prioritization functions.
        :param exploration_counter_kwargs: kwargs for ``CountExploration``
            (only used for the 'hash_count' reward type).
        :param relabeling_goal_sampling_mode: env goal-sampling mode used when
            relabeling goals from the environment.
        :param decode_vae_goals: if True, decode VAE-sampled goals in batch
            when paths are added.
        """
        if internal_keys is None:
            internal_keys = []

        # Always persist the decoded versions of obs / achieved / desired so
        # latents can be recomputed after the VAE changes.
        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        # Bonus scale may vary over epochs; a constant schedule is the default.
        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        # Per-transition caches, refreshed in ``refresh_latents``.
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        self._vae_sample_probs = None

        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        # Dispatch table mapping reward/priority type names to bound methods.
        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        # NOTE(review): ``self.skew`` is only assigned in refresh_latents();
        # calling sample_weighted_indices() before the first refresh would
        # raise AttributeError -- presumably refresh_latents always runs
        # first; confirm against the trainer.
        self.epoch = 0

    def add_path(self, path):
        # Optionally decode VAE-sampled goals in batch before storing.
        if self.decode_vae_goals:
            self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        # decoding the self-sampled vae images should be done in batch (here)
        # rather than in the env for efficiency
        desired_goals = combine_dicts(
            path['observations'],
            [self.desired_goal_key]
        )[self.desired_goal_key]
        desired_decoded_goals = self.env._decode(desired_goals)
        desired_decoded_goals = desired_decoded_goals.reshape(
            len(desired_decoded_goals),
            -1
        )
        # Write the decoded goal into both obs and next_obs of every step.
        for idx, next_obs in enumerate(path['observations']):
            path['observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]
            path['next_observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]

    def random_batch(self, batch_size):
        """Sample a batch, adding the (scheduled) exploration bonus to rewards."""
        batch = super().random_batch(batch_size)
        exploration_rewards_scale = float(self.explr_reward_scale_schedule.get_value(self.epoch))
        if self._give_explr_reward_bonus:
            batch_idxs = batch['indices'].flatten()
            batch['exploration_rewards'] = self._exploration_rewards[batch_idxs]
            # Mutates batch['rewards'] in place with the scaled bonus.
            batch['rewards'] += exploration_rewards_scale * batch['exploration_rewards']
        return batch

    def get_diagnostics(self):
        """Return stats over the current sample priorities and probabilities."""
        if self._vae_sample_probs is None or self._vae_sample_priorities is None:
            # Before the first refresh there are no priorities yet; log zeros
            # so the stat keys are always present.
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                np.zeros(self._size),
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                np.zeros(self._size),
            ))
        else:
            vae_sample_priorities = self._vae_sample_priorities[:self._size]
            vae_sample_probs = self._vae_sample_probs[:self._size]
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                vae_sample_priorities,
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                vae_sample_probs,
            ))
        return stats

    def refresh_latents(self, epoch):
        """Re-encode all stored observations with the current VAE.

        Walks the buffer in chunks of 512, re-encoding obs/next_obs (and the
        desired/achieved goal latents of next_obs), recomputing exploration
        rewards and sample priorities, and finally updating the VAE's latent
        distribution statistics and the normalized sampling probabilities.
        """
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512
        next_idx = min(batch_size, self._size)
        if self.exploration_rewards_type == 'hash_count':
            # you have to count everything then compute exploration rewards
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
                self.update_hash_count(normalized_imgs)
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)

        cur_idx = 0
        # Running sums for the latent mean/std over the whole buffer.
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            self._obs[self.observation_key][idxs] = \
                self.env._encode(self._obs[self.decoded_obs_key][idxs])
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_obs_key][idxs])
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be invalid,
            # so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_desired_goal_key][idxs])
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_achieved_goal_key][idxs])
            normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                # Reuse the exploration reward as the priority when both use
                # the same scoring function; otherwise score separately.
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(self._obs[self.observation_key][idxs], 2).sum(axis=0)

            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)

        # E[z] and std[z] over the refreshed latents.
        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(obs_square_sum / self._size - np.power(self.vae.dist_mu, 2))

        if self._prioritize_vae_samples:
            """
            priority^power is calculated in the priority function for
            image_bernoulli_prob or image_gaussian_inv_prob and directly
            here if not.
            """
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs(
                    self._vae_sample_priorities[:self._size]
                )
                self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = self._vae_sample_priorities[:self._size] ** self.power
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def sample_weighted_indices(self, batch_size):
        """Sample buffer indices, skewed by VAE priorities once skewing starts."""
        if (
                self._prioritize_vae_samples and
                self._vae_sample_probs is not None and
                self.skew
        ):
            indices = np.random.choice(
                len(self._vae_sample_probs),
                batch_size,
                p=self._vae_sample_probs,
            )
            assert (
                    np.max(self._vae_sample_probs) <= 1 and
                    np.min(self._vae_sample_probs) >= 0
            )
        else:
            indices = self._sample_indices(batch_size)
        return indices

    def _sample_goals_from_env(self, batch_size):
        self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return self.env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """
        Samples goals from weighted replay buffer for relabeling or exploration.
        Returns None if replay buffer is empty.

        Example of what might be returned:
        dict(
            image_desired_goals: image_achieved_goals[weighted_indices],
            latent_desired_goals: latent_desired_goals[weighted_indices],
        )
        """
        if self._size == 0:
            return None
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )
        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        next_latent_obs = self._next_obs[self.achieved_goal_key][weighted_idxs]
        return {
            self.decoded_desired_goal_key: next_image_obs,
            self.desired_goal_key: next_latent_obs
        }

    def random_vae_training_data(self, batch_size, epoch):
        """Sample (possibly skewed) decoded observations for VAE training."""
        # epoch no longer needed. Using self.skew in sample_weighted_indices
        # instead.
        weighted_idxs = self.sample_weighted_indices(
            batch_size,
        )

        next_image_obs = self._next_obs[self.decoded_obs_key][weighted_idxs]
        observations = ptu.from_numpy(next_image_obs)
        return dict(
            observations=observations,
        )

    def reconstruction_mse(self, next_vae_obs, indices):
        """Per-sample summed squared reconstruction error under the VAE."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)

        error = torch_input - recon_next_vae_obs
        mse = torch.sum(error ** 2, dim=1)
        return ptu.get_numpy(mse)

    def gaussian_inv_prob(self, next_vae_obs, indices):
        # exp of the summed squared error, i.e. inverse Gaussian likelihood
        # up to constants.
        return np.exp(self.reconstruction_mse(next_vae_obs, indices))

    def binary_cross_entropy(self, next_vae_obs, indices):
        """Per-sample cross-entropy of the input against its reconstruction."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)

        error = - torch_input * torch.log(
            torch.clamp(
                recon_next_vae_obs,
                min=1e-30,  # corresponds to about -70
            )
        )
        bce = torch.sum(error, dim=1)
        return ptu.get_numpy(bce)

    def bernoulli_inv_prob(self, next_vae_obs, indices):
        """Inverse Bernoulli likelihood of each input under its reconstruction."""
        torch_input = ptu.from_numpy(next_vae_obs)
        recon_next_vae_obs, _, _ = self.vae(torch_input)
        prob = (
                torch_input * recon_next_vae_obs
                + (1 - torch_input) * (1 - recon_next_vae_obs)
        ).prod(dim=1)
        return ptu.get_numpy(1 / prob)

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        # Delegates to the shared helper; ``power`` is applied inside.
        return compute_p_x_np_to_np(
            self.vae,
            next_vae_obs,
            power=self.power,
            **kwargs
        )

    def forward_model_error(self, next_vae_obs, indices):
        """Dynamics-model prediction error for the transitions at ``indices``."""
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]

        state_action_pair = ptu.from_numpy(np.c_[obs, actions])
        prediction = self.dynamics_model(state_action_pair)
        mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))
        return ptu.get_numpy(mse)

    def latent_novelty(self, next_vae_obs, indices):
        # Squared Mahalanobis-style distance from the buffer's latent mean.
        distances = ((self.env._encode(next_vae_obs) - self.vae.dist_mu) /
                     self.vae.dist_std) ** 2
        return distances.sum(axis=1)

    def latent_novelty_true_prior(self, next_vae_obs, indices):
        # Squared distance from the N(0, I) prior mean.
        distances = self.env._encode(next_vae_obs) ** 2
        return distances.sum(axis=1)

    def _kl_np_to_np(self, next_vae_obs, indices):
        """Per-sample KL(q(z|x) || N(0, I)) for the encoded inputs."""
        torch_input = ptu.from_numpy(next_vae_obs)
        mu, log_var = self.vae.encode(torch_input)
        return ptu.get_numpy(
            - torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        )

    def update_hash_count(self, next_vae_obs):
        """Increment the exploration counter with the encoded latent means."""
        torch_input = ptu.from_numpy(next_vae_obs)
        mus, log_vars = self.vae.encode(torch_input)
        mus = ptu.get_numpy(mus)
        self.exploration_counter.increment_counts(mus)
        return None

    def hash_count_reward(self, next_vae_obs, indices):
        obs = self.env._encode(next_vae_obs)
        return self.exploration_counter.compute_count_based_reward(obs)

    def no_reward(self, next_vae_obs, indices):
        # Zero bonus: used when exploration rewards / priorities are disabled.
        return np.zeros((len(next_vae_obs), 1))

    def initialize_dynamics_model(self):
        """Create the MLP forward-dynamics model, its optimizer, and loss."""
        obs_dim = self._obs[self.observation_key].shape[1]
        self.dynamics_model = Mlp(
            hidden_sizes=[128, 128],
            output_size=obs_dim,
            input_size=obs_dim + self._action_dim,
        )
        self.dynamics_model.to(ptu.device)
        self.dynamics_optimizer = Adam(self.dynamics_model.parameters())
        self.dynamics_loss = MSELoss()

    def train_dynamics_model(self, batches=50, batch_size=100):
        """Run ``batches`` gradient steps on the forward-dynamics model."""
        if not self.use_dynamics_model:
            return
        for _ in range(batches):
            indices = self._sample_indices(batch_size)
            self.dynamics_optimizer.zero_grad()
            obs = self._obs[self.observation_key][indices]
            next_obs = self._next_obs[self.observation_key][indices]
            actions = self._actions[indices]
            if self.exploration_rewards_type == 'inverse_model_error':
                # Swap so the model predicts obs from next_obs (inverse model).
                obs, next_obs = next_obs, obs

            state_action_pair = ptu.from_numpy(np.c_[obs, actions])
            prediction = self.dynamics_model(state_action_pair)
            mse = self.dynamics_loss(prediction, ptu.from_numpy(next_obs))

            mse.backward()
            self.dynamics_optimizer.step()

    def log_loss_under_uniform(self, model, data, batch_size, rl_logger, priority_function_kwargs):
        """Log the VAE's log-prob / KL / MSE over uniformly sampled data.

        Evaluates ``data`` in mini-batches under three sampling methods
        (true-prior, biased, importance) and writes the means into
        ``rl_logger``. Note: mutates ``priority_function_kwargs`` in place.
        """
        import torch.nn.functional as F
        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        for i in range(0, data.shape[0], batch_size):
            img = data[i:min(data.shape[0], i + batch_size), :]
            torch_img = ptu.from_numpy(img)
            reconstructions, obs_distribution_params, latent_distribution_params = self.vae(torch_img)

            priority_function_kwargs['sampling_method'] = 'true_prior_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_prior = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'biased_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_biased = log_d.mean()

            priority_function_kwargs['sampling_method'] = 'importance_sampling'
            log_p, log_q, log_d = compute_log_p_log_q_log_d(model, img, **priority_function_kwargs)
            log_prob_importance = (log_p - log_q + log_d).mean()

            kle = model.kl_divergence(latent_distribution_params)
            # NOTE(review): reduction='elementwise_mean' is the legacy spelling
            # of 'mean' and is removed in recent torch versions -- confirm the
            # pinned torch version before upgrading.
            mse = F.mse_loss(torch_img, reconstructions, reduction='elementwise_mean')
            mses.append(mse.item())
            kles.append(kle.item())
            log_probs_prior.append(log_prob_prior.item())
            log_probs_biased.append(log_prob_biased.item())
            log_probs_importance.append(log_prob_importance.item())

        rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior)
        rl_logger["Uniform Data Log Prob (Biased)"] = np.mean(log_probs_biased)
        rl_logger["Uniform Data Log Prob (Importance)"] = np.mean(log_probs_importance)
        rl_logger["Uniform Data KL"] = np.mean(kles)
        rl_logger["Uniform Data MSE"] = np.mean(mses)

    def _get_sorted_idx_and_train_weights(self):
        # (index, probability) pairs sorted by ascending sample probability.
        idx_and_weights = zip(range(len(self._vae_sample_probs)),
                              self._vae_sample_probs)
        return sorted(idx_and_weights, key=lambda x: x[1])
class TimestepPredictionModel(torch.nn.Module):
    """Classify the relative timestep of a frame within a trajectory.

    Encodes three image batches -- the first frame ``x0``, an intermediate
    frame ``xt``, and the final frame ``xT`` -- with a shared ResNet-18
    encoder, then an MLP head classifies the timestep of ``xt`` into
    ``output_classes`` buckets.

    :param representation_size: dimensionality of each per-image latent.
    :param architecture: kwargs forwarded to the ``Mlp`` prediction head
        (e.g. ``hidden_sizes``).
    :param normalize: if True, standardize inputs with ImageNet channel
        statistics before encoding.
    :param output_classes: number of timestep buckets the head predicts.
    :param delta_features: if True, feed the head ``[zt - z0, zT - z0]``
        (2 * representation_size inputs) instead of ``[z0, zt, zT]``.
    :param pretrained_features: if True, initialize the encoder from
        ImageNet-pretrained ResNet-18 weights (final fc layer excluded,
        since its size differs from ``representation_size``).

    The remaining parameters (``encoder_class``, ``decoder_class``,
    ``decoder_output_activation``, ``decoder_distribution``, ``init_w``,
    ``hidden_init``) are accepted for interface compatibility with sibling
    VAE-style models but are unused here.
    """

    def __init__(
            self,
            representation_size,
            architecture,
            normalize=True,
            output_classes=100,
            encoder_class=CNN,
            decoder_class=DCNN,
            decoder_output_activation=identity,
            decoder_distribution='bernoulli',
            input_channels=1,
            imsize=224,
            init_w=1e-3,
            min_variance=1e-3,
            hidden_init=ptu.fanin_init,
            delta_features=False,
            pretrained_features=False,
    ):
        super().__init__()
        if min_variance is None:
            self.log_min_variance = None
        else:
            self.log_min_variance = float(np.log(min_variance))
        self.input_channels = input_channels
        self.imsize = imsize
        self.imlength = self.imsize * self.imsize * self.input_channels
        self.representation_size = representation_size
        self.output_classes = output_classes
        self.normalize = normalize

        # ImageNet channel statistics broadcast to (C, H, W) for whole-image
        # normalization.
        # NOTE(review): these are plain tensors, not registered buffers, so a
        # later ``module.to(device)`` will NOT move them; they are pinned to
        # ``ptu.device`` at construction time only.
        self.img_mean = torch.tensor([0.485, 0.456, 0.406])
        self.img_std = torch.tensor([0.229, 0.224, 0.225])
        self.img_mean = self.img_mean.repeat(
            epic.CROP_WIDTH, epic.CROP_HEIGHT, 1,
        ).transpose(0, 2).to(ptu.device)
        self.img_std = self.img_std.repeat(
            epic.CROP_WIDTH, epic.CROP_HEIGHT, 1,
        ).transpose(0, 2).to(ptu.device)

        self.encoder = torchvision.models.resnet.ResNet(
            torchvision.models.resnet.BasicBlock,
            [2, 2, 2, 2],  # ResNet-18 layer configuration
            num_classes=representation_size,
        )
        # Adaptive pooling lets the encoder accept non-224x224 inputs.
        self.encoder.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if pretrained_features:
            # Load ImageNet weights, dropping the final fc layer whose output
            # size differs from ``representation_size``.
            exclude_names = ["fc"]
            state_dict = load_state_dict_from_url(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                progress=True,
            )
            new_state_dict = state_dict.copy()
            for key in state_dict:
                for name in exclude_names:
                    if name in key:
                        del new_state_dict[key]
                        break
            self.encoder.load_state_dict(new_state_dict, strict=False)

        self.delta_features = delta_features
        input_size = (
            representation_size * 2 if delta_features
            else representation_size * 3
        )
        self.predictor = Mlp(
            output_size=output_classes,
            input_size=input_size,
            **architecture,
        )
        # Bug fix: the head was hard-coded to "cuda:0", which crashed on
        # CPU-only machines and contradicted the ``ptu.device`` convention
        # used everywhere else in this file.
        self.predictor = self.predictor.to(ptu.device)
        self.epoch = 0

    def get_latents(self, x0, xt, xT):
        """Encode the three frame batches and return (z0, zt, zT).

        The batches are concatenated along dim 0 so all 3 * bz images go
        through a single encoder pass, then split back apart.
        """
        bz = x0.shape[0]
        x = torch.cat([x0, xt, xT], dim=0).view(
            -1, 3, epic.CROP_HEIGHT, epic.CROP_WIDTH,
        )
        z = self.encode(x)
        z0, zt, zT = z[:bz, :], z[bz:2 * bz, :], z[2 * bz:3 * bz, :]
        return z0, zt, zT

    def forward(self, x0, xt, xT):
        """Return classification logits over ``output_classes`` timestep buckets."""
        z0, zt, zT = self.get_latents(x0, xt, xT)
        if self.delta_features:
            # Feature differences relative to the first frame.
            dt = zt - z0
            dT = zT - z0
            z = torch.cat([dt, dT], dim=1)
        else:
            z = torch.cat([z0, zt, zT], dim=1)
        out = self.predictor(z)
        return out

    def encode(self, x):
        """Encode a batch of images, chunked by MAX_BATCH_SIZE to bound memory."""
        bz = x.shape[0]
        if self.normalize:
            x = x - self.img_mean
            x = x / self.img_std
        zs = []
        for i in range(0, bz, MAX_BATCH_SIZE):
            z = self.encoder(x[i:i + MAX_BATCH_SIZE, :, :, :])
            zs.append(z)
        z = torch.cat(zs)
        return z