예제 #1
0
 def test_episode_done_raises_error(self):
     """Calling act() with done=True must raise core.EpisodeDoneError."""
     attention_env = attention_allocation.LocationAllocationEnv()
     naive_agent = allocation_agents.NaiveProbabilityMatchingAgent(
         action_space=attention_env.action_space,
         observation_space=attention_env.observation_space,
         reward_fn=None,
     )
     first_obs = attention_env.reset()
     # Callable form of assertRaises: invokes naive_agent.act(first_obs, done=True)
     # and passes only if it raises EpisodeDoneError.
     self.assertRaises(core.EpisodeDoneError, naive_agent.act, first_obs, done=True)
예제 #2
0
 def test__allocate_by_counts(self):
     """Check allocation proportions match probabilities from counts.

     Draws ``n_samples`` allocations of ``n_resource`` units via the agent's
     ``_allocate`` and checks that, pooled over all draws, the share of units
     given to each location is within 0.05 (absolute) of that location's share
     of ``counts``.
     """
     env = attention_allocation.LocationAllocationEnv()
     agent = allocation_agents.NaiveProbabilityMatchingAgent(
         action_space=env.action_space, observation_space=env.observation_space, reward_fn=None
     )
     counts = [3, 6, 8]
     n_resource = 20
     n_samples = 100
     samples = [agent._allocate(n_resource, counts) for _ in range(n_samples)]
     # Pool the draws per location; compute each grand total once instead of
     # re-summing inside a comprehension for every element.
     per_location_totals = np.sum(samples, axis=0)
     counts_normalized = np.asarray(counts, dtype=float) / float(np.sum(counts))
     samples_normalized = per_location_totals / float(np.sum(per_location_totals))
     # assert_allclose reports the offending locations and the size of the
     # mismatch, unlike assertTrue(np.all(...)) which only says "False is not true".
     np.testing.assert_allclose(samples_normalized, counts_normalized, rtol=0, atol=0.05)
예제 #3
0
 def test_allocate_by_counts_zero(self):
     """Check allocations are even when counts are zero.

     With all-zero counts, the average number of units each location receives
     over ``n_samples`` draws must lie strictly within ``std_dev`` of an even
     split ``n_resource / len(counts)``, where ``std_dev`` is the standard
     deviation over every entry of every draw.
     """
     env = attention_allocation.LocationAllocationEnv()
     agent = allocation_agents.NaiveProbabilityMatchingAgent(
         action_space=env.action_space, observation_space=env.observation_space, reward_fn=None
     )
     counts = [0, 0, 0]
     n_resource = 15
     n_samples = 100
     samples = [agent._allocate(n_resource, counts) for _ in range(n_samples)]
     mean_samples = np.sum(samples, axis=0) / float(n_samples)
     expected_mean = n_resource / float(len(counts))
     # NOTE(review): np.std over all entries (not the standard error of the mean)
     # makes this a deliberately loose tolerance; kept unchanged.
     std_dev = np.std(samples)
     # "every |mean - expected| < std_dev" is equivalent to "the largest one is",
     # and assertLess reports the actual values when it fails.
     max_deviation = np.max(np.abs(mean_samples - expected_mean))
     self.assertLess(
         max_deviation,
         std_dev,
         msg="per-location means {} deviate from even split {}".format(
             mean_samples, expected_mean
         ),
     )
예제 #4
0
 def test_update_counts(self):
     """Check that counts are updated correctly given an observation.

     With ``decay_prob`` set to 0, the expected result is the prior ``counts``
     plus the ``observation`` elementwise: [3, 6, 8] + [1, 2, 0] == [4, 8, 8].
     """
     env = attention_allocation.LocationAllocationEnv()
     agent_params = allocation_agents.NaiveProbabilityMatchingAgentParams()
     # Presumably disables decay of previous counts so the update is a pure
     # sum — TODO confirm against NaiveProbabilityMatchingAgentParams.
     agent_params.decay_prob = 0
     agent = allocation_agents.NaiveProbabilityMatchingAgent(
         action_space=env.action_space,
         observation_space=env.observation_space,
         reward_fn=None,
         params=agent_params,
     )
     counts = [3, 6, 8]
     observation = np.array([1, 2, 0])
     updated_counts = agent._update_beliefs(observation, counts)
     # Unlike np.all(np.equal(...)), this rejects a result whose shape merely
     # broadcasts to (3,) and reports mismatching elements on failure.
     np.testing.assert_array_equal(updated_counts, [4, 8, 8])
예제 #5
0
 def test_can_interact_with_attention_env(self):
     """Smoke test: the naive agent can be driven by test_util.run_test_simulation."""
     attention_env = attention_allocation.LocationAllocationEnv()
     # The agent is built from the environment's own spaces; no reward function.
     agent_kwargs = dict(
         action_space=attention_env.action_space,
         observation_space=attention_env.observation_space,
         reward_fn=None,
     )
     naive_agent = allocation_agents.NaiveProbabilityMatchingAgent(**agent_kwargs)
     test_util.run_test_simulation(env=attention_env, agent=naive_agent)