Example #1
File: bc.py Project: twni2016/f-IRL
def try_evaluate(itr: int, policy_type: str):
    assert policy_type in ["Running"]
    update_time = itr * v['bc']['eval_freq']

    # eval real reward with the deterministic policy
    real_return_det = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['bc']['eval_episodes'], v['env']['T'], True)

    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    # eval real reward with the stochastic policy
    real_return_sto = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['bc']['eval_episodes'], v['env']['T'], False)

    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    logger.record_tabular(f"{policy_type} Update Time", update_time)

    return real_return_det, real_return_sto
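
The hook above delegates the actual rollouts to eval.evaluate_real_return(get_action, env, eval_episodes, T, deterministic), whose implementation is not shown on this page. As a rough orientation, a minimal sketch of such a helper could look as follows, assuming a classic Gym-style env (reset() returns an observation, step() returns (obs, reward, done, info)) and a policy callable get_action(obs, deterministic) -> action; the real f-IRL helper may differ in details.

import numpy as np

def evaluate_real_return_sketch(get_action, env, n_episodes, horizon, deterministic):
    # Average undiscounted return of the policy over n_episodes rollouts,
    # each truncated at `horizon` steps. Sketch only, not the f-IRL code.
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        ep_ret = 0.0
        for _ in range(horizon):
            action = get_action(obs, deterministic)
            obs, reward, done, _ = env.step(action)
            ep_ret += reward
            if done:
                break
        returns.append(ep_ret)
    return float(np.mean(returns))

The True/False flag in the calls above toggles deterministic versus stochastic action selection, which is why the example logs the two returns separately.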
Example #2
def try_evaluate(itr: int, policy_type: str, sac_info):
    assert policy_type in ["Running"]
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * v['env']['T']
    agent_emp_states = samples[0].copy()
    assert agent_emp_states.shape[0] == v['irl']['training_trajs']

    # KL summaries between expert states and the flattened agent state visitation
    metrics = eval.KL_summary(expert_samples, agent_emp_states.reshape(-1, agent_emp_states.shape[2]),
                              env_steps, policy_type)

    # eval real reward with the deterministic policy
    real_return_det = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['irl']['eval_episodes'], v['env']['T'], True)
    metrics['Real Det Return'] = real_return_det
    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    # eval real reward with the stochastic policy
    real_return_sto = eval.evaluate_real_return(sac_agent.get_action, env_fn(),
                                                v['irl']['eval_episodes'], v['env']['T'], False)
    metrics['Real Sto Return'] = real_return_sto
    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    if v['obj'] in ["emd"]:
        # EMD estimate: negated mean over the most recent 10% of critic losses
        eval_len = int(0.1 * len(critic_loss["main"]))
        emd = -np.array(critic_loss["main"][-eval_len:]).mean()
        metrics['emd'] = emd
        logger.record_tabular(f"{policy_type} EMD", emd)

    # plot_disc(v['obj'], log_folder, env_steps, 
    #     sac_info, critic_loss if v['obj'] in ["emd"] else disc_loss, metrics)
    if "PointMaze" in env_name:
        visual_disc(agent_emp_states, reward_func.get_scalar_reward, disc.log_density_ratio, v['obj'],
                log_folder, env_steps, gym_env.range_lim,
                sac_info, disc_loss, metrics)

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)

    return real_return_det, real_return_sto
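
Two array manipulations in this example are easy to misread: reshape(-1, agent_emp_states.shape[2]) collapses a (num_trajs, T, state_dim) rollout tensor into a flat (num_trajs * T, state_dim) matrix of visited states before the KL summary, and the EMD estimate is the negated mean of the last 10% of the recorded critic losses. A small self-contained check of both, using made-up shapes and loss values rather than f-IRL data:

import numpy as np

# Toy rollout tensor: 5 trajectories, horizon 20, 3-dimensional states.
agent_emp_states = np.random.randn(5, 20, 3)

# Flatten trajectories into one (num_trajs * T, state_dim) state matrix.
flat_states = agent_emp_states.reshape(-1, agent_emp_states.shape[2])
assert flat_states.shape == (100, 3)

# EMD estimate: negate the mean over the most recent 10% of critic losses.
critic_loss = {"main": list(np.linspace(-1.0, -3.0, 200))}  # made-up history
eval_len = int(0.1 * len(critic_loss["main"]))
emd = -np.array(critic_loss["main"][-eval_len:]).mean()
print(f"tail window = {eval_len} steps, EMD estimate = {emd:.3f}")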
Example #3
    def try_evaluate(self, policy_type: str, epoch):
        assert policy_type == "Running"
        update_time = self._n_train_steps_total * self.num_update_loops_per_train_call * self.num_disc_updates_per_loop_iter
        env_steps = self._n_env_steps_total

        sac_info = [None, None, None, None]
        samples = self.collect_fn(
            self.test_env,
            self.agent,
            n=self.training_trajs,
            state_indices=self.state_indices.detach().cpu().numpy())
        agent_emp_states = samples[0]

        expert_samples = (self.target_state_buffer[0]
                          if self.mode == 'airl' else self.target_state_buffer)
        metrics = eval.KL_summary(
            expert_samples,
            agent_emp_states.reshape(-1, agent_emp_states.shape[2]), env_steps,
            policy_type,
            self.v['task']['task_name'] == 'uniform' if self.expert_IS else False)

        if not self.expert_IS:
            real_return_det = eval.evaluate_real_return(
                self.agent.get_action, self.test_env,
                self.v['irl']['eval_episodes'], self.v['env']['T'], True)
            metrics["Real Det Return"] = real_return_det
            print(f"real det return avg: {real_return_det:.2f}")
            self.logger.record_tabular("Real Det Return",
                                       round(real_return_det, 2))

            real_return_sto = eval.evaluate_real_return(
                self.agent.get_action, self.test_env,
                self.v['irl']['eval_episodes'], self.v['env']['T'], False)
            metrics["Real Sto Return"] = real_return_sto
            print(f"real sto return avg: {real_return_sto:.2f}")
            self.logger.record_tabular("Real Sto Return",
                                       round(real_return_sto, 2))

            if real_return_det > self.max_real_return_det and real_return_sto > self.max_real_return_sto:
                self.max_real_return_det, self.max_real_return_sto = real_return_det, real_return_sto
                save_name = osp.join(
                    self.logger.get_dir(),
                    f"model/disc_epoch{epoch}_det{self.max_real_return_det:.0f}_sto{self.max_real_return_sto:.0f}.pkl"
                )
                if self.mode != 'airl':
                    torch.save(self.discriminator.state_dict(), save_name)
                else:
                    torch.save(self.reward_model.state_dict(), save_name)

        self.logger.record_tabular(f"{policy_type} Update Time", update_time)
        self.logger.record_tabular(f"{policy_type} Env Steps", env_steps)
        self.logger.dump_tabular()

        if self.expert_IS:
            if self.v['env']['env_name'] == 'ReacherDraw-v0':
                train_plot.plot_submission(agent_emp_states,
                                           self.get_scalar_reward, self.mode,
                                           self.logger.get_dir(), env_steps,
                                           self.range_lim, metrics,
                                           self.rho_expert)
            else:
                train_plot.plot_adv_irl(agent_emp_states,
                                        self.get_scalar_reward,
                                        self.logger.get_dir(), env_steps,
                                        self.range_lim, sac_info, metrics)
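
The checkpointing branch in this example only writes a model to disk when both the deterministic and the stochastic evaluation returns exceed their previous best values, and it encodes the epoch and both returns in the filename. Below is a stripped-down sketch of that pattern; the class and method names (BestCheckpointTracker, maybe_save) are illustrative and not part of f-IRL.

import os.path as osp
import torch

class BestCheckpointTracker:
    # Save a checkpoint only when BOTH evaluation returns improve on their
    # running maxima, mirroring the condition used in the example above.
    def __init__(self, save_dir):
        self.save_dir = save_dir
        self.max_det = float("-inf")
        self.max_sto = float("-inf")

    def maybe_save(self, model, epoch, return_det, return_sto):
        if return_det > self.max_det and return_sto > self.max_sto:
            self.max_det, self.max_sto = return_det, return_sto
            save_name = osp.join(
                self.save_dir,
                f"disc_epoch{epoch}_det{return_det:.0f}_sto{return_sto:.0f}.pkl")
            torch.save(model.state_dict(), save_name)
            return save_name
        return None

# Usage sketch (stand-in model, hypothetical directory):
# tracker = BestCheckpointTracker("/tmp/model")
# tracker.maybe_save(torch.nn.Linear(4, 1), epoch=3, return_det=812.0, return_sto=790.0)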