def test_vectorize_round_robin(self):
    """Check observations of a wrapped pair of RoundRobinMultiAgent envs.

    The wrapper is polled three times. The expected observations contain
    a single agent id per sub-env on each poll, in the order 0, 1, 0,
    and the rewards returned by the very first poll are empty.
    """
    wrapper = MultiAgentEnvWrapper(
        lambda v: RoundRobinMultiAgent(2), [], 2)

    def for_both_envs(agent_id):
        # Build a fresh {env_id: {agent_id: 0}} dict for env ids 0 and 1;
        # the same shape serves as an action dict and as an expected
        # observation dict.
        return {0: {agent_id: 0}, 1: {agent_id: 0}}

    # First poll: only agent 0 shows up, and no rewards exist yet.
    observations, rewards, _, _, _ = wrapper.poll()
    self.assertEqual(observations, for_both_envs(0))
    self.assertEqual(rewards, {0: {}, 1: {}})

    # Each step: the agent that was just observed acts, then the other
    # agent is expected in the following observation.
    for acting_agent, next_agent in ((0, 1), (1, 0)):
        wrapper.send_actions(for_both_envs(acting_agent))
        observations, rewards, _, _, _ = wrapper.poll()
        self.assertEqual(observations, for_both_envs(next_agent))
def _make_base_env(self):
    """Return a MultiAgentEnvWrapper around two 2-agent CartPole-v1 envs.

    The sub-environments are created up front and handed to the wrapper
    as a list; ``None`` is passed as the wrapper's first argument.
    """
    # The instance is not used by this helper.
    del self

    env_count = 2
    prebuilt_envs = []
    for _ in range(env_count):
        # make_multi_agent(...) returns a callable that is invoked with a
        # config dict to produce one sub-environment.
        prebuilt_envs.append(
            make_multi_agent("CartPole-v1")({"num_agents": 2}))

    # NOTE(review): the trailing 2 presumably mirrors the number of
    # sub-environments -- confirm against MultiAgentEnvWrapper's signature.
    return MultiAgentEnvWrapper(None, prebuilt_envs, 2)
def test_vectorize_basic(self):
    """Step a wrapped pair of BasicMultiAgent(2) envs through one episode.

    Sequence checked:
      * the initial poll yields zero observations, empty rewards and
        only a False ``"__all__"`` done flag per sub-env;
      * 24 steps yield zero observations, reward 1 per agent and all
        done flags False;
      * one further step turns every done flag True;
      * sending actions at that point raises ValueError until each
        sub-env is reset with ``try_reset``;
      * after the resets, a step behaves like the earlier 24 steps.
    """
    wrapper = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 2)

    # The builders below return fresh dicts on every call, so no dict
    # object is ever shared between the wrapper and the assertions.
    def per_agent(value):
        return {0: value, 1: value}

    def done_flags(flag):
        return {0: flag, 1: flag, "__all__": flag}

    def per_env(build):
        return {0: build(), 1: build()}

    def zeros():
        # Serves as the action dict and as the expected observations.
        return per_env(lambda: per_agent(0))

    def step():
        wrapper.send_actions(zeros())
        step_obs, step_rew, step_dones, _, _ = wrapper.poll()
        return step_obs, step_rew, step_dones

    def assert_running_step(step_obs, step_rew, step_dones):
        self.assertEqual(step_obs, zeros())
        self.assertEqual(step_rew, per_env(lambda: per_agent(1)))
        self.assertEqual(step_dones, per_env(lambda: done_flags(False)))

    # Initial poll, before any action has been sent.
    first_obs, first_rew, first_dones, _, _ = wrapper.poll()
    self.assertEqual(first_obs, zeros())
    self.assertEqual(first_rew, {0: {}, 1: {}})
    self.assertEqual(
        first_dones,
        {0: {"__all__": False}, 1: {"__all__": False}},
    )

    # 24 steps during which nothing terminates.
    for _ in range(24):
        assert_running_step(*step())

    # The next step finishes every agent in both sub-envs.
    _, _, final_dones = step()
    self.assertEqual(final_dones, per_env(lambda: done_flags(True)))

    # Reset processing
    with self.assertRaises(ValueError):
        wrapper.send_actions(zeros())
    for env_id in (0, 1):
        self.assertEqual(wrapper.try_reset(env_id), {env_id: per_agent(0)})

    # After the resets, stepping works again.
    assert_running_step(*step())
def test_no_reset_until_poll(self):
    """Check the sub-env's ``resetted`` flag flips only after ``poll()``.

    The flag is expected to be falsy right after the wrapper is built
    and truthy once ``poll()`` has been called.
    """
    wrapper = MultiAgentEnvWrapper(lambda v: BasicMultiAgent(2), [], 1)

    # Look the sub-env up anew for each check rather than caching it.
    sub_env_before_poll = wrapper.get_sub_environments()[0]
    self.assertFalse(sub_env_before_poll.resetted)

    wrapper.poll()

    sub_env_after_poll = wrapper.get_sub_environments()[0]
    self.assertTrue(sub_env_after_poll.resetted)