def testServingEnvHorizonNotSupported(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        episode_horizon=20,
        batch_steps=10,
        batch_mode="complete_episodes")
    ev.sample()
    self.assertRaises(Exception, lambda: ev.sample())
def testFilterSync(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)
    ev.sample()
    filters = ev.get_filters(flush_after=True)
    obs_f = filters["obs_filter"]
    self.assertNotEqual(obs_f.rs.n, 0)
    self.assertNotEqual(obs_f.buffer.n, 0)
def testCompleteEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=2,
        batch_mode="complete_episodes")
    batch = ev.sample()
    self.assertGreater(batch.count, 2)
    self.assertTrue(batch["dones"][-1])
    batch = ev.sample()
    self.assertGreater(batch.count, 2)
    self.assertTrue(batch["dones"][-1])
def testMetrics(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    remote_ev = CommonPolicyEvaluator.as_remote().remote(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    ev.sample()
    ray.get(remote_ev.sample.remote())
    result = collect_metrics(ev, [remote_ev])
    self.assertEqual(result.episodes_total, 20)
    self.assertEqual(result.episode_reward_mean, 10)
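# The metrics assertions above (20 episodes total, mean episode reward of 10
# for a 10-step episode) assume that the MockEnv fixture pays a reward of 1
# per step and terminates after `episode_length` steps. MockEnv itself is
# defined in the shared test helpers, not in this file; the class below is
# only a hypothetical sketch of that assumed behavior (gym is already
# imported by this module), not the actual fixture.
class _MockEnvSketch(gym.Env):
    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return 0

    def step(self, action):
        self.i += 1
        # Constant observation, +1 reward per step, done after episode_length.
        return 0, 1, self.i >= self.episode_length, {}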
def testBatchesSmallerWhenVectorized(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=8),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=16,
        num_envs=4)
    batch = ev.sample()
    self.assertEqual(batch.count, 16)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    batch = ev.sample()
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 4)
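# Why the counts above work out, assuming batch_steps is a total step count
# split across the vectorized sub-environments: with num_envs=4 and
# batch_steps=16, each sample advances every MockEnv copy by 4 steps. After
# the first sample each 8-step episode is only half done (episodes_total is
# still 0); after the second, all four copies finish their first episode
# (episodes_total == 4).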
def testCompleteEpisodesPacking(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(10),
        policy_graph=MockPolicyGraph,
        batch_steps=15,
        batch_mode="complete_episodes")
    batch = ev.sample()
    self.assertEqual(batch.count, 20)
def testVectorEnvSupport(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=10)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 8)
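# Same accounting for the explicit VectorEnv case, under the same assumption
# that batch_steps counts total transitions across the 8 sub-environments:
# eight 10-step samples are 80 transitions, i.e. 10 steps per sub-environment,
# so no 20-step episode has finished yet (episodes_total == 0); eight more
# samples bring every sub-environment to step 20 and all 8 episodes complete.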
def testMultiAgentSampleRoundRobin(self):
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    # Since agents are introduced into the env in round-robin fashion, some
    # of the env steps don't count as proper transitions.
    self.assertEqual(batch.policy_batches["p0"].count, 42)
    self.assertEqual(
        batch.policy_batches["p0"]["obs"].tolist()[:10],
        [0, 1, 2, 3, 4] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["new_obs"].tolist()[:10],
        [1, 2, 3, 4, 5] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["rewards"].tolist()[:10],
        [100, 100, 100, 100, 0] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["dones"].tolist()[:10],
        [False, False, False, False, True] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist()[:10],
        [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
def testServingEnvBadActions(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=BadPolicyGraph,
        sample_async=True,
        batch_steps=40,
        batch_mode="truncate_episodes")
    self.assertRaises(Exception, lambda: ev.sample())
def testBasic(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph)
    batch = ev.sample()
    for key in ["obs", "actions", "rewards", "dones", "advantages"]:
        self.assertIn(key, batch)
    self.assertGreater(batch["advantages"][0], 1)
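# testBasic above expects the sampled batch to already contain an
# "advantages" column with batch["advantages"][0] > 1, which implies that
# MockPolicyGraph adds advantages during trajectory postprocessing.
# MockPolicyGraph is defined in the shared test helpers; the class below is
# only a hypothetical sketch of that assumed shape. The PolicyGraph base
# class, the compute_advantages helper from RLlib's postprocessing utilities,
# and the gamma/last_r values are assumptions here, not taken from this file.
class _MockPolicyGraphSketch(PolicyGraph):
    def compute_actions(self, obs_batch, state_batches, is_training=False):
        # One constant action per observation, no RNN state, no extra outputs.
        return [0] * len(obs_batch), [], {}

    def postprocess_trajectory(self, batch, other_agent_batches=None):
        # Append discounted-return advantages; with CartPole's +1-per-step
        # rewards this makes advantages[0] > 1 for any multi-step episode.
        return compute_advantages(batch, 100.0, 0.9, use_gae=False)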
def testServingEnvOffPolicy(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        batch_steps=40,
        batch_mode="complete_episodes")
    for _ in range(3):
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
def testPackEpisodes(self):
    for batch_size in [1, 10, 100, 1000]:
        ev = CommonPolicyEvaluator(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_graph=MockPolicyGraph,
            batch_steps=batch_size,
            batch_mode="pack_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, batch_size)
def testAutoConcat(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=40),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        batch_steps=10,
        batch_mode="truncate_episodes",
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)
    batch = ev.sample()
    self.assertEqual(batch.count, 40)  # auto-concat up to 5 episodes
def testTruncateEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=2,
        batch_mode="truncate_episodes")
    batch = ev.sample()
    self.assertEqual(batch.count, 2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=1000,
        batch_mode="truncate_episodes")
    batch = ev.sample()
    self.assertLess(batch.count, 200)
def testMultiAgentSample(self):
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
            "p1": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    self.assertEqual(batch.policy_batches["p0"].count, 150)
    self.assertEqual(batch.policy_batches["p1"].count, 100)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist(),
        list(range(25)) * 6)
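# How the counts above follow from the setup, assuming BasicMultiAgent(5)
# steps all five agents on every env step: the mapping fn sends even agent
# ids (0, 2, 4) to "p0" and odd ids (1, 3) to "p1", so 50 env steps yield
# 3 * 50 = 150 transitions for p0 and 2 * 50 = 100 for p1. With 25-step
# episodes, 50 env steps cover 2 episodes, so p0's "t" column is
# 3 agents * 2 episodes = 6 repeats of range(25).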