def testTrainCartpole(self):
    """Train PG on a SimpleServing-wrapped CartPole env.

    Trains for at most 100 iterations, printing progress each time. The
    test passes as soon as the mean episode reward reaches 100; if it never
    does, an Exception is raised.
    """
    register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
    agent = PGAgent(env="test", config={"num_workers": 0})
    for iteration in range(100):
        res = agent.train()
        print("Iteration {}, reward {}, timesteps {}".format(
            iteration, res.episode_reward_mean, res.timesteps_total))
        if res.episode_reward_mean >= 100:
            break
    else:
        # Loop finished without hitting the reward target.
        raise Exception("failed to improve reward")
def testTrainMultiCartpoleSinglePolicy(self):
    """Train a single PG policy on a MultiCartpole env of n sub-envs.

    Trains for at most 100 iterations, printing progress each time. The
    test passes once the mean episode reward reaches 50 * n; otherwise an
    Exception is raised.
    """
    n = 10
    register_env("multi_cartpole", lambda _: MultiCartpole(n))
    agent = PGAgent(env="multi_cartpole", config={"num_workers": 0})
    # The reward threshold scales with the number of sub-environments.
    target = 50 * n
    for iteration in range(100):
        res = agent.train()
        print("Iteration {}, reward {}, timesteps {}".format(
            iteration, res.episode_reward_mean, res.timesteps_total))
        if res.episode_reward_mean >= target:
            break
    else:
        # Loop finished without hitting the reward target.
        raise Exception("failed to improve reward")
def testQueryEvaluators(self):
    """Check that per-evaluator queries return one entry per evaluator.

    With num_workers=2 the optimizer is expected to report three
    evaluators (presumably the local evaluator plus two remote workers;
    confirm against the optimizer implementation). Each one should have
    batch_steps equal to the configured batch_size of 5.
    """
    register_env("test", lambda _: gym.make("CartPole-v0"))
    agent = PGAgent(env="test", config={"num_workers": 2, "batch_size": 5})
    optimizer = agent.optimizer
    # Run both queries first, then assert.
    steps = optimizer.foreach_evaluator(lambda ev: ev.batch_steps)
    indexed_steps = optimizer.foreach_evaluator_with_index(
        lambda ev, i: (i, ev.batch_steps))
    self.assertEqual(steps, [5, 5, 5])
    self.assertEqual(indexed_steps, [(0, 5), (1, 5), (2, 5)])