def try_evaluate(itr: int, policy_type: str):
    # NOTE: closes over module-level globals (v, logger, eval, sac_agent, env_fn)
    # defined by the surrounding BC training script.
    assert policy_type in ["Running"]
    update_time = itr * v['bc']['eval_freq']

    # eval real reward: deterministic rollout
    real_return_det = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['bc']['eval_episodes'], v['env']['T'], True)
    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    # stochastic rollout
    real_return_sto = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['bc']['eval_episodes'], v['env']['T'], False)
    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    return real_return_det, real_return_sto
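
# The try_evaluate variants in this file delegate rollouts to
# eval.evaluate_real_return, whose definition lives in the eval module. The
# sketch below is a hypothetical reconstruction inferred from the call sites
# ((policy_fn, env, n_episodes, horizon, deterministic) -> mean undiscounted
# ground-truth return); the function name and policy_fn signature are
# assumptions, not the repo's actual implementation.
import numpy as np

def evaluate_real_return_sketch(policy_fn, env, n_episodes, horizon, deterministic):
    """Mean ground-truth return of policy_fn over n_episodes capped at horizon."""
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        ep_ret = 0.0
        for _ in range(horizon):
            action = policy_fn(obs, deterministic)   # assumed policy interface
            obs, reward, done, _ = env.step(action)  # classic gym 4-tuple API
            ep_ret += reward                         # undiscounted true reward
            if done:
                break
        returns.append(ep_ret)
    return float(np.mean(returns))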
def try_evaluate(itr: int, policy_type: str, sac_info):
    # NOTE: closes over module-level globals (v, logger, eval, sac_agent, env_fn,
    # samples, expert_samples, critic_loss, disc_loss, disc, reward_func,
    # gym_env, env_name, log_folder) defined by the IRL training script.
    assert policy_type in ["Running"]
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * v['env']['T']

    # KL summary between expert samples and the agent's empirical state marginal
    agent_emp_states = samples[0].copy()
    assert agent_emp_states.shape[0] == v['irl']['training_trajs']
    metrics = eval.KL_summary(expert_samples,
                              agent_emp_states.reshape(-1, agent_emp_states.shape[2]),
                              env_steps, policy_type)

    # eval real reward: deterministic rollout
    real_return_det = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['irl']['eval_episodes'], v['env']['T'], True)
    metrics['Real Det Return'] = real_return_det
    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    # stochastic rollout
    real_return_sto = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['irl']['eval_episodes'], v['env']['T'], False)
    metrics['Real Sto Return'] = real_return_sto
    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    if v['obj'] in ["emd"]:
        # EMD estimate: negated mean critic loss over the last 10% of updates
        eval_len = int(0.1 * len(critic_loss["main"]))
        emd = -np.array(critic_loss["main"][-eval_len:]).mean()
        metrics['emd'] = emd
        logger.record_tabular(f"{policy_type} EMD", emd)

    # plot_disc(v['obj'], log_folder, env_steps,
    #           sac_info, critic_loss if v['obj'] in ["emd"] else disc_loss, metrics)
    if "PointMaze" in env_name:
        visual_disc(agent_emp_states, reward_func.get_scalar_reward, disc.log_density_ratio,
                    v['obj'], log_folder, env_steps,
                    gym_env.range_lim, sac_info, disc_loss, metrics)

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)
    return real_return_det, real_return_sto
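
# eval.KL_summary is assumed to report divergence estimates between the expert
# state samples and the agent's empirical state marginal. A standard
# sample-based choice for such an estimate is the 1-nearest-neighbor KL
# estimator of Wang, Kulkarni & Verdu (2009); the sketch below illustrates
# that estimator only and is not taken from this repo's eval module.
import numpy as np

def knn_kl_estimate(x, y):
    """1-NN estimate of D_KL(P || Q) from x ~ P, shape (n, d), and y ~ Q, shape (m, d).

    Assumes no duplicate samples, so nearest-neighbor distances are nonzero.
    """
    n, d = x.shape
    m = y.shape[0]
    # rho_i: distance from x_i to its nearest neighbor among the other x's
    dxx = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(dxx, np.inf)
    rho = dxx.min(axis=1)
    # nu_i: distance from x_i to its nearest neighbor in y
    dxy = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    nu = dxy.min(axis=1)
    return d * np.mean(np.log(nu / rho)) + np.log(m / (n - 1))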
def try_evaluate(self, policy_type: str, epoch):
    assert policy_type == "Running"
    update_time = (self._n_train_steps_total *
                   self.num_update_loops_per_train_call *
                   self.num_disc_updates_per_loop_iter)
    env_steps = self._n_env_steps_total
    sac_info = [None, None, None, None]

    # roll out the current policy and compare its state marginal to the expert's
    samples = self.collect_fn(self.test_env, self.agent, n=self.training_trajs,
                              state_indices=self.state_indices.detach().cpu().numpy())
    agent_emp_states = samples[0]
    expert_samples = (self.target_state_buffer[0]
                      if self.mode == 'airl' else self.target_state_buffer)
    metrics = eval.KL_summary(
        expert_samples, agent_emp_states.reshape(-1, agent_emp_states.shape[2]),
        env_steps, policy_type,
        self.v['task']['task_name'] == 'uniform' if self.expert_IS else False)

    if not self.expert_IS:
        # eval real reward: deterministic rollout
        real_return_det = eval.evaluate_real_return(
            self.agent.get_action, self.test_env,
            self.v['irl']['eval_episodes'], self.v['env']['T'], True)
        metrics["Real Det Return"] = real_return_det
        print(f"real det return avg: {real_return_det:.2f}")
        self.logger.record_tabular("Real Det Return", round(real_return_det, 2))

        # stochastic rollout
        real_return_sto = eval.evaluate_real_return(
            self.agent.get_action, self.test_env,
            self.v['irl']['eval_episodes'], self.v['env']['T'], False)
        metrics["Real Sto Return"] = real_return_sto
        print(f"real sto return avg: {real_return_sto:.2f}")
        self.logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

        # checkpoint whenever both returns improve on the best seen so far
        if (real_return_det > self.max_real_return_det
                and real_return_sto > self.max_real_return_sto):
            self.max_real_return_det, self.max_real_return_sto = \
                real_return_det, real_return_sto
            save_name = osp.join(
                self.logger.get_dir(),
                f"model/disc_epoch{epoch}_det{self.max_real_return_det:.0f}"
                f"_sto{self.max_real_return_sto:.0f}.pkl")
            if self.mode != 'airl':
                torch.save(self.discriminator.state_dict(), save_name)
            else:
                torch.save(self.reward_model.state_dict(), save_name)

    self.logger.record_tabular(f"{policy_type} Update Time", update_time)
    self.logger.record_tabular(f"{policy_type} Env Steps", env_steps)
    self.logger.dump_tabular()

    if self.expert_IS:
        if self.v['env']['env_name'] == 'ReacherDraw-v0':
            train_plot.plot_submission(agent_emp_states, self.get_scalar_reward,
                                       self.mode, self.logger.get_dir(), env_steps,
                                       self.range_lim, metrics, self.rho_expert)
        else:
            train_plot.plot_adv_irl(agent_emp_states, self.get_scalar_reward,
                                    self.logger.get_dir(), env_steps,
                                    self.range_lim, sac_info, metrics)
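
# The checkpoint branch above saves only state_dict()s, so loading requires
# re-instantiating the model class first. A minimal, self-contained sketch of
# that round trip (TinyDisc and the path are illustrative stand-ins, not this
# repo's discriminator/reward model):
import torch
import torch.nn as nn

class TinyDisc(nn.Module):
    """Hypothetical stand-in for self.discriminator / self.reward_model."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x):
        return self.net(x)

if __name__ == "__main__":
    model = TinyDisc()
    torch.save(model.state_dict(), "disc_checkpoint.pkl")  # mirrors torch.save(...state_dict(), save_name)
    restored = TinyDisc()
    restored.load_state_dict(torch.load("disc_checkpoint.pkl"))
    restored.eval()  # eval mode before using the restored network for scoring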