Example #1
0
 def testTrainCartpole(self):
     """Check that a PGAgent reaches a mean episode reward of 100 on CartPole.

     Registers an env called "test" whose creator wraps
     ``gym.make("CartPole-v0")`` in ``SimpleServing``. It then trains a
     ``PGAgent`` on that env with ``{"num_workers": 0}`` for at most 100
     iterations, printing per-iteration stats.

     The test returns as soon as the mean episode reward is at least 100.
     If every iteration runs without reaching that reward, it raises a
     plain ``Exception``.
     """
     target_reward = 100
     register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
     agent = PGAgent(env="test", config={"num_workers": 0})
     for iteration in range(100):
         stats = agent.train()
         mean_reward = stats.episode_reward_mean
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, mean_reward, stats.timesteps_total))
         if mean_reward >= target_reward:
             break
     else:
         # Reached only when the loop finished without ``break``, i.e. the
         # reward target was never hit within the iteration budget.
         raise Exception("failed to improve reward")
Example #2
0
 def testTrainMultiCartpoleSinglePolicy(self):
     """Check that a PGAgent improves on a ``MultiCartpole(10)`` env.

     Registers an env called "multi_cartpole" built by ``MultiCartpole(10)``.
     It then trains a ``PGAgent`` on it with ``{"num_workers": 0}`` for at
     most 100 iterations, printing per-iteration stats.

     The test returns once the mean episode reward is at least
     ``50 * 10``. Presumably ``MultiCartpole(k)`` runs ``k`` CartPole
     copies whose rewards add up, so the target scales with ``k``; confirm
     this against ``MultiCartpole``. If the target is never reached, the
     test raises a plain ``Exception``.
     """
     num_cartpoles = 10
     register_env("multi_cartpole", lambda _: MultiCartpole(num_cartpoles))
     agent = PGAgent(env="multi_cartpole", config={"num_workers": 0})
     target_reward = 50 * num_cartpoles
     for iteration in range(100):
         stats = agent.train()
         mean_reward = stats.episode_reward_mean
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, mean_reward, stats.timesteps_total))
         if mean_reward >= target_reward:
             break
     else:
         # No ``break`` happened: the iteration budget ran out before the
         # mean reward reached the target.
         raise Exception("failed to improve reward")