def _reconstruct_img(self, flat_img):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(flat_img.reshape(1, -1)))
    reconstructions, _ = self.vae.decode(latent_distribution_params[0])
    imgs = ptu.get_numpy(reconstructions)
    imgs = imgs.reshape(
        1, self.input_channels, self.imsize, self.imsize
    )
    return imgs[0]
def _update_info(self, info, obs):
    latent_distribution_params = self.vae.encode(
        ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1, -1))
    )
    latent_obs = ptu.get_numpy(latent_distribution_params[0])[0]
    logvar = ptu.get_numpy(latent_distribution_params[1])[0]
    # assert (latent_obs == obs['latent_observation']).all()
    latent_goal = self.desired_goal['latent_desired_goal']
    dist = latent_goal - latent_obs
    var = np.exp(logvar.flatten())
    var = np.maximum(var, self.reward_min_variance)
    err = dist * dist / 2 / var
    mdist = np.sum(err)  # mahalanobis distance
    info["vae_mdist"] = mdist
    info["vae_success"] = 1 if mdist < self.epsilon else 0
    info["vae_dist"] = np.linalg.norm(dist, ord=self.norm_order)
    info["vae_dist_l1"] = np.linalg.norm(dist, ord=1)
    info["vae_dist_l2"] = np.linalg.norm(dist, ord=2)
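# A minimal standalone sketch (not a method of the class above) of the
# diagonal-Gaussian Mahalanobis distance that _update_info computes: for a
# latent error vector `dist` and per-dimension log-variance `logvar`, the
# distance is sum_i dist_i ** 2 / (2 * var_i). The function name and the
# `min_variance` default below are illustrative, not part of the original code.
def _mahalanobis_distance_sketch(dist, logvar, min_variance=1e-4):
    var = np.maximum(np.exp(logvar.flatten()), min_variance)
    return np.sum(dist * dist / (2.0 * var))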
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    p_z = self.sample_empowerment_latents(obs)

    """
    Update the networks
    """
    rews = np.mean(ptu.get_numpy(rewards))
    vf_loss, alpha_loss, alpha, qf1_loss, qf2_loss, emp_reward, pol_loss = \
        self.update_state_value(
            observation=obs,
            action=actions,
            done=terminals,
            next_observation=next_obs,
            p_z_given=p_z,
        )

    # Update the discriminator
    discriminator_loss = self.update_discriminator(
        observation=obs, action=actions)

    disc_loss = np.mean(ptu.get_numpy(discriminator_loss))
    pol = np.mean(ptu.get_numpy(pol_loss))
    emp_rew = np.mean(ptu.get_numpy(emp_reward))
    value_loss = np.mean(ptu.get_numpy(vf_loss))
    q1_loss = np.mean(ptu.get_numpy(qf1_loss))
    q2_loss = np.mean(ptu.get_numpy(qf2_loss))

    i = self._n_train_steps_total
    self.writer.add_scalar('data/reward', rews, i)
    self.writer.add_scalar('data/policy_loss', pol, i)
    self.writer.add_scalar('data/discriminator_loss', disc_loss, i)
    self.writer.add_scalar('data/empowerment_rewards', emp_rew, i)
    self.writer.add_scalar('data/value_loss', value_loss, i)
    self.writer.add_scalar('data/q_value_loss_1', q1_loss, i)
    self.writer.add_scalar('data/q_value_loss_2', q2_loss, i)

    self._n_train_steps_total += 1
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, torch.autograd.Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
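# Illustrative usage of np_ify above (values hypothetical): torch tensors are
# converted to numpy arrays via ptu.get_numpy; anything else passes through.
#   np_ify(torch.zeros(3))   # -> numpy array of zeros
#   np_ify([1.0, 2.0])       # -> the same list, unchanged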
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    v_pred = self.vf(obs)
    # Make sure policy accounts for squashing functions like tanh correctly!
    policy_outputs = self.policy(
        obs,
        reparameterize=self.train_policy_with_reparameterization,
        return_log_prob=True,
    )
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

    """
    Alpha Loss (if applicable)
    """
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (
            log_pi + self.target_entropy
        ).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
        alpha_loss = 0

    """
    QF Loss
    """
    target_v_values = self.target_vf(next_obs)
    q_target = self.reward_scale * rewards + \
        (1. - terminals) * self.discount * target_v_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    VF Loss
    """
    q_new_actions = torch.min(
        self.qf1(obs, new_actions),
        self.qf2(obs, new_actions),
    )
    v_target = q_new_actions - alpha * log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    policy_loss = None
    if self._n_train_steps_total % self.policy_update_period == 0:
        """
        Policy Loss
        """
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.vf, self.target_vf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            if self.train_policy_with_reparameterization:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi * (log_pi - log_policy_target).detach()
                ).mean()

            mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
            std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value ** 2).sum(dim=1).mean()
            )
            policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            policy_loss = policy_loss + policy_reg_loss

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
def _encode(self, imgs):
    latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
    return ptu.get_numpy(latent_distribution_params[0])
def _decode(self, latents):
    reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
    decoded = ptu.get_numpy(reconstructions)
    return decoded
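# Illustrative round trip through the helpers above (shapes are assumptions,
# following the flattened-image convention used elsewhere in this code):
#   latents = self._encode(flat_imgs)  # flat_imgs: (N, imsize * imsize * channels)
#   recons  = self._decode(latents)    # recons: same flattened image layout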
def update_state_value(self, observation, action, p_z_given,
                       next_observation, done):
    """
    Creates minimization operations for the state value functions.

    In principle, there is no need for a separate state value function
    approximator, since it could be evaluated using the Q-function and
    policy. However, in practice, the separate function approximator
    stabilizes training.

    :return:
    """
    qf1_loss, qf2_loss, empowerment_reward, _ = self.update_critic(
        observation=observation,
        action=action,
        p_z_given=p_z_given,
        next_observation=next_observation,
        done=done,
    )

    # Make sure policy accounts for squashing functions like tanh correctly!
    policy_outputs = self.policy(
        observation,
        reparameterize=self.train_policy_with_reparameterization,
        return_log_prob=True,
    )
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

    q_new_actions = torch.min(
        self.qf1(observation, new_actions),
        self.qf2(observation, new_actions),
    )

    (observation, z_one_hot) = self.split_obs(obs=observation)
    v_pred = self.value_network(observation)

    """
    Alpha Loss (if applicable)
    """
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (
            log_pi + self.target_entropy
        ).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
        alpha_loss = 0

    policy_distribution = self.policy.get_distibution(observation)
    log_pi = policy_distribution.log_pi

    v_target = q_new_actions - alpha * log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    policy_loss = None
    if self._n_train_steps_total % self.policy_update_period == 0:
        """
        Policy Loss
        """
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.value_network, self.target_vf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            if self.train_policy_with_reparameterization:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi * (log_pi - log_policy_target).detach()
                ).mean()

            mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
            std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value ** 2).sum(dim=1).mean()
            )
            policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            policy_loss = policy_loss + policy_reg_loss

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Empowerment Reward',
            ptu.get_numpy(empowerment_reward),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

    return (vf_loss, alpha_loss, alpha, qf1_loss, qf2_loss,
            empowerment_reward, policy_loss)
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy and Alpha Loss
    """
    new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
        obs, reparameterize=True, return_log_prob=True,
    )
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (
            log_pi + self.target_entropy
        ).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    new_next_actions, _, _, new_log_pi, *_ = self.policy(
        next_obs, reparameterize=True, return_log_prob=True,
    )
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + \
        (1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        ptu.soft_update_from_to(
            self.qf2, self.target_qf2, self.soft_target_tau
        )

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

    self._n_train_steps_total += 1
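# For reference: the Q target above is the twin-critic soft Bellman backup,
#   q_target = reward_scale * r + (1 - done) * gamma *
#              (min_i target_qf_i(s', a') - alpha * log pi(a'|s')),
# with a' freshly sampled from the current policy at s' and the target
# detached before computing the MSE critic losses.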