def make_remote_evaluators(self, env_creator, policy_graph, count,
                           remote_args):
    """Create ``count`` remote evaluators via ``self._make_evaluator``.

    Args:
        env_creator: Forwarded unchanged to ``self._make_evaluator``.
        policy_graph: Forwarded unchanged to ``self._make_evaluator``.
        count: Number of remote evaluators to build.
        remote_args: Keyword arguments passed to
            ``CommonPolicyEvaluator.as_remote``.

    Returns:
        A list of ``count`` results of ``self._make_evaluator``, built with
        worker indices ``1`` through ``count`` in order. (Index 0 is
        presumably reserved for the local evaluator; confirm against
        ``_make_evaluator``.)
    """
    # Build the remote actor class once and reuse its constructor for
    # every worker.
    remote_cls = CommonPolicyEvaluator.as_remote(**remote_args)
    actor_factory = remote_cls.remote

    evaluators = []
    for worker_index in range(1, count + 1):
        evaluators.append(
            self._make_evaluator(
                actor_factory, env_creator, policy_graph, worker_index))
    return evaluators
def testMetrics(self):
    """Metrics from a local and a remote evaluator are aggregated together.

    Both evaluators run a ``MockEnv`` with ``episode_length=10`` in
    ``"complete_episodes"`` batch mode. After one ``sample()`` on each, the
    combined metrics are expected to report 20 episodes in total with a
    mean episode reward of 10.
    """
    def evaluator_kwargs():
        # Return a fresh config (and a fresh env_creator lambda) on each
        # call, so every evaluator gets its own arguments, as before.
        return dict(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_graph=MockPolicyGraph,
            batch_mode="complete_episodes")

    local_ev = CommonPolicyEvaluator(**evaluator_kwargs())
    remote_ev = CommonPolicyEvaluator.as_remote().remote(
        **evaluator_kwargs())

    local_ev.sample()
    # Block until the remote sample finishes before collecting metrics.
    ray.get(remote_ev.sample.remote())

    metrics = collect_metrics(local_ev, [remote_ev])
    self.assertEqual(metrics.episodes_total, 20)
    self.assertEqual(metrics.episode_reward_mean, 10)
def _testWithOptimizer(self, optimizer_cls):
    """Train a two-policy MultiCartpole setup with ``optimizer_cls``.

    Three cartpole agents are mapped alternately to policies "p1" and "p2".
    The test returns as soon as the mean episode reward reaches
    ``25 * n``. If that has not happened within 200 optimizer steps, it
    prints the last metrics and raises.

    Args:
        optimizer_cls: The policy optimizer class under test. Special cases:
            ``SyncReplayOptimizer`` uses two DQN policies, and
            ``AsyncGradientsOptimizer`` also gets one remote evaluator.

    Raises:
        Exception: If the reward target is not reached in 200 steps.
    """
    n = 3

    # The env is only needed to read its spaces; close it so it does not
    # leak.
    probe_env = gym.make("CartPole-v0")
    act_space = probe_env.action_space
    obs_space = probe_env.observation_space
    probe_env.close()

    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls is SyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            "p1": (PGPolicyGraph, obs_space, act_space, {}),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }

    def policy_mapping_fn(agent_id):
        # Even agent ids go to "p1", odd ids to "p2".
        return ["p1", "p2"][agent_id % 2]

    # Share one config between the local and remote evaluators so they
    # cannot drift apart.
    evaluator_kwargs = dict(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=policy_mapping_fn,
        batch_steps=50)

    ev = CommonPolicyEvaluator(**evaluator_kwargs)
    if optimizer_cls is AsyncGradientsOptimizer:
        remote_evs = [
            CommonPolicyEvaluator.as_remote().remote(**evaluator_kwargs)]
    else:
        remote_evs = []

    optimizer = optimizer_cls(ev, remote_evs, {})
    for i in range(200):
        # Linearly decay DQN exploration from 1.0 down to a floor of 0.02.
        epsilon = max(0.02, 1 - i * .02)
        ev.foreach_policy(
            lambda p, _: p.set_epsilon(epsilon)
            if isinstance(p, DQNPolicyGraph) else None)
        optimizer.step()
        result = collect_metrics(ev, remote_evs)
        if i % 20 == 0:
            # Periodically sync the DQN target networks.
            ev.foreach_policy(
                lambda p, _: p.update_target()
                if isinstance(p, DQNPolicyGraph) else None)
            print("Iter {}, rew {}".format(i, result.policy_reward_mean))
            print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    print(result)
    raise Exception("failed to improve reward")