def debug_statistics(self):
    """
    Given an image $$x$$, sample several latents $$z_i$$ from the prior and
    decode them into $$\hat x_i$$. Compare these to $$\hat x$$, the
    reconstruction of $$x$$. Ideally:

    - All the $$\hat x_i$$ do worse than $$\hat x$$ (checks that the VAE
      isn't ignoring the latent).
    - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (checks for
      coverage).
    """
    debug_batch_size = 64

    data = self.get_batch(train=False)
    reconstructions, _, _ = self.model(data)
    img = data[0]
    recon_mse = ((reconstructions[0] - img) ** 2).mean().view(-1)
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    samples = ptu.randn(debug_batch_size, self.representation_size)
    random_imgs, _ = self.model.decode(samples)
    random_mses = (random_imgs - img_repeated) ** 2
    mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)

    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    stats.update(create_stats_ordered_dict(
        'debug/MSE of random decoding',
        ptu.get_numpy(random_mses),
    ))
    stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
    return stats
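# A concrete reading of the diagnostic above, as a tiny self-contained sketch with
# made-up MSE values (illustrative numbers only, not from any run): the improvement
# entries should be mostly positive, with nonzero spread across prior samples.
import numpy as np

recon_mse = 0.010  # hypothetical MSE of the reconstruction of x
random_mses = np.array([0.031, 0.045, 0.012, 0.080, 0.027, 0.009, 0.052, 0.038])

mse_improvement = random_mses - recon_mse
print(mse_improvement.mean() > 0)   # True: the VAE is not ignoring the latent
print(mse_improvement.std() > 0)    # True: prior decodings differ, i.e. some coverage
print((mse_improvement < 0).any())  # a few prior samples may still beat the reconstruction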
def load_dataset(self, dataset_path):
    dataset = load_local_or_remote_file(dataset_path)
    dataset = dataset.item()

    observations = dataset['observations']
    actions = dataset['actions']
    # dataset['observations'].shape  # (2000, 50, 6912)
    # dataset['actions'].shape       # (2000, 50, 2)
    # dataset['env'].shape           # (2000, 6912)
    N, H, imlength = observations.shape

    self.vae.eval()
    for n in range(N):
        x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
        x = ptu.from_numpy(observations[n, :, :] / 255.0)
        latents = self.vae.encode(x, x0, distrib=False)

        r1, r2 = self.vae.latent_sizes
        conditioning = latents[0, r1:]
        goal = torch.cat(
            [ptu.randn(self.vae.latent_sizes[0]), conditioning])
        goal = ptu.get_numpy(goal)  # latents[-1, :]
        latents = ptu.get_numpy(latents)

        latent_delta = latents - goal
        distances = np.zeros((H - 1, 1))
        for i in range(H - 1):
            distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

        terminals = np.zeros((H - 1, 1))
        # terminals[-1, 0] = 1
        path = dict(
            observations=[],
            actions=actions[n, :H - 1, :],
            next_observations=[],
            rewards=-distances,
            terminals=terminals,
        )

        for t in range(H - 1):
            # reward = -np.linalg.norm(latent_delta[i, :])
            obs = dict(
                latent_observation=latents[t, :],
                latent_achieved_goal=latents[t, :],
                latent_desired_goal=goal,
            )
            next_obs = dict(
                latent_observation=latents[t + 1, :],
                latent_achieved_goal=latents[t + 1, :],
                latent_desired_goal=goal,
            )
            path['observations'].append(obs)
            path['next_observations'].append(next_obs)
        # import ipdb; ipdb.set_trace()
        self.replay_buffer.add_path(path)
def dump_samples(self, epoch):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)[0].cpu()
    save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
    save_image(
        sample.data.view(
            64, self.input_channels, self.imsize, self.imsize
        ).transpose(2, 3),
        save_dir,
    )
def get_encoding_and_suff_stats(self, x):
    output = self(x)
    z_dim = output.shape[1] // 2
    means, log_var = output[:, :z_dim], output[:, z_dim:]
    stds = (0.5 * log_var).exp()
    epsilon = ptu.randn(means.shape)
    latents = epsilon * stds + means
    return latents, means, log_var, stds
def compute_density(self, data):
    orig_data_length = len(data)
    data = np.vstack([data for _ in range(self.n_average)])
    data = ptu.from_numpy(data)
    if self.mode == 'biased':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data)
        )
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'prior':
        latents = ptu.randn(len(data), self.z_dim)
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'importance_sampling':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data)
        )
        prior = Normal(ptu.zeros(1), ptu.ones(1))
        prior_log_prob = prior.log_prob(latents).sum(dim=1)

        encoder_distrib = Normal(means, stds)
        encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

        importance_weights = (prior_log_prob - encoder_log_prob).exp()
    else:
        raise NotImplementedError()

    unweighted_data_log_prob = self.compute_log_prob(
        data, self.decoder, latents
    ).squeeze(1)
    unweighted_data_prob = unweighted_data_log_prob.exp()
    unnormalized_data_prob = unweighted_data_prob * importance_weights

    """
    Average over `n_average`
    """
    dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
    # pre_avg.shape = ORIG_LEN x N_AVERAGE
    dp_stacked = torch.stack(dp_split, dim=1)
    # final.shape = ORIG_LEN
    unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)

    """
    Compute the importance weight denominators. This requires summing
    across the `n_average` dimension.
    """
    iw_split = torch.split(importance_weights, orig_data_length, dim=0)
    iw_stacked = torch.stack(iw_split, dim=1)
    iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

    final = unnormalized_dp / iw_denominators
    return ptu.get_numpy(final)
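# In the 'importance_sampling' branch, compute_density is a self-normalized
# importance-sampling estimate of p(x) = E_{p(z)}[p(x|z)] with the encoder as the
# proposal: p(x) ~= sum_i w_i * p(x|z_i) / sum_i w_i, where w_i = p(z_i) / q(z_i|x).
# A minimal self-contained sketch of the same estimator on a 1-D toy model; the
# prior, proposal, and "decoder" likelihood below are stand-ins, not the VAE's.
import torch
from torch.distributions import Normal

n_average = 1024
prior = Normal(torch.zeros(1), torch.ones(1))                 # p(z)
proposal = Normal(torch.tensor([1.0]), torch.tensor([0.5]))   # stand-in for q(z|x)

z = proposal.sample((n_average,))                   # z_i ~ q(z|x)
log_w = prior.log_prob(z) - proposal.log_prob(z)    # log w_i
w = log_w.exp()

x = torch.tensor(0.8)
log_px_given_z = Normal(z, 0.3).log_prob(x)         # toy decoder likelihood p(x|z_i)

p_x = (log_px_given_z.exp() * w).sum() / w.sum()    # self-normalized IS estimate
print(float(p_x))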
def sample_prior(self, batch_size, x_0, true_prior=True):
    if x_0.shape[0] == 1:
        x_0 = x_0.repeat(batch_size, 1)
    z_sample = ptu.randn(batch_size, self.latent_sizes[0])
    if not true_prior:
        stds = np.exp(0.5 * self.prior_logvar)
        z_sample = z_sample * stds + self.prior_mu
    conditioning = self.bn_c(self.c(self.dropout(self.cond_encoder(x_0))))
    cond_sample = torch.cat([z_sample, conditioning], dim=1)
    return cond_sample
def dump_samples(self, epoch):
    self.model.eval()
    batch, _ = self.eval_data["test/last_batch"]
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample, batch["observations"])[0].cpu()
    save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
    save_image(
        sample.data.view(64, 3, self.imsize, self.imsize).transpose(2, 3),
        save_dir,
    )

    x0 = batch["x_0"]
    x0_img = x0[:64].narrow(
        start=0, length=self.imlength // 2, dim=1
    ).contiguous().view(-1, 3, self.imsize, self.imsize).transpose(2, 3)
    save_dir = osp.join(self.log_dir, 'x0_%d.png' % epoch)
    save_image(x0_img.data.cpu(), save_dir)
def rsample(self, latent_distribution_params):
    mu, logvar = latent_distribution_params
    stds = (0.5 * logvar).exp()
    epsilon = ptu.randn(*mu.size())
    latents = epsilon * stds + mu
    return latents
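# rsample is the standard reparameterization trick: z = mu + sigma * eps with
# eps ~ N(0, I), so gradients flow to mu and logvar instead of being blocked by a
# non-differentiable sampling op. A small self-contained check of that property in
# plain torch, independent of the ptu helpers:
import torch

mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

eps = torch.randn(4)
z = mu + (0.5 * logvar).exp() * eps   # same formula as rsample above

z.sum().backward()
print(mu.grad)      # all ones: gradient reaches mu directly
print(logvar.grad)  # 0.5 * eps: gradient reaches logvar through the std term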
def sample(self, num_samples):
    return ptu.get_numpy(self.sample_given_z(
        ptu.randn(num_samples, self.z_dim)
    ))
def train_from_torch(self, batch):
    logger.push_tabular_prefix("train_q/")
    self.eval_statistics = dict()
    self._need_to_update_eval_statistics = True

    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    next_actions = self.target_policy(next_obs)
    noise = ptu.randn(next_actions.shape) * self.target_policy_noise
    noise = torch.clamp(
        noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip,
    )
    noisy_next_actions = next_actions + noise

    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)
    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target) ** 2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target) ** 2
    qf2_loss = bellman_errors_2.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_update_period == 0:
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)

        if self.demo_train_buffer._size >= self.bc_batch_size:
            if self.use_demo_awr:
                train_batch = self.get_batch_from_buffer(
                    self.demo_train_buffer)
                train_o = train_batch["observations"]
                train_u = train_batch["actions"]
                if self.goal_conditioned:
                    train_g = train_batch["resampled_goals"]
                    train_o = torch.cat((train_o, train_g), dim=1)
                train_pred_u = self.policy(train_o)
                train_error = (train_pred_u - train_u) ** 2
                train_bc_loss = train_error.mean()

                policy_q_output_demo_state = self.qf1(train_o, train_pred_u)
                demo_q_output = self.qf1(train_o, train_u)
                advantage = demo_q_output - policy_q_output_demo_state

                self.eval_statistics['Train BC Loss'] = np.mean(
                    ptu.get_numpy(train_bc_loss))

                if self.awr_policy_update:
                    train_bc_loss = train_error * torch.exp(
                        advantage * self.demo_beta)
                    self.eval_statistics['Advantage'] = np.mean(
                        ptu.get_numpy(advantage))

                if self._n_train_steps_total < self.max_steps_till_train_rl:
                    rl_weight = 0
                else:
                    rl_weight = self.rl_weight
                policy_loss = (
                    -rl_weight * q_output.mean()
                    + self.bc_weight * train_bc_loss.mean()
                )
            else:
                train_batch = self.get_batch_from_buffer(
                    self.demo_train_buffer)
                train_o = train_batch["observations"]
                # train_pred_u = self.policy(train_o)
                if self.goal_conditioned:
                    train_g = train_batch["resampled_goals"]
                    train_o = torch.cat((train_o, train_g), dim=1)
                train_pred_u = self.policy(train_o)
                train_u = train_batch["actions"]
                train_error = (train_pred_u - train_u) ** 2
                train_bc_loss = train_error.mean()

                # Advantage-weighted regression
                policy_error = (policy_actions - actions) ** 2
                policy_error = policy_error.mean(dim=1)
                advantage = q1_pred - q_output
                weights = F.softmax((advantage / self.beta)[:, 0], dim=0)
                if self.awr_policy_update:
                    policy_loss = self.rl_weight * (
                        policy_error * weights.detach() * self.bc_batch_size
                    ).mean()
                else:
                    policy_loss = (
                        -self.rl_weight * q_output.mean()
                        + self.bc_weight * train_bc_loss.mean()
                    )
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
                self.eval_statistics['BC Loss'] = np.mean(
                    ptu.get_numpy(train_bc_loss))
        else:
            # Normal TD3 update
            policy_loss = -self.rl_weight * q_output.mean()
        if self.update_policy and not (
                self.rl_weight == 0 and self.bc_weight == 0):
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        if policy_loss is None:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))

        if self.demo_test_buffer._size >= self.bc_batch_size:
            train_batch = self.get_batch_from_buffer(self.demo_train_buffer)
            train_u = train_batch["actions"]
            train_o = train_batch["observations"]
            if self.goal_conditioned:
                train_g = train_batch["resampled_goals"]
                train_o = torch.cat((train_o, train_g), dim=1)
            train_pred_u = self.policy(train_o)
            train_error = (train_pred_u - train_u) ** 2
            train_bc_loss = train_error

            policy_q_output_demo_state = self.qf1(train_o, train_pred_u)
            demo_q_output = self.qf1(train_o, train_u)
            train_advantage = demo_q_output - policy_q_output_demo_state

            test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
            test_o = test_batch["observations"]
            test_u = test_batch["actions"]
            if self.goal_conditioned:
                test_g = test_batch["resampled_goals"]
                test_o = torch.cat((test_o, test_g), dim=1)
            test_pred_u = self.policy(test_o)
            test_error = (test_pred_u - test_u) ** 2
            test_bc_loss = test_error

            policy_q_output_demo_state = self.qf1(test_o, test_pred_u)
            demo_q_output = self.qf1(test_o, test_u)
            test_advantage = demo_q_output - policy_q_output_demo_state

            self.eval_statistics.update(create_stats_ordered_dict(
                'Train BC Loss',
                ptu.get_numpy(train_bc_loss),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Train Demo Advantage',
                ptu.get_numpy(train_advantage),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Test BC Loss',
                ptu.get_numpy(test_bc_loss),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Test Demo Advantage',
                ptu.get_numpy(test_advantage),
            ))

    self._n_train_steps_total += 1
    logger.pop_tabular_prefix()
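# The AWR branch above weights per-sample regression errors toward dataset actions
# by a softmax over estimated advantages (the demo branch uses an exponential
# weighting instead). A minimal standalone sketch of the softmax weighting rule;
# the names action_error, q_data, and q_policy are illustrative, not part of the
# trainer's API.
import torch
import torch.nn.functional as F

def awr_weighted_loss(action_error, q_data, q_policy, beta=1.0):
    """Weight per-sample errors by softmax(advantage / beta) over the batch.

    action_error: (B,) squared errors between policy actions and dataset actions
    q_data:       (B,) Q(s, a_data);  q_policy: (B,) Q(s, pi(s))
    """
    advantage = q_data - q_policy                  # how much better the dataset action scores
    weights = F.softmax(advantage / beta, dim=0)   # normalize weights across the batch
    # Detach so the weights act as fixed per-sample coefficients, as in the trainer above.
    return (action_error * weights.detach() * action_error.shape[0]).mean()

# Example with random stand-in values:
B = 8
loss = awr_weighted_loss(torch.rand(B), torch.randn(B), torch.randn(B), beta=0.5)
print(float(loss))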
all_imgs = [
    x0.narrow(start=0, length=imlength, dim=1).contiguous().view(
        -1, 3, imsize, imsize).transpose(2, 3),
]
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%sx0.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

vae = load_local_or_remote_file(vae_path).to("cpu")
vae.eval()
model = vae

all_imgs = []
for i in range(N_ROWS):
    latent = ptu.randn(n, model.representation_size)
    # model.sample_prior(self.batch_size, env)
    samples = model.decode(latent)[0]
    all_imgs.extend([
        samples.view(n, 3, imsize, imsize)[:n].transpose(2, 3)
    ])
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%svae_samples.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

cvae = load_local_or_remote_file(cvae_path).to("cpu")
cvae.eval()