def get_diagnostics(self):
    path_lens = [len(path['actions']) for path in self._epoch_paths]
    stats = OrderedDict([
        ('num steps total', self._num_steps_total),
        ('num paths total', self._num_paths_total),
    ])
    stats.update(create_stats_ordered_dict(
        "path length",
        path_lens,
        always_show_all_stats=True,
    ))
    return stats
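# Why this sketch is here: create_stats_ordered_dict is used by nearly every
# diagnostics method in this file but is defined elsewhere. The helper below
# is a minimal sketch of its assumed behavior (the real helper also supports
# always_show_all_stats, exclude_max_min, and nested data); it is not the
# actual implementation.
from collections import OrderedDict

import numpy as np


def create_stats_ordered_dict_sketch(name, data):
    data = np.asarray(data)
    return OrderedDict([
        ('{} Mean'.format(name), float(np.mean(data))),
        ('{} Std'.format(name), float(np.std(data))),
        ('{} Max'.format(name), float(np.max(data))),
        ('{} Min'.format(name), float(np.min(data))),
    ])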
def log_diagnostics(self, paths, logger=default_logger):
    super().log_diagnostics(paths)
    MultitaskEnv.log_diagnostics(self, paths)
    statistics = OrderedDict()
    for stat_name in [
        'pos_error',
        'vel_error',
        'weighted_pos_error',
        'weighted_vel_error',
    ]:
        stat = get_stat_in_paths(paths, 'env_infos', stat_name)
        statistics.update(create_stats_ordered_dict(
            '{}'.format(stat_name),
            stat,
            always_show_all_stats=True,
        ))
        statistics.update(create_stats_ordered_dict(
            'Final {}'.format(stat_name),
            [s[-1] for s in stat],
            always_show_all_stats=True,
        ))
    weighted_error = (
        get_stat_in_paths(paths, 'env_infos', 'weighted_pos_error')
        + get_stat_in_paths(paths, 'env_infos', 'weighted_vel_error')
    )
    statistics.update(create_stats_ordered_dict(
        "Weighted Error",
        weighted_error,
        always_show_all_stats=True,
    ))
    statistics.update(create_stats_ordered_dict(
        "Final Weighted Error",
        [s[-1] for s in weighted_error],
        always_show_all_stats=True,
    ))
    for key, value in statistics.items():
        logger.record_tabular(key, value)
def debug_statistics(self):
    """
    Given an image $$x$$, sample a batch of latents $$z_i$$ from the prior
    and decode them into $$\hat x_i$$. Compare these to $$\hat x$$, the
    reconstruction of $$x$$.

    Ideally:
    - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure the VAE
      isn't ignoring the latent).
    - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
      coverage).
    """
    debug_batch_size = 64
    data = self.get_batch(train=False)
    reconstructions, _, _ = self.model(data)
    img = data[0]
    recon_mse = ((reconstructions[0] - img) ** 2).mean().view(-1)
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    samples = ptu.randn(debug_batch_size, self.representation_size)
    random_imgs, _ = self.model.decode(samples)
    random_mses = (random_imgs - img_repeated) ** 2
    mse_improvement = ptu.get_numpy(random_mses.mean(dim=1) - recon_mse)
    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    stats.update(create_stats_ordered_dict(
        'debug/MSE of random decoding',
        ptu.get_numpy(random_mses),
    ))
    stats['debug/MSE of reconstruction'] = ptu.get_numpy(recon_mse)[0]
    if self.skew_dataset:
        stats.update(create_stats_ordered_dict(
            'train weight',
            self._train_weights,
        ))
    return stats
def train_from_torch(self, batch): rewards = batch["rewards"] * self.reward_scale terminals = batch["terminals"] obs = batch["observations"] actions = batch["actions"] next_obs = batch["next_observations"] try: plan_lengths = batch["plan_lengths"] if self.single_plan_discounting: plan_lengths = torch.ones_like(plan_lengths) except KeyError as e: plan_lengths = torch.ones_like(rewards) """ Compute loss """ target_q_values = self.target_qf(next_obs).detach().max( 1, keepdim=True)[0] y_target = (rewards + (1.0 - terminals) * torch.pow(self.discount, plan_lengths) * target_q_values) y_target = y_target.detach() # actions is a one-hot vector y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) # huber loss correction. if self.huber_loss: y_target = torch.max(y_target, y_pred.sub(1)) y_target = torch.min(y_target, y_pred.add(1)) qf_loss = self.qf_criterion(y_pred, y_target) """ Soft target network updates """ self.qf_optimizer.zero_grad() qf_loss.backward() # for param in self.qf.parameters(): # introduced parameter clipping # param.grad.data.clamp_(-1, 1) self.qf_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau) """ Save some statistics for eval using just one batch. """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss)) self.eval_statistics.update( create_stats_ordered_dict("Y Predictions", ptu.get_numpy(y_pred)))
def low_train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    goals = batch['goals']
    # Approximation: this transition does not account for goal switching.
    next_goals = self.setter.goal_transition(obs, goals, next_obs)

    """
    Compute loss
    """
    best_action_idxs = self.low_qf(torch.cat(
        (next_obs, next_goals), dim=1)).max(1, keepdim=True)[1]
    target_q_values = self.low_target_qf(
        torch.cat((next_obs, next_goals),
                  dim=1)).gather(1, best_action_idxs).detach()
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.low_qf(torch.cat(
        (obs, goals), dim=1)) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.low_qf_optimizer.zero_grad()
    qf_loss.backward()
    if self.grad_clip_val is not None:
        nn.utils.clip_grad_norm_(self.low_qf.parameters(),
                                 self.grad_clip_val)
    self.low_qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.setter_and_target_update_period == 0:
        ptu.soft_update_from_to(self.low_qf, self.low_target_qf, self.tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    best_action_idxs = self.qf(next_obs).max(
        1, keepdim=True
    )[1]
    target_q_values = self.target_qf(next_obs).gather(
        1, best_action_idxs
    ).detach()
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf, self.target_qf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))
    self._n_train_steps_total += 1
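# Side-by-side sketch, on made-up tensors, of the two target rules used by
# the DQN variants in this file: vanilla DQN lets the target network both
# select and evaluate the next action, while the double-DQN rule above
# selects with the online network and evaluates with the target network,
# which reduces maximization bias.
import torch

_next_q_online = torch.tensor([[1.0, 2.0], [0.5, 0.1]])  # qf(next_obs)
_next_q_target = torch.tensor([[1.5, 0.2], [0.4, 0.3]])  # target_qf(next_obs)

_vanilla_target = _next_q_target.max(1, keepdim=True)[0]  # [[1.5], [0.4]]
_best_idxs = _next_q_online.max(1, keepdim=True)[1]       # [[1], [0]]
_double_target = _next_q_target.gather(1, _best_idxs)     # [[0.2], [0.4]]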
def get_diagnostics(self): path_lens = [len(path["actions"]) for path in self._epoch_paths] stats = OrderedDict([ ("num steps total", self._num_steps_total), ("num paths total", self._num_paths_total), ]) stats.update( create_stats_ordered_dict( "path length", path_lens, always_show_all_stats=True, )) success = [path["rewards"][-1][0] > 0 for path in self._epoch_paths] stats["SuccessRate"] = sum(success) / len(success) return stats
def log_diagnostics(self, paths, logger=default_logger):
    super().log_diagnostics(paths)
    MultitaskEnv.log_diagnostics(self, paths)
    statistics = OrderedDict()
    for name_in_env_infos, name_to_log in [
        ('x_pos', 'X Position'),
        ('y_pos', 'Y Position'),
        ('dist_from_origin', 'Distance from Origin'),
        ('desired_x_pos', 'Desired X Position'),
        ('desired_y_pos', 'Desired Y Position'),
        ('desired_dist_from_origin', 'Desired Distance from Origin'),
        ('pos_error', 'Distance to goal'),
    ]:
        stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
        statistics.update(create_stats_ordered_dict(
            name_to_log,
            stat,
            always_show_all_stats=True,
            exclude_max_min=True,
        ))
    for name_in_env_infos, name_to_log in [
        ('dist_from_origin', 'Distance from Origin'),
        ('desired_dist_from_origin', 'Desired Distance from Origin'),
        ('pos_error', 'Distance to goal'),
    ]:
        stat = get_stat_in_paths(paths, 'env_infos', name_in_env_infos)
        statistics.update(create_stats_ordered_dict(
            'Final {}'.format(name_to_log),
            [s[-1] for s in stat],
            always_show_all_stats=True,
        ))
    for key, value in statistics.items():
        logger.record_tabular(key, value)
def get_diagnostics(self):
    if self._vae_sample_probs is None or self._vae_sample_priorities is None:
        stats = create_stats_ordered_dict(
            "VAE Sample Weights",
            np.zeros(self._size),
        )
        stats.update(create_stats_ordered_dict(
            "VAE Sample Probs",
            np.zeros(self._size),
        ))
    else:
        vae_sample_priorities = self._vae_sample_priorities[:self._size]
        vae_sample_probs = self._vae_sample_probs[:self._size]
        stats = create_stats_ordered_dict(
            "VAE Sample Weights",
            vae_sample_priorities,
        )
        stats.update(create_stats_ordered_dict(
            "VAE Sample Probs",
            vae_sample_probs,
        ))
    return stats
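# Hedged sketch of the presumed link between the two arrays logged above (an
# assumption about the skewed-sampling scheme, not code from this buffer):
# per-sample priorities are normalized into sampling probabilities.
import numpy as np


def priorities_to_probs_sketch(priorities):
    priorities = np.asarray(priorities, dtype=np.float64)
    return priorities / priorities.sum()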
def __call__(self, paths, contexts):
    diagnostics = OrderedDict()
    for state_key in self.state_to_goal_keys_map:
        goal_key = self.state_to_goal_keys_map[state_key]
        values = []
        for i in range(len(paths)):
            state = paths[i]["observations"][-1][state_key]
            goal = contexts[i][goal_key]
            distance = np.linalg.norm(state - goal)
            values.append(distance)
        diagnostics_key = goal_key + "/final/distance"
        diagnostics.update(create_stats_ordered_dict(
            diagnostics_key,
            values,
        ))
    return diagnostics
def get_diagnostics(self):
    path_lens = [len(path["actions"]) for path in self._epoch_paths]
    stats = OrderedDict([
        ("num steps total", self._num_steps_total),
        ("num paths total", self._num_paths_total),
        ("num low level steps total", self._num_low_level_steps_total),
        (
            "num low level steps total true",
            self._num_low_level_steps_total_true,
        ),
    ])
    stats.update(create_stats_ordered_dict(
        "path length",
        path_lens,
        always_show_all_stats=True,
    ))
    return stats
def train_from_torch(self, batch): rewards = batch["rewards"] * self.reward_scale terminals = batch["terminals"] obs = batch["observations"] actions = batch["actions"] next_obs = batch["next_observations"] """ Compute loss """ target_q_values = self.target_qf(next_obs).detach().max( 1, keepdim=True)[0] y_target = rewards + (1.0 - terminals) * self.discount * target_q_values y_target = y_target.detach() # actions is a one-hot vector y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) qf_loss = self.qf_criterion(y_pred, y_target) """ Soft target network updates """ self.qf_optimizer.zero_grad() qf_loss.backward() self.qf_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau) """ Save some statistics for eval using just one batch. """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss)) self.eval_statistics.update( create_stats_ordered_dict( "Y Predictions", ptu.get_numpy(y_pred), )) self._n_train_steps_total += 1
def get_diagnostics(self):
    path_lens = [len(path['actions']) for path in self._epoch_paths]
    stats = OrderedDict([
        ('num steps total', self._num_steps_total),
        ('num paths total', self._num_paths_total),
    ])
    stats.update(create_stats_ordered_dict(
        "path length",
        path_lens,
        always_show_all_stats=True,
    ))
    paths_policy = [
        path for path in self._epoch_paths
        if 'expert' not in path['agent_infos'][0]
    ]
    success = [path['rewards'][-1][0] > 0 for path in paths_policy]
    stats['SuccessRate'] = sum(success) / len(success)
    stats['Expert_Supervision'] = 1 - len(paths_policy) / len(
        self._epoch_paths)
    return stats
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    target_q_values = self.target_qf(next_obs).detach().max(
        1, keepdim=True
    )[0]
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()
    self._update_target_network()

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))
def _add_exploration_bonus(self, paths):
    paths = copy.deepcopy(paths)
    entropy_decreases = []
    with torch.no_grad():
        for path in paths:
            for i in range(len(path['observations']) - 1):
                obs1 = path['observations'][i]
                labels1 = torch.tensor(path['env_infos'][i]['sup_labels'])
                valid_mask1 = ~torch.isnan(labels1)
                entropy_1 = [
                    sup_learner.get_distribution(
                        torch_ify(obs1)[None, :]).entropy()
                    for sup_learner in self.sup_learners
                ]
                entropy_1 = torch.mean(torch.stack(entropy_1)[valid_mask1])

                obs2 = path['observations'][i + 1]
                labels2 = torch.tensor(
                    path['env_infos'][i + 1]['sup_labels'])
                valid_mask2 = ~torch.isnan(labels2)
                entropy_2 = [
                    sup_learner.get_distribution(
                        torch_ify(obs2)[None, :]).entropy()
                    for sup_learner in self.sup_learners
                ]
                entropy_2 = torch.mean(torch.stack(entropy_2)[valid_mask2])

                entropy_decrease = (entropy_1 - entropy_2).item()
                entropy_decreases.append(entropy_decrease)
                path['rewards'][i] += (
                    self.exploration_bonus * entropy_decrease)
    if self._need_to_update_eval_statistics:
        self.eval_statistics.update(create_stats_ordered_dict(
            'Entropy Decrease',
            entropy_decreases,
        ))
    return paths
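# Tiny worked example (made-up numbers) of the bonus computed above: the
# reward at step i grows in proportion to how much the supervised learners'
# mean predictive entropy drops across the transition.
_entropy_before, _entropy_after = 1.2, 0.9  # nats, hypothetical values
_exploration_bonus = 0.1                    # hypothetical coefficient
_bonus = _exploration_bonus * (_entropy_before - _entropy_after)  # ~= 0.03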
def get_diagnostics(self):
    path_lens = [len(path["actions"]) for path in self._epoch_paths]
    average_score = (0 if self._num_episodes == 0 else
                     self._total_score / self._num_episodes)
    epoch_score = (0 if self._epoch_episodes == 0 else
                   self._epoch_score / self._epoch_episodes)
    explored = [path["explored"][0] for path in self._epoch_paths]
    paths_explored = (0 if len(explored) == 0 else
                      sum(explored).item() / len(explored))
    stats = OrderedDict([
        ("num steps total", self._num_steps_total),
        ("num paths total", self._num_paths_total),
        ("average score", average_score),
        ("epoch score", epoch_score),
        ("plans explored", paths_explored),
    ])
    action_lengths = [0] * 16  # TODO: fix magic number
    action_counts = [0] * 16
    for path in self._epoch_paths:
        action_lengths[path["actions"][0].item()] += len(path["actions"])
        action_counts[path["actions"][0].item()] += 1
    a = {}
    for i in range(16):
        if action_counts[i] > 0:
            action_lengths[i] = action_lengths[i] / action_counts[i]
            a[f"action {i} count"] = action_counts[i]
            a[f"action {i} length"] = action_lengths[i]
    stats.update(a)
    stats.update(create_stats_ordered_dict(
        "path length", path_lens, always_show_all_stats=True))
    return stats
def train_from_torch(self, batch):
    rewards_n = batch['rewards'].detach()
    terminals_n = batch['terminals'].detach()
    obs_n = batch['observations'].detach()
    actions_n = batch['actions'].detach()
    next_obs_n = batch['next_observations'].detach()
    batch_size = rewards_n.shape[0]
    num_agent = rewards_n.shape[1]
    whole_obs = obs_n.view(batch_size, -1)
    whole_actions = actions_n.view(batch_size, -1)
    whole_next_obs = next_obs_n.view(batch_size, -1)

    """
    Policy operations.
    """
    online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
    for agent in range(num_agent):
        policy_actions, info = self.policy_n[agent](
            obs_n[:, agent, :],
            return_info=True,
        )
        online_actions_n.append(policy_actions)
        online_pre_values_n.append(info['preactivation'])
        online_log_pis_n.append(info['log_prob'])
    k0_actions = torch.stack(online_actions_n)  # num_agent x batch x a_dim
    k0_actions = k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
    k0_inputs = torch.cat([obs_n, k0_actions], dim=-1)
    k0_contexts = self.context_graph(k0_inputs)  # batch x num_agent x c_dim
    k1_actions = self.cactor(
        k0_contexts, deterministic=self.deterministic_cactor_in_graph)
    k1_inputs = torch.cat([obs_n, k1_actions], dim=-1)
    k1_contexts = self.context_graph(k1_inputs)

    policy_gradients_n = []
    alpha_n = []
    for agent in range(num_agent):
        policy_actions = online_actions_n[agent]
        pre_value = online_pre_values_n[agent]
        log_pi = online_log_pis_n[agent]
        if self.pre_activation_weight > 0.:
            pre_activation_policy_loss = (pre_value**2).sum(dim=1).mean()
        else:
            pre_activation_policy_loss = torch.tensor(0.).to(ptu.device)
        if self.use_entropy_loss:
            if self.use_automatic_entropy_tuning:
                if self.state_dependent_alpha:
                    alpha = self.log_alpha_n[agent](obs_n[:, agent, :]).exp()
                else:
                    alpha = self.log_alpha_n[agent].exp()
                alpha_loss = -(alpha * (
                    log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer_n[agent].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer_n[agent].step()
                if self.state_dependent_alpha:
                    alpha = self.log_alpha_n[agent](
                        obs_n[:, agent, :]).exp().detach()
                else:
                    alpha = self.log_alpha_n[agent].exp().detach()
                alpha_n.append(alpha)
            else:
                alpha_loss = torch.tensor(0.).to(ptu.device)
                alpha = torch.tensor(self.init_alpha).to(ptu.device)
                alpha_n.append(alpha)
            entropy_loss = (alpha * log_pi).mean()
        else:
            entropy_loss = torch.tensor(0.).to(ptu.device)
        q_input = torch.cat(
            [policy_actions, k1_contexts[:, agent, :]], dim=-1)
        q1_output = self.qf1(q_input)
        q2_output = self.qf2(q_input)
        q_output = torch.min(q1_output, q2_output)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss
            + pre_activation_policy_loss * self.pre_activation_weight
            + entropy_loss
        )
        policy_gradients_n.append(torch.autograd.grad(
            policy_loss,
            self.policy_n[agent].parameters(),
            retain_graph=True))
        if self._need_to_update_eval_statistics:
            self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss {}'.format(
                agent)] = np.mean(ptu.get_numpy(pre_activation_policy_loss))
            self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(entropy_loss))
            if self.use_entropy_loss:
                if self.state_dependent_alpha:
                    self.eval_statistics.update(create_stats_ordered_dict(
                        'Alpha {}'.format(agent),
                        ptu.get_numpy(alpha),
                    ))
                else:
                    self.eval_statistics['Alpha {} Mean'.format(agent)] = \
                        np.mean(ptu.get_numpy(alpha))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action {}'.format(agent),
                ptu.get_numpy(policy_actions),
            ))
    for agent in range(num_agent):
        # self.policy_optimizer_n[agent].zero_grad()
        for pid, p in enumerate(self.policy_n[agent].parameters()):
            p.grad = policy_gradients_n[agent][pid]
        self.policy_optimizer_n[agent].step()

    """
    Critic operations.
    """
    with torch.no_grad():
        next_actions_n, next_log_pis_n = [], []
        for agent in range(num_agent):
            next_actions, next_info = self.policy_n[agent](
                next_obs_n[:, agent, :],
                return_info=True,
                deterministic=self.deterministic_next_action,
            )
            next_actions_n.append(next_actions)
            next_log_pis_n.append(next_info['log_prob'])
        next_k0_actions = torch.stack(next_actions_n)  # num_agent x batch x a_dim
        next_k0_actions = next_k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
        next_k0_inputs = torch.cat([next_obs_n, next_k0_actions], dim=-1)
        next_k0_contexts = self.context_graph(next_k0_inputs)  # batch x num_agent x c_dim
        next_k1_actions = self.cactor(
            next_k0_contexts,
            deterministic=self.deterministic_cactor_in_graph)
        next_k1_inputs = torch.cat([next_obs_n, next_k1_actions], dim=-1)
        next_k1_contexts = self.context_graph(next_k1_inputs)

    buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
    buffer_contexts = self.context_graph(buffer_inputs)  # batch x num_agent x c_dim
    q_inputs = torch.cat([actions_n, buffer_contexts], dim=-1)
    q1_preds_n = self.qf1(q_inputs)
    q2_preds_n = self.qf2(q_inputs)
    raw_qf1_loss_n, raw_qf2_loss_n, q_target_n = [], [], []
    for agent in range(num_agent):
        with torch.no_grad():
            next_policy_actions = next_actions_n[agent]
            next_log_pi = next_log_pis_n[agent]
            next_q_input = torch.cat(
                [next_policy_actions, next_k1_contexts[:, agent, :]],
                dim=-1)
            next_target_q1_values = self.target_qf1(next_q_input)
            next_target_q2_values = self.target_qf2(next_q_input)
            next_target_q_values = torch.min(next_target_q1_values,
                                             next_target_q2_values)
            if self.use_entropy_reward:
                if self.state_dependent_alpha:
                    next_alpha = self.log_alpha_n[agent](
                        next_obs_n[:, agent, :]).exp()
                else:
                    next_alpha = alpha_n[agent]
                next_target_q_values = (next_target_q_values
                                        - next_alpha * next_log_pi)
            q_target = (self.reward_scale * rewards_n[:, agent, :]
                        + (1. - terminals_n[:, agent, :])
                        * self.discount * next_target_q_values)
            q_target = torch.clamp(q_target, self.min_q_value,
                                   self.max_q_value)
            q_target_n.append(q_target)
        q1_pred = q1_preds_n[:, agent, :]
        raw_qf1_loss = self.qf_criterion(q1_pred, q_target)
        raw_qf1_loss_n.append(raw_qf1_loss)
        q2_pred = q2_preds_n[:, agent, :]
        raw_qf2_loss = self.qf_criterion(q2_pred, q_target)
        raw_qf2_loss_n.append(raw_qf2_loss)
        if self._need_to_update_eval_statistics:
            self.eval_statistics['QF1 Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(raw_qf1_loss))
            self.eval_statistics['QF2 Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(raw_qf2_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions {}'.format(agent),
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions {}'.format(agent),
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets {}'.format(agent),
                ptu.get_numpy(q_target),
            ))
    if self.sum_n_loss:
        raw_qf1_loss = torch.sum(torch.stack(raw_qf1_loss_n))
        raw_qf2_loss = torch.sum(torch.stack(raw_qf2_loss_n))
    else:
        raw_qf1_loss = torch.mean(torch.stack(raw_qf1_loss_n))
        raw_qf2_loss = torch.mean(torch.stack(raw_qf2_loss_n))

    if self.negative_sampling:
        perturb_actions = actions_n.clone()  # batch x agent x |A|
        batch_size, num_agent, a_dim = perturb_actions.shape
        perturb_agents = torch.randint(low=0, high=num_agent,
                                       size=(batch_size,))
        neg_actions = torch.rand(batch_size, a_dim) * 2. - 1.  # in [-1, 1]
        perturb_actions[torch.arange(batch_size), perturb_agents, :] = \
            neg_actions
        perturb_inputs = torch.cat([obs_n, perturb_actions], dim=-1)
        perturb_contexts = self.context_graph(perturb_inputs)  # batch x num_agent x c_dim
        perturb_q_inputs = torch.cat([actions_n, perturb_contexts], dim=-1)
        perturb_q1_preds = self.qf1(perturb_q_inputs)[
            torch.arange(batch_size), perturb_agents, :]
        perturb_q2_preds = self.qf2(perturb_q_inputs)[
            torch.arange(batch_size), perturb_agents, :]
        perturb_q_targets = torch.stack(q_target_n).transpose(
            0, 1).contiguous()[torch.arange(batch_size), perturb_agents, :]
        neg_loss1 = self.qf_criterion(perturb_q1_preds, perturb_q_targets)
        neg_loss2 = self.qf_criterion(perturb_q2_preds, perturb_q_targets)
    else:
        neg_loss1 = torch.tensor(0.).to(ptu.device)
        neg_loss2 = torch.tensor(0.).to(ptu.device)

    if self.qf_weight_decay > 0:
        reg_loss1 = self.qf_weight_decay * sum(
            torch.sum(param ** 2)
            for param in self.qf1.regularizable_parameters()
        )
        reg_loss2 = self.qf_weight_decay * sum(
            torch.sum(param ** 2)
            for param in self.qf2.regularizable_parameters()
        )
    else:
        reg_loss1 = torch.tensor(0.).to(ptu.device)
        reg_loss2 = torch.tensor(0.).to(ptu.device)
    qf1_loss = raw_qf1_loss + reg_loss1 + neg_loss1
    qf2_loss = raw_qf2_loss + reg_loss2 + neg_loss2
    if self._need_to_update_eval_statistics:
        self.eval_statistics['raw_qf1_loss'] = np.mean(
            ptu.get_numpy(raw_qf1_loss))
        self.eval_statistics['raw_qf2_loss'] = np.mean(
            ptu.get_numpy(raw_qf2_loss))
        self.eval_statistics['neg_qf1_loss'] = np.mean(
            ptu.get_numpy(neg_loss1))
        self.eval_statistics['neg_qf2_loss'] = np.mean(
            ptu.get_numpy(neg_loss2))
        self.eval_statistics['reg_qf1_loss'] = np.mean(
            ptu.get_numpy(reg_loss1))
        self.eval_statistics['reg_qf2_loss'] = np.mean(
            ptu.get_numpy(reg_loss2))

    self.context_graph_optimizer.zero_grad()
    cg_loss = qf1_loss + qf2_loss
    cg_loss.backward()
    self.context_graph_optimizer.step()

    """
    Central actor operations.
    """
    buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
    buffer_contexts = self.context_graph(buffer_inputs)  # batch x num_agent x c_dim
    cactor_loss_n = []
    for agent in range(num_agent):
        cactor_actions, cactor_info = self.cactor(
            buffer_contexts[:, agent, :],
            return_info=True,
        )
        cactor_pre_value = cactor_info['preactivation']
        if self.pre_activation_weight > 0:
            pre_activation_cactor_loss = (
                (cactor_pre_value**2).sum(dim=1).mean()
            )
        else:
            pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
        if self.use_cactor_entropy_loss:
            cactor_log_pi = cactor_info['log_prob']
            if self.use_automatic_entropy_tuning:
                if self.state_dependent_alpha:
                    calpha = self.log_calpha_n[agent](whole_obs).exp()
                else:
                    calpha = self.log_calpha_n[agent].exp()
                calpha_loss = -(calpha * (
                    cactor_log_pi + self.target_entropy).detach()).mean()
                self.calpha_optimizer_n[agent].zero_grad()
                calpha_loss.backward()
                self.calpha_optimizer_n[agent].step()
                if self.state_dependent_alpha:
                    calpha = self.log_calpha_n[agent](
                        whole_obs).exp().detach()
                else:
                    calpha = self.log_calpha_n[agent].exp().detach()
            else:
                calpha_loss = torch.tensor(0.).to(ptu.device)
                calpha = torch.tensor(self.init_alpha).to(ptu.device)
            cactor_entropy_loss = (calpha * cactor_log_pi).mean()
        else:
            cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
        q_input = torch.cat(
            [cactor_actions, buffer_contexts[:, agent, :]], dim=-1)
        q1_output = self.qf1(q_input)
        q2_output = self.qf2(q_input)
        q_output = torch.min(q1_output, q2_output)
        raw_cactor_loss = -q_output.mean()
        cactor_loss = (
            raw_cactor_loss
            + pre_activation_cactor_loss * self.pre_activation_weight
            + cactor_entropy_loss
        )
        cactor_loss_n.append(cactor_loss)
        if self._need_to_update_eval_statistics:
            if self.use_cactor_entropy_loss:
                if self.state_dependent_alpha:
                    self.eval_statistics.update(create_stats_ordered_dict(
                        'CAlpha {}'.format(agent),
                        ptu.get_numpy(calpha),
                    ))
                else:
                    self.eval_statistics['CAlpha {} Mean'.format(agent)] = \
                        np.mean(ptu.get_numpy(calpha))
            self.eval_statistics['Cactor Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(cactor_loss))
            self.eval_statistics['Raw Cactor Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(raw_cactor_loss))
            self.eval_statistics['Preactivation Cactor Loss {}'.format(
                agent)] = np.mean(ptu.get_numpy(pre_activation_cactor_loss))
            self.eval_statistics['Entropy Cactor Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(cactor_entropy_loss))
    if self.sum_n_loss:
        cactor_loss = torch.sum(torch.stack(cactor_loss_n))
    else:
        cactor_loss = torch.mean(torch.stack(cactor_loss_n))
    self.cactor_optimizer.zero_grad()
    cactor_loss.backward()
    self.cactor_optimizer.step()

    self._need_to_update_eval_statistics = False
    self._update_target_networks()
    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    context = batch['context']
    # data is (task, batch, feat)

    # run inference in networks
    action_distrib, p_z, task_z_with_grad = self.agent(
        obs, context, return_latent_posterior_and_task_z=True,
    )
    task_z_detached = task_z_with_grad.detach()
    new_actions, log_pi, pre_tanh_value = (
        action_distrib.rsample_logprob_and_pretanh())
    log_pi = log_pi.unsqueeze(1)
    policy_mean = action_distrib.mean
    policy_log_std = torch.log(action_distrib.stddev)

    # flatten out the task dimension
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)
    unscaled_rewards_flat = rewards.view(t * b, 1)
    rewards_flat = unscaled_rewards_flat * self.reward_scale
    terms_flat = terminals.view(t * b, 1)

    # Q and V networks
    # encoder will only get gradients from Q nets
    if self.backprop_q_loss_into_encoder:
        q1_pred = self.qf1(obs, actions, task_z_with_grad)
        q2_pred = self.qf2(obs, actions, task_z_with_grad)
    else:
        q1_pred = self.qf1(obs, actions, task_z_detached)
        q2_pred = self.qf2(obs, actions, task_z_detached)
    v_pred = self.vf(obs, task_z_detached)
    # get targets for use in V and Q updates
    with torch.no_grad():
        target_v_values = self.target_vf(next_obs, task_z_detached)

    """
    QF, Encoder, and Decoder Loss
    """
    # note: encoder/decoder do not get grads from policy or vf
    q_target = rewards_flat + (
        1. - terms_flat) * self.discount * target_v_values
    qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
        (q2_pred - q_target)**2)

    # KL constraint on z if probabilistic
    kl_div = kl_divergence(p_z, self.agent.latent_prior).sum()
    kl_loss = self.kl_lambda * kl_div
    if self.train_context_decoder:
        # TODO: change to use a distribution
        reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
        reward_prediction_loss = (
            (reward_pred - unscaled_rewards_flat)**2).mean()
        context_loss = kl_loss + reward_prediction_loss
    else:
        context_loss = kl_loss
        reward_prediction_loss = ptu.zeros(1)

    if self.train_encoder_decoder:
        self.context_optimizer.zero_grad()
    if self.train_agent:
        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
    context_loss.backward(retain_graph=True)
    qf_loss.backward()
    if self.train_agent:
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
    if self.train_encoder_decoder:
        self.context_optimizer.step()

    """
    VF update
    """
    min_q_new_actions = self._min_q(obs, new_actions, task_z_detached)
    v_target = min_q_new_actions - log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())
    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()
    self._update_target_network()

    """
    Policy update
    """
    # n.b. policy update includes dQ/da
    log_policy_target = min_q_new_actions
    policy_loss = (log_pi - log_policy_target).mean()
    mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
    std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
    pre_activation_reg_loss = self.policy_pre_activation_weight * (
        (pre_tanh_value**2).sum(dim=1).mean())
    policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
    policy_loss = policy_loss + policy_reg_loss
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # save some statistics for eval
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set this to None; this way, these statistics are
        # only computed for one batch.
        self.eval_statistics = OrderedDict()
        if self.use_information_bottleneck:
            z_mean = np.mean(np.abs(ptu.get_numpy(p_z.mean)))
            z_sig = np.mean(ptu.get_numpy(p_z.stddev))
            self.eval_statistics['Z mean-abs train'] = z_mean
            self.eval_statistics['Z std train'] = z_sig
            self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
            self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics['task_embedding/kl_divergence'] = (
            ptu.get_numpy(kl_div))
        self.eval_statistics['task_embedding/kl_loss'] = (
            ptu.get_numpy(kl_loss))
        self.eval_statistics['task_embedding/reward_prediction_loss'] = (
            ptu.get_numpy(reward_prediction_loss))
        self.eval_statistics['task_embedding/context_loss'] = (
            ptu.get_numpy(context_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
def compute_loss(
    self,
    batch,
    skip_statistics=False,
) -> Tuple[SACLosses, LossStatistics, torch.Tensor]:
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch["weights"]

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(-1)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(weights.detach() * (self.log_alpha * (
            log_pi + self.target_entropy).detach())).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (weights.detach() * (
        alpha * log_pi - q_new_actions)).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(-1)
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi
    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = (weights.detach() * (
        (q1_pred - q_target.detach())**2)).mean()
    qf2_loss = (weights.detach() * (
        (q2_pred - q_target.detach())**2)).mean()
    errors = (
        ((torch.abs(q_target - q1_pred) + torch.abs(q_target - q2_pred))
         / 2) * weights).detach()

    """
    Save some statistics for eval
    """
    eval_statistics = OrderedDict()
    if not skip_statistics:
        eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        eval_statistics.update(policy_statistics)
        if self.use_automatic_entropy_tuning:
            eval_statistics['Alpha'] = alpha.item()
            eval_statistics['Alpha Loss'] = alpha_loss.item()
    loss = SACLosses(
        policy_loss=policy_loss,
        qf1_loss=qf1_loss,
        qf2_loss=qf2_loss,
        alpha_loss=alpha_loss,
    )
    return loss, eval_statistics, errors
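# The per-sample `weights` input and `errors` output above suggest a
# prioritized-replay setup. A minimal sketch (an assumed convention, not this
# repo's buffer API) of turning the returned TD errors into new sampling
# priorities:
import numpy as np


def errors_to_priorities_sketch(errors, alpha=0.6, eps=1e-6):
    # Proportional-PER rule: p_i = (|delta_i| + eps) ** alpha.
    return (np.abs(errors) + eps) ** alpha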
def train_from_torch(self, batch):
    rewards_n = batch['rewards'].detach()
    terminals_n = batch['terminals'].detach()
    obs_n = batch['observations'].detach()
    actions_n = batch['actions'].detach()
    next_obs_n = batch['next_observations'].detach()
    batch_size = rewards_n.shape[0]
    num_agent = rewards_n.shape[1]
    whole_obs = obs_n.view(batch_size, -1)
    whole_actions = actions_n.view(batch_size, -1)
    whole_next_obs = next_obs_n.view(batch_size, -1)

    """
    Policy operations.
    """
    online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
    for agent in range(num_agent):
        policy_actions, info = self.policy_n[agent](
            obs_n[:, agent, :],
            return_info=True,
        )
        online_actions_n.append(policy_actions)
        online_pre_values_n.append(info['preactivation'])
        online_log_pis_n.append(info['log_prob'])
    k0_actions = torch.stack(online_actions_n)  # num_agent x batch x a_dim
    k0_actions = k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
    k0_inputs = torch.cat([obs_n, k0_actions], dim=-1)
    k0_contexts = self.cgca(k0_inputs)
    k1_actions = self.cactor(
        k0_contexts, deterministic=self.deterministic_cactor_in_graph)

    policy_gradients_n = []
    alpha_n = []
    for agent in range(num_agent):
        policy_actions = online_actions_n[agent]
        pre_value = online_pre_values_n[agent]
        log_pi = online_log_pis_n[agent]
        if self.pre_activation_weight > 0.:
            pre_activation_policy_loss = (pre_value**2).sum(dim=1).mean()
        else:
            pre_activation_policy_loss = torch.tensor(0.).to(ptu.device)
        if self.use_entropy_loss:
            if self.use_automatic_entropy_tuning:
                alpha = self.log_alpha_n[agent].exp()
                alpha_loss = -(alpha * (
                    log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer_n[agent].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer_n[agent].step()
                alpha = self.log_alpha_n[agent].exp().detach()
                alpha_n.append(alpha)
            else:
                alpha_loss = torch.tensor(0.).to(ptu.device)
                alpha = torch.tensor(self.init_alpha).to(ptu.device)
                alpha_n.append(alpha)
            entropy_loss = (alpha * log_pi).mean()
        else:
            entropy_loss = torch.tensor(0.).to(ptu.device)
        input_actions = k1_actions.clone()
        input_actions[:, agent, :] = policy_actions
        q1_output = self.qf1_n[agent](whole_obs,
                                      input_actions.view(batch_size, -1))
        q2_output = self.qf2_n[agent](whole_obs,
                                      input_actions.view(batch_size, -1))
        q_output = torch.min(q1_output, q2_output)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss
            + pre_activation_policy_loss * self.pre_activation_weight
            + entropy_loss
        )
        policy_gradients_n.append(torch.autograd.grad(
            policy_loss,
            self.policy_n[agent].parameters(),
            retain_graph=True))
        if self._need_to_update_eval_statistics:
            self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss {}'.format(
                agent)] = np.mean(ptu.get_numpy(pre_activation_policy_loss))
            self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(entropy_loss))
            if self.use_entropy_loss:
                self.eval_statistics['Alpha {} Mean'.format(agent)] = \
                    np.mean(ptu.get_numpy(alpha))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action {}'.format(agent),
                ptu.get_numpy(policy_actions),
            ))
    for agent in range(num_agent):
        # self.policy_optimizer_n[agent].zero_grad()
        for pid, p in enumerate(self.policy_n[agent].parameters()):
            p.grad = policy_gradients_n[agent][pid]
        self.policy_optimizer_n[agent].step()

    """
    Critic operations.
    """
    with torch.no_grad():
        next_actions_n, next_log_pis_n = [], []
        for agent in range(num_agent):
            next_actions, next_info = self.policy_n[agent](
                next_obs_n[:, agent, :],
                return_info=True,
                deterministic=self.deterministic_next_action,
            )
            next_actions_n.append(next_actions)
            next_log_pis_n.append(next_info['log_prob'])
        next_k0_actions = torch.stack(next_actions_n)  # num_agent x batch x a_dim
        next_k0_actions = next_k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
        next_k0_inputs = torch.cat([next_obs_n, next_k0_actions], dim=-1)
        next_k0_contexts = self.cgca(next_k0_inputs)
        next_k1_actions = self.cactor(
            next_k0_contexts,
            deterministic=self.deterministic_cactor_in_graph)
    for agent in range(num_agent):
        with torch.no_grad():
            input_actions = next_k1_actions.clone()
            input_actions[:, agent, :] = next_actions_n[agent]
            next_target_q1_values = self.target_qf1_n[agent](
                whole_next_obs,
                input_actions.view(batch_size, -1),
            )
            next_target_q2_values = self.target_qf2_n[agent](
                whole_next_obs,
                input_actions.view(batch_size, -1),
            )
            next_target_q_values = torch.min(next_target_q1_values,
                                             next_target_q2_values)
            if self.use_entropy_reward:
                next_alpha = alpha_n[agent]
                next_target_q_values = (
                    next_target_q_values
                    - next_alpha * next_log_pis_n[agent])
            q_target = (self.reward_scale * rewards_n[:, agent, :]
                        + (1. - terminals_n[:, agent, :])
                        * self.discount * next_target_q_values)
            q_target = torch.clamp(q_target, self.min_q_value,
                                   self.max_q_value)
        q1_pred = self.qf1_n[agent](whole_obs, whole_actions)
        raw_qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.qf_weight_decay > 0:
            reg_loss1 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in self.qf1_n[agent].regularizable_parameters()
            )
            qf1_loss = raw_qf1_loss + reg_loss1
        else:
            qf1_loss = raw_qf1_loss
        q2_pred = self.qf2_n[agent](whole_obs, whole_actions)
        raw_qf2_loss = self.qf_criterion(q2_pred, q_target)
        if self.qf_weight_decay > 0:
            reg_loss2 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in self.qf2_n[agent].regularizable_parameters()
            )
            qf2_loss = raw_qf2_loss + reg_loss2
        else:
            qf2_loss = raw_qf2_loss
        self.qf1_optimizer_n[agent].zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer_n[agent].step()
        self.qf2_optimizer_n[agent].zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer_n[agent].step()
        if self._need_to_update_eval_statistics:
            self.eval_statistics['QF1 Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(qf2_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions {}'.format(agent),
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions {}'.format(agent),
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets {}'.format(agent),
                ptu.get_numpy(q_target),
            ))

    """
    Central actor operations.
    """
    buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
    buffer_contexts_ca = self.cgca(buffer_inputs)
    cactor_actions, cactor_infos = self.cactor(
        buffer_contexts_ca, return_info=True)  # batch x agent_num x |A|
    cactor_loss_n = []
    for agent in range(num_agent):
        cactor_pre_value = cactor_infos['preactivation'][:, agent, :]
        if self.pre_activation_weight > 0:
            pre_activation_cactor_loss = (
                (cactor_pre_value**2).sum(dim=1).mean()
            )
        else:
            pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
        if self.use_cactor_entropy_loss:
            cactor_log_pi = cactor_infos['log_prob'][:, agent, :]
            if self.use_automatic_entropy_tuning:
                calpha = self.log_calpha_n[agent].exp()
                calpha_loss = -(calpha * (
                    cactor_log_pi + self.target_entropy).detach()).mean()
                self.calpha_optimizer_n[agent].zero_grad()
                calpha_loss.backward()
                self.calpha_optimizer_n[agent].step()
                calpha = self.log_calpha_n[agent].exp().detach()
            else:
                calpha_loss = torch.tensor(0.).to(ptu.device)
                calpha = torch.tensor(self.init_alpha).to(ptu.device)
            cactor_entropy_loss = (calpha * cactor_log_pi).mean()
        else:
            cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
        current_actions = actions_n.clone()
        current_actions[:, agent, :] = cactor_actions[:, agent, :]
        q1_output = self.qf1_n[agent](whole_obs,
                                      current_actions.view(batch_size, -1))
        q2_output = self.qf2_n[agent](whole_obs,
                                      current_actions.view(batch_size, -1))
        q_output = torch.min(q1_output, q2_output)
        raw_cactor_loss = -q_output.mean()
        cactor_loss = (
            raw_cactor_loss
            + pre_activation_cactor_loss * self.pre_activation_weight
            + cactor_entropy_loss
        )
        cactor_loss_n.append(cactor_loss)
        if self._need_to_update_eval_statistics:
            if self.use_cactor_entropy_loss:
                self.eval_statistics['CAlpha {} Mean'.format(agent)] = \
                    np.mean(ptu.get_numpy(calpha))
            self.eval_statistics['Cactor Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(cactor_loss))
            self.eval_statistics['Raw Cactor Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(raw_cactor_loss))
            self.eval_statistics['Preactivation Cactor Loss {}'.format(
                agent)] = np.mean(ptu.get_numpy(pre_activation_cactor_loss))
            self.eval_statistics['Entropy Cactor Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(cactor_entropy_loss))
    cactor_loss = torch.mean(torch.stack(cactor_loss_n))
    self.cactor_optimizer.zero_grad()
    cactor_loss.backward()
    # Track gradient norms of the context graph and the central actor
    # before stepping the optimizer.
    cgca_grad_norm = 0.
    for p in self.cgca.parameters():
        cgca_grad_norm += p.grad.data.norm(2).item() ** 2
    cgca_grad_norm = cgca_grad_norm ** 0.5
    cactor_grad_norm = 0.
    for p in self.cactor.parameters():
        cactor_grad_norm += p.grad.data.norm(2).item() ** 2
    cactor_grad_norm = cactor_grad_norm ** 0.5
    self.cactor_optimizer.step()
    if self._need_to_update_eval_statistics:
        self.eval_statistics['CGCA Gradient'] = cgca_grad_norm
        self.eval_statistics['CActor Gradient'] = cactor_grad_norm

    self._need_to_update_eval_statistics = False
    self._update_target_networks()
    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    rewards_n = batch['rewards'].detach()
    terminals_n = batch['terminals'].detach()
    obs_n = batch['observations'].detach()
    actions_n = batch['actions'].detach()
    next_obs_n = batch['next_observations'].detach()
    batch_size = rewards_n.shape[0]
    num_agent = rewards_n.shape[1]
    whole_obs = obs_n.view(batch_size, -1)
    whole_actions = actions_n.view(batch_size, -1)
    whole_next_obs = next_obs_n.view(batch_size, -1)

    """
    Policy operations.
    """
    online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
    for agent in range(num_agent):
        policy_actions, info = self.policy_n[agent](
            obs_n[:, agent, :],
            return_info=True,
        )
        online_actions_n.append(policy_actions)
        online_pre_values_n.append(info['preactivation'])
        online_log_pis_n.append(info['log_prob'])
    k0_actions = torch.stack(online_actions_n)  # num_agent x batch x a_dim
    k0_actions = k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
    k0_inputs = torch.cat([obs_n, k0_actions], dim=-1)
    k0_contexts = self.cgca(k0_inputs)
    k1_actions = [
        self.cactor_n[agent](
            k0_contexts[:, agent, :],
            deterministic=self.deterministic_cactor_in_graph)
        for agent in range(num_agent)
    ]
    k1_actions = torch.stack(k1_actions).transpose(0, 1).contiguous()
    k1_inputs = torch.cat([obs_n, k1_actions], dim=-1)
    k1_contexts_1 = self.cg1(k1_inputs)
    k1_contexts_2 = self.cg2(k1_inputs)

    policy_gradients_n = []
    alpha_n = []
    for agent in range(num_agent):
        policy_actions = online_actions_n[agent]
        pre_value = online_pre_values_n[agent]
        log_pi = online_log_pis_n[agent]
        if self.pre_activation_weight > 0.:
            pre_activation_policy_loss = (pre_value**2).sum(dim=1).mean()
        else:
            pre_activation_policy_loss = torch.tensor(0.).to(ptu.device)
        if self.use_entropy_loss:
            if self.use_automatic_entropy_tuning:
                alpha = self.log_alpha_n[agent].exp()
                alpha_loss = -(alpha * (
                    log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer_n[agent].zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer_n[agent].step()
                alpha = self.log_alpha_n[agent].exp().detach()
                alpha_n.append(alpha)
            else:
                alpha_loss = torch.tensor(0.).to(ptu.device)
                alpha = torch.tensor(self.init_alpha).to(ptu.device)
                alpha_n.append(alpha)
            entropy_loss = (alpha * log_pi).mean()
        else:
            entropy_loss = torch.tensor(0.).to(ptu.device)
        q1_input = torch.cat(
            [policy_actions, k1_contexts_1[:, agent, :]], dim=-1)
        q1_output = self.qf1_n[agent](q1_input)
        q2_input = torch.cat(
            [policy_actions, k1_contexts_2[:, agent, :]], dim=-1)
        q2_output = self.qf2_n[agent](q2_input)
        q_output = torch.min(q1_output, q2_output)
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss
            + pre_activation_policy_loss * self.pre_activation_weight
            + entropy_loss
        )
        policy_gradients_n.append(torch.autograd.grad(
            policy_loss,
            self.policy_n[agent].parameters(),
            retain_graph=True))
        if self._need_to_update_eval_statistics:
            self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss {}'.format(agent)] = \
                np.mean(ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss {}'.format(
                agent)] = np.mean(ptu.get_numpy(pre_activation_policy_loss))
            self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(
                ptu.get_numpy(entropy_loss))
            if self.use_entropy_loss:
                self.eval_statistics['Alpha {} Mean'.format(agent)] = \
                    np.mean(ptu.get_numpy(alpha))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Action {}'.format(agent),
                ptu.get_numpy(policy_actions),
            ))
    for agent in range(num_agent):
        # self.policy_optimizer_n[agent].zero_grad()
        for pid, p in enumerate(self.policy_n[agent].parameters()):
            p.grad = policy_gradients_n[agent][pid]
        self.policy_optimizer_n[agent].step()

    """
    Critic operations.
    """
    with torch.no_grad():
        next_actions_n, next_log_pis_n = [], []
        for agent in range(num_agent):
            next_actions, next_info = self.policy_n[agent](
                next_obs_n[:, agent, :],
                return_info=True,
                deterministic=self.deterministic_next_action,
            )
            next_actions_n.append(next_actions)
            next_log_pis_n.append(next_info['log_prob'])
        next_k0_actions = torch.stack(next_actions_n)  # num_agent x batch x a_dim
        next_k0_actions = next_k0_actions.transpose(0, 1).contiguous()  # batch x num_agent x a_dim
        next_k0_inputs = torch.cat([next_obs_n, next_k0_actions], dim=-1)
        next_k0_contexts = self.cgca(next_k0_inputs)
        next_k1_actions = [
            self.cactor_n[agent](
                next_k0_contexts[:, agent, :],
                deterministic=self.deterministic_cactor_in_graph)
            for agent in range(num_agent)
        ]
        next_k1_actions = torch.stack(
            next_k1_actions).transpose(0, 1).contiguous()
        next_k1_inputs = torch.cat([next_obs_n, next_k1_actions], dim=-1)
        next_k1_contexts_1 = self.target_cg1(next_k1_inputs)
        next_k1_contexts_2 = self.target_cg2(next_k1_inputs)
        next_q1_inputs = torch.cat(
            [next_k0_actions, next_k1_contexts_1], dim=-1)
        next_target_q1_values = [
            self.target_qf1_n[agent](next_q1_inputs[:, agent, :])
            for agent in range(num_agent)
        ]
        next_target_q1_values = torch.stack(
            next_target_q1_values).transpose(0, 1).contiguous()
        next_q2_inputs = torch.cat(
            [next_k0_actions, next_k1_contexts_2], dim=-1)
        next_target_q2_values = [
            self.target_qf2_n[agent](next_q2_inputs[:, agent, :])
            for agent in range(num_agent)
        ]
        next_target_q2_values = torch.stack(
            next_target_q2_values).transpose(0, 1).contiguous()
        next_target_q_values = torch.min(next_target_q1_values,
                                         next_target_q2_values)
        if self.use_entropy_reward:
            next_alphas = torch.stack(alpha_n)[None, :]
            next_log_pis = torch.stack(
                next_log_pis_n).transpose(0, 1).contiguous()
            next_target_q_values = (next_target_q_values
                                    - next_alphas * next_log_pis)
        q_targets = (self.reward_scale * rewards_n
                     + (1. - terminals_n)
                     * self.discount * next_target_q_values)
        q_targets = torch.clamp(q_targets, self.min_q_value,
                                self.max_q_value)

    buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
    buffer_contexts_1 = self.cg1(buffer_inputs)  # batch x num_agent x c_dim
    q1_inputs = torch.cat([actions_n, buffer_contexts_1], dim=-1)
    q1_preds = [self.qf1_n[agent](q1_inputs[:, agent, :])
                for agent in range(num_agent)]
    q1_preds = torch.stack(q1_preds).transpose(0, 1).contiguous()
    raw_qf1_loss = self.qf_criterion(q1_preds, q_targets)
    buffer_contexts_2 = self.cg2(buffer_inputs)  # batch x num_agent x c_dim
    q2_inputs = torch.cat([actions_n, buffer_contexts_2], dim=-1)
    q2_preds = [self.qf2_n[agent](q2_inputs[:, agent, :])
                for agent in range(num_agent)]
    q2_preds = torch.stack(q2_preds).transpose(0, 1).contiguous()
    raw_qf2_loss = self.qf_criterion(q2_preds, q_targets)
    if self.qf_weight_decay > 0:
        reg_loss1 = self.qf_weight_decay * sum(
            torch.sum(param ** 2)
            for param in list(self.qf1.regularizable_parameters())
            + list(self.cg1.regularizable_parameters())
        )
        reg_loss2 = self.qf_weight_decay * sum(
            torch.sum(param ** 2)
            for param in list(self.qf2.regularizable_parameters())
            + list(self.cg2.regularizable_parameters())
        )
    else:
        reg_loss1 = torch.tensor(0.).to(ptu.device)
        reg_loss2 = torch.tensor(0.).to(ptu.device)
    qf1_loss = raw_qf1_loss + reg_loss1
    qf2_loss = raw_qf2_loss + reg_loss2
    if self._need_to_update_eval_statistics:
        self.eval_statistics['Qf1 Loss'] = ptu.get_numpy(qf1_loss)
        self.eval_statistics['Qf2 Loss'] = ptu.get_numpy(qf2_loss)
        self.eval_statistics['Raw Qf1 Loss'] = ptu.get_numpy(raw_qf1_loss)
        self.eval_statistics['Raw Qf2 Loss'] = ptu.get_numpy(raw_qf2_loss)
        self.eval_statistics['Reg Qf1 Loss'] = ptu.get_numpy(reg_loss1)
        self.eval_statistics['Reg Qf2 Loss'] = ptu.get_numpy(reg_loss2)
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()
    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    """
    Central actor operations.
    """
    buffer_inputs = torch.cat([obs_n, actions_n], dim=-1)
    buffer_ca_contexts = self.cgca(buffer_inputs)
    cactor_outputs = [
        self.cactor_n[agent](buffer_ca_contexts[:, agent, :],
                             return_info=True)
        for agent in range(num_agent)
    ]  # [(action, info), ...]
    cactor_actions = torch.stack([
        cactor_outputs[agent][0] for agent in range(num_agent)
    ]).transpose(0, 1).contiguous()  # batch x agent_num x |A|
    buffer_contexts_1 = self.cg1(buffer_inputs).detach()
    buffer_contexts_2 = self.cg2(buffer_inputs).detach()
    cactor_pre_values = torch.stack([
        cactor_outputs[agent][1]['preactivation']
        for agent in range(num_agent)
    ]).transpose(0, 1).contiguous()
    if self.pre_activation_weight > 0:
        pre_activation_cactor_loss = (
            (cactor_pre_values**2).sum(dim=1).mean()
        )
    else:
        pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
    if self.use_cactor_entropy_loss:
        cactor_log_pis = torch.stack([
            cactor_outputs[agent][1]['log_prob']
            for agent in range(num_agent)
        ]).transpose(0, 1).contiguous()  # batch x num_agent x 1
        if self.use_automatic_entropy_tuning:
            calphas = torch.stack(self.log_calpha_n).exp()[None, :]
            calpha_loss = -(calphas * (
                cactor_log_pis + self.target_entropy).detach())
            calpha_loss = calpha_loss.mean()
            self.calpha_optimizer.zero_grad()
            calpha_loss.backward()
            self.calpha_optimizer.step()
            calphas = torch.stack(self.log_calpha_n).exp().detach()
        else:
            calpha_loss = torch.tensor(0.).to(ptu.device)
            calphas = torch.stack([
                torch.tensor(self.init_alpha).to(ptu.device)
                for i in range(num_agent)
            ])
        cactor_entropy_loss = (calphas[None, :] * cactor_log_pis).mean()
    else:
        cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
    q1_inputs = torch.cat([cactor_actions, buffer_contexts_1], dim=-1)
    q1_outputs = [self.qf1_n[agent](q1_inputs[:, agent, :])
                  for agent in range(num_agent)]
    q1_outputs = torch.stack(q1_outputs).transpose(0, 1).contiguous()
    q2_inputs = torch.cat([cactor_actions, buffer_contexts_2], dim=-1)
    q2_outputs = [self.qf2_n[agent](q2_inputs[:, agent, :])
                  for agent in range(num_agent)]
    q2_outputs = torch.stack(q2_outputs).transpose(0, 1).contiguous()
    q_outputs = torch.min(q1_outputs, q2_outputs)
    raw_cactor_loss = -q_outputs.mean()
    cactor_loss = (
        raw_cactor_loss
        + pre_activation_cactor_loss * self.pre_activation_weight
        + cactor_entropy_loss
    )
    if self._need_to_update_eval_statistics:
        if self.use_cactor_entropy_loss:
            self.eval_statistics.update(create_stats_ordered_dict(
                'CAlpha',
                ptu.get_numpy(calphas),
            ))
        self.eval_statistics['Cactor Loss'] = ptu.get_numpy(cactor_loss)
        self.eval_statistics['Raw Cactor Loss'] = ptu.get_numpy(
            raw_cactor_loss)
        self.eval_statistics['Preactivation Cactor Loss'] = ptu.get_numpy(
            pre_activation_cactor_loss)
        self.eval_statistics['Entropy Cactor Loss'] = ptu.get_numpy(
            cactor_entropy_loss)
    self.cactor_optimizer.zero_grad()
    cactor_loss.backward()
    self.cactor_optimizer.step()

    self._need_to_update_eval_statistics = False
    self._update_target_networks()
    self._n_train_steps_total += 1
def train_from_torch(self, batch):
    self._current_epoch += 1
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy and Alpha Loss
    """
    new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
        obs, reparameterize=True, return_log_prob=True,
    )
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (
            log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    if self.num_qs == 1:
        q_new_actions = self.qf1(obs, new_obs_actions)
    else:
        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    if self._current_epoch < self.policy_eval_start:
        """
        For the initial few epochs, optionally do behavioral cloning
        instead. Conventionally there is not much performance difference
        between having ~20k gradient steps here or not having them.
        """
        policy_log_prob = self.policy.log_prob(obs, actions)
        policy_loss = (alpha * log_pi - policy_log_prob).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    if self.num_qs > 1:
        q2_pred = self.qf2(obs, actions)

    new_next_actions, _, _, new_log_pi, *_ = self.policy(
        next_obs, reparameterize=True, return_log_prob=True,
    )
    new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
        obs, reparameterize=True, return_log_prob=True,
    )

    if not self.max_q_backup:
        if self.num_qs == 1:
            target_q_values = self.target_qf1(next_obs, new_next_actions)
        else:
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions),
                self.target_qf2(next_obs, new_next_actions),
            )
        if not self.deterministic_backup:
            target_q_values = target_q_values - alpha * new_log_pi
    if self.max_q_backup:
        """when using max q backup"""
        next_actions_temp, _ = self._get_policy_actions(
            next_obs, num_actions=10, network=self.policy)
        target_qf1_values = self._get_tensor_values(
            next_obs, next_actions_temp,
            network=self.target_qf1).max(1)[0].view(-1, 1)
        target_qf2_values = self._get_tensor_values(
            next_obs, next_actions_temp,
            network=self.target_qf2).max(1)[0].view(-1, 1)
        target_q_values = torch.min(target_qf1_values, target_qf2_values)

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    q_target = q_target.detach()
    qf1_loss = self.qf_criterion(q1_pred, q_target)
    if self.num_qs > 1:
        qf2_loss = self.qf_criterion(q2_pred, q_target)

    ## add CQL
    random_actions_tensor = torch.FloatTensor(
        q2_pred.shape[0] * self.num_random,
        actions.shape[-1]).uniform_(-1, 1)  # .cuda()
    curr_actions_tensor, curr_log_pis = self._get_policy_actions(
        obs, num_actions=self.num_random, network=self.policy)
    new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
        next_obs, num_actions=self.num_random, network=self.policy)
    q1_rand = self._get_tensor_values(obs, random_actions_tensor,
                                      network=self.qf1)
    q2_rand = self._get_tensor_values(obs, random_actions_tensor,
                                      network=self.qf2)
    q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor,
                                              network=self.qf1)
    q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor,
                                              network=self.qf2)
    q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor,
                                              network=self.qf1)
    q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor,
                                              network=self.qf2)

    cat_q1 = torch.cat(
        [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions],
        1
    )
    cat_q2 = torch.cat(
        [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions],
        1
    )
    std_q1 = torch.std(cat_q1, dim=1)
    std_q2 = torch.std(cat_q2, dim=1)

    if self.min_q_version == 3:
        # importance sampled version
        random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
        cat_q1 = torch.cat(
            [q1_rand - random_density,
             q1_next_actions - new_log_pis.detach(),
             q1_curr_actions - curr_log_pis.detach()],
            1
        )
        cat_q2 = torch.cat(
            [q2_rand - random_density,
             q2_next_actions - new_log_pis.detach(),
             q2_curr_actions - curr_log_pis.detach()],
            1
        )

    min_qf1_loss = torch.logsumexp(
        cat_q1 / self.temp, dim=1).mean() * self.min_q_weight * self.temp
    min_qf2_loss = torch.logsumexp(
        cat_q2 / self.temp, dim=1).mean() * self.min_q_weight * self.temp

    """Subtract the log likelihood of data"""
    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

    if self.with_lagrange:
        alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                  min=0.0, max=1000000.0)
        min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
        min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)
        self.alpha_prime_optimizer.zero_grad()
        alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
        alpha_prime_loss.backward(retain_graph=True)
        self.alpha_prime_optimizer.step()

    qf1_loss = qf1_loss + min_qf1_loss
    qf2_loss = qf2_loss + min_qf2_loss

    """
    Update networks
    """
    # Update the Q-functions.
    self._num_q_update_steps += 1
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward(retain_graph=True)
    self.qf1_optimizer.step()
    if self.num_qs > 1:
        self.qf2_optimizer.zero_grad()
        qf2_loss.backward(retain_graph=True)
        self.qf2_optimizer.step()

    self._num_policy_update_steps += 1
    self.policy_optimizer.zero_grad()
    policy_loss.backward(retain_graph=False)
    self.policy_optimizer.step()

    """
    Soft Updates
    """
    ptu.soft_update_from_to(
        self.qf1, self.target_qf1, self.soft_target_tau
    )
    if self.num_qs > 1:
        ptu.soft_update_from_to(
            self.qf2, self.target_qf2, self.soft_target_tau
        )

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        # Eval should set this to None; this way, these statistics are
        # only computed for one batch.
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['min QF1 Loss'] = np.mean(
            ptu.get_numpy(min_qf1_loss))
        if self.num_qs > 1:
            self.eval_statistics['QF2 Loss'] = np.mean(
                ptu.get_numpy(qf2_loss))
            self.eval_statistics['min QF2 Loss'] = np.mean(
                ptu.get_numpy(min_qf2_loss))

        if not self.discrete:
            self.eval_statistics['Std QF1 values'] = np.mean(
                ptu.get_numpy(std_q1))
            self.eval_statistics['Std QF2 values'] = np.mean(
                ptu.get_numpy(std_q2))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF1 in-distribution values',
                ptu.get_numpy(q1_curr_actions),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF2 in-distribution values',
                ptu.get_numpy(q2_curr_actions),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF1 random values',
                ptu.get_numpy(q1_rand),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF2 random values',
                ptu.get_numpy(q2_rand),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF1 next_actions values',
                ptu.get_numpy(q1_next_actions),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'QF2 next_actions values',
                ptu.get_numpy(q2_next_actions),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'actions', ptu.get_numpy(actions)
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'rewards', ptu.get_numpy(rewards)
            ))

        self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
        self.eval_statistics['Num Policy Updates'] = \
            self._num_policy_update_steps
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        if self.num_qs > 1:
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        if not self.discrete:
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        if self.with_lagrange:
            self.eval_statistics['Alpha_prime'] = alpha_prime.item()
            self.eval_statistics['min_q1_loss'] = ptu.get_numpy(
                min_qf1_loss).mean()
            self.eval_statistics['min_q2_loss'] = ptu.get_numpy(
                min_qf2_loss).mean()
            self.eval_statistics['threshold action gap'] = \
                self.target_action_gap
            self.eval_statistics['alpha prime loss'] = \
                alpha_prime_loss.item()
    self._n_train_steps_total += 1
def train_from_torch(self, batch): rewards_n = batch['rewards'].detach() terminals_n = batch['terminals'].detach() obs_n = batch['observations'].detach() actions_n = batch['actions'].detach() next_obs_n = batch['next_observations'].detach() batch_size = rewards_n.shape[0] num_agent = rewards_n.shape[1] """ Policy operations. """ online_actions_n, online_pre_values_n, online_log_pis_n = [], [], [] for agent in range(num_agent): if self.shared_obs: policy_actions, info = self.policy_n[agent]( obs_n, return_info=True, ) else: policy_actions, info = self.policy_n[agent]( obs_n[:, agent, :], return_info=True, ) online_actions_n.append(policy_actions) online_pre_values_n.append(info['preactivation']) online_log_pis_n.append(info['log_prob']) k0_actions = torch.stack(online_actions_n) # num_agent x batch x a_dim k0_actions = k0_actions.transpose( 0, 1).contiguous() # batch x num_agent x a_dim if self.shared_obs: k0_inputs = torch.cat( [obs_n, k0_actions.reshape(batch_size, -1)], dim=-1) else: k0_inputs = torch.cat([obs_n, k0_actions], dim=-1) k1_actions = self.cactor( k0_inputs, deterministic=self.deterministic_cactor_in_graph) if self.shared_obs: k1_inputs = torch.cat( [obs_n, k1_actions.reshape(batch_size, -1)], dim=-1) else: k1_inputs = torch.cat([obs_n, k1_actions], dim=-1) k1_contexts_1 = self.cg1(k1_inputs) k1_contexts_2 = self.cg2(k1_inputs) q1_inputs = torch.cat([k0_actions, k1_contexts_1], dim=-1) q1_outputs = self.qf1(q1_inputs) q2_inputs = torch.cat([k0_actions, k1_contexts_2], dim=-1) q2_outputs = self.qf2(q2_inputs) min_q_outputs = torch.min(q1_outputs, q2_outputs) # batch x num_agent x 1 policy_gradients_n = [] alpha_n = [] for agent in range(num_agent): policy_actions = online_actions_n[agent] pre_value = online_pre_values_n[agent] log_pi = online_log_pis_n[agent] if self.pre_activation_weight > 0.: pre_activation_policy_loss = ((pre_value**2).sum(dim=1).mean()) else: pre_activation_policy_loss = torch.tensor(0.).to(ptu.device) if self.use_entropy_loss: if self.use_automatic_entropy_tuning: alpha = self.log_alpha_n[agent].exp() alpha_loss = -( alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer_n[agent].zero_grad() alpha_loss.backward() self.alpha_optimizer_n[agent].step() alpha = self.log_alpha_n[agent].exp().detach() alpha_n.append(alpha) else: alpha_loss = torch.tensor(0.).to(ptu.device) alpha = torch.tensor(self.init_alpha).to(ptu.device) alpha_n.append(alpha) entropy_loss = (alpha * log_pi).mean() else: entropy_loss = torch.tensor(0.).to(ptu.device) raw_policy_loss = -min_q_outputs[:, agent, :].mean() policy_loss = ( raw_policy_loss + pre_activation_policy_loss * self.pre_activation_weight + entropy_loss) policy_gradients_n.append( torch.autograd.grad(policy_loss, self.policy_n[agent].parameters(), retain_graph=True)) if self._need_to_update_eval_statistics: self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics['Raw Policy Loss {}'.format( agent)] = np.mean(ptu.get_numpy(raw_policy_loss)) self.eval_statistics[ 'Preactivation Policy Loss {}'.format(agent)] = np.mean( ptu.get_numpy(pre_activation_policy_loss)) self.eval_statistics['Entropy Loss {}'.format( agent)] = np.mean(ptu.get_numpy(entropy_loss)) if self.use_entropy_loss: self.eval_statistics['Alpha {} Mean'.format( agent)] = np.mean(ptu.get_numpy(alpha)) self.eval_statistics.update( create_stats_ordered_dict( 'Policy Action {}'.format(agent), ptu.get_numpy(policy_actions), )) for agent in range(num_agent): # 
self.policy_optimizer_n[agent].zero_grad() for pid, p in enumerate(self.policy_n[agent].parameters()): p.grad = policy_gradients_n[agent][pid] self.policy_optimizer_n[agent].step() """ Critic operations. """ with torch.no_grad(): next_actions_n, next_log_pis_n = [], [] for agent in range(num_agent): if self.shared_obs: next_actions, next_info = self.policy_n[agent]( next_obs_n, return_info=True, deterministic=self.deterministic_next_action, ) else: next_actions, next_info = self.policy_n[agent]( next_obs_n[:, agent, :], return_info=True, deterministic=self.deterministic_next_action, ) next_actions_n.append(next_actions) next_log_pis_n.append(next_info['log_prob']) next_k0_actions = torch.stack( next_actions_n) # num_agent x batch x a_dim next_k0_actions = next_k0_actions.transpose( 0, 1).contiguous() # batch x num_agent x a_dim if self.shared_obs: next_k0_inputs = torch.cat( [next_obs_n, next_k0_actions.reshape(batch_size, -1)], dim=-1) else: next_k0_inputs = torch.cat([next_obs_n, next_k0_actions], dim=-1) next_k1_actions = self.cactor( next_k0_inputs, deterministic=self.deterministic_cactor_in_graph) if self.shared_obs: next_k1_inputs = torch.cat( [next_obs_n, next_k1_actions.reshape(batch_size, -1)], dim=-1) else: next_k1_inputs = torch.cat([next_obs_n, next_k1_actions], dim=-1) next_k1_contexts_1 = self.target_cg1(next_k1_inputs) next_k1_contexts_2 = self.target_cg2(next_k1_inputs) next_q1_inputs = torch.cat([next_k0_actions, next_k1_contexts_1], dim=-1) next_target_q1_values = self.target_qf1(next_q1_inputs) next_q2_inputs = torch.cat([next_k0_actions, next_k1_contexts_2], dim=-1) next_target_q2_values = self.target_qf2(next_q2_inputs) next_target_q_values = torch.min(next_target_q1_values, next_target_q2_values) if self.use_entropy_reward: next_alphas = torch.stack(alpha_n)[None, :] next_log_pis = torch.stack(next_log_pis_n).transpose( 0, 1).contiguous() next_target_q_values = next_target_q_values - next_alphas * next_log_pis q_targets = self.reward_scale * rewards_n + ( 1. - terminals_n) * self.discount * next_target_q_values q_targets = torch.clamp(q_targets, self.min_q_value, self.max_q_value) if self.grad_loss: k0_actions = actions_n.clone().detach() k0_actions.requires_grad = True else: k0_actions = actions_n k1_actions = actions_n if self.shared_obs: buffer_inputs = torch.cat( [obs_n, k0_actions.reshape(batch_size, -1)], dim=-1) else: buffer_inputs = torch.cat([obs_n, k0_actions], dim=-1) buffer_contexts_1 = self.cg1( buffer_inputs) # batch x num_agent x c_dim q1_inputs = torch.cat([k1_actions, buffer_contexts_1], dim=-1) q1_preds = self.qf1(q1_inputs) raw_qf1_loss = self.qf_criterion(q1_preds, q_targets) buffer_contexts_2 = self.cg2( buffer_inputs) # batch x num_agent x c_dim q2_inputs = torch.cat([k1_actions, buffer_contexts_2], dim=-1) q2_preds = self.qf2(q2_inputs) raw_qf2_loss = self.qf_criterion(q2_preds, q_targets) if self.negative_sampling: batch_size, num_agent, a_dim = actions_n.shape perturb_agents = torch.randint(low=0, high=num_agent, size=(batch_size, )).to(ptu.device) neg_actions = (torch.rand(batch_size, num_agent, a_dim) * 2. 
- 1.).to(ptu.device) # uniform in [-1, 1] perturb_k0_actions = actions_n.clone() # batch x agent x |A| perturb_k0_actions[torch.arange(batch_size), perturb_agents, :] = neg_actions[ torch.arange(batch_size), perturb_agents, :] perturb_k1_actions = neg_actions.clone() perturb_k1_actions[torch.arange(batch_size), perturb_agents, :] = actions_n[ torch.arange(batch_size), perturb_agents, :] if self.shared_obs: perturb_inputs = torch.cat( [obs_n, perturb_k0_actions.reshape(batch_size, -1)], dim=-1) else: perturb_inputs = torch.cat([obs_n, perturb_k0_actions], dim=-1) perturb_contexts_1 = self.cg1( perturb_inputs) # batch x num_agent x c_dim perturb_q1_inputs = torch.cat( [perturb_k1_actions, perturb_contexts_1], dim=-1) perturb_q1_preds = self.qf1(perturb_q1_inputs)[ torch.arange(batch_size), perturb_agents, :] perturb_contexts_2 = self.cg2( perturb_inputs) # batch x num_agent x c_dim perturb_q2_inputs = torch.cat( [perturb_k1_actions, perturb_contexts_2], dim=-1) perturb_q2_preds = self.qf2(perturb_q2_inputs)[ torch.arange(batch_size), perturb_agents, :] perturb_q_targets = q_targets[torch.arange(batch_size), perturb_agents, :] neg_loss1 = self.qf_criterion(perturb_q1_preds, perturb_q_targets) neg_loss2 = self.qf_criterion(perturb_q2_preds, perturb_q_targets) else: neg_loss1, neg_loss2 = torch.tensor(0.).to( ptu.device), torch.tensor(0.).to(ptu.device) if self.grad_loss: grad_loss1 = 0 for agent in range(num_agent): grads = torch.autograd.grad(torch.sum(q1_preds[:, agent, :]), k0_actions, retain_graph=True, create_graph=True)[0] grad_loss1 += grads[:, agent, :].norm(2) grad_loss2 = 0 for agent in range(num_agent): grads = torch.autograd.grad(torch.sum(q2_preds[:, agent, :]), k0_actions, retain_graph=True, create_graph=True)[0] grad_loss2 += grads[:, agent, :].norm(2) else: grad_loss1, grad_loss2 = torch.tensor(0.).to( ptu.device), torch.tensor(0.).to(ptu.device) if self.qf_weight_decay > 0: reg_loss1 = self.qf_weight_decay * sum( torch.sum(param**2) for param in list(self.qf1.regularizable_parameters()) + list(self.cg1.regularizable_parameters())) reg_loss2 = self.qf_weight_decay * sum( torch.sum(param**2) for param in list(self.qf2.regularizable_parameters()) + list(self.cg2.regularizable_parameters())) else: reg_loss1, reg_loss2 = torch.tensor(0.).to( ptu.device), torch.tensor(0.).to(ptu.device) qf1_loss = raw_qf1_loss + neg_loss1 + grad_loss1 + reg_loss1 qf2_loss = raw_qf2_loss + neg_loss2 + grad_loss2 + reg_loss2 if self._need_to_update_eval_statistics: self.eval_statistics['Qf1 Loss'] = ptu.get_numpy(qf1_loss) self.eval_statistics['Qf2 Loss'] = ptu.get_numpy(qf2_loss) self.eval_statistics['Raw Qf1 Loss'] = ptu.get_numpy(raw_qf1_loss) self.eval_statistics['Raw Qf2 Loss'] = ptu.get_numpy(raw_qf2_loss) self.eval_statistics['Neg Qf1 Loss'] = ptu.get_numpy(neg_loss1) self.eval_statistics['Neg Qf2 Loss'] = ptu.get_numpy(neg_loss2) self.eval_statistics['Grad Qf1 Loss'] = ptu.get_numpy(grad_loss1) self.eval_statistics['Grad Qf2 Loss'] = ptu.get_numpy(grad_loss2) self.eval_statistics['Reg Qf1 Loss'] = ptu.get_numpy(reg_loss1) self.eval_statistics['Reg Qf2 Loss'] = ptu.get_numpy(reg_loss2) self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() """ Central actor operations. 
""" if self.shared_obs: buffer_inputs = torch.cat( [obs_n, actions_n.reshape(batch_size, -1)], dim=-1) else: buffer_inputs = torch.cat([obs_n, actions_n], dim=-1) cactor_actions, cactor_infos = self.cactor(buffer_inputs, return_info=True) # batch x agent_num x |A| buffer_contexts_1 = self.cg1(buffer_inputs).detach() buffer_contexts_2 = self.cg2(buffer_inputs).detach() cactor_pre_values = cactor_infos['preactivation'] if self.pre_activation_weight > 0: pre_activation_cactor_loss = ((cactor_pre_values**2).sum( dim=1).mean()) else: pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device) if self.use_cactor_entropy_loss: cactor_log_pis = cactor_infos['log_prob'] # batch x num_ageng x 1 if self.use_automatic_entropy_tuning: calphas = torch.stack(self.log_calpha_n).exp()[None, :] calpha_loss = -( calphas * (cactor_log_pis + self.target_entropy).detach()) calpha_loss = calpha_loss.mean() self.calpha_optimizer.zero_grad() calpha_loss.backward() self.calpha_optimizer.step() calphas = torch.stack(self.log_calpha_n).exp().detach() else: calpha_loss = torch.tensor(0.).to(ptu.device) calphas = torch.stack([ torch.tensor(self.init_alpha).to(ptu.device) for i in range(num_agent) ]) cactor_entropy_loss = (calphas[None, :] * cactor_log_pis).mean() else: cactor_entropy_loss = torch.tensor(0.).to(ptu.device) q1_inputs = torch.cat([cactor_actions, buffer_contexts_1], dim=-1) q1_outputs = self.qf1(q1_inputs) q2_inputs = torch.cat([cactor_actions, buffer_contexts_2], dim=-1) q2_outputs = self.qf2(q2_inputs) q_outputs = torch.min(q1_outputs, q2_outputs) raw_cactor_loss = -q_outputs.mean() cactor_loss = ( raw_cactor_loss + pre_activation_cactor_loss * self.pre_activation_weight + cactor_entropy_loss) if self._need_to_update_eval_statistics: if self.use_cactor_entropy_loss: self.eval_statistics.update( create_stats_ordered_dict( 'CAlpha', ptu.get_numpy(calphas), )) self.eval_statistics['Cactor Loss'] = ptu.get_numpy(cactor_loss) self.eval_statistics['Raw Cactor Loss'] = ptu.get_numpy( raw_cactor_loss) self.eval_statistics['Preactivation Cactor Loss'] = ptu.get_numpy( pre_activation_cactor_loss) self.eval_statistics['Entropy Cactor Loss'] = ptu.get_numpy( cactor_entropy_loss) self.cactor_optimizer.zero_grad() cactor_loss.backward() self.cactor_optimizer.step() self._need_to_update_eval_statistics = False self._update_target_networks() self._n_train_steps_total += 1
def test_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] weights = batch.get('weights', None) if self.reward_transform: rewards = self.reward_transform(rewards) if self.terminal_transform: terminals = self.terminal_transform(terminals) """ Policy and Alpha Loss """ dist = self.policy(obs) new_obs_actions, log_pi = dist.rsample_and_logprob() policy_mle = dist.mle_estimate() if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = self.alpha q1_pred = self.qf1(obs, actions) q2_pred = self.qf2(obs, actions) # Make sure policy accounts for squashing functions like tanh correctly! next_dist = self.policy(next_obs) new_next_actions, new_log_pi = next_dist.rsample_and_logprob() target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions), self.target_qf2(next_obs, new_next_actions), ) - alpha * new_log_pi q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) qf1_new_actions = self.qf1(obs, new_obs_actions) qf2_new_actions = self.qf2(obs, new_obs_actions) q_new_actions = torch.min( qf1_new_actions, qf2_new_actions, ) policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['validation/QF1 Loss'] = np.mean( ptu.get_numpy(qf1_loss)) self.eval_statistics['validation/QF2 Loss'] = np.mean( ptu.get_numpy(qf2_loss)) self.eval_statistics['validation/Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'validation/Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'validation/Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'validation/Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'validation/Log Pis', ptu.get_numpy(log_pi), )) policy_statistics = add_prefix(dist.get_diagnostics(), "validation/policy/") self.eval_statistics.update(policy_statistics)
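test_from_torch above recomputes the SAC losses purely for logging and never steps an optimizer. If validation runs frequently, wrapping the computation in torch.no_grad() avoids building the autograd graph. A minimal sketch, assuming a generic critic and a precomputed target (the function name and signature are hypothetical):

import torch

def validation_qf_loss(qf, target_q, batch,
                       criterion=torch.nn.functional.mse_loss):
    """Compute a critic validation loss without tracking gradients.
    Sketch only; mirrors the logging-only use of test_from_torch above."""
    with torch.no_grad():
        q_pred = qf(batch['observations'], batch['actions'])
        return criterion(q_pred, target_q)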
def train_from_torch(self, batch, train=True, pretrain=False,): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] if self.reward_transform: rewards = self.reward_transform(rewards) if self.terminal_transform: terminals = self.terminal_transform(terminals) """ Policy and Alpha Loss """ dist = self.policy(obs) """ QF Loss """ q1_pred = self.qf1(obs, actions) q2_pred = self.qf2(obs, actions) target_vf_pred = self.vf(next_obs).detach() q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_vf_pred q_target = q_target.detach() qf1_loss = self.qf_criterion(q1_pred, q_target) qf2_loss = self.qf_criterion(q2_pred, q_target) """ VF Loss """ q_pred = torch.min( self.target_qf1(obs, actions), self.target_qf2(obs, actions), ).detach() vf_pred = self.vf(obs) vf_err = vf_pred - q_pred vf_sign = (vf_err > 0).float() vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile) vf_loss = (vf_weight * (vf_err ** 2)).mean() """ Policy Loss """ policy_logpp = dist.log_prob(actions) adv = q_pred - vf_pred exp_adv = torch.exp(adv / self.beta) if self.clip_score is not None: exp_adv = torch.clamp(exp_adv, max=self.clip_score) weights = exp_adv[:, 0].detach() policy_loss = (-policy_logpp * weights).mean() """ Update networks """ if self._n_train_steps_total % self.q_update_period == 0: self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.vf_optimizer.zero_grad() vf_loss.backward() self.vf_optimizer.step() if self._n_train_steps_total % self.policy_update_period == 0: self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to( self.qf1, self.target_qf1, self.soft_target_tau ) ptu.soft_update_from_to( self.qf2, self.target_qf2, self.soft_target_tau ) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( policy_loss )) self.eval_statistics.update(create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update(create_stats_ordered_dict( 'rewards', ptu.get_numpy(rewards), )) self.eval_statistics.update(create_stats_ordered_dict( 'terminals', ptu.get_numpy(terminals), )) self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size policy_statistics = add_prefix(dist.get_diagnostics(), "policy/") self.eval_statistics.update(policy_statistics) self.eval_statistics.update(create_stats_ordered_dict( 'Advantage Weights', ptu.get_numpy(weights), )) self.eval_statistics.update(create_stats_ordered_dict( 'Advantage Score', ptu.get_numpy(adv), )) self.eval_statistics.update(create_stats_ordered_dict( 'V1 Predictions', ptu.get_numpy(vf_pred), )) self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss)) self._n_train_steps_total += 1
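The VF update above is expectile regression: errors where V underestimates Q are weighted by self.quantile, and overestimates by 1 - self.quantile. A compact standalone version of that asymmetric loss (the helper name is hypothetical):

import torch

def expectile_loss(vf_pred, q_pred, quantile=0.7):
    """Asymmetric L2 loss matching the VF update above: err > 0 means the
    value function over-estimates Q and is penalized with weight
    (1 - quantile); under-estimates get weight quantile."""
    err = vf_pred - q_pred
    weight = torch.where(err > 0,
                         torch.full_like(err, 1.0 - quantile),
                         torch.full_like(err, quantile))
    return (weight * err.pow(2)).mean()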
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Critic operations. """ next_actions = self.target_policy(next_obs) noise = ptu.randn(next_actions.shape) * self.target_policy_noise noise = torch.clamp( noise, -self.target_policy_noise_clip, self.target_policy_noise_clip ) noisy_next_actions = next_actions + noise target_q1_values = self.target_qf1(next_obs, noisy_next_actions) target_q2_values = self.target_qf2(next_obs, noisy_next_actions) target_q_values = torch.min(target_q1_values, target_q2_values) q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values q_target = q_target.detach() q1_pred = self.qf1(obs, actions) bellman_errors_1 = (q1_pred - q_target) ** 2 qf1_loss = bellman_errors_1.mean() q2_pred = self.qf2(obs, actions) bellman_errors_2 = (q2_pred - q_target) ** 2 qf2_loss = bellman_errors_2.mean() """ Update Networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() policy_actions = policy_loss = None if self._n_train_steps_total % self.policy_and_target_update_period == 0: policy_actions = self.policy(obs) q_output = self.qf1(obs, policy_actions) policy_loss = - q_output.mean() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() ptu.soft_update_from_to(self.policy, self.target_policy, self.tau) ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau) if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False if policy_loss is None: policy_actions = self.policy(obs) q_output = self.qf1(obs, policy_actions) policy_loss = - q_output.mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( policy_loss )) self.eval_statistics.update(create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update(create_stats_ordered_dict( 'Bellman Errors 1', ptu.get_numpy(bellman_errors_1), )) self.eval_statistics.update(create_stats_ordered_dict( 'Bellman Errors 2', ptu.get_numpy(bellman_errors_2), )) self.eval_statistics.update(create_stats_ordered_dict( 'Policy Action', ptu.get_numpy(policy_actions), )) self._n_train_steps_total += 1
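The clipped-noise construction above is TD3's target policy smoothing. Isolated as a helper (a sketch; names are illustrative, and like the trainer above it does not re-clip the perturbed action to the action bounds):

import torch

def smoothed_target_actions(target_policy, next_obs,
                            noise_std=0.2, noise_clip=0.5):
    """TD3 target policy smoothing: perturb the target policy's action with
    clipped Gaussian noise so the critic target is harder to exploit."""
    actions = target_policy(next_obs)
    noise = (torch.randn_like(actions) * noise_std).clamp(-noise_clip, noise_clip)
    return actions + noise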
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Policy and Alpha Loss """ pis = self.policy(obs) if self.use_automatic_entropy_tuning: alpha_loss = -(pis.detach() * self.log_alpha.exp() * (torch.log(pis + 1e-3) + self.target_entropy).detach()).sum(-1).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 min_q = torch.min(self.qf1(obs), self.qf2(obs)).detach() policy_loss = (pis * (alpha * torch.log(pis + 1e-3) - min_q)).sum(-1).mean() """ QF Loss """ new_pis = self.policy(next_obs).detach() target_min_q_values = torch.min( self.target_qf1(next_obs), self.target_qf2(next_obs), ) target_q_values = ( new_pis * (target_min_q_values - alpha * torch.log(new_pis + 1e-3))).sum( -1, keepdim=True) q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values q1_pred = torch.sum(self.qf1(obs) * actions.detach(), dim=-1, keepdim=True) q2_pred = torch.sum(self.qf2(obs) * actions.detach(), dim=-1, keepdim=True) qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Update networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Pis', ptu.get_numpy(pis), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() self._n_train_steps_total += 1
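Unlike the continuous trainers, this discrete variant takes exact expectations over the action distribution instead of sampling. A condensed sketch of the target computation, assuming (as above) Q-networks that map a state to a vector of per-action values:

import torch

def discrete_soft_target(target_qf1, target_qf2, next_obs, pis, alpha, eps=1e-3):
    """Exact soft state value for discrete actions:
    V(s') = sum_a pi(a|s') * (min_i Q_i(s', a) - alpha * log pi(a|s'))."""
    min_q = torch.min(target_qf1(next_obs), target_qf2(next_obs))
    return (pis * (min_q - alpha * torch.log(pis + eps))).sum(-1, keepdim=True)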
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Policy and Alpha Loss """ new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy( obs, reparameterize=True, return_log_prob=True, ) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 q_new_actions = torch.min( self.qf1(obs, new_obs_actions), self.qf2(obs, new_obs_actions), ) policy_loss = (alpha * log_pi - q_new_actions).mean() """ QF Loss """ q1_pred = self.qf1(obs, actions) q2_pred = self.qf2(obs, actions) # Make sure policy accounts for squashing functions like tanh correctly! new_next_actions, _, _, new_log_pi, *_ = self.policy( next_obs, reparameterize=True, return_log_prob=True, ) target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions), self.target_qf2(next_obs, new_next_actions), ) - alpha * new_log_pi q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Update networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy mu', ptu.get_numpy(policy_mean), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy log std', ptu.get_numpy(policy_log_std), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() self._n_train_steps_total += 1
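The comment above ("Make sure policy accounts for squashing functions like tanh correctly!") refers to the change-of-variables correction for a tanh-squashed Gaussian. A minimal sketch of that log-probability, assuming a pre-tanh Gaussian sample u (the helper is illustrative, not the policy class's actual method):

import torch

def tanh_gaussian_log_prob(normal_dist, u, eps=1e-6):
    """log pi(a) for a = tanh(u), u ~ Normal(mean, std):
    log N(u) - sum_i log(1 - tanh(u_i)^2), by the change of variables."""
    a = torch.tanh(u)
    log_prob = normal_dist.log_prob(u) - torch.log(1 - a.pow(2) + eps)
    return log_prob.sum(dim=-1, keepdim=True)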
def train_from_torch(self, batch): self._current_epoch += 1 rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Behavior clone a policy """ recon, mean, std = self.vae(obs, actions) recon_loss = self.qf_criterion(recon, actions) kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() vae_loss = recon_loss + 0.5 * kl_loss self.vae_optimizer.zero_grad() vae_loss.backward() self.vae_optimizer.step() """ Critic Training """ with torch.no_grad(): # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ) state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view( next_obs.shape[0] * 10, next_obs.shape[1]) # 10BxS # Compute value of perturbed actions sampled from the VAE action_rep = self.policy(state_rep)[0] target_qf1 = self.target_qf1(state_rep, action_rep) target_qf2 = self.target_qf2(state_rep, action_rep) # Soft Clipped Double Q-learning target_Q = 0.75 * torch.min(target_qf1, target_qf2) + 0.25 * torch.max( target_qf1, target_qf2) target_Q = target_Q.view(next_obs.shape[0], -1).max(1)[0].view(-1, 1) target_Q = self.reward_scale * rewards + ( 1.0 - terminals) * self.discount * target_Q # Bx1 qf1_pred = self.qf1(obs, actions) # Bx1 qf2_pred = self.qf2(obs, actions) # Bx1 qf1_loss = (qf1_pred - target_Q.detach()).pow(2).mean() qf2_loss = (qf2_pred - target_Q.detach()).pow(2).mean() """ Actor Training """ sampled_actions, raw_sampled_actions = self.vae.decode_multiple( obs, num_decode=self.num_samples_mmd_match) actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy( obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match, 1).view(-1, obs.shape[1]), return_log_prob=True) actor_samples = actor_samples.view(obs.shape[0], self.num_samples_mmd_match, actions.shape[1]) raw_actor_actions = raw_actor_actions.view(obs.shape[0], self.num_samples_mmd_match, actions.shape[1]) if self.kernel_choice == 'laplacian': mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma) elif self.kernel_choice == 'gaussian': mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma) action_divergence = ((sampled_actions - actor_samples)**2).sum(-1) raw_action_divergence = ((raw_sampled_actions - raw_actor_actions)**2).sum(-1) q_val1 = self.qf1(obs, actor_samples[:, 0, :]) q_val2 = self.qf2(obs, actor_samples[:, 0, :]) if self.policy_update_style == '0': policy_loss = torch.min(q_val1, q_val2)[:, 0] elif self.policy_update_style == '1': policy_loss = ((q_val1 + q_val2) / 2.)[:, 0] # average the twin Qs; torch.mean(q_val1, q_val2) is not a valid call if self._n_train_steps_total >= 40000: # Now we can update the policy if self.mode == 'auto': policy_loss = (-policy_loss + self.log_alpha.exp() * (mmd_loss - self.target_mmd_thresh)).mean() else: policy_loss = (-policy_loss + 100 * mmd_loss).mean() else: if self.mode == 'auto': policy_loss = (self.log_alpha.exp() * (mmd_loss - self.target_mmd_thresh)).mean() else: policy_loss = 100 * mmd_loss.mean() """ Update Networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.policy_optimizer.zero_grad() if self.mode == 'auto': policy_loss.backward(retain_graph=True) self.policy_optimizer.step() if self.mode == 'auto': self.alpha_optimizer.zero_grad() (-policy_loss).backward() self.alpha_optimizer.step() self.log_alpha.data.clamp_(min=-5.0, max=10.0) """ Soft Updates """ if 
self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Num Q Updates'] = self._num_q_update_steps self.eval_statistics[ 'Num Policy Updates'] = self._num_policy_update_steps self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(qf1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(qf2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(target_Q), )) self.eval_statistics.update( create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss))) self.eval_statistics.update( create_stats_ordered_dict('Action Divergence', ptu.get_numpy(action_divergence))) self.eval_statistics.update( create_stats_ordered_dict( 'Raw Action Divergence', ptu.get_numpy(raw_action_divergence))) if self.mode == 'auto': self.eval_statistics['Alpha'] = self.log_alpha.exp().item() self._n_train_steps_total += 1
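The mmd_loss_laplacian / mmd_loss_gaussian calls above estimate the MMD between VAE-sampled actions and policy actions. A sketch of a Laplacian-kernel estimator over (batch, n, action_dim) samples, following the standard biased MMD^2 form; the exact kernel bandwidth convention here is an assumption, not necessarily the trainer's:

import torch

def mmd_laplacian(x, y, sigma=20.0):
    """Biased MMD^2 estimate with a Laplacian kernel
    k(a, b) = exp(-||a - b||_1 / (2 * sigma)), for x, y of shape (B, n, d)."""
    def kernel(a, b):
        diff = a.unsqueeze(2) - b.unsqueeze(1)           # (B, n, m, d)
        return torch.exp(-diff.abs().sum(-1) / (2.0 * sigma))
    k_xx = kernel(x, x).mean(dim=(1, 2))
    k_xy = kernel(x, y).mean(dim=(1, 2))
    k_yy = kernel(y, y).mean(dim=(1, 2))
    return (k_xx - 2.0 * k_xy + k_yy).clamp(min=1e-6).sqrt()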
def train_from_torch( self, batch, train=True, pretrain=False, ): """ :param batch: :param train: :param pretrain: :return: """ rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] weights = batch.get('weights', None) if self.reward_transform: rewards = self.reward_transform(rewards) if self.terminal_transform: terminals = self.terminal_transform(terminals) """ Policy and Alpha Loss """ dist = self.policy(obs) new_obs_actions, log_pi = dist.rsample_and_logprob() policy_mle = dist.mle_estimate() if self.brac: buf_dist = self.buffer_policy(obs) buf_log_pi = buf_dist.log_prob(actions) rewards = rewards + buf_log_pi if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = self.alpha """ QF Loss """ q1_pred = self.qf1(obs, actions) q2_pred = self.qf2(obs, actions) # Make sure policy accounts for squashing functions like tanh correctly! next_dist = self.policy(next_obs) new_next_actions, new_log_pi = next_dist.rsample_and_logprob() target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions), self.target_qf2(next_obs, new_next_actions), ) - alpha * new_log_pi q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Policy Loss """ qf1_new_actions = self.qf1(obs, new_obs_actions) qf2_new_actions = self.qf2(obs, new_obs_actions) q_new_actions = torch.min( qf1_new_actions, qf2_new_actions, ) # Advantage-weighted regression if self.awr_use_mle_for_vf: v1_pi = self.qf1(obs, policy_mle) v2_pi = self.qf2(obs, policy_mle) v_pi = torch.min(v1_pi, v2_pi) else: if self.vf_K > 1: vs = [] for i in range(self.vf_K): u = dist.sample() q1 = self.qf1(obs, u) q2 = self.qf2(obs, u) v = torch.min(q1, q2) # v = q1 vs.append(v) v_pi = torch.cat(vs, 1).mean(dim=1) else: # v_pi = self.qf1(obs, new_obs_actions) v1_pi = self.qf1(obs, new_obs_actions) v2_pi = self.qf2(obs, new_obs_actions) v_pi = torch.min(v1_pi, v2_pi) if self.awr_sample_actions: u = new_obs_actions if self.awr_min_q: q_adv = q_new_actions else: q_adv = qf1_new_actions elif self.buffer_policy_sample_actions: buf_dist = self.buffer_policy(obs) u, _ = buf_dist.rsample_and_logprob() qf1_buffer_actions = self.qf1(obs, u) qf2_buffer_actions = self.qf2(obs, u) q_buffer_actions = torch.min( qf1_buffer_actions, qf2_buffer_actions, ) if self.awr_min_q: q_adv = q_buffer_actions else: q_adv = qf1_buffer_actions else: u = actions if self.awr_min_q: q_adv = torch.min(q1_pred, q2_pred) else: q_adv = q1_pred policy_logpp = dist.log_prob(u) if self.use_automatic_beta_tuning: buffer_dist = self.buffer_policy(obs) beta = self.log_beta.exp() kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist) beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean() self.beta_optimizer.zero_grad() beta_loss.backward() self.beta_optimizer.step() else: beta = self.beta_schedule.get_value(self._n_train_steps_total) if self.normalize_over_state == "advantage": score = q_adv - v_pi if self.mask_positive_advantage: score = torch.sign(score) elif self.normalize_over_state == "Z": buffer_dist = self.buffer_policy(obs) K = self.Z_K buffer_obs = [] buffer_actions = [] log_bs = [] log_pis = [] for i in range(K): u = 
buffer_dist.sample() log_b = buffer_dist.log_prob(u) log_pi = dist.log_prob(u) buffer_obs.append(obs) buffer_actions.append(u) log_bs.append(log_b) log_pis.append(log_pi) buffer_obs = torch.cat(buffer_obs, 0) buffer_actions = torch.cat(buffer_actions, 0) p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, )) log_pi = torch.cat(log_pis, 0) log_pi = log_pi.sum(dim=1, ) q1_b = self.qf1(buffer_obs, buffer_actions) q2_b = self.qf2(buffer_obs, buffer_actions) q_b = torch.min(q1_b, q2_b) q_b = torch.reshape(q_b, (-1, K)) adv_b = q_b - v_pi # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True) # score = torch.exp((q_adv - v_pi) / beta) / Z # score = score / sum(score) logK = torch.log(ptu.tensor(float(K))) logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True) logS = (q_adv - v_pi) / beta - logZ # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True) # logS = q_adv/beta - logZ score = F.softmax(logS, dim=0) # score / sum(score) else: raise ValueError("Unsupported normalize_over_state: {}".format(self.normalize_over_state)) if self.clip_score is not None: score = torch.clamp(score, max=self.clip_score) if self.weight_loss and weights is None: if self.normalize_over_batch == True: # compare exactly: the string modes below are also truthy weights = F.softmax(score / beta, dim=0) elif self.normalize_over_batch == "whiten": adv_mean = torch.mean(score) adv_std = torch.std(score) + 1e-5 normalized_score = (score - adv_mean) / adv_std weights = torch.exp(normalized_score / beta) elif self.normalize_over_batch == "exp": weights = torch.exp(score / beta) elif self.normalize_over_batch == "step_fn": weights = (score > 0).float() elif not self.normalize_over_batch: weights = score else: raise ValueError("Unsupported normalize_over_batch: {}".format(self.normalize_over_batch)) weights = weights[:, 0] policy_loss = alpha * log_pi.mean() if self.use_awr_update and self.weight_loss: policy_loss = policy_loss + self.awr_weight * ( -policy_logpp * len(weights) * weights.detach()).mean() elif self.use_awr_update: policy_loss = policy_loss + self.awr_weight * ( -policy_logpp).mean() if self.use_reparam_update: policy_loss = policy_loss + self.reparam_weight * ( -q_new_actions).mean() policy_loss = self.rl_weight * policy_loss if self.compute_bc: train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch( self.demo_train_buffer, self.policy) policy_loss = policy_loss + self.bc_weight * train_policy_loss if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0: del self.buffer_policy_optimizer self.buffer_policy_optimizer = self.optimizer_class( self.buffer_policy.parameters(), weight_decay=self.policy_weight_decay, lr=self.policy_lr, ) self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer for i in range(self.num_buffer_policy_train_steps_on_reset): if self.train_bc_on_rl_buffer: if self.advantage_weighted_buffer_loss: buffer_dist = self.buffer_policy(obs) buffer_u = actions buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob( ) buffer_policy_logpp = buffer_dist.log_prob(buffer_u) buffer_policy_logpp = buffer_policy_logpp[:, None] buffer_q1_pred = self.qf1(obs, buffer_u) buffer_q2_pred = self.qf2(obs, buffer_u) buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred) buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions) buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions) buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi) buffer_score = buffer_q_adv - buffer_v_pi buffer_weights = F.softmax(buffer_score / beta, dim=0) buffer_policy_loss = self.awr_weight * ( -buffer_policy_logpp * len(buffer_weights) * buffer_weights.detach()).mean() else: 
buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch( self.replay_buffer.train_replay_buffer, self.buffer_policy) self.buffer_policy_optimizer.zero_grad() buffer_policy_loss.backward(retain_graph=True) self.buffer_policy_optimizer.step() if self.train_bc_on_rl_buffer: if self.advantage_weighted_buffer_loss: buffer_dist = self.buffer_policy(obs) buffer_u = actions buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob() buffer_policy_logpp = buffer_dist.log_prob(buffer_u) buffer_policy_logpp = buffer_policy_logpp[:, None] buffer_q1_pred = self.qf1(obs, buffer_u) buffer_q2_pred = self.qf2(obs, buffer_u) buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred) buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions) buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions) buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi) buffer_score = buffer_q_adv - buffer_v_pi buffer_weights = F.softmax(buffer_score / beta, dim=0) buffer_policy_loss = self.awr_weight * ( -buffer_policy_logpp * len(buffer_weights) * buffer_weights.detach()).mean() else: buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch( self.replay_buffer.train_replay_buffer, self.buffer_policy) """ Update networks """ if self._n_train_steps_total % self.q_update_period == 0: self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy: self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0: self.buffer_policy_optimizer.zero_grad() buffer_policy_loss.backward() self.buffer_policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. 
""" policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) self.eval_statistics.update( create_stats_ordered_dict( 'rewards', ptu.get_numpy(rewards), )) self.eval_statistics.update( create_stats_ordered_dict( 'terminals', ptu.get_numpy(terminals), )) policy_statistics = add_prefix(dist.get_diagnostics(), "policy/") self.eval_statistics.update(policy_statistics) self.eval_statistics.update( create_stats_ordered_dict( 'Advantage Weights', ptu.get_numpy(weights), )) self.eval_statistics.update( create_stats_ordered_dict( 'Advantage Score', ptu.get_numpy(score), )) if self.normalize_over_state == "Z": self.eval_statistics.update( create_stats_ordered_dict( 'logZ', ptu.get_numpy(logZ), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() if self.compute_bc: test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch( self.demo_test_buffer, self.policy) self.eval_statistics.update({ "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss), "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss), "bc/Train MSE": ptu.get_numpy(train_mse_loss), "bc/Test MSE": ptu.get_numpy(test_mse_loss), "bc/train_policy_loss": ptu.get_numpy(train_policy_loss), "bc/test_policy_loss": ptu.get_numpy(test_policy_loss), }) if self.train_bc_on_rl_buffer: _, buffer_train_logp_loss, _, _ = self.run_bc_batch( self.replay_buffer.train_replay_buffer, self.buffer_policy) _, buffer_test_logp_loss, _, _ = self.run_bc_batch( self.replay_buffer.validation_replay_buffer, self.buffer_policy) buffer_dist = self.buffer_policy(obs) kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist) _, train_offline_logp_loss, _, _ = self.run_bc_batch( self.demo_train_buffer, self.buffer_policy) _, test_offline_logp_loss, _, _ = self.run_bc_batch( self.demo_test_buffer, self.buffer_policy) self.eval_statistics.update({ "buffer_policy/Train Online Logprob": -1 * ptu.get_numpy(buffer_train_logp_loss), "buffer_policy/Test Online Logprob": -1 * ptu.get_numpy(buffer_test_logp_loss), "buffer_policy/Train Offline Logprob": -1 * ptu.get_numpy(train_offline_logp_loss), "buffer_policy/Test Offline Logprob": -1 * ptu.get_numpy(test_offline_logp_loss), "buffer_policy/train_policy_loss": ptu.get_numpy(buffer_policy_loss), # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss), "buffer_policy/kl_div": ptu.get_numpy(kldiv.mean()), }) if self.use_automatic_beta_tuning: self.eval_statistics.update({ "adaptive_beta/beta": ptu.get_numpy(beta.mean()), "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()), }) if self.validation_qlearning: train_data = self.replay_buffer.validation_replay_buffer.random_batch( self.bc_batch_size) train_data = np_to_pytorch_batch(train_data) obs = train_data['observations'] next_obs = train_data['next_observations'] # goals = train_data['resampled_goals'] train_data[ 'observations'] = obs # torch.cat((obs, goals), dim=1) 
train_data[ 'next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1) self.test_from_torch(train_data) self._n_train_steps_total += 1
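The policy update in this trainer is advantage-weighted regression: the log-likelihood of dataset actions, weighted by a softmax-normalized exponentiated advantage. A compact sketch of that weighting with hypothetical inputs (q_adv and v_pi of shape (batch, 1), matching the shapes used above):

import torch
import torch.nn.functional as F

def awr_policy_loss(dist, actions, q_adv, v_pi, beta=1.0):
    """Advantage-weighted regression: weights = softmax((Q - V) / beta) over
    the batch, then maximize the weighted log-likelihood of dataset actions."""
    score = (q_adv - v_pi).detach()
    weights = F.softmax(score / beta, dim=0)[:, 0]
    logpp = dist.log_prob(actions)
    # len(weights) rescales the softmax weights so they average to 1.
    return (-logpp * len(weights) * weights).mean()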