def debug_statistics(self):
    """
    Given an image $$x$$, samples a batch of latents $$z_i$$ from the
    prior and decodes them to $$\hat x_i$$. These are compared to
    $$\hat x$$, the reconstruction of $$x$$. Ideally
     - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure the
       VAE isn't ignoring the latent)
     - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
       coverage)
    """
    debug_batch_size = 64
    data = self.get_batch(train=False)
    reconstructions, _, _ = self.model(data)
    img = data[0]
    recon_mse = ((reconstructions[0] - img)**2).mean().view(-1)
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    samples = ptu.randn(debug_batch_size, self.representation_size)
    random_imgs, _ = self.model.decode(samples)
    random_mses = (random_imgs - img_repeated)**2
    mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    stats.update(create_stats_ordered_dict(
        'debug/MSE of random decoding',
        ptu.get_numpy(random_mses),
    ))
    stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
    return stats
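In symbols, with $$D$$ the flattened image dimension, the improvement statistic computed above is (my notation, not from the source)

$$ \Delta_i = \tfrac{1}{D}\lVert \hat x_i - x \rVert^2 - \tfrac{1}{D}\lVert \hat x - x \rVert^2, $$

so a healthy VAE should give $$\Delta_i > 0$$ for every prior sample (the reconstruction beats random decodings) while the $$\Delta_i$$ still vary noticeably across samples (the prior covers diverse images).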
def get_dataset_stats(self, data):
    torch_input = ptu.from_numpy(normalize_image(data))
    mus, log_vars = self.model.encode(torch_input)
    mus = ptu.get_numpy(mus)
    mean = np.mean(mus, axis=0)
    std = np.std(mus, axis=0)
    return mus, mean, std
def _update_info(self, info, obs):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1, -1))
    )
    latent_obs = ptu.get_numpy(latent_distribution_params[0])[0]
    latent_goal = self.desired_goal['latent_desired_goal']
    dist = latent_goal - latent_obs
    info["vae_dist"] = np.linalg.norm(dist, ord=2)
def _reconstruct_img(self, flat_img):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(flat_img.reshape(1, -1))
    )
    reconstructions, _ = self.vae.decode(latent_distribution_params[0])
    imgs = ptu.get_numpy(reconstructions)
    imgs = imgs.reshape(
        1, self.input_channels, self.imsize, self.imsize
    )
    return imgs[0]
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    target_q_values = self.target_qf(next_obs).detach().max(
        1, keepdim=True
    )[0]
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()
    self._update_target_network()

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))
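For reference, the target constructed above is the standard DQN bootstrap target (notation mine; $$\bar\theta$$ denotes the target network and $$d$$ the terminal flag):

$$ y = r + (1 - d)\,\gamma \max_{a'} Q_{\bar\theta}(s', a'), \qquad \mathcal{L}_{QF} = \text{criterion}\big(Q_\theta(s, a),\, y\big), $$

with $$y$$ detached so gradients only flow through $$Q_\theta(s, a)$$; the one-hot `actions` select that entry of the predicted Q-value vector.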
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy operations.
    """
    if self.policy_pre_activation_weight > 0:
        policy_actions, pre_tanh_value = self.policy(
            obs, return_preactivations=True,
        )
        pre_activation_policy_loss = (
            (pre_tanh_value**2).sum(dim=1).mean()
        )
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss
            + pre_activation_policy_loss * self.policy_pre_activation_weight
        )
    else:
        policy_actions = self.policy(obs)
        q_output = self.qf(obs, policy_actions)
        raw_policy_loss = policy_loss = -q_output.mean()

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    # speed up computation by not backpropping these gradients
    # (detach returns a new tensor, so it must be reassigned)
    next_actions = next_actions.detach()
    target_q_values = self.target_qf(
        next_obs,
        next_actions,
    )
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
    # Hack for ICLR rebuttal
    if hasattr(self, 'reward_type') and self.reward_type == 'indicator':
        q_target = torch.clamp(
            q_target,
            -self.reward_scale / (1 - self.discount),
            0,
        )
    q_pred = self.qf(obs, actions)
    bellman_errors = (q_pred - q_target)**2
    raw_qf_loss = self.qf_criterion(q_pred, q_target)

    if self.residual_gradient_weight > 0:
        residual_next_actions = self.policy(next_obs)
        # speed up computation by not backpropping these gradients
        residual_next_actions = residual_next_actions.detach()
        residual_target_q_values = self.qf(
            next_obs,
            residual_next_actions,
        )
        residual_q_target = (
            rewards
            + (1. - terminals) * self.discount * residual_target_q_values
        )
        residual_bellman_errors = (q_pred - residual_q_target)**2
        # noinspection PyUnresolvedReferences
        residual_qf_loss = residual_bellman_errors.mean()
        raw_qf_loss = (
            self.residual_gradient_weight * residual_qf_loss
            + (1 - self.residual_gradient_weight) * raw_qf_loss
        )

    if self.qf_weight_decay > 0:
        reg_loss = self.qf_weight_decay * sum(
            torch.sum(param**2)
            for param in self.qf.regularizable_parameters()
        )
        qf_loss = raw_qf_loss + reg_loss
    else:
        qf_loss = raw_qf_loss

    """
    Update Networks
    """
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    self._update_target_networks()

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss)
        )
        self.eval_statistics['Raw Policy Loss'] = np.mean(
            ptu.get_numpy(raw_policy_loss)
        )
        self.eval_statistics['Preactivation Policy Loss'] = (
            self.eval_statistics['Policy Loss']
            - self.eval_statistics['Raw Policy Loss']
        )
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors',
            ptu.get_numpy(bellman_errors),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
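In equation form, the DDPG update above uses (my notation; $$\mu_\phi$$ is the policy, barred symbols are target networks):

$$ y = \operatorname{clamp}\!\big(r + (1 - d)\,\gamma\, Q_{\bar\theta}(s', \mu_{\bar\phi}(s')),\; Q_{\min},\, Q_{\max}\big), \qquad \mathcal{L}_\pi = -Q_\theta(s, \mu_\phi(s)) + w_{\text{pre}} \lVert \text{pre-tanh}(\mu_\phi(s)) \rVert^2, $$

optionally mixing a residual-gradient Bellman error and an $$L_2$$ penalty on the critic weights into the critic loss, as the flags in the code indicate.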
def get_action(self, obs):
    obs = np.expand_dims(obs, axis=0)
    obs = ptu.from_numpy(obs).float()
    q_values = self.qf(obs).squeeze(0)
    q_values_np = ptu.get_numpy(q_values)
    return q_values_np.argmax(), {}
def _encode(self, imgs):
    latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
    return ptu.get_numpy(latent_distribution_params[0])
def _decode(self, latents):
    reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded
def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    kles = []
    zs = []
    for batch_idx in range(10):
        next_obs = self.get_batch(train=False)
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            next_obs
        )
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = self.beta * kle - log_prob

        encoder_mean = latent_distribution_params[0]
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())

        if batch_idx == 0 and save_reconstruction:
            n = min(next_obs.size(0), 8)
            comparison = torch.cat([
                next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                    .contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize
                    ),
                reconstructions.view(
                    self.batch_size,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                )[:n],
            ])
            save_dir = osp.join(logger.get_snapshot_dir(), 'r%d.png' % epoch)
            save_image(comparison.data.cpu(), save_dir, nrow=n)

    zs = np.array(zs)
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)

    if from_rl:
        self.vae_logger_stats_for_rl['Test VAE Epoch'] = epoch
        self.vae_logger_stats_for_rl['Test VAE Log Prob'] = np.mean(log_probs)
        self.vae_logger_stats_for_rl['Test VAE KL'] = np.mean(kles)
        self.vae_logger_stats_for_rl['Test VAE loss'] = np.mean(losses)
        self.vae_logger_stats_for_rl['VAE Beta'] = self.beta
    else:
        for key, value in self.debug_statistics().items():
            logger.record_tabular(key, value)
        logger.record_tabular("test/Log Prob", np.mean(log_probs))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", self.beta)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)  # slow...
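The scalar tracked as `loss` here is the $$\beta$$-weighted negative ELBO,

$$ \mathcal{L} = \beta\, D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\Vert\, p(z)\big) - \log p_\theta(x \mid z), $$

matching `loss = self.beta * kle - log_prob` above; the encoder means collected in `zs` are also used to refresh `dist_mu` and `dist_std` for later latent sampling.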
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Critic operations.
    """
    next_actions = self.target_policy(next_obs)
    noise = torch.normal(
        torch.zeros_like(next_actions),
        self.target_policy_noise,
    )
    noise = torch.clamp(
        noise,
        -self.target_policy_noise_clip,
        self.target_policy_noise_clip,
    )
    noisy_next_actions = next_actions + noise

    target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
    target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
    target_q_values = torch.min(target_q1_values, target_q2_values)
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()

    q1_pred = self.qf1(obs, actions)
    bellman_errors_1 = (q1_pred - q_target)**2
    qf1_loss = bellman_errors_1.mean()

    q2_pred = self.qf2(obs, actions)
    bellman_errors_2 = (q2_pred - q_target)**2
    qf2_loss = bellman_errors_2.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    policy_actions = policy_loss = None
    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        policy_actions = self.policy(obs)
        q_output = self.qf1(obs, policy_actions)
        policy_loss = -q_output.mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss)
        )
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 1',
            ptu.get_numpy(bellman_errors_1),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors 2',
            ptu.get_numpy(bellman_errors_2),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
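The critic target above is the TD3 combination of target-policy smoothing and clipped double-Q learning (my notation):

$$ \tilde a = \mu_{\bar\phi}(s') + \epsilon, \quad \epsilon = \operatorname{clip}\big(\mathcal{N}(0, \sigma),\, -c,\, c\big), \qquad y = r + (1 - d)\,\gamma \min_{i \in \{1, 2\}} Q_{\bar\theta_i}(s', \tilde a), $$

with $$\sigma$$ = `target_policy_noise` and $$c$$ = `target_policy_noise_clip`; the policy and target networks are only updated every `policy_and_target_update_period` training steps.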
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    v_pred = self.vf(obs)
    # Make sure policy accounts for squashing functions like tanh correctly!
    policy_outputs = self.policy(
        obs,
        reparameterize=self.train_policy_with_reparameterization,
        return_log_prob=True,
    )
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

    """
    Alpha Loss (if applicable)
    """
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(
            self.log_alpha * (log_pi + self.target_entropy).detach()
        ).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
        alpha_loss = 0

    """
    QF Loss
    """
    target_v_values = self.target_vf(next_obs)
    q_target = (
        self.reward_scale * rewards
        + (1. - terminals) * self.discount * target_v_values
    )
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    VF Loss
    """
    q_new_actions = torch.min(
        self.qf1(obs, new_actions),
        self.qf2(obs, new_actions),
    )
    v_target = q_new_actions - alpha * log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    policy_loss = None
    if self._n_train_steps_total % self.policy_update_period == 0:
        """
        Policy Loss
        """
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value**2).sum(dim=1).mean()
        )
        policy_reg_loss = (
            mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        )
        policy_loss = policy_loss + policy_reg_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.vf, self.target_vf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            if self.train_policy_with_reparameterization:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi * (log_pi - log_policy_target).detach()
                ).mean()
            mean_reg_loss = self.policy_mean_reg_weight * (
                policy_mean**2
            ).mean()
            std_reg_loss = self.policy_std_reg_weight * (
                policy_log_std**2
            ).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value**2).sum(dim=1).mean()
            )
            policy_reg_loss = (
                mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            )
            policy_loss = policy_loss + policy_reg_loss

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
def _reconstruction_squared_error_np_to_np(self, np_imgs):
    torch_input = ptu.from_numpy(normalize_image(np_imgs))
    recons, *_ = self.model(torch_input)
    error = torch_input - recons
    return ptu.get_numpy((error**2).sum(dim=1))
def _kl_np_to_np(self, np_imgs):
    torch_input = ptu.from_numpy(normalize_image(np_imgs))
    mu, log_var = self.model.encode(torch_input)
    return ptu.get_numpy(
        -torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    )
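This is the closed-form KL divergence between the diagonal-Gaussian posterior and the standard-normal prior,

$$ D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\Vert\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2} \sum_j \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big), $$

except that the code omits the constant factor of $$\tfrac{1}{2}$$ (so it returns twice the KL), which leaves relative comparisons between images unchanged.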
def get_param_values_np(self):
    state_dict = self.state_dict()
    np_dict = OrderedDict()
    for key, tensor in state_dict.items():
        np_dict[key] = ptu.get_numpy(tensor)
    return np_dict
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    goals = batch['goals']
    num_steps_left = batch['num_steps_left']

    """
    Policy operations.
    """
    policy_actions, pre_tanh_value = self.policy(
        obs, goals, num_steps_left,
        return_preactivations=True,
    )
    pre_activation_policy_loss = ((pre_tanh_value**2).sum(dim=1).mean())
    q_output = self.qf(
        observations=obs,
        actions=policy_actions,
        num_steps_left=num_steps_left,
        goals=goals,
    )
    raw_policy_loss = -q_output.mean()
    policy_loss = (
        raw_policy_loss
        + pre_activation_policy_loss * self.policy_pre_activation_weight
    )

    """
    Critic operations.
    """
    next_actions = self.target_policy(
        observations=next_obs,
        goals=goals,
        num_steps_left=num_steps_left - 1,
    )
    # speed up computation by not backpropping these gradients
    # (detach returns a new tensor, so it must be reassigned)
    next_actions = next_actions.detach()
    target_q_values = self.target_qf(
        observations=next_obs,
        actions=next_actions,
        goals=goals,
        num_steps_left=num_steps_left - 1,
    )
    q_target = rewards + (1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    q_pred = self.qf(
        observations=obs,
        actions=actions,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    if self.tdm_normalizer:
        q_pred = self.tdm_normalizer.distance_normalizer.normalize_scale(
            q_pred
        )
        q_target = self.tdm_normalizer.distance_normalizer.normalize_scale(
            q_target
        )
    bellman_errors = (q_pred - q_target)**2
    qf_loss = self.qf_criterion(q_pred, q_target)

    """
    Update Networks
    """
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    self._update_target_networks()

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss)
        )
        self.eval_statistics['Raw Policy Loss'] = np.mean(
            ptu.get_numpy(raw_policy_loss)
        )
        self.eval_statistics['Preactivation Policy Loss'] = (
            self.eval_statistics['Policy Loss']
            - self.eval_statistics['Raw Policy Loss']
        )
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Bellman Errors',
            ptu.get_numpy(bellman_errors),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Action',
            ptu.get_numpy(policy_actions),
        ))
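The TDM critic target mirrors the DDPG one but is conditioned on the goal $$g$$ and the remaining horizon $$t$$, which is decremented for the bootstrap (my notation):

$$ y = r + (1 - d)\,\gamma\, Q_{\bar\theta}\big(s',\, \mu_{\bar\phi}(s', g, t - 1),\, g,\, t - 1\big), $$

with both $$Q_\theta(s, a, g, t)$$ and $$y$$ optionally rescaled by `tdm_normalizer` before the Bellman error is computed.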