def update_target_networks(self):
    ptu.soft_update_from_to(
        self.qf1, self.target_qf1, self.soft_target_tau
    )
    ptu.soft_update_from_to(
        self.qf2, self.target_qf2, self.soft_target_tau
    )
def _update_target_network(self, high=True):
    if high:
        ptu.soft_update_from_to(self.high_vf, self.high_vf_target, self.soft_target_tau)
    else:
        ptu.soft_update_from_to(self.low_vf, self.low_vf_target, self.soft_target_tau)
def _update_target_networks(self):
    if self.use_soft_update:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)
    else:
        if self._n_train_steps_total % self.target_hard_update_period == 0:
            ptu.copy_model_params_from_to(self.qf1, self.target_qf1)
            ptu.copy_model_params_from_to(self.qf2, self.target_qf2)
def _update_target_networks(self):
    if self.use_soft_update:
        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
    else:
        if self._n_env_steps_total % self.target_hard_update_period == 0:
            ptu.copy_model_params_from_to(self.qf, self.target_qf)
            ptu.copy_model_params_from_to(self.policy, self.target_policy)
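# Aside: a minimal sketch of the two ptu target-update helpers called throughout this
# file. This is an assumption about their behavior (Polyak averaging for the soft
# update, a parameter-wise hard copy for the hard update), not the actual rlkit
# implementation.
def soft_update_from_to(source, target, tau):
    # target <- tau * source + (1 - tau) * target, parameter by parameter
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )


def copy_model_params_from_to(source, target):
    # hard update: overwrite the target's parameters with the source's
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)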
def train_from_torch(self, batch): rewards = batch["rewards"] terminals = batch["terminals"] obs = batch["observations"] actions = batch["actions"] next_obs = batch["next_observations"] try: plan_lengths = batch["plan_lengths"] if self.single_plan_discounting: plan_lengths = torch.ones_like(plan_lengths) except KeyError as e: plan_lengths = torch.ones_like(rewards) """ Compute loss """ best_action_idxs = self.qf(next_obs).max(1, keepdim=True)[1] target_q_values = self.target_qf(next_obs).gather(1, best_action_idxs).detach() y_target = ( rewards + (1.0 - terminals) * torch.pow(self.discount, plan_lengths) * target_q_values ) y_target = y_target.detach() # actions is a one-hot vector y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) if self.huber_loss: y_target = torch.max(y_target, y_pred.sub(1)) y_target = torch.min(y_target, y_pred.add(1)) qf_loss = self.qf_criterion(y_pred, y_target) """ Update networks """ self.qf_optimizer.zero_grad() qf_loss.backward() self.qf_optimizer.step() """ Soft target network updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau) """ Save some statistics for eval using just one batch. """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss)) self.eval_statistics.update( create_stats_ordered_dict("Y Predictions", ptu.get_numpy(y_pred)) )
def train_from_torch(self, batch):
    rewards = batch['rewards'] * self.reward_scale
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    qf_losses = []
    for ensemble_idx in range(self.ensemble_size):
        qf = self.qfs[ensemble_idx]
        target_qf = self.target_qfs[ensemble_idx]

        target_q_values = target_qf(next_obs).detach().max(
            1, keepdim=True
        )[0]
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)
        qf_losses.append(qf_loss)

        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            if ensemble_idx == self.ensemble_size - 1:
                self._need_to_update_eval_statistics = False
            self.eval_statistics['QF %d Loss' % ensemble_idx] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Y %d Predictions' % ensemble_idx,
                ptu.get_numpy(y_pred),
            ))

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    total_qf_loss = sum(qf_losses)
    total_qf_loss.backward()
    self.qf_optimizer.step()

    for ensemble_idx in range(self.ensemble_size):
        qf = self.qfs[ensemble_idx]
        target_qf = self.target_qfs[ensemble_idx]
        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                qf, target_qf, self.soft_target_tau
            )
def _do_training(self):
    batch = self.get_batch()

    """
    Optimize Critic/Actor.
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    _, _, v_pred = self.target_policy(next_obs, None)
    y_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * v_pred
    y_target = y_target.detach()
    mu, y_pred, v = self.policy(obs, actions)
    policy_loss = self.policy_criterion(y_pred, y_target)

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    """
    Update Target Networks
    """
    if self.use_soft_update:
        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
    else:
        if self._n_train_steps_total % self.target_hard_update_period == 0:
            ptu.copy_model_params_from_to(self.policy, self.target_policy)

    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy v',
                ptu.get_numpy(v),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(mu),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Y targets',
                ptu.get_numpy(y_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Y predictions',
                ptu.get_numpy(y_pred),
            ))
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    best_action_idxs = self.qf(next_obs).max(
        1, keepdim=True
    )[1]
    target_q_values = self.target_qf(next_obs).gather(
        1, best_action_idxs
    ).detach()
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf, self.target_qf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))
    self._n_train_steps_total += 1
def low_train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    goals = batch['goals']
    # roughly an approximation, since this doesn't account for goal switching
    next_goals = self.setter.goal_transition(obs, goals, next_obs)

    """
    Compute loss
    """
    best_action_idxs = self.low_qf(torch.cat(
        (next_obs, next_goals), dim=1)).max(1, keepdim=True)[1]
    target_q_values = self.low_target_qf(
        torch.cat((next_obs, next_goals), dim=1)).gather(1, best_action_idxs).detach()
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.low_qf(torch.cat(
        (obs, goals), dim=1)) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.low_qf_optimizer.zero_grad()
    qf_loss.backward()
    if self.grad_clip_val is not None:
        nn.utils.clip_grad_norm_(self.low_qf.parameters(), self.grad_clip_val)
    self.low_qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.setter_and_target_update_period == 0:
        ptu.soft_update_from_to(self.low_qf, self.low_target_qf, self.tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Y Predictions',
                ptu.get_numpy(y_pred),
            ))
def _do_training(self, n_steps_total):
    raw_subtraj_batch, start_indices = (
        self.replay_buffer.train_replay_buffer.random_subtrajectories(
            self.num_subtrajs_per_batch))
    subtraj_batch = create_torch_subtraj_batch(raw_subtraj_batch)
    if self.save_memory_gradients:
        subtraj_batch['memories'].requires_grad = True
    self.train_critic(subtraj_batch)
    self.train_policy(subtraj_batch, start_indices)
    if self.use_soft_update:
        ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
    else:
        if n_steps_total % self.target_hard_update_period == 0:
            ptu.copy_model_params_from_to(self.qf, self.target_qf)
            ptu.copy_model_params_from_to(self.policy, self.target_policy)
def train_from_torch(self, batch): rewards = batch["rewards"] * self.reward_scale terminals = batch["terminals"] obs = batch["observations"] actions = batch["actions"] next_obs = batch["next_observations"] """ Compute loss """ target_q_values = self.target_qf(next_obs).detach().max( 1, keepdim=True)[0] y_target = rewards + (1.0 - terminals) * self.discount * target_q_values y_target = y_target.detach() # actions is a one-hot vector y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True) qf_loss = self.qf_criterion(y_pred, y_target) """ Soft target network updates """ self.qf_optimizer.zero_grad() qf_loss.backward() self.qf_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau) """ Save some statistics for eval using just one batch. """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False self.eval_statistics["QF Loss"] = np.mean(ptu.get_numpy(qf_loss)) self.eval_statistics.update( create_stats_ordered_dict( "Y Predictions", ptu.get_numpy(y_pred), )) self._n_train_steps_total += 1
def _pre_take_step(self, indices, context1, context1_):
    # num_tasks = len(indices)
    # data is (task, batch, feat)
    # positive sample
    self.curl_optimizer.zero_grad()
    self.encoder_optimizer.zero_grad()

    if self.use_information_bottleneck:
        kl_div = self.agent.compute_kl_div()
        kl_loss = self.kl_lambda / 1000 * kl_div
        kl_loss.backward(retain_graph=True)

    # NOTE: `loss` (the contrastive loss computed from the positive pair
    # context1 / context1_) is not defined in this snippet; its computation is
    # omitted here.
    loss.backward()
    self.curl_optimizer.step()
    self.encoder_optimizer.step()

    ptu.soft_update_from_to(self.agent.context_encoder,
                            self.agent.context_encoder_target,
                            self.encoder_tau)
def _update_target_networks(self):
    for cg1, target_cg1, qf1, target_qf1, cg2, target_cg2, qf2, target_qf2 in \
            zip(self.cg1_n, self.target_cg1_n, self.qf1_n, self.target_qf1_n,
                self.cg2_n, self.target_cg2_n, self.qf2_n, self.target_qf2_n):
        if self.use_soft_update:
            ptu.soft_update_from_to(cg1, target_cg1, self.tau)
            ptu.soft_update_from_to(qf1, target_qf1, self.tau)
            ptu.soft_update_from_to(cg2, target_cg2, self.tau)
            ptu.soft_update_from_to(qf2, target_qf2, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(cg1, target_cg1)
                ptu.copy_model_params_from_to(qf1, target_qf1)
                ptu.copy_model_params_from_to(cg2, target_cg2)
                ptu.copy_model_params_from_to(qf2, target_qf2)
def _update_target_networks(self):
    for policy, target_policy, qf, target_qf in \
            zip(self.policy_n, self.target_policy_n, self.qf_n, self.target_qf_n):
        if self.use_soft_update:
            ptu.soft_update_from_to(policy, target_policy, self.tau)
            ptu.soft_update_from_to(qf, target_qf, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(qf, target_qf)
                ptu.copy_model_params_from_to(policy, target_policy)
    if self.double_q:
        for qf2, target_qf2 in zip(self.qf2_n, self.target_qf2_n):
            if self.use_soft_update:
                ptu.soft_update_from_to(qf2, target_qf2, self.tau)
            else:
                if self._n_train_steps_total % self.target_hard_update_period == 0:
                    ptu.copy_model_params_from_to(qf2, target_qf2)
def train_from_torch(
    self,
    batch,
    train=True,
    pretrain=False,
):
    """
    :param batch:
    :param train:
    :param pretrain:
    :return:
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch.get('weights', None)
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    policy_mle = dist.mle_estimate()

    if self.brac:
        buf_dist = self.buffer_policy(obs)
        buf_log_pi = buf_dist.log_prob(actions)
        rewards = rewards + buf_log_pi

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions)
    qf2_new_actions = self.qf2(obs, new_obs_actions)
    q_new_actions = torch.min(
        qf1_new_actions,
        qf2_new_actions,
    )

    # Advantage-weighted regression
    if self.awr_use_mle_for_vf:
        v1_pi = self.qf1(obs, policy_mle)
        v2_pi = self.qf2(obs, policy_mle)
        v_pi = torch.min(v1_pi, v2_pi)
    else:
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u)
                q2 = self.qf2(obs, u)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions)
            v2_pi = self.qf2(obs, new_obs_actions)
            v_pi = torch.min(v1_pi, v2_pi)

    if self.awr_sample_actions:
        u = new_obs_actions
        if self.awr_min_q:
            q_adv = q_new_actions
        else:
            q_adv = qf1_new_actions
    elif self.buffer_policy_sample_actions:
        buf_dist = self.buffer_policy(obs)
        u, _ = buf_dist.rsample_and_logprob()
        qf1_buffer_actions = self.qf1(obs, u)
        qf2_buffer_actions = self.qf2(obs, u)
        q_buffer_actions = torch.min(
            qf1_buffer_actions,
            qf2_buffer_actions,
        )
        if self.awr_min_q:
            q_adv = q_buffer_actions
        else:
            q_adv = qf1_buffer_actions
    else:
        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)

    if self.normalize_over_state == "advantage":
        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)
    elif self.normalize_over_state == "Z":
        buffer_dist = self.buffer_policy(obs)
        K = self.Z_K
        buffer_obs = []
        buffer_actions = []
        log_bs = []
        log_pis = []
        for i in range(K):
            u = buffer_dist.sample()
            log_b = buffer_dist.log_prob(u)
            log_pi = dist.log_prob(u)
            buffer_obs.append(obs)
            buffer_actions.append(u)
            log_bs.append(log_b)
            log_pis.append(log_pi)
        buffer_obs = torch.cat(buffer_obs, 0)
        buffer_actions = torch.cat(buffer_actions, 0)
        p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
        log_pi = torch.cat(log_pis, 0)
        log_pi = log_pi.sum(dim=1, )
        q1_b = self.qf1(buffer_obs, buffer_actions)
        q2_b = self.qf2(buffer_obs, buffer_actions)
        q_b = torch.min(q1_b, q2_b)
        q_b = torch.reshape(q_b, (-1, K))
        adv_b = q_b - v_pi
        # if self._n_train_steps_total % 100 == 0:
        #     import ipdb; ipdb.set_trace()
        # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
        # score = torch.exp((q_adv - v_pi) / beta) / Z
        # score = score / sum(score)
        logK = torch.log(ptu.tensor(float(K)))
        logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
        logS = (q_adv - v_pi) / beta - logZ
        # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
        # logS = q_adv/beta - logZ
        score = F.softmax(logS, dim=0)  # score / sum(score)
    else:
        # the original relied on a bare `error` name here; raise explicitly instead
        raise ValueError("unknown normalize_over_state: %s" % self.normalize_over_state)

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    if self.weight_loss and weights is None:
        # normalize_over_batch may be a boolean or one of the string modes below,
        # so compare against True explicitly rather than relying on truthiness
        if self.normalize_over_batch is True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif self.normalize_over_batch is False:
            weights = score
        else:
            raise ValueError("unknown normalize_over_batch: %s" % self.normalize_over_batch)
    weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.reparam_weight * (
            -q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss
    if self.compute_bc:
        train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
            self.demo_train_buffer, self.policy)
        policy_loss = policy_loss + self.bc_weight * train_policy_loss

    if not pretrain and self.buffer_policy_reset_period > 0 and \
            self._n_train_steps_total % self.buffer_policy_reset_period == 0:
        del self.buffer_policy_optimizer
        self.buffer_policy_optimizer = self.optimizer_class(
            self.buffer_policy.parameters(),
            weight_decay=self.policy_weight_decay,
            lr=self.policy_lr,
        )
        self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
        for i in range(self.num_buffer_policy_train_steps_on_reset):
            if self.train_bc_on_rl_buffer:
                if self.advantage_weighted_buffer_loss:
                    buffer_dist = self.buffer_policy(obs)
                    buffer_u = actions
                    buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                    buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                    buffer_policy_logpp = buffer_policy_logpp[:, None]

                    buffer_q1_pred = self.qf1(obs, buffer_u)
                    buffer_q2_pred = self.qf2(obs, buffer_u)
                    buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                    buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                    buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                    buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                    buffer_score = buffer_q_adv - buffer_v_pi
                    buffer_weights = F.softmax(buffer_score / beta, dim=0)
                    buffer_policy_loss = self.awr_weight * (
                        -buffer_policy_logpp * len(buffer_weights) *
                        buffer_weights.detach()).mean()
                else:
                    buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                        self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                self.buffer_policy_optimizer.zero_grad()
                buffer_policy_loss.backward(retain_graph=True)
                self.buffer_policy_optimizer.step()

    if self.train_bc_on_rl_buffer:
        if self.advantage_weighted_buffer_loss:
            buffer_dist = self.buffer_policy(obs)
            buffer_u = actions
            buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
            buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
            buffer_policy_logpp = buffer_policy_logpp[:, None]

            buffer_q1_pred = self.qf1(obs, buffer_u)
            buffer_q2_pred = self.qf2(obs, buffer_u)
            buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

            buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
            buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
            buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

            buffer_score = buffer_q_adv - buffer_v_pi
            buffer_weights = F.softmax(buffer_score / beta, dim=0)
            buffer_policy_loss = self.awr_weight * (
                -buffer_policy_logpp * len(buffer_weights) *
                buffer_weights.detach()).mean()
        else:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer,
                    self.buffer_policy)

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self.train_bc_on_rl_buffer and \
            self._n_train_steps_total % self.policy_update_period == 0:
        self.buffer_policy_optimizer.zero_grad()
        buffer_policy_loss.backward()
        self.buffer_policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'rewards',
                ptu.get_numpy(rewards),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'terminals',
                ptu.get_numpy(terminals),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Advantage Weights',
                ptu.get_numpy(weights),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Advantage Score',
                ptu.get_numpy(score),
            ))
        if self.normalize_over_state == "Z":
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'logZ',
                    ptu.get_numpy(logZ),
                ))

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.compute_bc:
            test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            self.eval_statistics.update({
                "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
            })
        if self.train_bc_on_rl_buffer:
            _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.train_replay_buffer,
                self.buffer_policy)
            _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.validation_replay_buffer,
                self.buffer_policy)
            buffer_dist = self.buffer_policy(obs)
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

            _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_train_buffer, self.buffer_policy)
            _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_test_buffer, self.buffer_policy)

            self.eval_statistics.update({
                "buffer_policy/Train Online Logprob": -1 * ptu.get_numpy(buffer_train_logp_loss),
                "buffer_policy/Test Online Logprob": -1 * ptu.get_numpy(buffer_test_logp_loss),
                "buffer_policy/Train Offline Logprob": -1 * ptu.get_numpy(train_offline_logp_loss),
                "buffer_policy/Test Offline Logprob": -1 * ptu.get_numpy(test_offline_logp_loss),
                "buffer_policy/train_policy_loss": ptu.get_numpy(buffer_policy_loss),
                # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                "buffer_policy/kl_div": ptu.get_numpy(kldiv.mean()),
            })
        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

        if self.validation_qlearning:
            train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.test_from_torch(train_data)

    self._n_train_steps_total += 1
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Critic operations. """ next_actions = self.target_policy(next_obs) noise = ptu.randn(next_actions.shape) * self.target_policy_noise noise = torch.clamp( noise, -self.target_policy_noise_clip, self.target_policy_noise_clip ) noisy_next_actions = next_actions + noise target_q1_values = self.target_qf1(next_obs, noisy_next_actions) target_q2_values = self.target_qf2(next_obs, noisy_next_actions) target_q_values = torch.min(target_q1_values, target_q2_values) q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values q_target = q_target.detach() q1_pred = self.qf1(obs, actions) bellman_errors_1 = (q1_pred - q_target) ** 2 qf1_loss = bellman_errors_1.mean() q2_pred = self.qf2(obs, actions) bellman_errors_2 = (q2_pred - q_target) ** 2 qf2_loss = bellman_errors_2.mean() """ Update Networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() policy_actions = policy_loss = None if self._n_train_steps_total % self.policy_and_target_update_period == 0: policy_actions = self.policy(obs) q_output = self.qf1(obs, policy_actions) policy_loss = - q_output.mean() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() ptu.soft_update_from_to(self.policy, self.target_policy, self.tau) ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau) if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False if policy_loss is None: policy_actions = self.policy(obs) q_output = self.qf1(obs, policy_actions) policy_loss = - q_output.mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( policy_loss )) self.eval_statistics.update(create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update(create_stats_ordered_dict( 'Bellman Errors 1', ptu.get_numpy(bellman_errors_1), )) self.eval_statistics.update(create_stats_ordered_dict( 'Bellman Errors 2', ptu.get_numpy(bellman_errors_2), )) self.eval_statistics.update(create_stats_ordered_dict( 'Policy Action', ptu.get_numpy(policy_actions), )) self._n_train_steps_total += 1
def train_from_torch(self, batch, train=True, pretrain=False,):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    target_vf_pred = self.vf(next_obs).detach()

    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_vf_pred
    q_target = q_target.detach()
    qf1_loss = self.qf_criterion(q1_pred, q_target)
    qf2_loss = self.qf_criterion(q2_pred, q_target)

    """
    VF Loss
    """
    q_pred = torch.min(
        self.target_qf1(obs, actions),
        self.target_qf2(obs, actions),
    ).detach()
    vf_pred = self.vf(obs)
    vf_err = vf_pred - q_pred
    vf_sign = (vf_err > 0).float()
    vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
    vf_loss = (vf_weight * (vf_err ** 2)).mean()

    """
    Policy Loss
    """
    policy_logpp = dist.log_prob(actions)

    adv = q_pred - vf_pred
    exp_adv = torch.exp(adv / self.beta)
    if self.clip_score is not None:
        exp_adv = torch.clamp(exp_adv, max=self.clip_score)

    weights = exp_adv[:, 0].detach()
    policy_loss = (-policy_logpp * weights).mean()

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        ptu.soft_update_from_to(
            self.qf2, self.target_qf2, self.soft_target_tau
        )

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards',
            ptu.get_numpy(rewards),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals',
            ptu.get_numpy(terminals),
        ))
        self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights',
            ptu.get_numpy(weights),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score',
            ptu.get_numpy(adv),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V1 Predictions',
            ptu.get_numpy(vf_pred),
        ))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))

    self._n_train_steps_total += 1
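# Aside: a standalone sketch of the asymmetric (expectile-weighted) squared error used
# for the VF update above; `quantile` plays the same role as self.quantile. This is an
# illustrative helper, not part of the trainer class.
import torch

def expectile_loss(pred, target, quantile=0.7):
    # errors where pred > target get weight (1 - quantile), the rest get weight quantile,
    # mirroring the vf_err / vf_sign / vf_weight computation in train_from_torch above
    err = pred - target
    sign = (err > 0).float()
    weight = (1 - sign) * quantile + sign * (1 - quantile)
    return (weight * err.pow(2)).mean()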
def train_from_torch(self, batch):
    self._current_epoch += 1
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Behavior clone a policy
    """
    recon, mean, std = self.vae(obs, actions)
    recon_loss = self.qf_criterion(recon, actions)
    kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + 0.5 * kl_loss

    self.vae_optimizer.zero_grad()
    vae_loss.backward()
    self.vae_optimizer.step()

    """
    Critic Training
    """
    # import ipdb; ipdb.set_trace()
    with torch.no_grad():
        # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
        state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
            next_obs.shape[0] * 10, next_obs.shape[1])  # 10BxS

        # Compute value of perturbed actions sampled from the VAE
        action_rep = self.policy(state_rep)[0]
        target_qf1 = self.target_qf1(state_rep, action_rep)
        target_qf2 = self.target_qf2(state_rep, action_rep)

        # Soft Clipped Double Q-learning
        target_Q = 0.75 * torch.min(target_qf1, target_qf2) \
            + 0.25 * torch.max(target_qf1, target_qf2)
        target_Q = target_Q.view(next_obs.shape[0], -1).max(1)[0].view(-1, 1)
        target_Q = self.reward_scale * rewards + (
            1.0 - terminals) * self.discount * target_Q  # Bx1

    qf1_pred = self.qf1(obs, actions)  # Bx1
    qf2_pred = self.qf2(obs, actions)  # Bx1
    qf1_loss = (qf1_pred - target_Q.detach()).pow(2).mean()
    qf2_loss = (qf2_pred - target_Q.detach()).pow(2).mean()

    """
    Actor Training
    """
    sampled_actions, raw_sampled_actions = self.vae.decode_multiple(
        obs, num_decode=self.num_samples_mmd_match)
    actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy(
        obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match, 1).view(-1, obs.shape[1]),
        return_log_prob=True)
    actor_samples = actor_samples.view(obs.shape[0],
                                       self.num_samples_mmd_match,
                                       actions.shape[1])
    raw_actor_actions = raw_actor_actions.view(obs.shape[0],
                                               self.num_samples_mmd_match,
                                               actions.shape[1])

    if self.kernel_choice == 'laplacian':
        mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                           raw_actor_actions,
                                           sigma=self.mmd_sigma)
    elif self.kernel_choice == 'gaussian':
        mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                          raw_actor_actions,
                                          sigma=self.mmd_sigma)

    action_divergence = ((sampled_actions - actor_samples) ** 2).sum(-1)
    raw_action_divergence = ((raw_sampled_actions - raw_actor_actions) ** 2).sum(-1)

    q_val1 = self.qf1(obs, actor_samples[:, 0, :])
    q_val2 = self.qf2(obs, actor_samples[:, 0, :])
    if self.policy_update_style == '0':
        policy_loss = torch.min(q_val1, q_val2)[:, 0]
    elif self.policy_update_style == '1':
        # NOTE: the original line called torch.mean(q_val1, q_val2), which is not a
        # valid signature; the intended elementwise mean of the two Q values is
        # computed explicitly here.
        policy_loss = ((q_val1 + q_val2) / 2.0)[:, 0]

    if self._n_train_steps_total >= 40000:
        # Now we can update the policy
        if self.mode == 'auto':
            policy_loss = (-policy_loss + self.log_alpha.exp() *
                           (mmd_loss - self.target_mmd_thresh)).mean()
        else:
            policy_loss = (-policy_loss + 100 * mmd_loss).mean()
    else:
        if self.mode == 'auto':
            policy_loss = (self.log_alpha.exp() *
                           (mmd_loss - self.target_mmd_thresh)).mean()
        else:
            policy_loss = 100 * mmd_loss.mean()

    """
    Update Networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.policy_optimizer.zero_grad()
    if self.mode == 'auto':
        policy_loss.backward(retain_graph=True)
    self.policy_optimizer.step()

    if self.mode == 'auto':
        self.alpha_optimizer.zero_grad()
        (-policy_loss).backward()
        self.alpha_optimizer.step()
        self.log_alpha.data.clamp_(min=-5.0, max=10.0)

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
        self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(qf1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(qf2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(target_Q),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss)))
        self.eval_statistics.update(
            create_stats_ordered_dict('Action Divergence',
                                      ptu.get_numpy(action_divergence)))
        self.eval_statistics.update(
            create_stats_ordered_dict('Raw Action Divergence',
                                      ptu.get_numpy(raw_action_divergence)))
        if self.mode == 'auto':
            self.eval_statistics['Alpha'] = self.log_alpha.exp().item()

    self._n_train_steps_total += 1
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Policy and Alpha Loss """ pis = self.policy(obs) if self.use_automatic_entropy_tuning: alpha_loss = -(pis.detach() * self.log_alpha.exp() * (torch.log(pis + 1e-3) + self.target_entropy).detach()).sum(-1).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 min_q = torch.min(self.qf1(obs), self.qf2(obs)).detach() policy_loss = (pis * (alpha * torch.log(pis + 1e-3) - min_q)).sum(-1).mean() """ QF Loss """ new_pis = self.policy(next_obs).detach() target_min_q_values = torch.min( self.target_qf1(next_obs), self.target_qf2(next_obs), ) target_q_values = ( new_pis * (target_min_q_values - alpha * torch.log(new_pis + 1e-3))).sum( -1, keepdim=True) q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values q1_pred = torch.sum(self.qf1(obs) * actions.detach(), dim=-1, keepdim=True) q2_pred = torch.sum(self.qf2(obs) * actions.detach(), dim=-1, keepdim=True) qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Update networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Pis', ptu.get_numpy(pis), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() self._n_train_steps_total += 1
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    goals = batch['goals']
    num_steps_left = batch['num_steps_left']

    q1_pred = self.qf1(
        observations=obs,
        actions=actions,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    q2_pred = self.qf2(
        observations=obs,
        actions=actions,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    # Make sure policy accounts for squashing functions like tanh correctly!
    policy_outputs = self.policy(
        obs,
        goals,
        num_steps_left,
        reparameterize=self.train_policy_with_reparameterization,
        return_log_prob=True)
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
    if not self.dense_rewards:
        log_pi = log_pi * terminals

    """
    QF Loss
    """
    target_v_values = self.target_vf(
        observations=next_obs,
        goals=goals,
        num_steps_left=num_steps_left - 1,
    )
    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_v_values
    q_target = q_target.detach()
    bellman_errors_1 = (q1_pred - q_target) ** 2
    bellman_errors_2 = (q2_pred - q_target) ** 2
    qf1_loss = bellman_errors_1.mean()
    qf2_loss = bellman_errors_2.mean()

    """
    VF Loss
    """
    q1_new_actions = self.qf1(
        observations=obs,
        actions=new_actions,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    q2_new_actions = self.qf2(
        observations=obs,
        actions=new_actions,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    q_new_actions = torch.min(q1_new_actions, q2_new_actions)
    v_target = q_new_actions - log_pi
    v_pred = self.vf(
        observations=obs,
        goals=goals,
        num_steps_left=num_steps_left,
    )
    v_target = v_target.detach()
    bellman_errors = (v_pred - v_target) ** 2
    vf_loss = bellman_errors.mean()

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    """
    Policy Loss
    """
    # paper says to do + but apparently that's a typo. Do Q - V.
    if self.train_policy_with_reparameterization:
        policy_loss = (log_pi - q_new_actions).mean()
    else:
        log_policy_target = q_new_actions - v_pred
        policy_loss = (log_pi * (log_pi - log_policy_target).detach()).mean()
    mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
    std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
    pre_tanh_value = policy_outputs[-1]
    pre_activation_reg_loss = self.policy_pre_activation_weight * (
        (pre_tanh_value ** 2).sum(dim=1).mean())
    policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
    policy_loss = policy_loss + policy_reg_loss

    if self._n_train_steps_total % self.policy_update_period == 0:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
def _do_training(self):
    batch = self.get_batch()
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    v_pred = self.vf(obs)
    # Make sure policy accounts for squashing functions like tanh correctly!
    policy_outputs = self.policy(
        obs,
        reparameterize=self.train_policy_with_reparameterization,
        return_log_prob=True)
    new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]

    """
    Alpha Loss (if applicable)
    """
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
        alpha_loss = 0

    """
    QF Loss
    """
    target_v_values = self.target_vf(next_obs)
    q_target = self.reward_scale * rewards \
        + (1. - terminals) * self.discount * target_v_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    VF Loss
    """
    q_new_actions = torch.min(
        self.qf1(obs, new_actions),
        self.qf2(obs, new_actions),
    )
    v_target = q_new_actions - alpha * log_pi
    vf_loss = self.vf_criterion(v_pred, v_target.detach())

    """
    Update networks
    """
    self.qf1_optimizer.zero_grad()
    qf1_loss.backward()
    self.qf1_optimizer.step()

    self.qf2_optimizer.zero_grad()
    qf2_loss.backward()
    self.qf2_optimizer.step()

    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    policy_loss = None
    if self._n_train_steps_total % self.policy_update_period == 0:
        """
        Policy Loss
        """
        if self.train_policy_with_reparameterization:
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        else:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_pi * (alpha * log_pi - log_policy_target).detach()
            ).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(
            self.vf, self.target_vf, self.soft_target_tau
        )

    """
    Save some statistics for eval using just one batch.
    """
    if self.need_to_update_eval_statistics:
        self.need_to_update_eval_statistics = False
        if policy_loss is None:
            if self.train_policy_with_reparameterization:
                policy_loss = (log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi * (log_pi - log_policy_target).detach()
                ).mean()

            mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
            std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value ** 2).sum(dim=1).mean()
            )
            policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            policy_loss = policy_loss + policy_reg_loss

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
            policy_loss
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()
def _do_training(self): batch = self.get_batch() rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] goals = batch['goals'] num_steps_left = batch['num_steps_left'] """ Critic operations. """ next_actions = self.target_policy( observations=next_obs, goals=goals, num_steps_left=num_steps_left - 1, ) noise = torch.normal( torch.zeros_like(next_actions), self.target_policy_noise, ) noise = torch.clamp(noise, -self.target_policy_noise_clip, self.target_policy_noise_clip) noisy_next_actions = next_actions + noise target_q1_values = self.target_qf1( observations=next_obs, actions=noisy_next_actions, goals=goals, num_steps_left=num_steps_left - 1, ) target_q2_values = self.target_qf2( observations=next_obs, actions=noisy_next_actions, goals=goals, num_steps_left=num_steps_left - 1, ) target_q_values = torch.min(target_q1_values, target_q2_values) q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values q_target = q_target.detach() q1_pred = self.qf1( observations=obs, actions=actions, goals=goals, num_steps_left=num_steps_left, ) q2_pred = self.qf2( observations=obs, actions=actions, goals=goals, num_steps_left=num_steps_left, ) bellman_errors_1 = (q1_pred - q_target)**2 bellman_errors_2 = (q2_pred - q_target)**2 qf1_loss = bellman_errors_1.mean() qf2_loss = bellman_errors_2.mean() """ Update Networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() policy_actions, pre_tanh_value = self.policy( obs, goals, num_steps_left, return_preactivations=True, ) q_output = self.qf1( observations=obs, actions=policy_actions, num_steps_left=num_steps_left, goals=goals, ) policy_loss = -q_output.mean() if self._n_train_steps_total % self.policy_and_target_update_period == 0: self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() ptu.soft_update_from_to(self.policy, self.target_policy, self.tau) ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau) if self.need_to_update_eval_statistics: self.need_to_update_eval_statistics = False self.eval_statistics = OrderedDict() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Bellman1 Errors', ptu.get_numpy(bellman_errors_1), )) self.eval_statistics.update( create_stats_ordered_dict( 'Bellman2 Errors', ptu.get_numpy(bellman_errors_2), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy Action', ptu.get_numpy(policy_actions), ))
def train_from_torch(self, batch): self._current_epoch += 1 rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Policy and Alpha Loss """ new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy( obs, reparameterize=True, return_log_prob=True, ) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 if self.num_qs == 1: q_new_actions = self.qf1(obs, new_obs_actions) else: q_new_actions = torch.min( self.qf1(obs, new_obs_actions), self.qf2(obs, new_obs_actions), ) policy_loss = (alpha*log_pi - q_new_actions).mean() if self._current_epoch < self.policy_eval_start: """ For the initial few epochs, try doing behaivoral cloning, if needed conventionally, there's not much difference in performance with having 20k gradient steps here, or not having it """ policy_log_prob = self.policy.log_prob(obs, actions) policy_loss = (alpha * log_pi - policy_log_prob).mean() """ QF Loss """ q1_pred = self.qf1(obs, actions) if self.num_qs > 1: q2_pred = self.qf2(obs, actions) new_next_actions, _, _, new_log_pi, *_ = self.policy( next_obs, reparameterize=True, return_log_prob=True, ) new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy( obs, reparameterize=True, return_log_prob=True, ) if not self.max_q_backup: if self.num_qs == 1: target_q_values = self.target_qf1(next_obs, new_next_actions) else: target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions), self.target_qf2(next_obs, new_next_actions), ) if not self.deterministic_backup: target_q_values = target_q_values - alpha * new_log_pi if self.max_q_backup: """when using max q backup""" next_actions_temp, _ = self._get_policy_actions(next_obs, num_actions=10, network=self.policy) target_qf1_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1) target_qf2_values = self._get_tensor_values(next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1) target_q_values = torch.min(target_qf1_values, target_qf2_values) q_target = self.reward_scale * rewards + (1. 
- terminals) * self.discount * target_q_values q_target = q_target.detach() qf1_loss = self.qf_criterion(q1_pred, q_target) if self.num_qs > 1: qf2_loss = self.qf_criterion(q2_pred, q_target) ## add CQL random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.num_random, actions.shape[-1]).uniform_(-1, 1) # .cuda() curr_actions_tensor, curr_log_pis = self._get_policy_actions(obs, num_actions=self.num_random, network=self.policy) new_curr_actions_tensor, new_log_pis = self._get_policy_actions(next_obs, num_actions=self.num_random, network=self.policy) q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1) q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2) q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1) q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2) q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1) q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2) cat_q1 = torch.cat( [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1 ) cat_q2 = torch.cat( [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1 ) std_q1 = torch.std(cat_q1, dim=1) std_q2 = torch.std(cat_q2, dim=1) if self.min_q_version == 3: # importance sammpled version random_density = np.log(0.5 ** curr_actions_tensor.shape[-1]) cat_q1 = torch.cat( [q1_rand - random_density, q1_next_actions - new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()], 1 ) cat_q2 = torch.cat( [q2_rand - random_density, q2_next_actions - new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()], 1 ) min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp """Subtract the log likelihood of data""" min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight if self.with_lagrange: alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0) min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap) min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap) self.alpha_prime_optimizer.zero_grad() alpha_prime_loss = (-min_qf1_loss - min_qf2_loss)*0.5 alpha_prime_loss.backward(retain_graph=True) self.alpha_prime_optimizer.step() qf1_loss = qf1_loss + min_qf1_loss qf2_loss = qf2_loss + min_qf2_loss """ Update networks """ # Update the Q-functions iff self._num_q_update_steps += 1 self.qf1_optimizer.zero_grad() qf1_loss.backward(retain_graph=True) self.qf1_optimizer.step() if self.num_qs > 1: self.qf2_optimizer.zero_grad() qf2_loss.backward(retain_graph=True) self.qf2_optimizer.step() self._num_policy_update_steps += 1 self.policy_optimizer.zero_grad() policy_loss.backward(retain_graph=False) self.policy_optimizer.step() """ Soft Updates """ ptu.soft_update_from_to( self.qf1, self.target_qf1, self.soft_target_tau ) if self.num_qs > 1: ptu.soft_update_from_to( self.qf2, self.target_qf2, self.soft_target_tau ) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. 
""" policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss)) if self.num_qs > 1: self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss)) if not self.discrete: self.eval_statistics['Std QF1 values'] = np.mean(ptu.get_numpy(std_q1)) self.eval_statistics['Std QF2 values'] = np.mean(ptu.get_numpy(std_q2)) self.eval_statistics.update(create_stats_ordered_dict( 'QF1 in-distribution values', ptu.get_numpy(q1_curr_actions), )) self.eval_statistics.update(create_stats_ordered_dict( 'QF2 in-distribution values', ptu.get_numpy(q2_curr_actions), )) self.eval_statistics.update(create_stats_ordered_dict( 'QF1 random values', ptu.get_numpy(q1_rand), )) self.eval_statistics.update(create_stats_ordered_dict( 'QF2 random values', ptu.get_numpy(q2_rand), )) self.eval_statistics.update(create_stats_ordered_dict( 'QF1 next_actions values', ptu.get_numpy(q1_next_actions), )) self.eval_statistics.update(create_stats_ordered_dict( 'QF2 next_actions values', ptu.get_numpy(q2_next_actions), )) self.eval_statistics.update(create_stats_ordered_dict( 'actions', ptu.get_numpy(actions) )) self.eval_statistics.update(create_stats_ordered_dict( 'rewards', ptu.get_numpy(rewards) )) self.eval_statistics['Num Q Updates'] = self._num_q_update_steps self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( policy_loss )) self.eval_statistics.update(create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) if self.num_qs > 1: self.eval_statistics.update(create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update(create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) if not self.discrete: self.eval_statistics.update(create_stats_ordered_dict( 'Policy mu', ptu.get_numpy(policy_mean), )) self.eval_statistics.update(create_stats_ordered_dict( 'Policy log std', ptu.get_numpy(policy_log_std), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() if self.with_lagrange: self.eval_statistics['Alpha_prime'] = alpha_prime.item() self.eval_statistics['min_q1_loss'] = ptu.get_numpy(min_qf1_loss).mean() self.eval_statistics['min_q2_loss'] = ptu.get_numpy(min_qf2_loss).mean() self.eval_statistics['threshold action gap'] = self.target_action_gap self.eval_statistics['alpha prime loss'] = alpha_prime_loss.item() self._n_train_steps_total += 1
def _do_training_step(self, epoch, loop_iter): ''' Train the discriminator ''' self.encoder_optimizer.zero_grad() self.policy_optimizer.zero_grad() # prep the batches # OLD VERSION ----------------------------------------------------------------------------- # context_batch, context_pred_batch, test_pred_batch, mask = self._get_training_batch() # post_dist = self.encoder(context_batch, mask) # z = post_dist.sample() # N_tasks x Dim # # z = post_dist.mean # # convert it to a pytorch tensor # # note that our objective says we should maximize likelihood of # # BOTH the context_batch and the test_batch # obs_batch = np.concatenate((context_pred_batch['observations'], test_pred_batch['observations']), axis=0) # obs_batch = Variable(ptu.from_numpy(obs_batch), requires_grad=False) # acts_batch = np.concatenate((context_pred_batch['actions'], test_pred_batch['actions']), axis=0) # acts_batch = Variable(ptu.from_numpy(acts_batch), requires_grad=False) # # make z's for expert samples # context_pred_z = z.repeat(1, self.num_context_trajs_for_training * self.train_samples_per_traj).view( # -1, # z.size(1) # ) # test_pred_z = z.repeat(1, self.num_test_trajs_for_training * self.train_samples_per_traj).view( # -1, # z.size(1) # ) # z_batch = torch.cat([context_pred_z, test_pred_z], dim=0) # NEW VERSION (this is more fair to this model) ------------------------------------------- context_batch, mask, pred_batch = self._get_training_batch(epoch) post_dist = self.encoder(context_batch, mask) z = post_dist.sample() # N_tasks x Dim # z = post_dist.mean obs_batch = Variable(ptu.from_numpy(pred_batch['observations']), requires_grad=False) acts_batch = Variable(ptu.from_numpy(pred_batch['actions']), requires_grad=False) z_batch = z.repeat(1, self.policy_optim_batch_size_per_task).view( -1, z.size(1)) input_batch = torch.cat([obs_batch, z_batch], dim=-1) if self.use_mse_objective: pred_acts = self.policy(input_batch)[1] recon_loss = self.mse_loss(pred_acts, acts_batch) else: recon_loss = -1.0 * self.policy.get_log_prob( input_batch, acts_batch).mean() # add KL loss term cur_KL_beta = linear_schedule( self._n_train_steps_total * self.num_update_loops_per_train_call + loop_iter - self.KL_ramp_up_start_iter, 0.0, self.max_KL_beta, self.KL_ramp_up_end_iter - self.KL_ramp_up_start_iter) KL_loss = self._compute_KL_loss(post_dist) if cur_KL_beta == 0.0: KL_loss = KL_loss.detach() loss = recon_loss + cur_KL_beta * KL_loss loss.backward() self.policy_optimizer.step() self.encoder_optimizer.step() if self.use_target_policy: ptu.soft_update_from_to(self.policy, self.target_policy, self.soft_target_policy_tau) if self.use_target_enc: ptu.soft_update_from_to(self.encoder, self.target_enc, self.soft_target_enc_tau) """ Save some statistics for eval """ if self.eval_statistics is None: """ Eval should set this to None. This way, these statistics are only computed for one batch. 
""" self.eval_statistics = OrderedDict() if self.use_target_policy: enc_to_use = self.target_enc if self.use_target_enc else self.encoder pol_to_use = self.target_policy if self.use_mse_objective: pred_acts = pol_to_use(input_batch)[1] target_loss = self.mse_loss(pred_acts, acts_batch) self.eval_statistics['Target MSE Loss'] = np.mean( ptu.get_numpy(target_loss)) else: target_loss = -1.0 * pol_to_use.get_log_prob( input_batch, acts_batch).mean() self.eval_statistics['Target Neg Log Like'] = np.mean( ptu.get_numpy(target_loss)) else: if self.use_mse_objective: self.eval_statistics['Target MSE Loss'] = np.mean( ptu.get_numpy(recon_loss)) else: self.eval_statistics['Target Neg Log Like'] = np.mean( ptu.get_numpy(recon_loss)) self.eval_statistics['Target KL'] = np.mean(ptu.get_numpy(KL_loss)) self.eval_statistics['Cur KL Beta'] = cur_KL_beta self.eval_statistics['Max KL Beta'] = self.max_KL_beta self.eval_statistics['Avg Post Mean Abs'] = np.mean( np.abs(ptu.get_numpy(post_dist.mean))) self.eval_statistics['Avg Post Cov Abs'] = np.mean( np.abs(ptu.get_numpy(post_dist.cov)))
def _update_target_network(self):
    ptu.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau)
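# Every ptu.soft_update_from_to call in this file performs Polyak averaging of
# parameters. A stand-alone approximation is sketched below; the project's own helper
# may differ in details.
import torch
import torch.nn as nn

def soft_update_from_to(source: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.data.mul_(1.0 - tau).add_(tau * src_p.data)

# Usage with throwaway networks.
vf, target_vf = nn.Linear(4, 1), nn.Linear(4, 1)
soft_update_from_to(vf, target_vf, tau=0.005)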
def _take_step(self, indices, obs_enc, act_enc, rewards_enc): num_tasks = len(indices) import time t6 = time.time() # data is (task, batch, feat) batch = self.replay_loader.next() # print('sample', time.time() - t6) t7 = time.time() obs, actions, rewards, next_obs, terms = [x.cuda() for x in batch] # print('to_cuda', time.time() - t7) t5 = time.time() enc_data = self.prepare_encoder_data(obs_enc, act_enc, rewards_enc) # print('prep enc data', time.time() - t5) self.cnn_optimizer.zero_grad() self.qf1_optimizer.zero_grad() self.context_optimizer.zero_grad() t5 = time.time() # run inference in networks q1_pred, q1_next_pred, q2_next_pred, policy_outputs, task_z = self.policy(obs, actions, next_obs, enc_data, obs_enc, act_enc) #print('policy', time.time() - t5) # new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4] new_actions = policy_outputs # KL constraint on z if probabilistic t4 = time.time() kl_loss = 0 if self.use_information_bottleneck: kl_div = self.policy.compute_kl_div() kl_loss = self.kl_lambda * kl_div #print('kl', time.time() - t4) # kl_loss.backward(retain_graph=True) # qf and encoder update (note encoder does not get grads from policy or vf) rewards_flat = rewards.view(self.batch_size * num_tasks, -1) # scale rewards for Bellman update rewards_flat = rewards_flat * self.reward_scale terms_flat = terms.view(self.batch_size * num_tasks, -1) actions = actions.view(self.batch_size * num_tasks, -1) t3 = time.time() best_action_idxs = q1_next_pred.max( 1, keepdim=True )[1] target_q_values = q2_next_pred.gather( 1, best_action_idxs ).detach() #print('get actions', time.time() - t3) y_target = rewards_flat + (1. - terms_flat) * self.discount * target_q_values y_target = y_target.detach() t2 = time.time() # actions is a one-hot vector y_pred = torch.sum(q1_pred * actions, dim=1, keepdim=True) qf_loss = self.qf_criterion(y_pred, y_target) #print('compute loss', time.time() - t2) t1 = time.time() """ Update networks """ loss = qf_loss + kl_loss loss.backward() #print('backward', time.time() - t1) t0 = time.time() self.qf1_optimizer.step() self.cnn_optimizer.step() self.context_optimizer.step() #print('step', time.time() - t0) """ Soft target network updates """ if self.target_update_period > 1 and self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to( self.policy.qf1, self.policy.qf2, 1 ) else: ptu.soft_update_from_to( self.policy.qf1, self.policy.qf2, self.soft_target_tau, ) # save some statistics for eval if self.eval_statistics is None: # eval should set this to None. # this way, these statistics are only computed for one batch. # TODO this is kind of annoying and higher variance, why not just average # across all the train steps? self.eval_statistics = OrderedDict() if self.use_information_bottleneck: z_mean = np.mean(np.abs(ptu.get_numpy(self.policy.z_dists[0].mean))) z_sig = np.mean(ptu.get_numpy(self.policy.z_dists[0].variance)) self.eval_statistics['Z mean train'] = z_mean self.eval_statistics['Z variance train'] = z_sig self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div) self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss) self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss)) # self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss)) # self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( # policy_loss # )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Predictions', ptu.get_numpy(q1_pred), ))
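# The Bellman target above follows the double-DQN pattern: the online network picks the
# argmax action and the target network evaluates it. A minimal sketch with dummy tensors
# and stand-in linear networks (not the meta-RL policy used above):
import torch
import torch.nn as nn

batch_size, obs_dim, n_actions, discount = 16, 10, 4, 0.99
qf = nn.Linear(obs_dim, n_actions)          # online network
target_qf = nn.Linear(obs_dim, n_actions)   # target network

next_obs = torch.randn(batch_size, obs_dim)
rewards = torch.randn(batch_size, 1)
terminals = torch.zeros(batch_size, 1)

# Argmax under the online network, value under the target network.
best_action_idxs = qf(next_obs).max(1, keepdim=True)[1]
target_q_values = target_qf(next_obs).gather(1, best_action_idxs).detach()
y_target = rewards + (1.0 - terminals) * discount * target_q_values
print(y_target.shape)  # torch.Size([16, 1])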
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] """ Policy and Alpha Loss """ new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy( obs, reparameterize=True, return_log_prob=True, ) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 q_new_actions = torch.min( self.qf1(obs, new_obs_actions), self.qf2(obs, new_obs_actions), ) policy_loss = (alpha * log_pi - q_new_actions).mean() """ QF Loss """ q1_pred = self.qf1(obs, actions) q2_pred = self.qf2(obs, actions) # Make sure policy accounts for squashing functions like tanh correctly! new_next_actions, _, _, new_log_pi, *_ = self.policy( next_obs, reparameterize=True, return_log_prob=True, ) target_q_values = torch.min( self.target_qf1(next_obs, new_next_actions), self.target_qf2(next_obs, new_next_actions), ) - alpha * new_log_pi q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Update networks """ self.qf1_optimizer.zero_grad() qf1_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.zero_grad() qf2_loss.backward() self.qf2_optimizer.step() self.policy_optimizer.zero_grad() policy_loss.backward() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy mu', ptu.get_numpy(policy_mean), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy log std', ptu.get_numpy(policy_log_std), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() self._n_train_steps_total += 1
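# The temperature update in the SAC trainer above keeps the policy entropy near a target
# value. A self-contained sketch of that update with dummy log-probabilities; the
# -action_dim target is the common heuristic for continuous control, not necessarily the
# value configured for this trainer.
import torch

action_dim = 6
target_entropy = -float(action_dim)

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

log_pi = torch.randn(256, 1)                 # stand-in for policy log-probs on a batch

# Same form as the alpha loss above: increase alpha when entropy is below target,
# decrease it when entropy is above target.
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp()
print(alpha.item())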
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] masks = batch['masks'] # variables for logging tot_qf1_loss, tot_qf2_loss, tot_q1_pred, tot_q2_pred, tot_q_target = 0, 0, 0, 0, 0 tot_log_pi, tot_policy_mean, tot_policy_log_std, tot_policy_loss = 0, 0, 0, 0 tot_alpha, tot_alpha_loss = 0, 0 std_Q_actor_list = self.corrective_feedback(obs=obs, update_type=0) std_Q_critic_list = self.corrective_feedback(obs=next_obs, update_type=1) for en_index in range(self.num_ensemble): mask = masks[:,en_index].reshape(-1, 1) """ Policy and Alpha Loss """ new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy[en_index]( obs, reparameterize=True, return_log_prob=True, ) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha[en_index] * (log_pi + self.target_entropy).detach()) * mask alpha_loss = alpha_loss.sum() / (mask.sum() + 1) self.alpha_optimizer[en_index].zero_grad() alpha_loss.backward() self.alpha_optimizer[en_index].step() alpha = self.log_alpha[en_index].exp() else: alpha_loss = 0 alpha = 1 q_new_actions = torch.min( self.qf1[en_index](obs, new_obs_actions), self.qf2[en_index](obs, new_obs_actions), ) if self.feedback_type == 0 or self.feedback_type == 2: std_Q = std_Q_actor_list[en_index] else: std_Q = std_Q_actor_list[0] if self.feedback_type == 1 or self.feedback_type == 0: weight_actor_Q = torch.sigmoid(-std_Q*self.temperature_act) + 0.5 else: weight_actor_Q = 2*torch.sigmoid(-std_Q*self.temperature_act) policy_loss = (alpha*log_pi - q_new_actions - self.expl_gamma * std_Q) * mask * weight_actor_Q.detach() policy_loss = policy_loss.sum() / (mask.sum() + 1) """ QF Loss """ q1_pred = self.qf1[en_index](obs, actions) q2_pred = self.qf2[en_index](obs, actions) # Make sure policy accounts for squashing functions like tanh correctly! new_next_actions, _, _, new_log_pi, *_ = self.policy[en_index]( next_obs, reparameterize=True, return_log_prob=True, ) target_q_values = torch.min( self.target_qf1[en_index](next_obs, new_next_actions), self.target_qf2[en_index](next_obs, new_next_actions), ) - alpha * new_log_pi if self.feedback_type == 0 or self.feedback_type == 2: if self.feedback_type == 0: weight_target_Q = torch.sigmoid(-std_Q_critic_list[en_index]*self.temperature) + 0.5 else: weight_target_Q = 2*torch.sigmoid(-std_Q_critic_list[en_index]*self.temperature) else: if self.feedback_type == 1: weight_target_Q = torch.sigmoid(-std_Q_critic_list[0]*self.temperature) + 0.5 else: weight_target_Q = 2*torch.sigmoid(-std_Q_critic_list[0]*self.temperature) q_target = self.reward_scale * rewards + (1. 
- terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) * mask * (weight_target_Q.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) * mask * (weight_target_Q.detach()) qf1_loss = qf1_loss.sum() / (mask.sum() + 1) qf2_loss = qf2_loss.sum() / (mask.sum() + 1) """ Update networks """ self.qf1_optimizer[en_index].zero_grad() qf1_loss.backward() self.qf1_optimizer[en_index].step() self.qf2_optimizer[en_index].zero_grad() qf2_loss.backward() self.qf2_optimizer[en_index].step() self.policy_optimizer[en_index].zero_grad() policy_loss.backward() self.policy_optimizer[en_index].step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to( self.qf1[en_index], self.target_qf1[en_index], self.soft_target_tau ) ptu.soft_update_from_to( self.qf2[en_index], self.target_qf2[en_index], self.soft_target_tau ) """ Statistics for log """ tot_qf1_loss += qf1_loss * (1/self.num_ensemble) tot_qf2_loss += qf2_loss * (1/self.num_ensemble) tot_q1_pred += q1_pred * (1/self.num_ensemble) tot_q2_pred += q2_pred * (1/self.num_ensemble) tot_q_target += q_target * (1/self.num_ensemble) tot_log_pi += log_pi * (1/self.num_ensemble) tot_policy_mean += policy_mean * (1/self.num_ensemble) tot_policy_log_std += policy_log_std * (1/self.num_ensemble) tot_alpha += alpha.item() * (1/self.num_ensemble) tot_alpha_loss += alpha_loss.item() tot_policy_loss = (log_pi - q_new_actions).mean() * (1/self.num_ensemble) """ Save some statistics for eval """ if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(tot_qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(tot_qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy( tot_policy_loss )) self.eval_statistics.update(create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(tot_q1_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(tot_q2_pred), )) self.eval_statistics.update(create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(tot_q_target), )) self.eval_statistics.update(create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(tot_log_pi), )) self.eval_statistics.update(create_stats_ordered_dict( 'Policy mu', ptu.get_numpy(tot_policy_mean), )) self.eval_statistics.update(create_stats_ordered_dict( 'Policy log std', ptu.get_numpy(tot_policy_log_std), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = tot_alpha self.eval_statistics['Alpha Loss'] = tot_alpha_loss self._n_train_steps_total += 1
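# The per-sample weights in the ensemble update above shrink the Bellman (and actor)
# terms when the ensemble disagrees. A minimal sketch of the two weighting variants,
# assuming std_Q is the per-state standard deviation of target Q-values across the
# ensemble (the values below are dummies):
import torch

std_Q = torch.rand(256, 1)          # stand-in for ensemble std of target Q-values
temperature = 10.0

# Variant bounded in (0.5, 1.0]: never fully discards a sample.
weight_soft = torch.sigmoid(-std_Q * temperature) + 0.5

# Variant bounded in (0.0, 1.0]: can down-weight high-disagreement samples toward 0.
weight_hard = 2.0 * torch.sigmoid(-std_Q * temperature)

print(weight_soft.min().item(), weight_soft.max().item())
print(weight_hard.min().item(), weight_hard.max().item())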
def _do_reward_training(self, epoch): ''' Train the discriminator ''' self.disc_optimizer.zero_grad() expert_batch = self.get_disc_training_batch(self.disc_optim_batch_size, True) policy_batch = self.get_disc_training_batch(self.disc_optim_batch_size, False) expert_obs = expert_batch['observations'] policy_obs = policy_batch['observations'] if self.wrap_absorbing: expert_obs = torch.cat( [expert_obs, expert_batch['absorbing'][:, 0:1]], dim=-1) policy_obs = torch.cat( [policy_obs, policy_batch['absorbing'][:, 0:1]], dim=-1) if not self.state_only: expert_actions = expert_batch['actions'] policy_actions = policy_batch['actions'] if self.use_disc_input_noise: noise_scale = linear_schedule(epoch, self.disc_input_noise_scale_start, self.disc_input_noise_scale_end, self.epochs_till_end_scale) if noise_scale > 0.0: expert_obs = expert_obs + noise_scale * Variable( torch.randn(expert_obs.size())) if not self.state_only: expert_actions = expert_actions + noise_scale * Variable( torch.randn(expert_actions.size())) policy_obs = policy_obs + noise_scale * Variable( torch.randn(policy_obs.size())) if not self.state_only: policy_actions = policy_actions + noise_scale * Variable( torch.randn(policy_actions.size())) obs = torch.cat([expert_obs, policy_obs], dim=0) if not self.state_only: actions = torch.cat([expert_actions, policy_actions], dim=0) if self.state_only: disc_logits = self.discriminator(obs, None) else: disc_logits = self.discriminator(obs, actions) disc_preds = (disc_logits > 0).type(disc_logits.data.type()) disc_ce_loss = self.bce(disc_logits, self.bce_targets) accuracy = (disc_preds == self.bce_targets).type( torch.FloatTensor).mean() disc_ce_loss.backward() ce_grad_norm = 0.0 for name, param in self.discriminator.named_parameters(): if param.grad is not None: if self.disc_grad_buffer_is_empty: self.disc_grad_buffer[name] = param.grad.data.clone() else: self.disc_grad_buffer[name].copy_(param.grad.data) param_norm = param.grad.data.norm(2) ce_grad_norm += param_norm**2 ce_grad_norm = ce_grad_norm**0.5 self.disc_grad_buffer_is_empty = False ce_clip_coef = self.disc_ce_grad_clip / (ce_grad_norm + 1e-6) if ce_clip_coef < 1.: for name, grad in self.disc_grad_buffer.items(): grad.mul_(ce_clip_coef) if ce_clip_coef < 1.0: ce_grad_norm *= ce_clip_coef self.max_disc_ce_grad = max(ce_grad_norm, self.max_disc_ce_grad) self.disc_ce_grad_norm += ce_grad_norm self.disc_ce_grad_norm_counter += 1 self.disc_optimizer.zero_grad() if self.use_grad_pen: eps = Variable(torch.rand(expert_obs.size(0), 1)) if ptu.gpu_enabled(): eps = eps.cuda() interp_obs = eps * expert_obs + (1 - eps) * policy_obs interp_obs = interp_obs.detach() interp_obs.requires_grad = True if self.state_only: gradients = autograd.grad( outputs=self.discriminator(interp_obs, None).sum(), inputs=[interp_obs], # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(), create_graph=True, retain_graph=True, only_inputs=True) total_grad = gradients[0] else: interp_actions = eps * expert_actions + (1 - eps) * policy_actions interp_actions = interp_actions.detach() interp_actions.requires_grad = True gradients = autograd.grad( outputs=self.discriminator(interp_obs, interp_actions).sum(), inputs=[interp_obs, interp_actions], # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(), create_graph=True, retain_graph=True, only_inputs=True) total_grad = torch.cat([gradients[0], gradients[1]], dim=1) # GP from Gulrajani et al. 
gradient_penalty = ((total_grad.norm(2, dim=1) - 1)**2).mean() disc_grad_pen_loss = gradient_penalty * self.grad_pen_weight # # GP from Mescheder et al. # gradient_penalty = (total_grad.norm(2, dim=1) ** 2).mean() # disc_grad_pen_loss = gradient_penalty * 0.5 * self.grad_pen_weight disc_grad_pen_loss.backward() gp_grad_norm = 0.0 for p in list( filter(lambda p: p.grad is not None, self.discriminator.parameters())): param_norm = p.grad.data.norm(2) gp_grad_norm += param_norm**2 gp_grad_norm = gp_grad_norm**0.5 gp_clip_coef = self.disc_gp_grad_clip / (gp_grad_norm + 1e-6) if gp_clip_coef < 1.: for p in self.discriminator.parameters(): p.grad.data.mul_(gp_clip_coef) if gp_clip_coef < 1.: gp_grad_norm *= gp_clip_coef self.max_disc_gp_grad = max(gp_grad_norm, self.max_disc_gp_grad) self.disc_gp_grad_norm += gp_grad_norm self.disc_gp_grad_norm_counter += 1 # now add back the gradients from the CE loss for name, param in self.discriminator.named_parameters(): param.grad.data.add_(self.disc_grad_buffer[name]) self.disc_optimizer.step() if self.use_target_disc: ptu.soft_update_from_to(self.discriminator, self.target_disc, self.soft_target_disc_tau) """ Save some statistics for eval """ if self.rewardf_eval_statistics is None: """ Eval should set this to None. This way, these statistics are only computed for one batch. """ self.rewardf_eval_statistics = OrderedDict() if self.use_target_disc: if self.state_only: target_disc_logits = self.target_disc(obs, None) else: target_disc_logits = self.target_disc(obs, actions) target_disc_preds = (target_disc_logits > 0).type( target_disc_logits.data.type()) target_disc_ce_loss = self.bce(target_disc_logits, self.bce_targets) target_accuracy = (target_disc_preds == self.bce_targets).type( torch.FloatTensor).mean() if self.use_grad_pen: eps = Variable(torch.rand(expert_obs.size(0), 1)) if ptu.gpu_enabled(): eps = eps.cuda() interp_obs = eps * expert_obs + (1 - eps) * policy_obs interp_obs = interp_obs.detach() interp_obs.requires_grad = True if self.state_only: target_gradients = autograd.grad( outputs=self.target_disc(interp_obs, None).sum(), inputs=[interp_obs], # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(), create_graph=True, retain_graph=True, only_inputs=True) total_target_grad = target_gradients[0] else: interp_actions = eps * expert_actions + ( 1 - eps) * policy_actions interp_actions = interp_actions.detach() interp_actions.requires_grad = True target_gradients = autograd.grad( outputs=self.target_disc(interp_obs, interp_actions).sum(), inputs=[interp_obs, interp_actions], # grad_outputs=torch.ones(exp_specs['batch_size'], 1).cuda(), create_graph=True, retain_graph=True, only_inputs=True) total_target_grad = torch.cat( [target_gradients[0], target_gradients[1]], dim=1) # GP from Gulrajani et al. target_gradient_penalty = (( total_target_grad.norm(2, dim=1) - 1)**2).mean() # # GP from Mescheder et al. 
# target_gradient_penalty = (total_target_grad.norm(2, dim=1) ** 2).mean() self.rewardf_eval_statistics['Target Disc CE Loss'] = np.mean( ptu.get_numpy(target_disc_ce_loss)) self.rewardf_eval_statistics['Target Disc Acc'] = np.mean( ptu.get_numpy(target_accuracy)) self.rewardf_eval_statistics['Target Grad Pen'] = np.mean( ptu.get_numpy(target_gradient_penalty)) self.rewardf_eval_statistics['Target Grad Pen W'] = np.mean( self.grad_pen_weight) self.rewardf_eval_statistics['Disc CE Loss'] = np.mean( ptu.get_numpy(disc_ce_loss)) self.rewardf_eval_statistics['Disc Acc'] = np.mean( ptu.get_numpy(accuracy)) if self.use_grad_pen: self.rewardf_eval_statistics['Grad Pen'] = np.mean( ptu.get_numpy(gradient_penalty)) self.rewardf_eval_statistics['Grad Pen W'] = np.mean( self.grad_pen_weight) self.rewardf_eval_statistics[ 'Disc Avg CE Grad Norm this epoch'] = np.mean( self.disc_ce_grad_norm / self.disc_ce_grad_norm_counter) self.rewardf_eval_statistics[ 'Disc Max CE Grad Norm this epoch'] = np.mean( self.max_disc_ce_grad) self.rewardf_eval_statistics[ 'Disc Avg GP Grad Norm this epoch'] = np.mean( self.disc_gp_grad_norm / self.disc_gp_grad_norm_counter) self.rewardf_eval_statistics[ 'Disc Max GP Grad Norm this epoch'] = np.mean( self.max_disc_gp_grad) if self.use_disc_input_noise: self.rewardf_eval_statistics[ 'Disc Input Noise Scale'] = noise_scale self.max_disc_ce_grad = 0.0 self.disc_ce_grad_norm = 0.0 self.disc_ce_grad_norm_counter = 0.0 self.max_disc_gp_grad = 0.0 self.disc_gp_grad_norm = 0.0 self.disc_gp_grad_norm_counter = 0.0
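# The discriminator update above uses the gradient penalty of Gulrajani et al.:
# penalize the discriminator's gradient norm on interpolations between expert and
# policy samples. A compact stand-alone sketch for the state-only case with a stand-in
# discriminator and random data:
import torch
import torch.nn as nn
from torch import autograd

obs_dim, batch_size, grad_pen_weight = 8, 32, 10.0
discriminator = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 1))

expert_obs = torch.randn(batch_size, obs_dim)
policy_obs = torch.randn(batch_size, obs_dim)

# Interpolate between expert and policy samples (one eps per sample).
eps = torch.rand(batch_size, 1)
interp_obs = (eps * expert_obs + (1 - eps) * policy_obs).detach()
interp_obs.requires_grad = True

# Gradient of the discriminator output w.r.t. the interpolated inputs.
gradients = autograd.grad(
    outputs=discriminator(interp_obs).sum(),
    inputs=[interp_obs],
    create_graph=True,
    retain_graph=True,
)[0]

# Push the gradient norm toward 1, as in the loss above.
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
disc_grad_pen_loss = gradient_penalty * grad_pen_weight
print(disc_grad_pen_loss.item())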
def train_from_torch(self, batch): rewards = batch['rewards'] terminals = batch['terminals'] obs = batch['observations'] actions = batch['actions'] next_obs = batch['next_observations'] skills = batch['skills'] """ MI estimator btw prioceptive and extrioceptive sensors """ prio_obs = obs[:, :self.prio_extrio_bound] extrio_obs = obs[:, self.prio_extrio_bound:] mi_btw_states = estimate_mutual_information( "smile", prio_obs, extrio_obs, critic_fn=self.mi_estimator, clip=self.smile_clip) mi_loss = -mi_btw_states """ DF Loss and Intrinsic Reward """ z_hat = torch.argmax(skills, dim=1) d_pred = self.df(next_obs) d_pred_log_softmax = F.log_softmax(d_pred, 1) _, pred_z = torch.max(d_pred_log_softmax, dim=1, keepdim=True) skill_mi_rewards = d_pred_log_softmax[torch.arange(d_pred.shape[0]), z_hat] - math.log( 1 / self.policy.skill_dim) df_loss = self.df_criterion(d_pred, z_hat) rewards = skill_mi_rewards.reshape(-1, 1) + mi_btw_states.reshape( -1, 1) #+ rewards """ Policy and Alpha Loss """ new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy( obs, skill_vec=skills, reparameterize=True, return_log_prob=True, ) obs_skills = torch.cat((obs, skills), dim=1) if self.use_automatic_entropy_tuning: alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() alpha = self.log_alpha.exp() else: alpha_loss = 0 alpha = 1 q_new_actions = torch.min( self.qf1(obs_skills, new_obs_actions), self.qf2(obs_skills, new_obs_actions), ) policy_loss = (alpha * log_pi - q_new_actions).mean() """ QF Loss """ q1_pred = self.qf1(obs_skills, actions) q2_pred = self.qf2(obs_skills, actions) # Make sure policy accounts for squashing functions like tanh correctly! new_next_actions, _, _, new_log_pi, *_ = self.policy( next_obs, skill_vec=skills, reparameterize=True, return_log_prob=True, ) next_obs_skills = torch.cat((next_obs, skills), dim=1) target_q_values = torch.min( self.target_qf1(next_obs_skills, new_next_actions), self.target_qf2(next_obs_skills, new_next_actions), ) - alpha * new_log_pi q_target = self.reward_scale * rewards + ( 1. - terminals) * self.discount * target_q_values qf1_loss = self.qf_criterion(q1_pred, q_target.detach()) qf2_loss = self.qf_criterion(q2_pred, q_target.detach()) """ Update networks """ self.df_optimizer.zero_grad() df_loss.backward() self.df_optimizer.step() self.mi_optimizer.zero_grad() mi_loss.backward() self.mi_optimizer.step() self.qf1_optimizer.zero_grad() self.qf2_optimizer.zero_grad() self.policy_optimizer.zero_grad() qf1_loss.backward() qf2_loss.backward() policy_loss.backward() self.qf2_optimizer.step() self.qf1_optimizer.step() self.policy_optimizer.step() """ Soft Updates """ if self._n_train_steps_total % self.target_update_period == 0: ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau) ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau) """ Save some statistics for eval """ df_accuracy = torch.sum( torch.eq(z_hat, pred_z.reshape(1, list( pred_z.size())[0])[0])).float() / list( pred_z.size())[0] if self._need_to_update_eval_statistics: self._need_to_update_eval_statistics = False """ Eval should set this to None. This way, these statistics are only computed for one batch. 
""" policy_loss = (log_pi - q_new_actions).mean() self.eval_statistics['MI btw States Rewards'] = np.mean( ptu.get_numpy(mi_btw_states)) self.eval_statistics['MI btw Skill Rewards'] = np.mean( ptu.get_numpy(skill_mi_rewards)) self.eval_statistics['Sum Rewards'] = np.mean( ptu.get_numpy(rewards)) self.eval_statistics['DF Loss'] = np.mean(ptu.get_numpy(df_loss)) self.eval_statistics['DF Accuracy'] = np.mean( ptu.get_numpy(df_accuracy)) self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss)) self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss)) self.eval_statistics['Policy Loss'] = np.mean( ptu.get_numpy(policy_loss)) self.eval_statistics.update( create_stats_ordered_dict( 'Q1 Predictions', ptu.get_numpy(q1_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q2 Predictions', ptu.get_numpy(q2_pred), )) self.eval_statistics.update( create_stats_ordered_dict( 'D Predictions', ptu.get_numpy(pred_z), )) self.eval_statistics.update( create_stats_ordered_dict( 'Q Targets', ptu.get_numpy(q_target), )) self.eval_statistics.update( create_stats_ordered_dict( 'Log Pis', ptu.get_numpy(log_pi), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy mu', ptu.get_numpy(policy_mean), )) self.eval_statistics.update( create_stats_ordered_dict( 'Policy log std', ptu.get_numpy(policy_log_std), )) if self.use_automatic_entropy_tuning: self.eval_statistics['Alpha'] = alpha.item() self.eval_statistics['Alpha Loss'] = alpha_loss.item() self._n_train_steps_total += 1