def update(self, aug_obs):
    obs = aug_obs
    enc_features = self.enc(obs)
    mu = self.enc_mu(enc_features)
    logvar = self.enc_logvar(enc_features)
    stds = (0.5 * logvar).exp()
    epsilon = ptu.randn(*mu.size())
    code = epsilon * stds + mu
    # Analytic KL divergence between q(z|x) and the standard normal prior.
    kle = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp(), dim=1
    ).mean()
    obs_distribution_params = self.dec(code)
    log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                reduction='mean')
    loss = self.beta * kle - log_prob
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.cpu().item()
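# A minimal, self-contained sketch of the two pieces `update` combines: the
# reparameterization trick and the analytic Gaussian KL term, assuming a
# standard normal prior N(0, I). Names here are illustrative, not part of
# the codebase.
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable.
    std = (0.5 * logvar).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

def gaussian_kl(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dims,
    # averaged over the batch.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()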
def debug_statistics(self):
    """
    Given an image $$x$$, sample a batch of latents from the prior
    $$z_i$$ and decode them into $$\hat x_i$$. Compare these to
    $$\hat x$$, the reconstruction of $$x$$.

    Ideally:
    - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure the
      VAE isn't ignoring the latent).
    - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
      coverage).
    """
    debug_batch_size = 64
    data = self.get_batch(train=False)
    reconstructions, _, _ = self.model(data)
    img = data[0]
    recon_mse = ((reconstructions[0] - img) ** 2).mean().view(-1)
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    samples = ptu.randn(debug_batch_size, self.representation_size)
    random_imgs, _ = self.model.decode(samples)
    random_mses = (random_imgs - img_repeated) ** 2
    mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    stats.update(create_stats_ordered_dict(
        'debug/MSE of random decoding',
        ptu.get_numpy(random_mses),
    ))
    stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
    return stats
def load_dataset(self, dataset_path):
    dataset = load_local_or_remote_file(dataset_path)
    dataset = dataset.item()

    observations = dataset['observations']
    actions = dataset['actions']

    # dataset['observations'].shape  # (2000, 50, 6912)
    # dataset['actions'].shape       # (2000, 50, 2)
    # dataset['env'].shape           # (2000, 6912)
    N, H, imlength = observations.shape

    self.vae.eval()
    for n in range(N):
        x0 = ptu.from_numpy(dataset['env'][n:n + 1, :] / 255.0)
        x = ptu.from_numpy(observations[n, :, :] / 255.0)
        latents = self.vae.encode(x, x0, distrib=False)

        r1, r2 = self.vae.latent_sizes
        conditioning = latents[0, r1:]
        goal = torch.cat(
            [ptu.randn(self.vae.latent_sizes[0]), conditioning])
        goal = ptu.get_numpy(goal)
        latents = ptu.get_numpy(latents)

        latent_delta = latents - goal
        distances = np.zeros((H - 1, 1))
        for i in range(H - 1):
            distances[i, 0] = np.linalg.norm(latent_delta[i + 1, :])

        terminals = np.zeros((H - 1, 1))
        # terminals[-1, 0] = 1
        path = dict(
            observations=[],
            actions=actions[n, :H - 1, :],
            next_observations=[],
            rewards=-distances,
            terminals=terminals,
        )

        for t in range(H - 1):
            obs = dict(
                latent_observation=latents[t, :],
                latent_achieved_goal=latents[t, :],
                latent_desired_goal=goal,
            )
            next_obs = dict(
                latent_observation=latents[t + 1, :],
                latent_achieved_goal=latents[t + 1, :],
                latent_desired_goal=goal,
            )
            path['observations'].append(obs)
            path['next_observations'].append(next_obs)

        self.replay_buffer.add_path(path)
def dump_samples(self, epoch):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)[0].cpu()
    save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
    save_image(
        sample.data.view(64, self.input_channels, self.imsize,
                         self.imsize).transpose(2, 3),
        save_dir)
def dump_samples(self, epoch):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)[0].cpu()
    save_dir = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
    save_image(
        sample.data.view(64, self.input_channels, self.imsize,
                         self.imsize),
        save_dir)
def get_encoding_and_suff_stats(self, x):
    output = self(x)
    z_dim = output.shape[1] // 2
    means, log_var = output[:, :z_dim], output[:, z_dim:]
    stds = (0.5 * log_var).exp()
    epsilon = ptu.randn(means.shape)
    latents = epsilon * stds + means
    return latents, means, log_var, stds
def dump_samples(self, epoch, save_prefix='s'):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)[0].cpu()
    save_dir = osp.join(logger.get_snapshot_dir(),
                        '{}{}.png'.format(save_prefix, epoch))
    save_image(
        sample.data.view(64, self.input_channels, self.imsize,
                         self.imsize).transpose(2, 3),
        save_dir)
def sample_prior(self, batch_size, x_0):
    if x_0.shape[0] == 1:
        x_0 = x_0.repeat(batch_size, 1)
    x_0 = x_0.reshape(-1, self.input_channels, self.imsize, self.imsize)
    z_cond, _, _, _ = self.netE.cond_encoder(x_0)
    z_delta = ptu.randn(batch_size, self.latent_size, 1, 1)
    cond_sample = torch.cat([z_delta, z_cond], dim=1)
    return cond_sample
def dump_samples(self, epoch):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)[0].cpu()
    project_path = osp.abspath(os.curdir)
    save_dir = osp.join(project_path, 'result_image', 's%d.png' % epoch)
    save_image(
        sample.data.view(64, self.input_channels, self.imsize,
                         self.imsize).transpose(2, 3),
        save_dir)
def sample_prior(self, batch_size, x_0, true_prior=True):
    if x_0.shape[0] == 1:
        x_0 = x_0.repeat(batch_size, 1)
    z_sample = ptu.randn(batch_size, self.latent_sizes[0])
    if not true_prior:
        stds = np.exp(0.5 * self.prior_logvar)
        z_sample = z_sample * stds + self.prior_mu
    conditioning = self.bn_c(self.c(self.dropout(self.cond_encoder(x_0))))
    cond_sample = torch.cat([z_sample, conditioning], dim=1)
    return cond_sample
def sample_prior(self, batch_size, cond=None, image_cond=True):
    if cond.shape[0] == 1:
        cond = cond.repeat(batch_size, axis=0)
    cond = ptu.from_numpy(cond)
    if image_cond:
        z_cond = self.encode_cond(batch_size, cond)
    else:
        z_cat = cond.reshape(batch_size, 2 * self.embedding_dim,
                             self.root_len, self.root_len)
        z_cond = z_cat[:, self.embedding_dim:]
    z_delta = ptu.randn(batch_size, self.embedding_dim,
                        self.root_len, self.root_len)
    z_cat = torch.cat([z_delta, z_cond], dim=1).view(
        -1, self.representation_size)
    return ptu.get_numpy(z_cat)
def compute_density(self, data):
    orig_data_length = len(data)
    data = np.vstack([data for _ in range(self.n_average)])
    data = ptu.from_numpy(data)
    if self.mode == 'biased':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'prior':
        latents = ptu.randn(len(data), self.z_dim)
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'importance_sampling':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        prior = Normal(ptu.zeros(1), ptu.ones(1))
        prior_log_prob = prior.log_prob(latents).sum(dim=1)

        encoder_distrib = Normal(means, stds)
        encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)

        importance_weights = (prior_log_prob - encoder_log_prob).exp()
    else:
        raise NotImplementedError()
    unweighted_data_log_prob = self.compute_log_prob(
        data, self.decoder, latents).squeeze(1)
    unweighted_data_prob = unweighted_data_log_prob.exp()
    unnormalized_data_prob = unweighted_data_prob * importance_weights
    """
    Average over `n_average`.
    """
    dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
    # dp_stacked.shape = ORIG_LEN x N_AVERAGE
    dp_stacked = torch.stack(dp_split, dim=1)
    # unnormalized_dp.shape = ORIG_LEN
    unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
    """
    Compute the importance-weight denominators. This requires summing
    across the `n_average` dimension.
    """
    iw_split = torch.split(importance_weights, orig_data_length, dim=0)
    iw_stacked = torch.stack(iw_split, dim=1)
    iw_denominators = iw_stacked.sum(dim=1, keepdim=False)

    final = unnormalized_dp / iw_denominators
    return ptu.get_numpy(final)
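# A hedged, standalone sketch of the self-normalized importance-sampling
# estimate that `compute_density` implements in its 'importance_sampling'
# mode: with z_i ~ q(z|x) and weights w_i = p(z_i) / q(z_i | x),
#     p(x) ~= sum_i w_i * p(x | z_i) / sum_i w_i.
# `toy_decoder_log_prob` is a made-up stand-in for the decoder likelihood.
import torch
from torch.distributions import Normal

def snis_density(mu, std, toy_decoder_log_prob, n_samples=128):
    q = Normal(mu, std)                                      # encoder posterior q(z|x)
    p = Normal(torch.zeros_like(mu), torch.ones_like(std))   # prior N(0, I)
    z = q.sample((n_samples,))                               # (n_samples, z_dim)
    log_w = p.log_prob(z).sum(-1) - q.log_prob(z).sum(-1)    # log importance weights
    w = log_w.exp()
    lik = toy_decoder_log_prob(z).exp()                      # p(x | z_i) per sample
    return (w * lik).sum() / w.sum()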
def sample_prior(self, batch_size, cond=None, image_cond=True):
    if cond.shape[0] == 1:
        cond = cond.repeat(batch_size, axis=0)
    cond = ptu.from_numpy(cond)
    if image_cond:
        cond = cond.reshape(-1, self.input_channels, self.imsize,
                            self.imsize)
        z_cond, _, _, _ = self.netE.cond_encoder(cond)
        z_cond = z_cond.reshape(-1, self.latent_size)
    else:
        z_cond = cond[:, self.latent_size:]
    z_delta = ptu.randn(batch_size, self.latent_size)
    cond_sample = torch.cat([z_delta, z_cond], dim=1)
    return ptu.get_numpy(cond_sample)
def dump_samples(self, epoch):
    self.model.eval()
    batch, _ = self.eval_data["test/last_batch"]
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample, batch["observations"])[0].cpu()
    save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
    save_image(
        sample.data.view(64, 3, self.imsize, self.imsize).transpose(2, 3),
        save_dir)

    x0 = batch["x_0"]
    x0_img = x0[:64].narrow(start=0, length=self.imlength // 2,
                            dim=1).contiguous().view(
        -1, 3, self.imsize, self.imsize).transpose(2, 3)
    save_dir = osp.join(self.log_dir, 'x0_%d.png' % epoch)
    save_image(x0_img.data.cpu(), save_dir)
def get_output_for(self, aug_obs, sample=True):
    """
    Return the log probability of the given observation.
    """
    obs = aug_obs
    with torch.no_grad():
        enc_features = self.enc(obs)
        mu = self.enc_mu(enc_features)
        logvar = self.enc_logvar(enc_features)
        stds = (0.5 * logvar).exp()
        if sample:
            epsilon = ptu.randn(*mu.size())
        else:
            # Deterministic path: use the posterior mean (epsilon = 0).
            epsilon = torch.zeros_like(mu)
        code = epsilon * stds + mu
        obs_distribution_params = self.dec(code)
        log_prob = -1. * F.mse_loss(obs, obs_distribution_params,
                                    reduction='none')
        log_prob = torch.sum(log_prob, -1, keepdim=True)
        return log_prob.detach()
def sample(self, num_samples):
    return ptu.get_numpy(
        self.sample_given_z(ptu.randn(num_samples, self.z_dim)))
def fixed_noise(self, b_size):
    return ptu.randn(b_size, self.representation_size, 1, 1)
def noise(self, size, num_epochs, epoch):
    # Gaussian noise whose std anneals linearly from 0.1 to 0 over training.
    noise = ptu.randn(size)
    std = 0.1 * (num_epochs - epoch) / num_epochs
    return std * noise
def fixed_noise(self, b_size, latent):
    z_cond = latent[:, self.model.latent_size:]
    z_delta = ptu.randn(b_size, self.model.latent_size, 1, 1)
    return torch.cat([z_delta, z_cond], dim=1)
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    noise = ptu.randn(next_actions.shape) * self.target_policy_noise
    noise = torch.clamp(
        noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip
    )
    noisy_next_actions = next_actions + noise

    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)
    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target) ** 2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target) ** 2
    qf2_loss = bellman_errors_2.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)
        policy_loss = -q_output.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        if policy_loss is None:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
    self._n_train_steps_total += 1
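# A minimal sketch of the two TD3 tricks the critic update above relies on,
# written in plain torch. `target_policy`, `target_qf1`, and `target_qf2`
# are placeholders for the target networks, not this repo's objects.
import torch

def td3_q_target(reward, not_done, next_obs, target_policy,
                 target_qf1, target_qf2,
                 discount=0.99, noise_std=0.2, noise_clip=0.5):
    # Target policy smoothing: perturb the target action with clipped noise.
    next_action = target_policy(next_obs)
    noise = (torch.randn_like(next_action) * noise_std).clamp(
        -noise_clip, noise_clip)
    next_action = next_action + noise
    # Clipped double Q-learning: take the min of the two target critics.
    target_q = torch.min(target_qf1(next_obs, next_action),
                         target_qf2(next_obs, next_action))
    return reward + not_done * discount * target_q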
all_imgs = [
    x0.narrow(start=0, length=imlength, dim=1).contiguous().view(
        -1, 3, imsize, imsize).transpose(2, 3),
]
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%sx0.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

vae = load_local_or_remote_file(vae_path).to("cpu")
vae.eval()
model = vae

all_imgs = []
for i in range(N_ROWS):
    latent = ptu.randn(n, model.representation_size)
    samples = model.decode(latent)[0]
    all_imgs.extend([samples.view(
        n,
        3,
        imsize,
        imsize,
    )[:n].transpose(2, 3)])
comparison = torch.cat(all_imgs)
save_dir = "/home/ashvin/data/s3doodad/share/multiobj/%svae_samples.png" % prefix
save_image(comparison.data.cpu(), save_dir, nrow=8)

cvae = load_local_or_remote_file(cvae_path).to("cpu")
cvae.eval()
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    gt.stamp('preback_start', unique=False)

    """
    Update QF
    """
    with torch.no_grad():
        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = torch.clamp(next_actions + noise,
                                         -self.max_action, self.max_action)

        next_tau, next_tau_hat, next_presum_tau = self.get_tau(
            next_obs, noisy_next_actions, fp=self.target_fp)
        target_z1_values = self.target_zf1(next_obs, noisy_next_actions,
                                           next_tau_hat)
        target_z2_values = self.target_zf2(next_obs, noisy_next_actions,
                                           next_tau_hat)
        target_z_values = torch.min(target_z1_values, target_z2_values)
        z_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_z_values

    tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
    z1_pred = self.zf1(obs, actions, tau_hat)
    z2_pred = self.zf2(obs, actions, tau_hat)
    zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat,
                                 next_presum_tau)
    zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat,
                                 next_presum_tau)
    gt.stamp('preback_zf', unique=False)

    self.zf1_optimizer.zero_grad()
    zf1_loss.backward()
    self.zf1_optimizer.step()
    gt.stamp('backward_zf1', unique=False)

    self.zf2_optimizer.zero_grad()
    zf2_loss.backward()
    self.zf2_optimizer.step()
    gt.stamp('backward_zf2', unique=False)

    """
    Update FP
    """
    if self.tau_type == 'fqf':
        with torch.no_grad():
            dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1])
                            - z1_pred[:, :-1] - z1_pred[:, 1:]
                            + 2 * self.zf2(obs, actions, tau[:, :-1])
                            - z2_pred[:, :-1] - z2_pred[:, 1:])
            dWdtau /= dWdtau.shape[0]  # (N, T-1)
        gt.stamp('preback_fp', unique=False)

        self.fp_optimizer.zero_grad()
        tau[:, :-1].backward(gradient=dWdtau)
        self.fp_optimizer.step()
        gt.stamp('backward_fp', unique=False)

    """
    Policy Loss
    """
    policy_actions = self.policy(obs)
    risk_param = self.risk_schedule(self._n_train_steps_total)

    if self.risk_type == 'VaR':
        tau_ = ptu.ones_like(rewards) * risk_param
        q_new_actions = self.zf1(obs, policy_actions, tau_)
    else:
        with torch.no_grad():
            new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                obs, policy_actions, fp=self.fp)
        z_new_actions = self.zf1(obs, policy_actions, new_tau_hat)
        if self.risk_type in ['neutral', 'std']:
            q_new_actions = torch.sum(new_presum_tau * z_new_actions,
                                      dim=1, keepdim=True)
            if self.risk_type == 'std':
                q_std = new_presum_tau * (z_new_actions - q_new_actions).pow(2)
                q_new_actions -= risk_param * q_std.sum(
                    dim=1, keepdim=True).sqrt()
        else:
            with torch.no_grad():
                risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                             risk_param)
            q_new_actions = torch.sum(
                risk_weights * new_presum_tau * z_new_actions,
                dim=1, keepdim=True)

    policy_loss = -q_new_actions.mean()
    gt.stamp('preback_policy', unique=False)

    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                              self.clip_norm)
        self.policy_optimizer.step()
        gt.stamp('backward_policy', unique=False)

        ptu.soft_update_from_to(self.policy, self.target_policy,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                self.soft_target_tau)
        if self.tau_type == 'fqf':
            ptu.soft_update_from_to(self.fp, self.target_fp,
                                    self.soft_target_tau)
    gt.stamp('soft_update', unique=False)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
        self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
        self.eval_statistics['Policy Loss'] = policy_loss.item()
        self.eval_statistics['Policy Grad'] = policy_grad
        self.eval_statistics.update(create_stats_ordered_dict(
            'Z1 Predictions',
            ptu.get_numpy(z1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Z2 Predictions',
            ptu.get_numpy(z2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Z Targets',
            ptu.get_numpy(z_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
    self._n_train_steps_total += 1
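# A hedged sketch of the quantile-regression Huber loss that `zf_criterion`
# presumably computes. This is the standard QR loss from the distributional-RL
# literature (QR-DQN/IQN style), not necessarily this repo's exact criterion.
import torch
import torch.nn.functional as F

def quantile_huber_loss(z_pred, z_target, tau_hat, kappa=1.0):
    # z_pred: (N, T) predicted quantile values at fractions tau_hat (N, T);
    # z_target: (N, T') target quantile values.
    u = z_target.unsqueeze(1) - z_pred.unsqueeze(2)          # TD errors, (N, T, T')
    huber = F.huber_loss(z_pred.unsqueeze(2).expand_as(u),
                         z_target.unsqueeze(1).expand_as(u),
                         reduction='none', delta=kappa)
    # Asymmetric quantile weight |tau - 1{u < 0}|.
    weight = (tau_hat.unsqueeze(2) - (u < 0).float()).abs()
    return (weight * huber / kappa).mean()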
def sample_prior(self, batch_size):
    z_s = ptu.randn(batch_size, self.representation_size)
    return ptu.get_numpy(z_s)
def dump_samples(self, epoch):
    self.model.eval()
    sample = ptu.randn(64, self.representation_size)
    sample = self.model.decode(sample)
    save_dir = osp.join(self.log_dir, 's%d.png' % epoch)
    save_image(sample.data.transpose(2, 3), save_dir)
def rsample(self, latent_distribution_params):
    mu, logvar = latent_distribution_params
    stds = (0.5 * logvar).exp()
    epsilon = ptu.randn(*mu.size())
    latents = epsilon * stds + mu
    return latents
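# A small, hypothetical check of the property `rsample` provides: because the
# sample is written as mu + std * eps, gradients flow from the sample back to
# mu and logvar (shapes below are illustrative).
import torch

mu = torch.zeros(4, 8, requires_grad=True)
logvar = torch.zeros(4, 8, requires_grad=True)
z = torch.randn_like(mu) * (0.5 * logvar).exp() + mu
z.sum().backward()
assert mu.grad is not None and logvar.grad is not None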