def rsample(self, *args, **kwargs):
    event = super().rsample(*args, **kwargs)
    clipped = torch.max(
        torch.min(event, ptu.ones_like(event) - self._clip),
        -1 * ptu.ones_like(event) + self._clip,
    )
    # Straight-through estimator: the forward pass uses the clipped sample,
    # the backward pass uses the gradient of the unclipped sample.
    event = event - event.detach() + clipped.detach()
    event *= self._mult
    return event
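# A minimal standalone sketch (not part of the class above) of the
# straight-through clipping trick used by the `event - event.detach() +
# clipped.detach()` line: the forward value is clipped, but gradients flow as
# if no clipping happened. The names `clip` and `mult` mirror the attributes
# above but this snippet is illustrative only.
import torch

def straight_through_clip(event: torch.Tensor, clip: float = 0.01, mult: float = 1.0):
    # Clip the sample into [-1 + clip, 1 - clip].
    clipped = torch.clamp(event, min=-1.0 + clip, max=1.0 - clip)
    # Forward pass returns the clipped value; backward pass sees the identity,
    # so gradients are not zeroed out at the clipping boundary.
    out = event - event.detach() + clipped.detach()
    return out * mult

x = torch.tensor([0.5, 1.5, -2.0], requires_grad=True)
y = straight_through_clip(x)
y.sum().backward()
print(y)        # values are clipped into the interval
print(x.grad)   # gradients are all ones (identity), even where clipping occurred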
def decode(self, input):
    output = self(input)
    if self.output_var == 'learned':
        mu, logvar = torch.split(output, 2, dim=1)
        var = logvar.exp()
    else:
        mu = output
        var = self.output_var * ptu.ones_like(mu)
    return mu, var
def encode(self,
           input,
           lstm_hidden=None,
           return_hidden=False,
           return_vae_latent=False):
    '''
    input: [seq_len x batch x flatten_img_dim] of flattened images
    lstm_hidden: [lstm_layers x batch x lstm_hidden_size]
    mark: change depends on how latent distribution parameters are used
    '''
    seq_len, batch_size, feature_size = input.shape
    # print("in lstm encode: ", seq_len, batch_size, feature_size)
    input = input.reshape((-1, feature_size))
    feature = self.encoder(input)  # [seq_len x batch x conv_output_size]

    vae_mu = self.vae_fc1(feature)
    if self.log_min_variance is None:
        vae_logvar = self.vae_fc2(feature)
    else:
        vae_logvar = self.log_min_variance + torch.abs(self.vae_fc2(feature))

    # lstm_input = self.rsample((vae_mu, vae_logvar))
    # if self.detach_vae_output:
    #     lstm_input = lstm_input.detach()
    if self.detach_vae_output:
        lstm_input = vae_mu.detach().clone()
    else:
        lstm_input = vae_mu
    lstm_input = lstm_input.view((seq_len, batch_size, -1))
    # if self.detach_vae_output:
    #     lstm_input = lstm_input.detach()

    if lstm_hidden is None:
        lstm_hidden = (
            ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size),
            ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size),
        )
    h, hidden = self.lstm(lstm_input,
                          lstm_hidden)  # [seq_len, batch_size, lstm_hidden_size]
    lstm_latent = self.lstm_fc(h)

    ret = (lstm_latent, ptu.ones_like(lstm_latent))
    if return_vae_latent:
        ret += (vae_mu, vae_logvar)

    if return_hidden:
        return ret, hidden
    return ret  # , lstm_input  # [seq_len, batch_size, representation_size]
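# Illustrative sketch of the reparameterization step that the commented-out
# `self.rsample((vae_mu, vae_logvar))` call above would typically perform.
# This is a generic VAE sampler under that assumption, not the project's
# actual implementation.
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = (0.5 * logvar).exp()      # log-variance -> standard deviation
    eps = torch.randn_like(std)     # noise sampled outside the graph
    return mu + eps * std           # differentiable w.r.t. mu and logvar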
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    gt.stamp('preback_start', unique=False)

    """
    Update Alpha
    """
    new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
        obs,
        reparameterize=True,
        return_log_prob=True,
    )
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha.exp() *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha
    gt.stamp('preback_alpha', unique=False)

    """
    Update ZF
    """
    with torch.no_grad():
        new_next_actions, _, _, new_log_pi, *_ = self.target_policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        next_tau, next_tau_hat, next_presum_tau = self.get_tau(
            next_obs, new_next_actions, fp=self.target_fp)
        target_z1_values = self.target_zf1(next_obs, new_next_actions,
                                           next_tau_hat)
        target_z2_values = self.target_zf2(next_obs, new_next_actions,
                                           next_tau_hat)
        target_z_values = torch.min(target_z1_values,
                                    target_z2_values) - alpha * new_log_pi
        z_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_z_values

    tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
    z1_pred = self.zf1(obs, actions, tau_hat)
    z2_pred = self.zf2(obs, actions, tau_hat)
    zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat, next_presum_tau)
    zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat, next_presum_tau)
    gt.stamp('preback_zf', unique=False)

    self.zf1_optimizer.zero_grad()
    zf1_loss.backward()
    self.zf1_optimizer.step()
    gt.stamp('backward_zf1', unique=False)

    self.zf2_optimizer.zero_grad()
    zf2_loss.backward()
    self.zf2_optimizer.step()
    gt.stamp('backward_zf2', unique=False)

    """
    Update FP
    """
    if self.tau_type == 'fqf':
        with torch.no_grad():
            dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                            z1_pred[:, :-1] - z1_pred[:, 1:] +
                            2 * self.zf2(obs, actions, tau[:, :-1]) -
                            z2_pred[:, :-1] - z2_pred[:, 1:])
            dWdtau /= dWdtau.shape[0]  # (N, T-1)
        gt.stamp('preback_fp', unique=False)

        self.fp_optimizer.zero_grad()
        tau[:, :-1].backward(gradient=dWdtau)
        self.fp_optimizer.step()
        gt.stamp('backward_fp', unique=False)

    """
    Update Policy
    """
    risk_param = self.risk_schedule(self._n_train_steps_total)

    if self.risk_type == 'VaR':
        tau_ = ptu.ones_like(rewards) * risk_param
        q1_new_actions = self.zf1(obs, new_actions, tau_)
        q2_new_actions = self.zf2(obs, new_actions, tau_)
    else:
        with torch.no_grad():
            new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                obs, new_actions, fp=self.fp)
        z1_new_actions = self.zf1(obs, new_actions, new_tau_hat)
        z2_new_actions = self.zf2(obs, new_actions, new_tau_hat)
        if self.risk_type in ['neutral', 'std']:
            q1_new_actions = torch.sum(new_presum_tau * z1_new_actions,
                                       dim=1,
                                       keepdims=True)
            q2_new_actions = torch.sum(new_presum_tau * z2_new_actions,
                                       dim=1,
                                       keepdims=True)
            if self.risk_type == 'std':
                q1_std = new_presum_tau * (z1_new_actions -
                                           q1_new_actions).pow(2)
                q2_std = new_presum_tau * (z2_new_actions -
                                           q2_new_actions).pow(2)
                q1_new_actions -= risk_param * q1_std.sum(
                    dim=1, keepdims=True).sqrt()
                q2_new_actions -= risk_param * q2_std.sum(
                    dim=1, keepdims=True).sqrt()
        else:
            with torch.no_grad():
                risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                             risk_param)
            q1_new_actions = torch.sum(risk_weights * new_presum_tau *
                                       z1_new_actions,
                                       dim=1,
                                       keepdims=True)
            q2_new_actions = torch.sum(risk_weights * new_presum_tau *
                                       z2_new_actions,
                                       dim=1,
                                       keepdims=True)
    q_new_actions = torch.min(q1_new_actions, q2_new_actions)

    policy_loss = (alpha * log_pi - q_new_actions).mean()
    gt.stamp('preback_policy', unique=False)

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                          self.clip_norm)
    self.policy_optimizer.step()
    gt.stamp('backward_policy', unique=False)

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.policy, self.target_policy,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                self.soft_target_tau)
        if self.tau_type == 'fqf':
            ptu.soft_update_from_to(self.fp, self.target_fp,
                                    self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
        self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
        self.eval_statistics['Policy Loss'] = policy_loss.item()
        self.eval_statistics['Policy Grad'] = policy_grad
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z1 Predictions',
                ptu.get_numpy(z1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z2 Predictions',
                ptu.get_numpy(z2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z Targets',
                ptu.get_numpy(z_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
    self._n_train_steps_total += 1
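# Illustrative sketch of the pattern used in the FQF fraction-proposal update
# above: `tau[:, :-1].backward(gradient=dWdtau)` performs a vector-Jacobian
# product, backpropagating a precomputed dW/dtau through the fraction network.
# The tensors below are toy stand-ins, not the trainer's real quantities.
import torch

params = torch.randn(3, requires_grad=True)
tau = torch.cumsum(torch.softmax(params, dim=0), dim=0)  # stand-in fraction proposals

dW_dtau = torch.tensor([0.2, -0.1, 0.05])  # pretend gradient of the 1-Wasserstein loss
tau.backward(gradient=dW_dtau)             # same as (dW_dtau * tau).sum().backward()
print(params.grad)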
def forward_ith_model(self, input, i):
    mean = self.ensemble[i](input)
    return self.get_dist(mean, ptu.ones_like(mean))
def mode(self):
    mode = torch.max(
        torch.min(self.mean, ptu.ones_like(self.mean)),
        -1 * ptu.ones_like(self.mean),
    )
    return mode
def forward(
    self,
    obs,
    action,
    use_network_action=False,
    state=None,
    batch_indices=None,
    raps_obs_indices=None,
):
    """
    Forward world model on trajectory.

    :param obs: (batch_size, path_length, obs_dim)
    :param action: List [(batch_size, path_length, high_level_action_dim),
        (batch_size, path_length, low_level_action_dim)]
    :param use_network_action:
    :param state:
    :param batch_indices:
    :param raps_obs_indices:

    :return post: Dict
        mean: (batch_size, path_length, stoch_size)
        std: (batch_size, path_length, stoch_size)
        stoch: (batch_size, path_length, stoch_size)
        deter: (batch_size, path_length, deter_size)
    :return prior:
        mean: (batch_size, path_length, stoch_size)
        std: (batch_size, path_length, stoch_size)
        stoch: (batch_size, path_length, stoch_size)
        deter: (batch_size, path_length, deter_size)
    :return post_dist:
        mean: (batch_size*path_length, stoch_size)
        std: (batch_size*path_length, stoch_size)
    :return prior_dist:
        mean: (batch_size*path_length, stoch_size)
        std: (batch_size*path_length, stoch_size)
    :return image_dist:
        mean: (batch_size*path_length, obs_dim)
    :return reward_dist:
        mean: (batch_size*raps_path_length, 1)
    :return pred_discount_dist:
        logits: (batch_size*raps_path_length, 1)
    :return embed: (batch_size, path_length, embed_dim)
    :return low_level_action_preds: (batch_size, path_length, low_level_action_dim)
    """
    assert (obs.shape[:2] == action[0].shape[:2] == action[1].shape[:2]
            ), "Obs and action first two dimensions should be the same."
    original_batch_size = action[1].shape[0]
    path_length = action[1].shape[1]
    if state is None:
        state = self.initial(original_batch_size)
    post, prior = (
        dict(mean=[], std=[], stoch=[], deter=[]),
        dict(mean=[], std=[], stoch=[], deter=[]),
    )
    obs_path_len = obs.shape[1]
    obs = obs.reshape(-1, obs.shape[-1])
    embed = self.encode(obs)
    embedding_size = embed.shape[1]
    embed = embed.reshape(original_batch_size, obs_path_len, embedding_size)
    if obs_path_len < path_length:
        idxs = raps_obs_indices.tolist()
    else:
        idxs = np.arange(0, path_length, 1).tolist()
    post, prior, low_level_action_preds = self.forward_batch(
        path_length,
        action,
        embed,
        post,
        prior,
        state,
        idxs,
        use_network_action,
    )

    for key in post.keys():
        post[key] = torch.cat(post[key], dim=1)

    for key in prior.keys():
        prior[key] = torch.cat(prior[key], dim=1)

    if self.use_prior_instead_of_posterior:
        # In this case, o_hat_t depends on a_t-1 and o_t-1; the reset obs is
        # decoded from the null state + action. This only works when the first
        # state is the reset obs and never changes.
        feat = self.get_features(prior)
    else:
        feat = self.get_features(post)
    raps_obs_feat = feat[:, raps_obs_indices]
    raps_obs_feat = raps_obs_feat.reshape(-1, raps_obs_feat.shape[-1])
    if batch_indices.shape != raps_obs_indices.shape:
        feat = get_indexed_arr_from_batch_indices(feat, batch_indices).reshape(
            -1, feat.shape[-1])
    else:
        feat = feat[:, batch_indices]
    images = self.decode(feat)
    rewards = self.reward(raps_obs_feat)
    pred_discounts = self.pred_discount(raps_obs_feat)
    if batch_indices.shape != raps_obs_indices.shape:
        post_dist = self.get_dist(
            get_indexed_arr_from_batch_indices(
                post["mean"], batch_indices).reshape(-1, post["mean"].shape[-1]),
            get_indexed_arr_from_batch_indices(
                post["std"], batch_indices).reshape(-1, post["std"].shape[-1]),
        )
        prior_dist = self.get_dist(
            get_indexed_arr_from_batch_indices(
                prior["mean"], batch_indices).reshape(-1, prior["mean"].shape[-1]),
            get_indexed_arr_from_batch_indices(
                prior["std"], batch_indices).reshape(-1, prior["std"].shape[-1]),
        )
    else:
        post_dist = self.get_dist(
            post["mean"][:, batch_indices].reshape(-1, post["mean"].shape[-1]),
            post["std"][:, batch_indices].reshape(-1, post["std"].shape[-1]),
        )
        prior_dist = self.get_dist(
            prior["mean"][:, batch_indices].reshape(-1, prior["mean"].shape[-1]),
            prior["std"][:, batch_indices].reshape(-1, prior["std"].shape[-1]),
        )
    image_dist = self.get_dist(images, ptu.ones_like(images), dims=3)
    if self.reward_classifier:
        reward_dist = self.get_dist(rewards, None, normal=False)
    else:
        reward_dist = self.get_dist(rewards, ptu.ones_like(rewards))
    pred_discount_dist = self.get_dist(pred_discounts, None, normal=False)
    return (
        post,
        prior,
        post_dist,
        prior_dist,
        image_dist,
        reward_dist,
        pred_discount_dist,
        embed,
        low_level_action_preds,
    )
def forward(self, obs, action):
    """
    Forward world model on trajectory.

    :param obs: (batch_size, path_length, obs_dim)
    :param action: (batch_size, path_length, action_dim)

    :return post: Dict
        mean: (batch_size, path_length, stoch_size)
        std: (batch_size, path_length, stoch_size)
        stoch: (batch_size, path_length, stoch_size)
        deter: (batch_size, path_length, deter_size)
    :return prior:
        mean: (batch_size, path_length, stoch_size)
        std: (batch_size, path_length, stoch_size)
        stoch: (batch_size, path_length, stoch_size)
        deter: (batch_size, path_length, deter_size)
    :return post_dist:
        mean: (batch_size*path_length, stoch_size)
        std: (batch_size*path_length, stoch_size)
    :return prior_dist:
        mean: (batch_size*path_length, stoch_size)
        std: (batch_size*path_length, stoch_size)
    :return image_dist:
        mean: (batch_size*path_length, obs_dim)
    :return reward_dist:
        mean: (batch_size*raps_path_length, 1)
    :return pred_discount_dist:
        logits: (batch_size*raps_path_length, 1)
    :return embed: (batch_size, path_length, embed_dim)
    """
    original_batch_size = obs.shape[0]
    state = self.initial(original_batch_size)
    path_length = obs.shape[1]
    post, prior = (
        dict(mean=[], std=[], stoch=[], deter=[]),
        dict(mean=[], std=[], stoch=[], deter=[]),
    )
    obs = obs.reshape(-1, obs.shape[-1])
    embed = self.encode(obs)
    embedding_size = embed.shape[1]
    embed = embed.reshape(original_batch_size, path_length, embedding_size)
    post, prior = self.forward_batch(
        path_length,
        action,
        embed,
        post,
        prior,
        state,
    )

    for key in post.keys():
        post[key] = torch.cat(post[key], dim=1)

    for key in prior.keys():
        prior[key] = torch.cat(prior[key], dim=1)

    if self.use_prior_instead_of_posterior:
        # In this case, o_hat_t depends on a_t-1 and o_t-1; the reset obs is
        # decoded from the null state + action. This only works when the first
        # state is the reset obs and never changes.
        feat = self.get_features(prior)
    else:
        feat = self.get_features(post)
    feat = feat.reshape(-1, feat.shape[-1])
    images = self.decode(feat)
    rewards = self.reward(feat)
    pred_discounts = self.pred_discount(feat)
    post_dist = self.get_dist(
        post["mean"].reshape(-1, post["mean"].shape[-1]),
        post["std"].reshape(-1, post["std"].shape[-1]),
    )
    prior_dist = self.get_dist(
        prior["mean"].reshape(-1, prior["mean"].shape[-1]),
        prior["std"].reshape(-1, prior["std"].shape[-1]),
    )
    image_dist = self.get_dist(images, ptu.ones_like(images), dims=3)
    reward_dist = self.get_dist(rewards, ptu.ones_like(rewards))
    pred_discount_dist = self.get_dist(pred_discounts, None, normal=False)
    return (
        post,
        prior,
        post_dist,
        prior_dist,
        image_dist,
        reward_dist,
        pred_discount_dist,
        embed,
    )
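# Hedged sketch of what a `get_dist(mean, std, dims=..., normal=...)` helper
# of this kind commonly looks like in Dreamer-style world models: a diagonal
# Gaussian with the last `dims` dimensions treated as the event, and a
# Bernoulli-over-logits path when `normal=False`. The project's actual helper
# may differ; this is only to clarify how the calls above are typically used.
import torch
from torch.distributions import Normal, Independent, Bernoulli

def get_dist_sketch(mean, std=None, dims=1, normal=True):
    if normal:
        return Independent(Normal(mean, std), dims)
    # `std is None, normal=False` path: interpret `mean` as logits.
    return Independent(Bernoulli(logits=mean), dims)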
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    gt.stamp('preback_start', unique=False)

    """
    Update QF
    """
    with torch.no_grad():
        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = torch.clamp(next_actions + noise,
                                         -self.max_action, self.max_action)

        next_tau, next_tau_hat, next_presum_tau = self.get_tau(
            next_obs, noisy_next_actions, fp=self.target_fp)
        target_z1_values = self.target_zf1(next_obs, noisy_next_actions,
                                           next_tau_hat)
        target_z2_values = self.target_zf2(next_obs, noisy_next_actions,
                                           next_tau_hat)
        target_z_values = torch.min(target_z1_values, target_z2_values)
        z_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_z_values

    tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
    z1_pred = self.zf1(obs, actions, tau_hat)
    z2_pred = self.zf2(obs, actions, tau_hat)
    zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat, next_presum_tau)
    zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat, next_presum_tau)
    gt.stamp('preback_zf', unique=False)

    self.zf1_optimizer.zero_grad()
    zf1_loss.backward()
    self.zf1_optimizer.step()
    gt.stamp('backward_zf1', unique=False)

    self.zf2_optimizer.zero_grad()
    zf2_loss.backward()
    self.zf2_optimizer.step()
    gt.stamp('backward_zf2', unique=False)

    """
    Update FP
    """
    if self.tau_type == 'fqf':
        with torch.no_grad():
            dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                            z1_pred[:, :-1] - z1_pred[:, 1:] +
                            2 * self.zf2(obs, actions, tau[:, :-1]) -
                            z2_pred[:, :-1] - z2_pred[:, 1:])
            dWdtau /= dWdtau.shape[0]  # (N, T-1)
        gt.stamp('preback_fp', unique=False)

        self.fp_optimizer.zero_grad()
        tau[:, :-1].backward(gradient=dWdtau)
        self.fp_optimizer.step()
        gt.stamp('backward_fp', unique=False)

    """
    Policy Loss
    """
    policy_actions = self.policy(obs)
    risk_param = self.risk_schedule(self._n_train_steps_total)

    if self.risk_type == 'VaR':
        tau_ = ptu.ones_like(rewards) * risk_param
        q_new_actions = self.zf1(obs, policy_actions, tau_)
    else:
        with torch.no_grad():
            new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                obs, policy_actions, fp=self.fp)
        z_new_actions = self.zf1(obs, policy_actions, new_tau_hat)
        if self.risk_type in ['neutral', 'std']:
            q_new_actions = torch.sum(new_presum_tau * z_new_actions,
                                      dim=1,
                                      keepdims=True)
            if self.risk_type == 'std':
                q_std = new_presum_tau * (z_new_actions -
                                          q_new_actions).pow(2)
                q_new_actions -= risk_param * q_std.sum(
                    dim=1, keepdims=True).sqrt()
        else:
            with torch.no_grad():
                risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                             risk_param)
            q_new_actions = torch.sum(risk_weights * new_presum_tau *
                                      z_new_actions,
                                      dim=1,
                                      keepdims=True)

    policy_loss = -q_new_actions.mean()
    gt.stamp('preback_policy', unique=False)

    if self._n_train_steps_total % self.policy_and_target_update_period == 0:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                              self.clip_norm)
        self.policy_optimizer.step()
        gt.stamp('backward_policy', unique=False)

        ptu.soft_update_from_to(self.policy, self.target_policy,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                self.soft_target_tau)
        if self.tau_type == 'fqf':
            ptu.soft_update_from_to(self.fp, self.target_fp,
                                    self.soft_target_tau)
    gt.stamp('soft_update', unique=False)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
        self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
        self.eval_statistics['Policy Loss'] = policy_loss.item()
        self.eval_statistics['Policy Grad'] = policy_grad
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z1 Predictions',
                ptu.get_numpy(z1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z2 Predictions',
                ptu.get_numpy(z2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Z Targets',
                ptu.get_numpy(z_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy Action',
                ptu.get_numpy(policy_actions),
            ))
    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    context = batch['context']

    if self.reward_transform:
        rewards = self.reward_transform(rewards)

    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist, p_z, task_z_with_grad = self.agent(
        obs,
        context,
        return_latent_posterior_and_task_z=True,
    )
    task_z_detached = task_z_with_grad.detach()
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(1)
    next_dist = self.agent(next_obs, context)

    if self._debug_ignore_context:
        task_z_with_grad = task_z_with_grad * 0

    # flattens out the task dimension
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)
    unscaled_rewards_flat = rewards.view(t * b, 1)
    rewards_flat = unscaled_rewards_flat * self.reward_scale
    terms_flat = terminals.view(t * b, 1)

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    if self.backprop_q_loss_into_encoder:
        q1_pred = self.qf1(obs, actions, task_z_with_grad)
        q2_pred = self.qf2(obs, actions, task_z_with_grad)
    else:
        q1_pred = self.qf1(obs, actions, task_z_detached)
        q2_pred = self.qf2(obs, actions, task_z_detached)
    # Make sure policy accounts for squashing functions like tanh correctly!
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(1)
    with torch.no_grad():
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions, task_z_detached),
            self.target_qf2(next_obs, new_next_actions, task_z_detached),
        ) - alpha * new_log_pi

    q_target = rewards_flat + (
        1. - terms_flat) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Context Encoder Loss
    """
    if self._debug_use_ground_truth_context:
        kl_div = kl_loss = ptu.zeros(0)
    else:
        kl_div = kl_divergence(p_z,
                               self.agent.latent_prior).mean(dim=0).sum()
        kl_loss = self.kl_lambda * kl_div

    if self.train_context_decoder:
        # TODO: change to use a distribution
        reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
        reward_prediction_loss = ((reward_pred -
                                   unscaled_rewards_flat)**2).mean()
        context_loss = kl_loss + reward_prediction_loss
    else:
        context_loss = kl_loss
        reward_prediction_loss = ptu.zeros(1)

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached)
    qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached)
    q_new_actions = torch.min(
        qf1_new_actions,
        qf2_new_actions,
    )

    # Advantage-weighted regression
    if self.vf_K > 1:
        vs = []
        for i in range(self.vf_K):
            u = dist.sample()
            q1 = self.qf1(obs, u, task_z_detached)
            q2 = self.qf2(obs, u, task_z_detached)
            v = torch.min(q1, q2)
            # v = q1
            vs.append(v)
        v_pi = torch.cat(vs, 1).mean(dim=1)
    else:
        # v_pi = self.qf1(obs, new_obs_actions)
        v1_pi = self.qf1(obs, new_obs_actions, task_z_detached)
        v2_pi = self.qf2(obs, new_obs_actions, task_z_detached)
        v_pi = torch.min(v1_pi, v2_pi)

    u = actions
    if self.awr_min_q:
        q_adv = torch.min(q1_pred, q2_pred)
    else:
        q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta *
                          (kldiv - self.beta_epsilon).detach()).mean()

        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)
        beta_loss = ptu.zeros(1)

    score = q_adv - v_pi
    if self.mask_positive_advantage:
        score = torch.sign(score)

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    weights = batch.get('weights', None)
    if self.weight_loss and weights is None:
        if self.normalize_over_batch == True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif self.normalize_over_batch == False:
            weights = score
        elif self.normalize_over_batch == 'uniform':
            weights = F.softmax(ptu.ones_like(score) / beta, dim=0)
        else:
            raise ValueError(self.normalize_over_batch)
    weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.train_reparam_weight * (
            -q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        if self.train_encoder_decoder:
            self.context_optimizer.zero_grad()
        if self.train_agent:
            self.qf1_optimizer.zero_grad()
            self.qf2_optimizer.zero_grad()
        context_loss.backward(retain_graph=True)
        # retain graph because the encoder is trained by both QF losses
        qf1_loss.backward(retain_graph=True)
        qf2_loss.backward()
        if self.train_agent:
            self.qf1_optimizer.step()
            self.qf2_optimizer.step()
        if self.train_encoder_decoder:
            self.context_optimizer.step()

    if self.train_agent:
        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
    self._num_gradient_steps += 1

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics['task_embedding/kl_divergence'] = (
            ptu.get_numpy(kl_div))
        self.eval_statistics['task_embedding/kl_loss'] = (
            ptu.get_numpy(kl_loss))
        self.eval_statistics['task_embedding/reward_prediction_loss'] = (
            ptu.get_numpy(reward_prediction_loss))
        self.eval_statistics['task_embedding/context_loss'] = (
            ptu.get_numpy(context_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'rewards',
                ptu.get_numpy(rewards),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'terminals',
                ptu.get_numpy(terminals),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Advantage Score',
                ptu.get_numpy(score),
            ))
        self.eval_statistics['reparam_weight'] = self.train_reparam_weight
        self.eval_statistics['num_gradient_steps'] = (
            self._num_gradient_steps)

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

    self._n_train_steps_total += 1
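# Standalone sketch of the advantage-weighted regression (AWR) weighting used
# above when `normalize_over_batch == True`: advantages are turned into
# softmax weights with temperature beta, and the weighted negative
# log-likelihood of the buffer actions becomes the policy-loss term. The
# helper name and shapes here are illustrative, not the trainer's API.
import torch
import torch.nn.functional as F

def awr_weights(q_adv: torch.Tensor, v_pi: torch.Tensor, beta: float) -> torch.Tensor:
    score = q_adv - v_pi                   # advantage estimate, shape (N, 1)
    return F.softmax(score / beta, dim=0)  # weights sum to 1 over the batch

# Example usage (mirroring the loss term above):
# weights = awr_weights(q_adv, v_pi, beta=1.0)
# policy_loss = -(log_prob * len(weights) * weights.detach().squeeze(1)).mean()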
def update_parameters(self, memory, batch_size, updates):
    """
    Update the parameters of SAC-NF.

    Exactly like SAC, but keep two separate Adam optimizers for the Gaussian
    policy AND the NF layers, and call .backward() on their losses sequentially.
    """
    state_batch, action_batch, reward_batch, next_state_batch, mask_batch = \
        memory.sample(batch_size=batch_size)

    obs = torch.FloatTensor(state_batch).to(self.device)
    next_obs = torch.FloatTensor(next_state_batch).to(self.device)
    actions = torch.FloatTensor(action_batch).to(self.device)
    rewards = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
    mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

    # for visualization
    # with torch.no_grad():
    #     sample_size = 500
    #     _action, _logprob, _preact, _, _ = self.policy.evaluate(
    #         state_batch, num_samples=sample_size)
    #     _action = _action.cpu().detach()
    #     _preact = _preact.cpu().detach()
    #     _logprob = _logprob.view(batch_size, sample_size, -1).cpu().detach()
    #     info = {
    #         'action': _action,
    #         'preact': _preact,
    #         'logprob': _logprob,
    #     }
    info = {}

    '''
    update critic
    '''
    with torch.no_grad():
        new_next_actions, next_state_log_pi, _, _, _ = self.policy.evaluate(
            next_obs)
        next_tau, next_tau_hat, next_presum_tau = self.get_tau(
            new_next_actions)
        target_z1_values = self.target_zf1(next_obs, new_next_actions,
                                           next_tau_hat)
        target_z2_values = self.target_zf2(next_obs, new_next_actions,
                                           next_tau_hat)
        min_qf_next_target = torch.min(
            target_z1_values,
            target_z2_values) - self.alpha * next_state_log_pi
        z_target = rewards + mask_batch * self.gamma * min_qf_next_target

    tau, tau_hat, presum_tau = self.get_tau(actions)
    z1_pred = self.zf1(obs, actions, tau_hat)
    z2_pred = self.zf2(obs, actions, tau_hat)
    # Two Q-functions to mitigate positive bias in the policy improvement step
    zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat, next_presum_tau)
    zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat, next_presum_tau)

    new_actions, log_pi, _, _, _ = self.policy.evaluate(obs)

    # update
    self.zf1_optimizer.zero_grad()
    zf1_loss.backward()
    self.zf1_optimizer.step()

    self.zf2_optimizer.zero_grad()
    zf2_loss.backward()
    self.zf2_optimizer.step()

    risk_param = self.risk_schedule(self._n_train_steps_total)

    if self.risk_type == 'VaR':
        tau_ = ptu.ones_like(rewards) * risk_param
        q1_new_actions = self.zf1(obs, new_actions, tau_)
        q2_new_actions = self.zf2(obs, new_actions, tau_)
    else:
        with torch.no_grad():
            new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                obs, new_actions)
        z1_new_actions = self.zf1(obs, new_actions, new_tau_hat)
        z2_new_actions = self.zf2(obs, new_actions, new_tau_hat)
        if self.risk_type in ['neutral', 'std']:
            q1_new_actions = torch.sum(new_presum_tau * z1_new_actions,
                                       dim=1,
                                       keepdims=True)
            q2_new_actions = torch.sum(new_presum_tau * z2_new_actions,
                                       dim=1,
                                       keepdims=True)
            if self.risk_type == 'std':
                q1_std = new_presum_tau * (z1_new_actions -
                                           q1_new_actions).pow(2)
                q2_std = new_presum_tau * (z2_new_actions -
                                           q2_new_actions).pow(2)
                q1_new_actions -= risk_param * q1_std.sum(
                    dim=1, keepdims=True).sqrt()
                q2_new_actions -= risk_param * q2_std.sum(
                    dim=1, keepdims=True).sqrt()
        else:
            with torch.no_grad():
                risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                             risk_param)
            q1_new_actions = torch.sum(risk_weights * new_presum_tau *
                                       z1_new_actions,
                                       dim=1,
                                       keepdims=True)
            q2_new_actions = torch.sum(risk_weights * new_presum_tau *
                                       z2_new_actions,
                                       dim=1,
                                       keepdims=True)
    q_new_actions = torch.min(q1_new_actions, q2_new_actions)

    policy_loss = (self.alpha * log_pi - q_new_actions).mean()
    nf_loss = ((self.alpha * log_pi) - q_new_actions).mean()

    self.policy_optim.zero_grad()
    policy_loss.backward(retain_graph=True)
    self.policy_optim.step()

    self.nf_optim.zero_grad()
    nf_loss.backward()
    self.nf_optim.step()

    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self.alpha = self.log_alpha.exp()
        alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
    else:
        alpha_loss = torch.tensor(0.).to(self.device)
        alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

    # update target value functions
    if updates % self.target_update_interval == 0:
        soft_update(self.target_zf1, self.zf1, self.tau)
        soft_update(self.target_zf2, self.zf2, self.tau)

    return (zf1_loss.item(), zf2_loss.item(), policy_loss.item(),
            alpha_loss.item(), alpha_tlogs.item(), info)
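# Minimal sketch of the Polyak averaging performed by `soft_update` /
# `ptu.soft_update_from_to` in the trainers above: target parameters move a
# small step tau toward the online parameters. The argument order here follows
# the SAC-NF call above (target first, source second); other codebases flip it.
import torch

def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float):
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            # target <- tau * source + (1 - tau) * target
            t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)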