def test_multi_agent_sample_round_robin(self):
    """Sample a round-robin multi-agent env under one policy and check its batch.

    A ``RolloutWorker`` drives ``RoundRobinMultiAgent(5, increment_obs=True)``
    with every agent mapped to policy "p0" and a 50-step rollout fragment.
    The first 10 rows of p0's batch are then compared field by field
    (obs, new_obs, rewards, dones, t).

    NOTE(review): another method with this exact name appears in the same
    chunk. If both are defined in one class, only the later definition stays
    on the class and is collected by unittest. Confirm and rename one.
    """
    worker = RolloutWorker(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_spec={"p0": PolicySpec(policy_class=MockPolicy)},
        policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
        rollout_fragment_length=50,
    )
    batch = worker.sample()
    self.assertEqual(batch.count, 50)

    p0 = batch.policy_batches["p0"]
    # Agents are introduced into the env round-robin, so not every env step
    # yields a complete transition for p0: 42 rows out of 50 env steps.
    self.assertEqual(p0.count, 42)

    check(p0["obs"][:10], one_hot(np.array([0, 1, 2, 3, 4] * 2), 10))
    check(p0["new_obs"][:10], one_hot(np.array([1, 2, 3, 4, 5] * 2), 10))

    expected_rewards = [100, 100, 100, 100, 0] * 2
    self.assertEqual(p0["rewards"].tolist()[:10], expected_rewards)

    expected_dones = [False, False, False, False, True] * 2
    self.assertEqual(p0["dones"].tolist()[:10], expected_dones)

    expected_t = [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]
    self.assertEqual(p0["t"].tolist()[:10], expected_t)
def test_multi_agent_sample_round_robin(self):
    """Legacy-API variant: sample a round-robin multi-agent env under policy "p0".

    The worker is built with the ``policy=`` mapping form, whose value is a
    tuple ``(MockPolicy, obs_space, act_space, {})``, and a one-argument
    ``policy_mapping_fn``. The test then checks the first 10 rows of p0's
    batch.

    NOTE(review): another method with this exact name appears in the same
    chunk. If both are defined in one class, this later definition replaces
    the earlier one and only this one is collected by unittest. Confirm and
    rename one.
    """
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(10)
    worker = RolloutWorker(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy={"p0": (MockPolicy, obs_space, act_space, {})},
        policy_mapping_fn=lambda agent_id: "p0",
        rollout_fragment_length=50,
    )
    batch = worker.sample()
    self.assertEqual(batch.count, 50)

    p0 = batch.policy_batches["p0"]
    # Agents enter the env round-robin, so some env steps are not complete
    # transitions for p0: 42 rows out of 50 env steps.
    self.assertEqual(p0.count, 42)

    # NOTE(review): one_hot(...) results are compared for equality against
    # .tolist() output (nested Python lists). This presumably relies on
    # one_hot returning plain lists; a numpy array here would make list
    # equality raise. TODO confirm.
    self.assertEqual(
        p0["obs"].tolist()[:10],
        [one_hot(idx, 10) for idx in range(5)] * 2,
    )
    self.assertEqual(
        p0["new_obs"].tolist()[:10],
        [one_hot(idx, 10) for idx in range(1, 6)] * 2,
    )
    self.assertEqual(
        p0["rewards"].tolist()[:10],
        [100, 100, 100, 100, 0] * 2,
    )
    self.assertEqual(
        p0["dones"].tolist()[:10],
        [False, False, False, False, True] * 2,
    )
    self.assertEqual(
        p0["t"].tolist()[:10],
        [4, 9, 14, 19, 24, 5, 10, 15, 20, 25],
    )
def test_vectorize_round_robin(self):
    """Poll a vectorized RoundRobinMultiAgent(2) and check agents alternate.

    Both sub-envs (keys 0 and 1 in the polled dicts) start with agent 0
    observing and empty rewards. After agent 0 acts, agent 1 observes; after
    agent 1 acts, agent 0 observes again.
    """
    # NOTE(review): the meaning of the `[]` and `2` constructor arguments is
    # not visible here; presumably `2` is the number of sub-envs. Confirm.
    env = MultiAgentEnvWrapper(lambda v: RoundRobinMultiAgent(2), [], 2)

    obs, rew, _, _, _ = env.poll()
    self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
    self.assertEqual(rew, {0: {}, 1: {}})

    # Each round: the currently observing agent acts in both sub-envs, then
    # the other agent should be the only one observing.
    for acting_agent, next_agent in ((0, 1), (1, 0)):
        env.send_actions({0: {acting_agent: 0}, 1: {acting_agent: 0}})
        obs, _, _, _, _ = env.poll()
        self.assertEqual(obs, {0: {next_agent: 0}, 1: {next_agent: 0}})
def test_round_robin_mock(self):
    """Step RoundRobinMultiAgent(2) directly and check turn-taking and termination.

    After reset only agent 0 observes. For 5 rounds, agent 0 acts and then
    agent 1 acts. Each step hands the observation to the other agent and
    leaves ``done["__all__"]`` False. The next step (the 11th overall) ends
    the episode.
    """
    env = RoundRobinMultiAgent(2)
    self.assertEqual(env.reset(), {0: 0})

    for _ in range(5):
        for actor, observer in ((0, 1), (1, 0)):
            obs, _, done, _ = env.step({actor: 0})
            self.assertEqual(obs, {observer: 0})
            self.assertEqual(done["__all__"], False)

    _, _, done, _ = env.step({0: 0})
    self.assertEqual(done["__all__"], True)