def test_train_single_episode(self):
    """Train a PpoAgent with the single-episode duration on every backend.

    For each backend yielded by ``get_backends(PpoAgent)`` a fresh agent is
    trained once while a ``log._CallbackCounts`` callback records the gym
    events; the recorded counters are then checked.
    """
    for backend in get_backends(PpoAgent):
        agent = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        counter = log._CallbackCounts()
        callbacks = [log.Agent(), counter, duration._SingleEpisode()]
        agent.train(callbacks)

        # Exactly one gym init, with matching begin/end notifications.
        assert counter.gym_init_begin_count == counter.gym_init_end_count == 1
        # Every step that was started must also have been completed.
        assert counter.gym_step_begin_count == counter.gym_step_end_count
        # Step count is bounded by 10 plus the number of resets seen.
        assert counter.gym_step_begin_count < 10 + counter.gym_reset_begin_count
def test_play_single_episode(self):
    """Play a PpoAgent with the single-episode duration on every backend.

    For each backend yielded by ``get_backends(PpoAgent)`` the agent is first
    trained with a ``duration._SingleEpisode`` callback (without counting),
    then played with a ``log._CallbackCounts`` callback attached; only the
    counters recorded by that callback are checked.
    """
    for backend in get_backends(PpoAgent):
        agent = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        counter = log._CallbackCounts()
        # Built before training, but only handed to play() below.
        play_callbacks = [log.Agent(), counter, duration._SingleEpisode()]

        agent.train(duration._SingleEpisode())
        agent.play(play_callbacks)

        # Exactly one gym init, with matching begin/end notifications.
        assert counter.gym_init_begin_count == counter.gym_init_end_count == 1
        # Steps are balanced and there are at most 10 of them.
        assert counter.gym_step_begin_count == counter.gym_step_end_count <= 10
def test_single_episode(self):
    """Train a PpoAgent on "CartPole-v0" with the single-episode duration.

    A ``log._CallbackCounts`` callback is attached during training and the
    test checks that it recorded exactly one training-iteration begin.
    """
    ppo = agents.PpoAgent("CartPole-v0")
    counter = log._CallbackCounts()
    callbacks = [duration._SingleEpisode(), log._Callbacks(), counter]
    ppo.train(callbacks)

    assert counter.train_iteration_begin_count == 1