def get_dataset_stats(self, data):
    torch_input = ptu.from_numpy(normalize_image(data))
    mus, log_vars = self.model.encode(torch_input)
    mus = ptu.get_numpy(mus)
    mean = np.mean(mus, axis=0)
    std = np.std(mus, axis=0)
    return mus, mean, std

def _reconstruct_img(self, flat_img):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(flat_img.reshape(1, -1)))
    reconstructions, _ = self.vae.decode(latent_distribution_params[0])
    imgs = ptu.get_numpy(reconstructions)
    imgs = imgs.reshape(
        1, self.input_channels, self.imsize, self.imsize
    )
    return imgs[0]

def train_from_torch(self, batch):
    rewards = batch['rewards'] * self.reward_scale
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    target_q_values = self.target_qf(next_obs).detach().max(
        1, keepdim=True
    )[0]
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update the Q-network
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf, self.target_qf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))

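# A minimal sketch of the Polyak averaging assumed to happen inside
# ptu.soft_update_from_to above (illustrative only; the actual helper lives
# in the pytorch utility module, not in this file):
def _soft_update_sketch(source, target, tau):
    # target <- (1 - tau) * target + tau * source, parameter by parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )
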
def _update_info(self, info, obs):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1, -1)))
    latent_obs = ptu.get_numpy(latent_distribution_params[0])[0]
    logvar = ptu.get_numpy(latent_distribution_params[1])[0]
    # assert (latent_obs == obs['latent_observation']).all()
    latent_goal = self.desired_goal['latent_desired_goal']
    dist = latent_goal - latent_obs
    var = np.exp(logvar.flatten())
    var = np.maximum(var, self.reward_min_variance)
    err = dist * dist / 2 / var
    mdist = np.sum(err)  # Mahalanobis distance
    info["vae_mdist"] = mdist
    info["vae_success"] = 1 if mdist < self.epsilon else 0
    info["vae_dist"] = np.linalg.norm(dist, ord=self.norm_order)
    info["vae_dist_l1"] = np.linalg.norm(dist, ord=1)
    info["vae_dist_l2"] = np.linalg.norm(dist, ord=2)

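# Worked example (hypothetical values) for the distance computed above:
# with dist = latent_goal - latent_obs = [1.0, 0.0] and logvar = [0.0, 0.0]
# (unit variance), err = [0.5, 0.0], so the Mahalanobis-style distance is
# mdist = 0.5, while vae_dist_l2 = 1.0.
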
def train_epoch(self, epoch, sample_batch=None, batches=100, from_rl=False):
    self.model.train()
    losses = []
    log_probs = []
    kles = []
    zs = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(batches):
        if sample_batch is not None:
            data = sample_batch(self.batch_size, epoch)
            # obs = data['obs']
            next_obs = data['next_obs']
            # actions = data['actions']
        else:
            next_obs = self.get_batch(epoch=epoch)
            obs = None
            actions = None

        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            next_obs)
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        encoder_mean = self.model.get_encoding_from_latent_distribution_params(
            latent_distribution_params)
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        # Negative ELBO with a (possibly annealed) KL weight beta.
        loss = -1 * log_prob + beta * kle

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())

        if self.log_interval and batch_idx % self.log_interval == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx, batches, loss.item() / len(next_obs)))
    if not from_rl:
        zs = np.array(zs)
        self.model.dist_mu = zs.mean(axis=0)
        self.model.dist_std = zs.std(axis=0)
    self.eval_statistics['train/log prob'] = np.mean(log_probs)
    self.eval_statistics['train/KL'] = np.mean(kles)
    self.eval_statistics['train/loss'] = np.mean(losses)

def debug_statistics(self):
    r"""
    Given an image $$x$$, sample a bunch of latents from the prior $$z_i$$
    and decode them into $$\hat x_i$$. Compare these to $$\hat x$$, the
    reconstruction of $$x$$.

    Ideally
     - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure the VAE
       isn't ignoring the latent).
     - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
       coverage).
    """
    debug_batch_size = 64
    data = self.get_batch(train=False)
    reconstructions, _, _ = self.model(data)
    img = data[0]
    recon_mse = ((reconstructions[0] - img) ** 2).mean().view(-1)
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    samples = ptu.randn(debug_batch_size, self.representation_size)
    random_imgs, _ = self.model.decode(samples)
    random_mses = (random_imgs - img_repeated) ** 2
    mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    stats.update(create_stats_ordered_dict(
        'debug/MSE of random decoding',
        ptu.get_numpy(random_mses),
    ))
    stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
    if self.skew_dataset:
        stats.update(create_stats_ordered_dict(
            'train weight',
            self._train_weights,
        ))
    return stats

def compute_p_x_np_to_np(
        model,
        data,
        power,
        decoder_distribution='bernoulli',
        num_latents_to_sample=1,
        sampling_method='importance_sampling',
):
    assert data.dtype == np.float64, 'images should be normalized'
    assert -1 <= power <= 0, 'power for skew-fit should belong to [-1, 0]'
    log_p, log_q, log_d = compute_log_p_log_q_log_d(
        model,
        data,
        decoder_distribution,
        num_latents_to_sample,
        sampling_method,
    )
    if sampling_method == 'importance_sampling':
        log_p_x = (log_p - log_q + log_d).mean(dim=1)
    elif sampling_method in ('biased_sampling', 'true_prior_sampling'):
        log_p_x = log_d.mean(dim=1)
    else:
        raise ValueError('Invalid sampling method provided')
    log_p_x_skewed = power * log_p_x
    return ptu.get_numpy(log_p_x_skewed)

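# Hypothetical helper (an assumption, not part of the original code): the
# skewed log-likelihoods returned above, log p(x)^power, can be normalized
# into resampling weights proportional to p(x)^power for skewed replay.
def _skew_weights_sketch(log_p_x_skewed):
    # Subtract the max before exponentiating for numerical stability.
    w = np.exp(log_p_x_skewed - np.max(log_p_x_skewed))
    return w / w.sum()
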
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy and Alpha Loss
    """
    new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
        obs, reparameterize=True, return_log_prob=True,
    )
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(
            self.log_alpha * (log_pi + self.target_entropy).detach()
        ).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure the policy accounts for squashing functions like tanh correctly!
    new_next_actions, _, _, new_log_pi, *_ = self.policy(
        next_obs, reparameterize=True, return_log_prob=True,
    )
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        ptu.soft_update_from_to(
            self.qf2, self.target_qf2, self.soft_target_tau
        )

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        # Only compute these statistics for one batch per epoch; the flag is
        # flipped back at the end of each epoch.
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
    self._n_train_steps_total += 1

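# Sketch (assumption, not shown in this file): when use_automatic_entropy_tuning
# is enabled, target_entropy is commonly the negative action dimensionality and
# log_alpha is a learned scalar, set up roughly as:
#
#     self.target_entropy = -np.prod(env.action_space.shape).item()
#     self.log_alpha = ptu.zeros(1, requires_grad=True)
#     self.alpha_optimizer = optimizer_class([self.log_alpha], lr=policy_lr)
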
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    # Target policy smoothing: add clipped noise to the target actions.
    noise = ptu.randn(next_actions.shape) * self.target_policy_noise
    noise = torch.clamp(
        noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip,
    )
    noisy_next_actions = next_actions + noise

    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)
    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target) ** 2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target) ** 2
    qf2_loss = bellman_errors_2.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        # Delayed policy update: the actor and the target networks are
        # updated less often than the critics.
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)
        policy_loss = -q_output.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        if policy_loss is None:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
    self._n_train_steps_total += 1

def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    kles = []
    zs = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(10):
        next_obs = self.get_batch(train=False)
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            next_obs)
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = -1 * log_prob + beta * kle

        encoder_mean = latent_distribution_params[0]
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())

        if batch_idx == 0 and save_reconstruction:
            n = min(next_obs.size(0), 8)
            comparison = torch.cat([
                next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                    .contiguous()
                    .view(-1, self.input_channels, self.imsize, self.imsize)
                    .transpose(2, 3),
                reconstructions.view(
                    self.batch_size,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                )[:n].transpose(2, 3),
            ])
            save_dir = osp.join(logger.get_snapshot_dir(), 'r%d.png' % epoch)
            save_image(comparison.data.cpu(), save_dir, nrow=n)

    zs = np.array(zs)
    self.eval_statistics['epoch'] = epoch
    self.eval_statistics['test/log prob'] = np.mean(log_probs)
    self.eval_statistics['test/KL'] = np.mean(kles)
    self.eval_statistics['test/loss'] = np.mean(losses)
    self.eval_statistics['beta'] = beta
    if not from_rl:
        for k, v in self.eval_statistics.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)

def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy operations.
    """
    if self.policy_pre_activation_weight > 0:
        policy_actions, pre_tanh_value = self.policy(
            obs, return_preactivations=True,
        )
        pre_activation_policy_loss = (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss
            + pre_activation_policy_loss * self.policy_pre_activation_weight
        )
    else:
        policy_actions = self.policy(obs)
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = policy_loss = -q_output.mean()

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    # Speed up computation by not backpropping these gradients.
    # Note: detach() is not in-place, so the result must be reassigned.
    next_actions = next_actions.detach()
    target_q_values = self.target_qf(
        next_obs,
        next_actions,
    )
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
    q_pred = self.qf(obs, actions)
    bellman_errors = (q_pred - q_target) ** 2
    raw_qf_loss = self.qf_criterion(q_pred, q_target)

    if self.qf_weight_decay > 0:
        reg_loss = self.qf_weight_decay * sum(
            torch.sum(param ** 2)
            for param in self.qf.regularizable_parameters()
        )
        qf_loss = raw_qf_loss + reg_loss
    else:
        qf_loss = raw_qf_loss

    """
    Update Networks
    """
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    self._update_target_networks()

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics['Raw Policy Loss'] = np.mean(
            ptu.get_numpy(raw_policy_loss))
        self.eval_statistics['Preactivation Policy Loss'] = (
            self.eval_statistics['Policy Loss']
            - self.eval_statistics['Raw Policy Loss']
        )
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors',
            ptu.get_numpy(bellman_errors),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
    self._n_train_steps_total += 1

def _encode(self, imgs):
    latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
    # Use the mean of the latent distribution as the encoding.
    return ptu.get_numpy(latent_distribution_params[0])

def _decode(self, latents):
    reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded

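# Hypothetical usage sketch (shapes are assumptions): round-tripping a batch
# of N flattened, normalized images through the wrappers above:
#
#     latents = self._encode(flat_imgs)   # (N, representation_size)
#     recons = self._decode(latents)      # (N, flattened image length)
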
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, torch.autograd.Variable):
        # In PyTorch >= 0.4, Variable is merged with Tensor, so this check
        # also matches plain torch.Tensor inputs.
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
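# Example behavior of np_ify (values are illustrative only):
#
#     np_ify(torch.zeros(3))   # -> numpy array of shape (3,)
#     np_ify(np.zeros(3))      # -> returned unchanged
#     np_ify(5.0)              # -> returned unchanged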