Example #1
    def evaluate(self, env_fn, hparams, stochastic):
        """Run evaluation of the agent, sampling actions or taking the policy mode."""
        if stochastic:
            # Sample actions from the policy distribution.
            policy_to_actions_lambda = lambda policy: policy.sample()
        else:
            # Act deterministically with the most likely action.
            policy_to_actions_lambda = lambda policy: policy.mode()
        hparams.add_hparam("policy_to_actions_lambda",
                           policy_to_actions_lambda)
        hparams.add_hparam("force_beginning_resets", False)
        hparams.add_hparam("env_fn", env_fn)
        hparams.add_hparam("frame_stack_size", self.frame_stack_size)

        rl_trainer_lib.evaluate(hparams, self.agent_model_dir)
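
The lambdas and environment settings above are attached to the hyperparameter set via add_hparam, which adds a new key and raises if the key already exists. A minimal sketch of that behavior, assuming the hparams object has the tf.contrib.training.HParams (TF 1.x) interface that tensor2tensor uses:

from tensorflow.contrib.training import HParams

hparams = HParams(learning_rate=0.1)
hparams.add_hparam("frame_stack_size", 4)    # adds a new hyperparameter
print(hparams.frame_stack_size)              # -> 4
# hparams.add_hparam("learning_rate", 0.2)   # would raise ValueError: key exists
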
def evaluate_single_config(hparams, agent_model_dir):
    """Evaluate the PPO agent in the real environment."""
    eval_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    eval_hparams.num_agents = hparams.num_agents
    env = setup_env(hparams, batch_size=hparams.num_agents)
    environment_spec = rl.standard_atari_env_spec(env)
    eval_hparams.add_hparam("environment_spec", environment_spec)
    eval_hparams.add_hparam("policy_to_actions_lambda",
                            hparams.policy_to_actions_lambda)

    env.start_new_epoch(0)
    rl_trainer_lib.evaluate(eval_hparams, agent_model_dir)
    # Keep the first num_agents rollouts collected during evaluation.
    rollouts = env.current_epoch_rollouts()[:hparams.num_agents]
    env.close()

    assert len(rollouts) == hparams.num_agents
    # Mean episode reward, first with reward clipping and then without.
    return tuple(
        compute_mean_reward(rollouts, clipped) for clipped in (True, False))
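
A possible call site, sketched under assumptions: the hparams set name ("rlmb_base") and the checkpoint path below are placeholders, and the set is assumed to define the ppo_params and num_agents fields used above. The returned pair is the mean episode reward with clipping first, then without.

from tensor2tensor.utils import trainer_lib

# Assumed setup: the hparams set name and checkpoint directory are illustrative only.
hparams = trainer_lib.create_hparams("rlmb_base")
hparams.add_hparam("policy_to_actions_lambda",
                   lambda policy: policy.mode())  # deterministic evaluation

mean_reward_clipped, mean_reward_unclipped = evaluate_single_config(
    hparams, agent_model_dir="/tmp/t2t_rl/ppo_agent")
print("mean reward: %.2f (clipped), %.2f (raw)"
      % (mean_reward_clipped, mean_reward_unclipped))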