Esempio n. 1
0
 def test_train_single_episode(self):
     """Train a PpoAgent once per backend and check the recorded gym callback counts.

     For every backend yielded by ``get_backends(PpoAgent)`` a fresh agent is
     trained with a ``log._CallbackCounts`` instance among its callbacks; the
     counters on that instance are then verified.
     """
     # NOTE(review): duration._SingleEpisode presumably limits the run to a
     # single episode -- confirm against the duration module.
     for backend_name in get_backends(PpoAgent):
         agent = agents.PpoAgent(gym_env_name=_env_name, backend=backend_name)
         counter = log._CallbackCounts()
         train_callbacks = [log.Agent(), counter, duration._SingleEpisode()]
         agent.train(train_callbacks)

         # Gym initialisation must have begun and ended exactly once.
         init_begun = counter.gym_init_begin_count
         init_ended = counter.gym_init_end_count
         assert init_begun == 1 and init_ended == 1

         # Every step that began must also have ended.
         steps_begun = counter.gym_step_begin_count
         assert steps_begun == counter.gym_step_end_count

         # Steps taken beyond the number of resets must stay below 10.
         assert steps_begun - counter.gym_reset_begin_count < 10
Esempio n. 2
0
 def test_play_single_episode(self):
     """Train, then play a PpoAgent once per backend and check the play-time counters.

     The ``log._CallbackCounts`` instance is only passed to ``play``, so the
     asserted counts reflect the callbacks fired during ``play`` and not
     those fired during the preceding ``train`` call.
     """
     for backend_name in get_backends(PpoAgent):
         agent = agents.PpoAgent(gym_env_name=_env_name, backend=backend_name)
         counter = log._CallbackCounts()
         play_callbacks = [log.Agent(), counter, duration._SingleEpisode()]

         # train() receives a single callback object here, not a list.
         agent.train(duration._SingleEpisode())
         agent.play(play_callbacks)

         # Gym initialisation must have begun and ended exactly once.
         init_begun = counter.gym_init_begin_count
         init_ended = counter.gym_init_end_count
         assert init_begun == 1 and init_ended == 1

         # Begun and ended step counts must match and be at most 10.
         steps_begun = counter.gym_step_begin_count
         steps_ended = counter.gym_step_end_count
         assert steps_begun == steps_ended
         assert steps_ended <= 10
Esempio n. 3
0
 def test_single_episode(self):
     """Train a PpoAgent on "CartPole-v0" and check that exactly one training iteration began."""
     ppo_agent = agents.PpoAgent("CartPole-v0")
     counter = log._CallbackCounts()
     # NOTE(review): duration._SingleEpisode presumably stops training after
     # one episode -- confirm against the duration module.
     train_callbacks = [duration._SingleEpisode(), log._Callbacks(), counter]
     ppo_agent.train(train_callbacks)

     iterations_begun = counter.train_iteration_begin_count
     assert iterations_begun == 1