def testMultiAgentSampleRoundRobin(self):
    act_space = gym.spaces.Discrete(2)
    # With increment_obs=True observations count up past 1, so the obs space
    # must be larger than Discrete(2).
    obs_space = gym.spaces.Discrete(10)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    # Since we round-robin introduce agents into the env, some of the env
    # steps don't count as proper transitions.
    self.assertEqual(batch.policy_batches["p0"].count, 42)
    self.assertEqual(
        batch.policy_batches["p0"]["obs"].tolist()[:10],
        [0, 1, 2, 3, 4] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["new_obs"].tolist()[:10],
        [1, 2, 3, 4, 5] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["rewards"].tolist()[:10],
        [100, 100, 100, 100, 0] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["dones"].tolist()[:10],
        [False, False, False, False, True] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist()[:10],
        [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
def testCompleteEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(10),
        policy_graph=MockPolicyGraph,
        batch_steps=5,
        batch_mode="complete_episodes")
    batch = ev.sample()
    # In complete_episodes mode the episode runs to its full length of 10,
    # even though only 5 steps were requested.
    self.assertEqual(batch.count, 10)
def testServingEnvBadActions(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=BadPolicyGraph,
        sample_async=True,
        batch_steps=40,
        batch_mode="truncate_episodes")
    self.assertRaises(Exception, lambda: ev.sample())
def testBasic(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph)
    batch = ev.sample()
    for key in ["obs", "actions", "rewards", "dones", "advantages"]:
        self.assertIn(key, batch)
    self.assertGreater(batch["advantages"][0], 1)
def testServingEnvHorizonNotSupported(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        episode_horizon=20,
        batch_steps=10,
        batch_mode="complete_episodes")
    ev.sample()
    self.assertRaises(Exception, lambda: ev.sample())
def testServingEnvTruncateEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        batch_steps=40,
        batch_mode="truncate_episodes")
    for _ in range(3):
        batch = ev.sample()
        self.assertEqual(batch.count, 40)
def testAutoConcat(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=40),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        batch_steps=10,
        batch_mode="truncate_episodes",
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)  # give the async sampler time to run ahead
    batch = ev.sample()
    self.assertEqual(batch.count, 40)  # auto-concat up to 5 episodes
def testCompleteEpisodesPacking(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(10),
        policy_graph=MockPolicyGraph,
        batch_steps=15,
        batch_mode="complete_episodes")
    batch = ev.sample()
    # Two full 10-step episodes are packed to cover the 15 requested steps.
    self.assertEqual(batch.count, 20)
    self.assertEqual(
        batch["t"].tolist(),
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def testServingEnvOffPolicy(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
        policy_graph=MockPolicyGraph,
        batch_steps=40,
        batch_mode="complete_episodes")
    for _ in range(3):
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
        # The serving env logs the off-policy action 42 for every step.
        self.assertEqual(batch["actions"][0], 42)
        self.assertEqual(batch["actions"][-1], 42)
def testFilterSync(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)
    ev.sample()
    filters = ev.get_filters(flush_after=True)
    obs_f = filters["default"]
    self.assertNotEqual(obs_f.rs.n, 0)
    self.assertNotEqual(obs_f.buffer.n, 0)
def testMetrics(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    remote_ev = CommonPolicyEvaluator.as_remote().remote(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    ev.sample()
    ray.get(remote_ev.sample.remote())
    result = collect_metrics(ev, [remote_ev])
    self.assertEqual(result.episodes_total, 20)
    self.assertEqual(result.episode_reward_mean, 10)
def testBatchesSmallerWhenVectorized(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=8),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=16,
        num_envs=4)
    batch = ev.sample()
    self.assertEqual(batch.count, 16)
    # 16 steps spread over 4 envs is only 4 steps per env, so no 8-step
    # episode has finished yet.
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    batch = ev.sample()
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 4)
def testGetFilters(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    self.sample_and_flush(ev)
    filters = ev.get_filters(flush_after=False)
    time.sleep(2)
    filters2 = ev.get_filters(flush_after=False)
    obs_f = filters["default"]
    obs_f2 = filters2["default"]
    self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
    self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
def testAutoVectorization(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=20),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=16,
        num_envs=8)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
    # 8 samples of 16 steps = 16 steps per env, so no 20-step episode is
    # done yet.
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
    # After another 16 steps per env, each of the 8 envs has completed one
    # episode.
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 8)
def testBatchDivisibilityCheck(self):
    self.assertRaises(
        ValueError,
        lambda: CommonPolicyEvaluator(
            env_creator=lambda _: MockEnv(episode_length=8),
            policy_graph=MockPolicyGraph,
            batch_mode="truncate_episodes",
            batch_steps=15,
            num_envs=4))
def make_remote_evaluators(self, env_creator, policy_graph, count,
                           remote_args):
    """Convenience method to return a number of remote evaluators."""
    cls = CommonPolicyEvaluator.as_remote(**remote_args).remote
    return [
        self._make_evaluator(cls, env_creator, policy_graph, i + 1)
        for i in range(count)
    ]
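# NOTE: make_remote_evaluators above delegates to self._make_evaluator, which
# is defined outside this section. The sketch below is a hypothetical
# stand-in, not the file's actual helper; in particular, treating the fourth
# argument as a worker index passed through to the constructor is an
# assumption based only on the `i + 1` call site.
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index):
    # `cls` is either CommonPolicyEvaluator itself or its
    # .as_remote().remote handle, so the same keyword arguments serve for
    # both local and remote construction.
    return cls(
        env_creator=env_creator,
        policy_graph=policy_graph,
        worker_index=worker_index)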
def testMultiAgentSample(self):
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
            "p1": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    # Agents 0, 2, 4 map to p0 and agents 1, 3 to p1, so 50 env steps yield
    # 150 and 100 per-policy transitions respectively.
    self.assertEqual(batch.policy_batches["p0"].count, 150)
    self.assertEqual(batch.policy_batches["p1"].count, 100)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist(), list(range(25)) * 6)
def _testWithOptimizer(self, optimizer_cls):
    n = 3
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls == SyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            "p1": (PGPolicyGraph, obs_space, act_space, {}),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
        batch_steps=50)
    if optimizer_cls == AsyncGradientsOptimizer:
        remote_evs = [
            CommonPolicyEvaluator.as_remote().remote(
                env_creator=lambda _: MultiCartpole(n),
                policy_graph=policies,
                policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
                batch_steps=50)
        ]
    else:
        remote_evs = []
    optimizer = optimizer_cls(ev, remote_evs, {})
    for i in range(200):
        ev.foreach_policy(
            lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
            if isinstance(p, DQNPolicyGraph) else None)
        optimizer.step()
        result = collect_metrics(ev, remote_evs)
        if i % 20 == 0:
            ev.foreach_policy(
                lambda p, _: p.update_target()
                if isinstance(p, DQNPolicyGraph) else None)
            print("Iter {}, rew {}".format(i, result.policy_reward_mean))
            print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    print(result)
    raise Exception("failed to improve reward")
def testTrainMultiCartpoleManyPolicies(self):
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {}
    for i in range(20):
        policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
                                       {})
    policy_ids = list(policies.keys())
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    optimizer = SyncSamplesOptimizer(ev, [], {})
    for i in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(i, result.policy_reward_mean))
        print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    raise Exception("failed to improve reward")
def testSyncFilter(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    obs_f = self.sample_and_flush(ev)

    # Current State
    filters = ev.get_filters(flush_after=False)
    obs_f = filters["default"]
    self.assertLessEqual(obs_f.buffer.n, 20)

    # Sync in a filter with a larger running-stats count; the evaluator
    # should adopt the synced stats without inheriting a larger buffer.
    new_obsf = obs_f.copy()
    new_obsf.rs._n = 100
    ev.sync_filters({"default": new_obsf})
    filters = ev.get_filters(flush_after=False)
    obs_f = filters["default"]
    self.assertGreaterEqual(obs_f.rs.n, 100)
    self.assertLessEqual(obs_f.buffer.n, 20)
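# Several tests above call self.sample_and_flush, which is not shown in this
# section. Below is a minimal sketch, assuming the helper draws one sample,
# flushes the evaluator's filters, checks that the default observation filter
# actually saw data, and returns it; the sleep duration and exact assertions
# are assumptions for illustration, not the canonical helper.
def sample_and_flush(self, ev):
    time.sleep(2)  # give the async sampler time to collect observations
    ev.sample()
    filters = ev.get_filters(flush_after=True)
    obs_f = filters["default"]
    self.assertNotEqual(obs_f.rs.n, 0)
    self.assertNotEqual(obs_f.buffer.n, 0)
    return obs_f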