def get_mask_diagnostics(unused):
    from rlkit.core.logging import append_log, add_prefix, OrderedDict
    log = OrderedDict()
    for prefix, collector in zip(addl_log_prefixes, addl_collectors):
        paths = collector.collect_new_paths(
            max_path_length,
            variant['algo_kwargs']['num_eval_steps_per_epoch'],
            discard_incomplete_paths=True,
        )
        old_path_info = eval_env.get_diagnostics(paths)

        keys_to_keep = []
        for key in old_path_info.keys():
            if ('env_infos' in key) and ('final' in key) and ('Mean' in key):
                keys_to_keep.append(key)
        path_info = OrderedDict()
        for key in keys_to_keep:
            path_info[key] = old_path_info[key]

        generic_info = add_prefix(
            path_info,
            prefix,
        )
        append_log(log, generic_info)

    for collector in addl_collectors:
        collector.end_epoch(0)

    return log
def get_eval_diagnostics(key_to_paths):
    stats = OrderedDict()
    for eval_env_name, paths in key_to_paths.items():
        env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
        stats.update(add_prefix(
            env.get_diagnostics(paths),
            eval_env_name,
            divider='/',
        ))
        stats.update(add_prefix(
            eval_util.get_generic_path_information(paths),
            eval_env_name,
            divider='/',
        ))
    return stats
def train(self):
    # first train only the Q function
    iteration = 0
    for i in range(self.num_batches):
        train_data = self.replay_buffer.random_batch(self.batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.trainer.train_from_torch(train_data)
        if i % self.logging_period == 0:
            stats_with_prefix = add_prefix(
                self.trainer.eval_statistics, prefix="trainer/")
            self.trainer.end_epoch(iteration)
            iteration += 1
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
def get_mask_diagnostics(unused):
    from rlkit.core.logging import append_log, add_prefix, OrderedDict
    from rlkit.misc import eval_util
    log = OrderedDict()
    for prefix, collector in zip(log_prefixes, collectors):
        paths = collector.collect_new_paths(
            max_path_length,
            masking_eval_steps,
            discard_incomplete_paths=True,
        )
        generic_info = add_prefix(
            eval_util.get_generic_path_information(paths),
            prefix,
        )
        append_log(log, generic_info)

    for collector in collectors:
        collector.end_epoch(0)

    return log
def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()
def test_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch.get('weights', None)
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    policy_mle = dist.mle_estimate()

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    qf1_new_actions = self.qf1(obs, new_obs_actions)
    qf2_new_actions = self.qf2(obs, new_obs_actions)
    q_new_actions = torch.min(qf1_new_actions, qf2_new_actions)

    policy_loss = (log_pi - q_new_actions).mean()

    self.eval_statistics['validation/QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
    self.eval_statistics['validation/QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
    self.eval_statistics['validation/Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
    self.eval_statistics.update(create_stats_ordered_dict(
        'validation/Q1 Predictions', ptu.get_numpy(q1_pred)))
    self.eval_statistics.update(create_stats_ordered_dict(
        'validation/Q2 Predictions', ptu.get_numpy(q2_pred)))
    self.eval_statistics.update(create_stats_ordered_dict(
        'validation/Q Targets', ptu.get_numpy(q_target)))
    self.eval_statistics.update(create_stats_ordered_dict(
        'validation/Log Pis', ptu.get_numpy(log_pi)))
    policy_statistics = add_prefix(dist.get_diagnostics(), "validation/policy/")
    self.eval_statistics.update(policy_statistics)
def pretrain_q_with_bc_data(self):
    """Pretrain the Q-functions on demonstration data, then train the policy and Q-functions together."""
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain1_steps):
        self.eval_statistics = dict()

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)
        if i % self.pretraining_logging_period == 0:
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.q_num_pretrain2_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # goals = train_data['resampled_goals']
        train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
        train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

    if self.post_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_pretrain_hyperparams)
def train_from_torch(self, batch, train=True, pretrain=False):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    target_vf_pred = self.vf(next_obs).detach()

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_vf_pred
    q_target = q_target.detach()
    qf1_loss = self.qf_criterion(q1_pred, q_target)
    qf2_loss = self.qf_criterion(q2_pred, q_target)

    """
    VF Loss
    """
    q_pred = torch.min(
        self.target_qf1(obs, actions),
        self.target_qf2(obs, actions),
    ).detach()
    vf_pred = self.vf(obs)
    vf_err = vf_pred - q_pred
    vf_sign = (vf_err > 0).float()
    vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
    vf_loss = (vf_weight * (vf_err ** 2)).mean()

    """
    Policy Loss
    """
    policy_logpp = dist.log_prob(actions)

    adv = q_pred - vf_pred
    exp_adv = torch.exp(adv / self.beta)
    if self.clip_score is not None:
        exp_adv = torch.clamp(exp_adv, max=self.clip_score)

    weights = exp_adv[:, 0].detach()
    policy_loss = (-policy_logpp * weights).mean()

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards', ptu.get_numpy(rewards)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals', ptu.get_numpy(terminals)))
        self.eval_statistics['replay_buffer_len'] = self.replay_buffer._size
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights', ptu.get_numpy(weights)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score', ptu.get_numpy(adv)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V1 Predictions', ptu.get_numpy(vf_pred)))
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))

    self._n_train_steps_total += 1
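# The VF loss above is an expectile (asymmetric L2) regression: errors where V
# underestimates min(Q1, Q2) are weighted by `quantile`, overestimates by
# (1 - quantile), so quantile > 0.5 pushes V toward an upper expectile of Q.
# A self-contained sketch of that weighting (function name is illustrative):
import torch

def expectile_loss(vf_pred: torch.Tensor, q_pred: torch.Tensor, quantile: float = 0.7) -> torch.Tensor:
    """Asymmetric squared error between V(s) and the (detached) Q estimate."""
    err = vf_pred - q_pred
    sign = (err > 0).float()
    weight = (1 - sign) * quantile + sign * (1 - quantile)
    return (weight * err ** 2).mean()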
def compute_loss(
    self,
    batch,
    skip_statistics=False,
) -> Tuple[SACLosses, LossStatistics]:
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    collisions = batch['collision']
    risk = batch['risk']

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(-1)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    ## TODO: finetune the loss function BELOW
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    # r_new_actions = torch.min(
    #     self.rf1(obs, new_obs_actions),
    #     self.rf2(obs, new_obs_actions),
    # )
    r_new_actions = self.rf1(obs, new_obs_actions)

    r_bound = self.delta
    r_loss_coeff = self.risk_coeff
    r_diff = r_new_actions - r_bound
    m = nn.Hardtanh(0, 1)
    r_policy_loss = m(r_diff) * r_loss_coeff

    # TODO(cyrushx): Add risk in policy loss.
    # policy_loss = (alpha*log_pi - q_new_actions + 1.*r_new_actions).mean()
    policy_loss = (alpha * log_pi - q_new_actions + r_policy_loss).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(-1)
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Risk Critic Loss
    """
    r1_pred = self.rf1(obs, actions)
    r2_pred = self.rf2(obs, actions)
    # TODO(cyrushx): Replace target_r_values with ground truth risk values.
    target_r_values = self.target_rf1(next_obs, new_next_actions)
    # target_r_values = torch.min(
    #     self.target_rf1(next_obs, new_next_actions),
    #     self.target_rf2(next_obs, new_next_actions),
    # ) - alpha * new_log_pi
    r_target = risk
    rf1_loss = self.rf_criterion(r1_pred, r_target.detach())
    rf2_loss = self.rf_criterion(r2_pred, r_target.detach())

    """
    Save some statistics for eval
    """
    eval_statistics = OrderedDict()
    if not skip_statistics:
        eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        eval_statistics['RF1 Loss'] = np.mean(ptu.get_numpy(rf1_loss))
        eval_statistics['RF2 Loss'] = np.mean(ptu.get_numpy(rf2_loss))
        eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'R1 Predictions', ptu.get_numpy(r1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'R2 Predictions', ptu.get_numpy(r2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'R Targets', ptu.get_numpy(r_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'Collisions', ptu.get_numpy(collisions)))
        eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        eval_statistics.update(policy_statistics)
        if self.use_automatic_entropy_tuning:
            eval_statistics['Alpha'] = alpha.item()
            eval_statistics['Alpha Loss'] = alpha_loss.item()

    loss = SACLosses(
        policy_loss=policy_loss,
        qf1_loss=qf1_loss,
        qf2_loss=qf2_loss,
        alpha_loss=alpha_loss,
        rf1_loss=rf1_loss,
        rf2_loss=rf2_loss,
    )

    return loss, eval_statistics
def compute_loss(
    self,
    batch,
    skip_statistics=False,
) -> Tuple[SACLosses, LossStatistics]:
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    collisions = batch['collision']

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(-1)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    ## TODO: finetune the loss function BELOW
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    r_new_actions = self.rf1(obs, new_obs_actions)
    # r_new_actions = torch.min(
    #     self.rf1(obs, new_obs_actions),
    #     self.rf2(obs, new_obs_actions),
    # )

    # Compute an approximate step function, such that
    # f(risk <= r_bound) = 1 and f(risk > r_bound) = 0.
    r_bound = 0.3
    r_left = r_bound - r_new_actions
    m = nn.Hardtanh(0, 0.01)
    r_step = m(r_left) * 100

    policy_loss = (alpha * log_pi - q_new_actions * r_step).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(-1)
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Risk Critic Loss
    """
    r1_pred = self.rf1(obs, actions)
    r2_pred = self.rf2(obs, actions)
    target_r_values = self.target_rf1(next_obs, new_next_actions)
    # target_r_values = torch.min(
    #     self.target_rf1(next_obs, new_next_actions),
    #     self.target_rf2(next_obs, new_next_actions),
    # ) - alpha * new_log_pi
    r_target = collisions + (1. - terminals) * (1 - collisions) * target_r_values
    rf1_loss = self.qf_criterion(r1_pred, r_target.detach())
    rf2_loss = self.qf_criterion(r2_pred, r_target.detach())

    """
    Save some statistics for eval
    """
    eval_statistics = OrderedDict()
    if not skip_statistics:
        eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        eval_statistics['RF1 Loss'] = np.mean(ptu.get_numpy(rf1_loss))
        eval_statistics['RF2 Loss'] = np.mean(ptu.get_numpy(rf2_loss))
        eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'R1 Predictions', ptu.get_numpy(r1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'R2 Predictions', ptu.get_numpy(r2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'R Targets', ptu.get_numpy(r_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        eval_statistics.update(policy_statistics)
        if self.use_automatic_entropy_tuning:
            eval_statistics['Alpha'] = alpha.item()
            eval_statistics['Alpha Loss'] = alpha_loss.item()

    loss = SACLosses(
        policy_loss=policy_loss,
        qf1_loss=qf1_loss,
        qf2_loss=qf2_loss,
        alpha_loss=alpha_loss,
        rf1_loss=rf1_loss,
        rf2_loss=rf2_loss,
    )

    return loss, eval_statistics
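# The Hardtanh construction above approximates a step function on the predicted
# risk: nn.Hardtanh(0, 0.01) clips (r_bound - risk) into [0, 0.01] and the factor
# of 100 rescales it to [0, 1], so the gate is ~1 when risk <= r_bound - 0.01 and
# 0 when risk >= r_bound. A standalone sketch of the same gate (illustrative only):
import torch
import torch.nn as nn

def risk_gate(predicted_risk: torch.Tensor, r_bound: float = 0.3) -> torch.Tensor:
    """Soft indicator that the predicted risk stays below r_bound."""
    clip = nn.Hardtanh(0, 0.01)
    return clip(r_bound - predicted_risk) * 100.0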
def get_snapshot(self):
    snapshot = {}
    for name, collector in self.path_collectors.items():
        snapshot.update(
            add_prefix(collector.get_snapshot(), name, divider='/'),
        )
    return snapshot
def compute_loss(
    self,
    batch,
    return_statistics=False,
) -> Union[PGRLosses, Tuple[PGRLosses, MutableMapping]]:
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    eval_statistics = OrderedDict()

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(-1)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    else:
        alpha_loss = 0
    alpha = self.get_alpha()

    if not self._qfs_were_initialized and self._auto_init_qf_bias:
        average_value = (rewards - alpha * log_pi).mean()
        self.qf1.last_fc.bias.data = average_value
        self.qf2.last_fc.bias.data = average_value
        # Mark the Q-functions as initialized so the bias is only set once.
        self._qfs_were_initialized = True

    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )

    """
    QF Loss
    """
    bootstrap_value, q1_pred, q2_pred, bootstrap_log_pi_term = (
        self.get_bootstrap_stats(
            obs,
            actions,
            next_obs,
        ))
    # Use the unscaled bootstrap values/rewards so that the weight on the
    # Q-value/reward has the correct scale relative to the other terms.
    raw_discount = self.get_discount_factor(
        bootstrap_value,
        rewards,
        obs,
        actions,
    )
    discount = (self._weight_on_prior_discount * self.discount
                + (1 - self._weight_on_prior_discount) * raw_discount)
    q_target = self._compute_target_q_value(
        discount,
        rewards,
        terminals,
        bootstrap_value,
        eval_statistics,
        return_statistics,
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Save some statistics for eval
    """
    if return_statistics:
        eval_statistics.update(create_stats_ordered_dict(
            'rewards', ptu.get_numpy(rewards)))
        eval_statistics.update(create_stats_ordered_dict(
            'bootstrap log pi', ptu.get_numpy(bootstrap_log_pi_term)))
        if isinstance(discount, torch.Tensor):
            eval_statistics.update(create_stats_ordered_dict(
                'discount factor', ptu.get_numpy(raw_discount)))
        else:
            eval_statistics.update(create_stats_ordered_dict(
                'discount factor', np.array([raw_discount])))
        if isinstance(discount, torch.Tensor):
            eval_statistics.update(create_stats_ordered_dict(
                'used discount factor', ptu.get_numpy(discount)))
        else:
            eval_statistics.update(create_stats_ordered_dict(
                'used discount factor', np.array([discount])))
        eval_statistics['weight on prior discount'] = self._weight_on_prior_discount
        reward_scale = self.reward_scale
        if isinstance(reward_scale, torch.Tensor):
            reward_scale = ptu.get_numpy(reward_scale)
        eval_statistics['reward scale'] = reward_scale
        eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        eval_statistics['Policy Q-only Loss'] = np.mean(ptu.get_numpy(-q_new_actions))
        eval_statistics['Policy entropy-only Loss'] = np.mean(ptu.get_numpy(alpha * log_pi))
        eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        eval_statistics.update(policy_statistics)
        if self.use_automatic_entropy_tuning:
            eval_statistics['Alpha'] = alpha.item()
            eval_statistics['Alpha Loss'] = alpha_loss.item()

    losses = PGRLosses(
        policy_loss=policy_loss,
        qf1_loss=qf1_loss,
        qf2_loss=qf2_loss,
        alpha_loss=alpha_loss,
    )
    if return_statistics:
        return losses, eval_statistics
    else:
        return losses
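# The effective discount above is a convex blend of the fixed prior discount and
# the raw discount returned by get_discount_factor:
#     discount = w * prior + (1 - w) * raw,  with w = self._weight_on_prior_discount.
# A tiny numeric illustration (values are made up):
weight_on_prior = 0.5
prior_discount, raw_discount = 0.99, 0.90
blended = weight_on_prior * prior_discount + (1 - weight_on_prior) * raw_discount
assert abs(blended - 0.945) < 1e-9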
def _get_vae_diagnostics(self):
    return add_prefix(
        self.model_trainer.get_diagnostics(),
        prefix='vae_trainer/',
    )
def train_from_torch(self, batch):
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    context = batch['context']
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist, p_z, task_z_with_grad = self.agent(
        obs,
        context,
        return_latent_posterior_and_task_z=True,
    )
    task_z_detached = task_z_with_grad.detach()
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(1)
    next_dist = self.agent(next_obs, context)

    if self._debug_ignore_context:
        task_z_with_grad = task_z_with_grad * 0

    # flattens out the task dimension
    t, b, _ = obs.size()
    obs = obs.view(t * b, -1)
    actions = actions.view(t * b, -1)
    next_obs = next_obs.view(t * b, -1)
    unscaled_rewards_flat = rewards.view(t * b, 1)
    rewards_flat = unscaled_rewards_flat * self.reward_scale
    terms_flat = terminals.view(t * b, 1)

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    if self.backprop_q_loss_into_encoder:
        q1_pred = self.qf1(obs, actions, task_z_with_grad)
        q2_pred = self.qf2(obs, actions, task_z_with_grad)
    else:
        q1_pred = self.qf1(obs, actions, task_z_detached)
        q2_pred = self.qf2(obs, actions, task_z_detached)
    # Make sure policy accounts for squashing functions like tanh correctly!
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(1)
    with torch.no_grad():
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions, task_z_detached),
            self.target_qf2(next_obs, new_next_actions, task_z_detached),
        ) - alpha * new_log_pi

    q_target = rewards_flat + (1. - terms_flat) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Context Encoder Loss
    """
    if self._debug_use_ground_truth_context:
        kl_div = kl_loss = ptu.zeros(0)
    else:
        kl_div = kl_divergence(p_z, self.agent.latent_prior).mean(dim=0).sum()
        kl_loss = self.kl_lambda * kl_div
    if self.train_context_decoder:
        # TODO: change to use a distribution
        reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
        reward_prediction_loss = ((reward_pred - unscaled_rewards_flat) ** 2).mean()
        context_loss = kl_loss + reward_prediction_loss
    else:
        context_loss = kl_loss
        reward_prediction_loss = ptu.zeros(1)

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached)
    qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached)
    q_new_actions = torch.min(qf1_new_actions, qf2_new_actions)

    # Advantage-weighted regression
    if self.vf_K > 1:
        vs = []
        for i in range(self.vf_K):
            u = dist.sample()
            q1 = self.qf1(obs, u, task_z_detached)
            q2 = self.qf2(obs, u, task_z_detached)
            v = torch.min(q1, q2)
            # v = q1
            vs.append(v)
        v_pi = torch.cat(vs, 1).mean(dim=1)
    else:
        # v_pi = self.qf1(obs, new_obs_actions)
        v1_pi = self.qf1(obs, new_obs_actions, task_z_detached)
        v2_pi = self.qf2(obs, new_obs_actions, task_z_detached)
        v_pi = torch.min(v1_pi, v2_pi)

    u = actions
    if self.awr_min_q:
        q_adv = torch.min(q1_pred, q2_pred)
    else:
        q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)
        beta_loss = ptu.zeros(1)

    score = q_adv - v_pi
    if self.mask_positive_advantage:
        score = torch.sign(score)

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    weights = batch.get('weights', None)
    if self.weight_loss and weights is None:
        if self.normalize_over_batch == True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif self.normalize_over_batch == False:
            weights = score
        elif self.normalize_over_batch == 'uniform':
            weights = F.softmax(ptu.ones_like(score) / beta, dim=0)
        else:
            raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.train_reparam_weight * (-q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        if self.train_encoder_decoder:
            self.context_optimizer.zero_grad()
        if self.train_agent:
            self.qf1_optimizer.zero_grad()
            self.qf2_optimizer.zero_grad()
        # retain graph because the encoder is trained by both QF losses
        context_loss.backward(retain_graph=True)
        qf1_loss.backward(retain_graph=True)
        qf2_loss.backward()
        if self.train_agent:
            self.qf1_optimizer.step()
            self.qf2_optimizer.step()
        if self.train_encoder_decoder:
            self.context_optimizer.step()

    if self.train_agent:
        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
    self._num_gradient_steps += 1

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        self.eval_statistics['task_embedding/kl_divergence'] = ptu.get_numpy(kl_div)
        self.eval_statistics['task_embedding/kl_loss'] = ptu.get_numpy(kl_loss)
        self.eval_statistics['task_embedding/reward_prediction_loss'] = (
            ptu.get_numpy(reward_prediction_loss))
        self.eval_statistics['task_embedding/context_loss'] = ptu.get_numpy(context_loss)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards', ptu.get_numpy(rewards)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals', ptu.get_numpy(terminals)))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights', ptu.get_numpy(weights)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score', ptu.get_numpy(score)))
        self.eval_statistics['reparam_weight'] = self.train_reparam_weight
        self.eval_statistics['num_gradient_steps'] = self._num_gradient_steps

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

    self._n_train_steps_total += 1
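# The context-encoder loss above regularizes the latent task posterior toward
# the agent's latent prior with a KL term scaled by kl_lambda. A self-contained
# sketch of that term for diagonal Gaussians (shapes and the kl_lambda value are
# illustrative, not the trainer's actual settings):
import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(torch.zeros(4, 5), 0.5 * torch.ones(4, 5))  # q(z | context)
prior = Normal(torch.zeros(4, 5), torch.ones(4, 5))            # p(z)
kl_div = kl_divergence(posterior, prior).mean(dim=0).sum()
kl_loss = 0.1 * kl_div  # kl_lambda = 0.1 is a placeholder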
def train(self):
    iteration = 0
    timer.return_global_times = True
    timer.reset()
    for i in range(self.num_batches):
        if self.use_meta_learning_buffer:
            train_data = self.meta_replay_buffer.sample_meta_batch(
                rl_batch_size=self.batch_size,
                meta_batch_size=self.meta_batch_size,
                embedding_batch_size=self.task_embedding_batch_size,
            )
            train_data = np_to_pytorch_batch(train_data)
        else:
            task_indices = np.random.choice(
                self.train_tasks,
                self.meta_batch_size,
            )
            train_data = self.replay_buffer.sample_batch(
                task_indices,
                self.batch_size,
            )
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            train_data['observations'] = obs
            train_data['next_observations'] = next_obs
            train_data['context'] = (
                self.task_embedding_replay_buffer.sample_context(
                    task_indices,
                    self.task_embedding_batch_size,
                ))
        timer.start_timer('train', unique=False)
        self.trainer.train_from_torch(train_data)
        timer.stop_timer('train')
        if i % self.logging_period == 0 or i == self.num_batches - 1:
            stats_with_prefix = add_prefix(
                self.trainer.eval_statistics, prefix="trainer/")
            self.trainer.end_epoch(iteration)
            logger.record_dict(stats_with_prefix)

            timer.start_timer('extra_fns', unique=False)
            for fn in self._extra_eval_fns:
                extra_stats = fn()
                logger.record_dict(extra_stats)
            timer.stop_timer('extra_fns')

            # TODO: evaluate during offline RL
            # eval_stats = self.get_eval_statistics()
            # eval_stats_with_prefix = add_prefix(eval_stats, prefix="eval/")
            # logger.record_dict(eval_stats_with_prefix)

            logger.record_tabular('iteration', iteration)
            logger.record_dict(_get_epoch_timings())
            try:
                import os
                import psutil
                process = psutil.Process(os.getpid())
                logger.record_tabular(
                    'RAM Usage (Mb)',
                    int(process.memory_info().rss / 1000000))
            except ImportError:
                pass
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            iteration += 1
def train_from_torch(
    self,
    batch,
    train=True,
    pretrain=False,
):
    """
    :param batch:
    :param train:
    :param pretrain:
    :return:
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch.get('weights', None)
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    policy_mle = dist.mle_estimate()

    if self.brac:
        buf_dist = self.buffer_policy(obs)
        buf_log_pi = buf_dist.log_prob(actions)
        rewards = rewards + buf_log_pi

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions)
    qf2_new_actions = self.qf2(obs, new_obs_actions)
    q_new_actions = torch.min(qf1_new_actions, qf2_new_actions)

    # Advantage-weighted regression
    if self.awr_use_mle_for_vf:
        v1_pi = self.qf1(obs, policy_mle)
        v2_pi = self.qf2(obs, policy_mle)
        v_pi = torch.min(v1_pi, v2_pi)
    else:
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u)
                q2 = self.qf2(obs, u)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions)
            v2_pi = self.qf2(obs, new_obs_actions)
            v_pi = torch.min(v1_pi, v2_pi)

    if self.awr_sample_actions:
        u = new_obs_actions
        if self.awr_min_q:
            q_adv = q_new_actions
        else:
            q_adv = qf1_new_actions
    elif self.buffer_policy_sample_actions:
        buf_dist = self.buffer_policy(obs)
        u, _ = buf_dist.rsample_and_logprob()
        qf1_buffer_actions = self.qf1(obs, u)
        qf2_buffer_actions = self.qf2(obs, u)
        q_buffer_actions = torch.min(qf1_buffer_actions, qf2_buffer_actions)
        if self.awr_min_q:
            q_adv = q_buffer_actions
        else:
            q_adv = qf1_buffer_actions
    else:
        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)

    if self.normalize_over_state == "advantage":
        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)
    elif self.normalize_over_state == "Z":
        buffer_dist = self.buffer_policy(obs)
        K = self.Z_K
        buffer_obs = []
        buffer_actions = []
        log_bs = []
        log_pis = []
        for i in range(K):
            u = buffer_dist.sample()
            log_b = buffer_dist.log_prob(u)
            log_pi = dist.log_prob(u)
            buffer_obs.append(obs)
            buffer_actions.append(u)
            log_bs.append(log_b)
            log_pis.append(log_pi)
        buffer_obs = torch.cat(buffer_obs, 0)
        buffer_actions = torch.cat(buffer_actions, 0)
        p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1))
        log_pi = torch.cat(log_pis, 0)
        log_pi = log_pi.sum(dim=1)
        q1_b = self.qf1(buffer_obs, buffer_actions)
        q2_b = self.qf2(buffer_obs, buffer_actions)
        q_b = torch.min(q1_b, q2_b)
        q_b = torch.reshape(q_b, (-1, K))
        adv_b = q_b - v_pi
        # if self._n_train_steps_total % 100 == 0:
        #     import ipdb; ipdb.set_trace()
        # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
        # score = torch.exp((q_adv - v_pi) / beta) / Z
        # score = score / sum(score)
        logK = torch.log(ptu.tensor(float(K)))
        logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
        logS = (q_adv - v_pi) / beta - logZ
        # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
        # logS = q_adv/beta - logZ
        score = F.softmax(logS, dim=0)  # score / sum(score)
    else:
        raise ValueError(self.normalize_over_state)

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    if self.weight_loss and weights is None:
        if self.normalize_over_batch == True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif self.normalize_over_batch == False:
            weights = score
        else:
            raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.reparam_weight * (-q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss
    if self.compute_bc:
        train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
            self.demo_train_buffer, self.policy)
        policy_loss = policy_loss + self.bc_weight * train_policy_loss

    if not pretrain and self.buffer_policy_reset_period > 0 and \
            self._n_train_steps_total % self.buffer_policy_reset_period == 0:
        del self.buffer_policy_optimizer
        self.buffer_policy_optimizer = self.optimizer_class(
            self.buffer_policy.parameters(),
            weight_decay=self.policy_weight_decay,
            lr=self.policy_lr,
        )
        self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
        for i in range(self.num_buffer_policy_train_steps_on_reset):
            if self.train_bc_on_rl_buffer:
                if self.advantage_weighted_buffer_loss:
                    buffer_dist = self.buffer_policy(obs)
                    buffer_u = actions
                    buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                    buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                    buffer_policy_logpp = buffer_policy_logpp[:, None]
                    buffer_q1_pred = self.qf1(obs, buffer_u)
                    buffer_q2_pred = self.qf2(obs, buffer_u)
                    buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)
                    buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                    buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                    buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)
                    buffer_score = buffer_q_adv - buffer_v_pi
                    buffer_weights = F.softmax(buffer_score / beta, dim=0)
                    buffer_policy_loss = self.awr_weight * (
                        -buffer_policy_logpp * len(buffer_weights) *
                        buffer_weights.detach()).mean()
                else:
                    buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                        self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)
                self.buffer_policy_optimizer.zero_grad()
                buffer_policy_loss.backward(retain_graph=True)
                self.buffer_policy_optimizer.step()

    if self.train_bc_on_rl_buffer:
        if self.advantage_weighted_buffer_loss:
            buffer_dist = self.buffer_policy(obs)
            buffer_u = actions
            buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
            buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
            buffer_policy_logpp = buffer_policy_logpp[:, None]
            buffer_q1_pred = self.qf1(obs, buffer_u)
            buffer_q2_pred = self.qf2(obs, buffer_u)
            buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)
            buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
            buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
            buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)
            buffer_score = buffer_q_adv - buffer_v_pi
            buffer_weights = F.softmax(buffer_score / beta, dim=0)
            buffer_policy_loss = self.awr_weight * (
                -buffer_policy_logpp * len(buffer_weights) *
                buffer_weights.detach()).mean()
        else:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer,
                    self.buffer_policy)

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if self.train_bc_on_rl_buffer and \
            self._n_train_steps_total % self.policy_update_period == 0:
        self.buffer_policy_optimizer.zero_grad()
        buffer_policy_loss.backward()
        self.buffer_policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards', ptu.get_numpy(rewards)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals', ptu.get_numpy(terminals)))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights', ptu.get_numpy(weights)))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score', ptu.get_numpy(score)))

        if self.normalize_over_state == "Z":
            self.eval_statistics.update(create_stats_ordered_dict(
                'logZ', ptu.get_numpy(logZ)))

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.compute_bc:
            test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            self.eval_statistics.update({
                "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
            })
        if self.train_bc_on_rl_buffer:
            _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.train_replay_buffer, self.buffer_policy)
            _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.validation_replay_buffer, self.buffer_policy)
            buffer_dist = self.buffer_policy(obs)
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

            _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_train_buffer, self.buffer_policy)
            _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_test_buffer, self.buffer_policy)

            self.eval_statistics.update({
                "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
            })
        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

        if self.validation_qlearning:
            train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.test_from_torch(train_data)

    self._n_train_steps_total += 1
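# The policy update above is advantage-weighted regression: log-probabilities of
# the chosen actions are reweighted by a function of the advantage divided by
# the temperature beta. A minimal sketch of the default "softmax over the batch"
# weighting used when normalize_over_batch == True (tensors are illustrative):
import torch
import torch.nn.functional as F

def awr_weights(q_adv: torch.Tensor, v_pi: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Per-sample weights: softmax of (Q(s, a) - V(s)) / beta over the batch."""
    score = q_adv - v_pi
    return F.softmax(score / beta, dim=0)

# The weighted loss is then roughly:
#     policy_loss ~= -(log_pi * batch_size * awr_weights(q_adv, v_pi).detach()).mean()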
def compute_loss(
    self,
    batch,
    skip_statistics=False,
) -> Tuple[SACLosses, LossStatistics]:
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    log_pi = log_pi.unsqueeze(-1)
    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
    policy_loss = (alpha * log_pi - q_new_actions).mean()

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    new_log_pi = new_log_pi.unsqueeze(-1)
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = (self.reward_scale * rewards +
                (1. - terminals) * self.discount * target_q_values)
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Save some statistics for eval
    """
    eval_statistics = OrderedDict()
    if not skip_statistics:
        eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions', ptu.get_numpy(q1_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions', ptu.get_numpy(q2_pred)))
        eval_statistics.update(create_stats_ordered_dict(
            'Q Targets', ptu.get_numpy(q_target)))
        eval_statistics.update(create_stats_ordered_dict(
            'Log Pis', ptu.get_numpy(log_pi)))
        eval_statistics.update(create_stats_ordered_dict(
            'rewards', ptu.get_numpy(rewards)))
        eval_statistics.update(create_stats_ordered_dict(
            'terminals', ptu.get_numpy(terminals)))
        reward_scale = self.reward_scale
        if isinstance(reward_scale, torch.Tensor):
            reward_scale = ptu.get_numpy(reward_scale)
        eval_statistics['reward scale'] = reward_scale
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        eval_statistics.update(policy_statistics)
        if self.use_automatic_entropy_tuning:
            eval_statistics['Alpha'] = alpha.item()
            eval_statistics['Alpha Loss'] = alpha_loss.item()

    loss = SACLosses(
        policy_loss=policy_loss,
        qf1_loss=qf1_loss,
        qf2_loss=qf2_loss,
        alpha_loss=alpha_loss,
    )

    return loss, eval_statistics
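# compute_loss only builds the SAC losses; the target critics are updated
# elsewhere (see ptu.soft_update_from_to in the trainers above). A generic,
# self-contained sketch of that Polyak averaging step, assuming the usual
# target <- tau * source + (1 - tau) * target convention:
import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float):
    """Move each target parameter a fraction tau toward the source parameter."""
    with torch.no_grad():
        for p_src, p_tgt in zip(source.parameters(), target.parameters()):
            p_tgt.mul_(1.0 - tau).add_(tau * p_src)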
def get_diagnostics(self):
    diagnostics = OrderedDict()
    for name, collector in self.path_collectors.items():
        diagnostics.update(
            add_prefix(collector.get_diagnostics(), name, divider='/'),
        )
    return diagnostics
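# add_prefix is used throughout to namespace statistics per collector, env, or
# trainer. An illustrative example of the intended key transformation (assuming
# the rlkit helper simply prepends `prefix + divider` to every key):
from collections import OrderedDict

stats = OrderedDict([('num paths total', 10), ('path length Mean', 50.0)])
# add_prefix(stats, 'expl', divider='/') is expected to yield keys such as
#     'expl/num paths total' and 'expl/path length Mean'.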