def testTrainCartpole(self):
    """Check that PG learns CartPole served through a SimpleServing wrapper.

    Registers the env under the name "test", then trains a PGAgent with
    zero workers for at most 100 iterations, printing progress each time.
    The test passes as soon as the mean episode reward reaches 100 and
    raises ``Exception`` if that never happens.
    """
    register_env(
        "test", lambda _: SimpleServing(gym.make("CartPole-v0")))
    agent = PGAgent(env="test", config={"num_workers": 0})
    for iteration in range(100):
        stats = agent.train()
        mean_reward = stats.episode_reward_mean
        print("Iteration {}, reward {}, timesteps {}".format(
            iteration, mean_reward, stats.timesteps_total))
        if mean_reward >= 100:
            return
    else:
        # Loop exhausted without reaching the reward target.
        raise Exception("failed to improve reward")
def testTrainMultiCartpoleSinglePolicy(self):
    """Check that PG improves on a 10-agent MultiCartpole env.

    Registers ``MultiCartpole(10)`` as "multi_cartpole" and trains a
    PGAgent with zero workers. No multi-agent policy config is passed,
    which presumably means one policy is shared by all agents (as the test
    name suggests; confirm against the agent defaults). The test passes
    once the mean episode reward reaches ``50 * n`` within 100
    iterations, and raises ``Exception`` otherwise.
    """
    n = 10
    register_env("multi_cartpole", lambda _: MultiCartpole(n))
    agent = PGAgent(env="multi_cartpole", config={"num_workers": 0})
    # Target scales with the agent count; presumably the reported episode
    # reward sums over all n sub-envs (NOTE(review): confirm).
    target_reward = 50 * n
    for iteration in range(100):
        stats = agent.train()
        mean_reward = stats.episode_reward_mean
        print("Iteration {}, reward {}, timesteps {}".format(
            iteration, mean_reward, stats.timesteps_total))
        if mean_reward >= target_reward:
            return
    else:
        # Loop exhausted without reaching the reward target.
        raise Exception("failed to improve reward")