Beispiel #1
0
 def test_can_interact_with_attention_env(self):
     """Smoke-test MLEGreedyAgent driving a default LocationAllocationEnv.

     The agent is built from the env's action/observation spaces, is given
     a rewards.VectorSumReward over 'incidents_seen', and is run through
     test_util.run_test_simulation.
     """
     # NOTE(review): a later method in this file reuses this exact name; if
     # both live in the same class, this definition is shadowed and never
     # runs — consider giving it a distinct name.
     default_env = attention_allocation.LocationAllocationEnv()
     incidents_reward = rewards.VectorSumReward('incidents_seen')
     greedy_agent = allocation_agents.MLEGreedyAgent(
         action_space=default_env.action_space,
         observation_space=default_env.observation_space,
         reward_fn=incidents_reward)
     test_util.run_test_simulation(env=default_env, agent=greedy_agent)
 def test_can_interact_with_attention_env(self):
   """Smoke-test MLEGreedyAgent constructed with reward_fn=None.

   Builds the agent from a default LocationAllocationEnv's spaces and runs
   it through test_util.run_test_simulation.
   """
   simulation_env = attention_allocation.LocationAllocationEnv()
   greedy_agent = allocation_agents.MLEGreedyAgent(
       action_space=simulation_env.action_space,
       observation_space=simulation_env.observation_space,
       reward_fn=None)
   test_util.run_test_simulation(agent=greedy_agent, env=simulation_env)
 def test_allocate_beliefs_greedy(self):
   """Check MLEGreedyAgent._allocate with epsilon=0.0.

   Allocating 5 attention units over beliefs [5, 2, 1, 1] is expected to
   yield exactly [4, 1, 0, 0].
   """
   env_params = attention_allocation.Params(
       n_locations=4,
       prior_incident_counts=(10, 10, 10, 10),
       n_attention_units=5,
       incident_rates=[0, 0, 0, 0])
   env = attention_allocation.LocationAllocationEnv(params=env_params)
   # epsilon=0.0 presumably disables random exploration so the allocation
   # is deterministic — TODO confirm against MLEGreedyAgentParams.
   agent_params = allocation_agents.MLEGreedyAgentParams(epsilon=0.0)
   agent = allocation_agents.MLEGreedyAgent(
       action_space=env.action_space,
       observation_space=env.observation_space,
       reward_fn=rewards.VectorSumReward('incidents_seen'),
       params=agent_params)
   allocation = agent._allocate(5, [5, 2, 1, 1])
   # assert_array_equal checks shape as well as values and reports the
   # mismatching elements on failure; assertTrue(np.all(...)) only reports
   # "False is not true".
   np.testing.assert_array_equal(allocation, [4, 1, 0, 0])
 def test_allocate_beliefs_fair_unsatisfiable(self):
   """Expect gym.error.InvalidAction from _allocate with alpha=0.25.

   Per the test name, the alpha=0.25 setting is presumably an unsatisfiable
   fairness constraint when 5 units are spread over beliefs [5, 2, 1, 1].
   """
   location_count = 4
   n_units = 5
   belief_counts = [5, 2, 1, 1]
   env_params = attention_allocation.Params(
       n_locations=location_count,
       prior_incident_counts=(10,) * location_count,
       n_attention_units=n_units,
       incident_rates=[0] * location_count)
   env = attention_allocation.LocationAllocationEnv(params=env_params)
   fairness_params = allocation_agents.MLEGreedyAgentParams(
       epsilon=0.0, alpha=0.25)
   fair_agent = allocation_agents.MLEGreedyAgent(
       action_space=env.action_space,
       observation_space=env.observation_space,
       reward_fn=rewards.VectorSumReward('incidents_seen'),
       params=fairness_params)
   with self.assertRaises(gym.error.InvalidAction):
     fair_agent._allocate(n_units, belief_counts)