def test_episode_done_raises_error(self):
    """Check that acting with done=True raises core.EpisodeDoneError."""
    environment = attention_allocation.LocationAllocationEnv()
    matching_agent = allocation_agents.NaiveProbabilityMatchingAgent(
        action_space=environment.action_space,
        observation_space=environment.observation_space,
        reward_fn=None,
    )
    first_observation = environment.reset()
    # Callable form of assertRaises: the agent is invoked with the given
    # arguments and the test fails unless EpisodeDoneError propagates.
    self.assertRaises(
        core.EpisodeDoneError,
        matching_agent.act,
        first_observation,
        done=True,
    )
def test__allocate_by_counts(self):
    """Check that aggregate allocation fractions track the count fractions.

    Draws repeated allocations from the agent and compares the share of
    resources each location received overall with that location's share
    of the counts, within an absolute tolerance of 0.05.
    """
    environment = attention_allocation.LocationAllocationEnv()
    matching_agent = allocation_agents.NaiveProbabilityMatchingAgent(
        action_space=environment.action_space,
        observation_space=environment.observation_space,
        reward_fn=None,
    )
    counts = [3, 6, 8]
    n_resource = 20
    n_samples = 100

    draws = []
    for _ in range(n_samples):
        draws.append(matching_agent._allocate(n_resource, counts))

    # Share of the total count held by each location.
    count_total = float(np.sum(counts))
    expected_fractions = [value / count_total for value in counts]

    # Share of all allocated resources that each location received.
    allocated_total = float(np.sum(draws))
    per_location_totals = np.sum(draws, axis=0)
    observed_fractions = [value / allocated_total for value in per_location_totals]

    self.assertTrue(np.allclose(expected_fractions, observed_fractions, atol=0.05))
def test_allocate_by_counts_zero(self):
    """Check that all-zero counts yield roughly even mean allocations.

    Each location's mean allocation over the samples must lie within one
    standard deviation (computed over every sampled value) of the even
    split n_resource / len(counts).
    """
    environment = attention_allocation.LocationAllocationEnv()
    matching_agent = allocation_agents.NaiveProbabilityMatchingAgent(
        action_space=environment.action_space,
        observation_space=environment.observation_space,
        reward_fn=None,
    )
    counts = [0, 0, 0]
    n_resource = 15
    n_samples = 100

    draws = []
    for _ in range(n_samples):
        draws.append(matching_agent._allocate(n_resource, counts))

    per_location_mean = np.sum(draws, axis=0) / float(n_samples)
    even_share = n_resource / float(len(counts))
    spread = np.std(draws)

    # Element-wise comparison over the per-location means.
    within_spread = np.abs(per_location_mean - even_share) < spread
    self.assertTrue(np.all(within_spread))
def test_update_counts(self):
    """Check that _update_beliefs folds an observation into the counts.

    With decay_prob set to 0, counts [3, 6, 8] and observation [1, 2, 0]
    are expected to produce [4, 8, 8].
    """
    environment = attention_allocation.LocationAllocationEnv()

    params = allocation_agents.NaiveProbabilityMatchingAgentParams()
    # NOTE(review): presumably a zero decay probability turns decay off so
    # the expected values below are deterministic - confirm against the agent.
    params.decay_prob = 0

    matching_agent = allocation_agents.NaiveProbabilityMatchingAgent(
        action_space=environment.action_space,
        observation_space=environment.observation_space,
        reward_fn=None,
        params=params,
    )

    prior_counts = [3, 6, 8]
    new_observation = np.array([1, 2, 0])
    expected_counts = [4, 8, 8]

    updated_counts = matching_agent._update_beliefs(new_observation, prior_counts)
    elementwise_match = np.equal(updated_counts, expected_counts)
    self.assertTrue(np.all(elementwise_match))
def test_can_interact_with_attention_env(self):
    """Run the agent against a LocationAllocationEnv via run_test_simulation."""
    environment = attention_allocation.LocationAllocationEnv()
    matching_agent = allocation_agents.NaiveProbabilityMatchingAgent(
        action_space=environment.action_space,
        observation_space=environment.observation_space,
        reward_fn=None,
    )
    test_util.run_test_simulation(env=environment, agent=matching_agent)