Exemplo n.º 1
0
 def test_multi_agent_sample_round_robin(self):
     """Check one sampled fragment from a round-robin env mapped to policy "p0".

     A RolloutWorker wraps ``RoundRobinMultiAgent(5, increment_obs=True)``,
     maps every agent to the single MockPolicy "p0", and samples with a
     rollout fragment length of 50. The test pins the env-step count, the
     per-policy transition count, and the first ten entries of the
     obs / new_obs / rewards / dones / t columns.
     """
     worker = RolloutWorker(
         env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
         policy_spec={"p0": PolicySpec(policy_class=MockPolicy)},
         policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
         rollout_fragment_length=50,
     )
     sampled = worker.sample()
     self.assertEqual(sampled.count, 50)

     p0_batch = sampled.policy_batches["p0"]
     # Agents are introduced into the env one at a time (round robin), so
     # not every env step yields a proper transition for the policy.
     self.assertEqual(p0_batch.count, 42)

     # Observations cycle 0..4 and each next observation is one higher.
     expected_obs = one_hot(np.array(list(range(5)) * 2), 10)
     check(p0_batch["obs"][:10], expected_obs)
     expected_new_obs = one_hot(np.array(list(range(1, 6)) * 2), 10)
     check(p0_batch["new_obs"][:10], expected_new_obs)

     # Within each group of five rows only the last one is terminal, and
     # that terminal row carries zero reward.
     expected_rewards = ([100] * 4 + [0]) * 2
     self.assertEqual(p0_batch["rewards"].tolist()[:10], expected_rewards)
     expected_dones = ([False] * 4 + [True]) * 2
     self.assertEqual(p0_batch["dones"].tolist()[:10], expected_dones)

     expected_t = [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]
     self.assertEqual(p0_batch["t"].tolist()[:10], expected_t)
Exemplo n.º 2
0
 def test_multi_agent_sample_round_robin(self):
     """Check one sampled fragment from a round-robin env under policy "p0".

     Builds a RolloutWorker around
     ``RoundRobinMultiAgent(5, increment_obs=True)`` with a single
     MockPolicy "p0" (Discrete(10) observations, Discrete(2) actions),
     samples a fragment of length 50, and pins the env-step count, the
     per-policy transition count, and the first ten rows of the
     obs / new_obs / rewards / dones / t columns.
     """
     obs_space = gym.spaces.Discrete(10)
     act_space = gym.spaces.Discrete(2)
     worker = RolloutWorker(
         env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
         policy={"p0": (MockPolicy, obs_space, act_space, {})},
         policy_mapping_fn=lambda agent_id: "p0",
         rollout_fragment_length=50)
     sampled = worker.sample()
     self.assertEqual(sampled.count, 50)

     p0_batch = sampled.policy_batches["p0"]
     # Agents are introduced into the env one at a time (round robin), so
     # not every env step yields a proper transition for the policy.
     self.assertEqual(p0_batch.count, 42)

     # Observations cycle 0..4 and each next observation is one higher.
     expected_obs = [one_hot(i, 10) for i in range(5)] * 2
     self.assertEqual(p0_batch["obs"].tolist()[:10], expected_obs)
     expected_new_obs = [one_hot(i, 10) for i in range(1, 6)] * 2
     self.assertEqual(p0_batch["new_obs"].tolist()[:10], expected_new_obs)

     # Within each group of five rows only the last one is terminal, and
     # that terminal row carries zero reward.
     expected_rewards = ([100] * 4 + [0]) * 2
     self.assertEqual(p0_batch["rewards"].tolist()[:10], expected_rewards)
     expected_dones = ([False] * 4 + [True]) * 2
     self.assertEqual(p0_batch["dones"].tolist()[:10], expected_dones)

     expected_t = [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]
     self.assertEqual(p0_batch["t"].tolist()[:10], expected_t)
Exemplo n.º 3
0
 def test_vectorize_round_robin(self):
     """Poll a two-env wrapper of RoundRobinMultiAgent(2) across three turns.

     The asserted dicts have outer keys 0 and 1 (presumably the sub-env
     index; confirm against MultiAgentEnvWrapper) and a single inner key
     naming the agent whose observation is reported. The reported agent
     goes 0 -> 1 -> 0 as actions are sent.
     """

     def for_both_envs(agent_id):
         # Fresh dict on every call, so nothing is shared with the env.
         return {env_idx: {agent_id: 0} for env_idx in range(2)}

     vec_env = MultiAgentEnvWrapper(lambda v: RoundRobinMultiAgent(2), [], 2)

     # First poll: agent 0 is reported in both envs, no rewards yet.
     observations, rewards, _, _, _ = vec_env.poll()
     self.assertEqual(observations, for_both_envs(0))
     self.assertEqual(rewards, {0: {}, 1: {}})

     # After agent 0 acts, agent 1 is reported.
     vec_env.send_actions(for_both_envs(0))
     observations, _, _, _, _ = vec_env.poll()
     self.assertEqual(observations, for_both_envs(1))

     # After agent 1 acts, the report returns to agent 0.
     vec_env.send_actions(for_both_envs(1))
     observations, _, _, _, _ = vec_env.poll()
     self.assertEqual(observations, for_both_envs(0))
Exemplo n.º 4
0
 def test_round_robin_mock(self):
     """Step RoundRobinMultiAgent(2) directly and check the turn order.

     After reset only agent 0 is observed. For five rounds, agents 0 and 1
     alternate: each step by one agent returns an observation for the other
     and leaves ``done["__all__"]`` False. One further step by agent 0 must
     report ``done["__all__"]`` True.
     """
     env = RoundRobinMultiAgent(2)
     first_obs = env.reset()
     self.assertEqual(first_obs, {0: 0})

     # (agent acting now, agent expected to be observed next)
     turn_order = ((0, 1), (1, 0))
     for _ in range(5):
         for acting, upcoming in turn_order:
             obs, _, done, _ = env.step({acting: 0})
             self.assertEqual(obs, {upcoming: 0})
             self.assertEqual(done["__all__"], False)

     # The eleventh step ends the episode for all agents.
     _, _, done, _ = env.step({0: 0})
     self.assertEqual(done["__all__"], True)