Пример #1
0
    def test_MLE_rate_estimation(self):
        """Checks that the MLE agent's beliefs end up close to the true rates.

        Runs an MLEProbabilityMatchingAgent against a seeded
        LocationAllocationEnv for a fixed number of steps, then compares
        ``agent.beliefs`` element-wise with ``env_params.incident_rates``
        using an absolute tolerance of 0.5.
        """
        env_params = attention_allocation.Params()
        env_params.prior_incident_counts = (500, 500)
        env_params.n_attention_units = 5

        # pylint: disable=g-long-lambda
        agent_params = allocation_agents.MLEProbabilityMatchingAgentParams()

        # Build the agent's features only from the 'incidents_seen'
        # observation key.
        agent_params.feature_selection_fn = lambda obs: allocation_agents._get_added_vector_features(
            obs, env_params.n_locations, keys=['incidents_seen'])
        agent_params.interval = 200
        # NOTE(review): presumably epsilon=0 disables exploration so the
        # run is driven purely by the MLE estimates — confirm in agent code.
        agent_params.epsilon = 0

        env = attention_allocation.LocationAllocationEnv(env_params)
        agent = allocation_agents.MLEProbabilityMatchingAgent(
            action_space=env.action_space,
            reward_fn=lambda x: None,
            observation_space=env.observation_space,
            params=agent_params)

        # Seed both the agent's RNG and the environment so the run is
        # reproducible.
        seed = 0
        agent.rng.seed(seed)
        env.seed(seed)

        observation = env.reset()
        done = False
        steps = 200
        # NOTE(review): the loop does not stop early when `done` is True; it
        # always takes `steps` steps and only forwards `done` to the agent.
        for _ in range(steps):
            action = agent.act(observation, done)
            observation, _, done, _ = env.step(action)

        # assert_allclose reports which elements differ and by how much,
        # whereas assertTrue(np.all(np.isclose(...))) only reports
        # "False is not true". rtol=1e-5 matches np.isclose's default, so the
        # accepted tolerance is unchanged.
        np.testing.assert_allclose(
            list(agent.beliefs),
            list(env_params.incident_rates),
            rtol=1e-5,
            atol=0.5)
Пример #2
0
 def test_can_interact_with_attention_env(self):
     """Smoke test: an MLE agent can run a simulation in the attention env.

     Builds a LocationAllocationEnv with no explicit params and an
     MLEProbabilityMatchingAgent with ``params=None``, rewarded by the
     vector sum of 'incidents_seen', then hands both to
     ``test_util.run_test_simulation``.
     """
     environment = attention_allocation.LocationAllocationEnv()
     # Read the spaces before constructing the reward, matching the
     # argument-evaluation order of a single inline constructor call.
     action_space = environment.action_space
     observation_space = environment.observation_space
     incident_reward = rewards.VectorSumReward('incidents_seen')
     mle_agent = allocation_agents.MLEProbabilityMatchingAgent(
         action_space=action_space,
         observation_space=observation_space,
         reward_fn=incident_reward,
         params=None)
     test_util.run_test_simulation(env=environment, agent=mle_agent)