def testExternalMultiAgentEnvTruncateEpisodes(self):
    """Sample an external multi-agent env in ``truncate_episodes`` mode.

    Builds a ``PolicyEvaluator`` around
    ``SimpleMultiServing(BasicMultiAgent(4))`` with ``batch_steps=40`` and
    verifies, for three consecutive samples, that each batch reports a
    count of 160 and contains exactly four distinct ``agent_index`` values.
    """
    num_agents = 4

    def make_env(_):
        # The creator argument is ignored, exactly as in the lambda form.
        return SimpleMultiServing(BasicMultiAgent(num_agents))

    evaluator = PolicyEvaluator(
        env_creator=make_env,
        policy_graph=MockPolicyGraph,
        batch_steps=40,
        batch_mode="truncate_episodes",
    )

    for _ in range(3):
        sampled = evaluator.sample()
        # NOTE(review): 160 == 40 steps x 4 agents; presumably ``count``
        # tallies per-agent steps here -- confirm against the sampler.
        self.assertEqual(sampled.count, 160)
        distinct_agents = np.unique(sampled["agent_index"])
        self.assertEqual(len(distinct_agents), num_agents)
def testExternalMultiAgentEnvCompleteEpisodes(self):
    """Sample an external multi-agent env in ``complete_episodes`` mode.

    Builds a ``RolloutWorker`` around
    ``SimpleMultiServing(BasicMultiAgent(4))`` with ``batch_steps=40`` and
    verifies, for three consecutive samples, that each batch reports a
    count of 40 and contains exactly four distinct ``agent_index`` values.
    """
    num_agents = 4

    def make_env(_):
        # The creator argument is ignored, exactly as in the lambda form.
        return SimpleMultiServing(BasicMultiAgent(num_agents))

    worker = RolloutWorker(
        env_creator=make_env,
        policy=MockPolicy,
        batch_steps=40,
        batch_mode="complete_episodes",
    )

    for _ in range(3):
        sampled = worker.sample()
        self.assertEqual(sampled.count, 40)
        distinct_agents = np.unique(sampled["agent_index"])
        self.assertEqual(len(distinct_agents), num_agents)
def test_external_multi_agent_env_truncate_episodes(self):
    """Sample an external multi-agent env in ``truncate_episodes`` mode.

    Builds a ``RolloutWorker`` around
    ``SimpleMultiServing(BasicMultiAgent(4))`` with
    ``rollout_fragment_length=40`` and verifies, for three consecutive
    samples, that each batch reports a count of 160 and contains exactly
    four distinct ``agent_index`` values.
    """
    num_agents = 4

    def make_env(_):
        # The creator argument is ignored, exactly as in the lambda form.
        return SimpleMultiServing(BasicMultiAgent(num_agents))

    worker = RolloutWorker(
        env_creator=make_env,
        policy=MockPolicy,
        rollout_fragment_length=40,
        batch_mode="truncate_episodes",
    )

    for _ in range(3):
        sampled = worker.sample()
        # NOTE(review): 160 == 40 steps x 4 agents; presumably ``count``
        # tallies per-agent steps here -- confirm against the sampler.
        self.assertEqual(sampled.count, 160)
        distinct_agents = np.unique(sampled["agent_index"])
        self.assertEqual(len(distinct_agents), num_agents)
def testExternalMultiAgentEnvSample(self):
    """Sample once from an external multi-agent env with two policies.

    Two agents are mapped onto policies ``"p0"`` and ``"p1"`` by the parity
    of their agent id. Both policies use ``MockPolicyGraph`` over
    ``Discrete(2)`` observation and action spaces. With ``batch_steps=50``
    a single ``sample()`` call is expected to report a count of 50.
    """
    num_agents = 2
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)

    def make_env(_):
        # The creator argument is ignored, exactly as in the lambda form.
        return SimpleMultiServing(BasicMultiAgent(num_agents))

    def pick_policy(agent_id):
        # Even agent ids go to "p0", odd ones to "p1".
        return "p{}".format(agent_id % 2)

    # Each policy id gets its own spec tuple with a fresh, empty config dict.
    policy_specs = {
        policy_id: (MockPolicyGraph, obs_space, act_space, {})
        for policy_id in ("p0", "p1")
    }

    evaluator = PolicyEvaluator(
        env_creator=make_env,
        policy_graph=policy_specs,
        policy_mapping_fn=pick_policy,
        batch_steps=50,
    )

    sampled = evaluator.sample()
    self.assertEqual(sampled.count, 50)
def test_external_multi_agent_env_sample(self):
    """Sample once from an external multi-agent env with two policies.

    Two agents are mapped onto policies ``"p0"`` and ``"p1"`` by the parity
    of their agent id. Both policies use ``MockPolicy`` over ``Discrete(2)``
    observation and action spaces. With ``batch_steps=50`` a single
    ``sample()`` call is expected to report a count of 50.
    """
    num_agents = 2
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)

    def make_env(_):
        # The creator argument is ignored, exactly as in the lambda form.
        return SimpleMultiServing(BasicMultiAgent(num_agents))

    def pick_policy(agent_id):
        # Even agent ids go to "p0", odd ones to "p1".
        return "p{}".format(agent_id % 2)

    # Each policy id gets its own spec tuple with a fresh, empty config dict.
    policy_specs = {
        policy_id: (MockPolicy, obs_space, act_space, {})
        for policy_id in ("p0", "p1")
    }

    worker = RolloutWorker(
        env_creator=make_env,
        policy=policy_specs,
        policy_mapping_fn=pick_policy,
        batch_steps=50,
    )

    sampled = worker.sample()
    self.assertEqual(sampled.count, 50)