# Assumed imports from the surrounding easyagents test module; the methods
# below belong to a test class providing assert_are_equal, and helpers such as
# get_backends / max_avg_rewards plus the env-name constants are module-level.
import pytest

from easyagents import agents, core, env
from easyagents.agents import EasyAgent, PpoAgent
from easyagents.callbacks import duration, log


def test_to_dict_from_dict(self):
    # round trip: _to_dict followed by _from_dict yields an equal agent
    oldseed = agents.seed
    agents.seed = 123
    p1 = PpoAgent(gym_env_name=_line_world_name, fc_layers=(10, 20, 30), backend='tfagents')
    d = p1._to_dict()
    agents.seed = oldseed
    p2: EasyAgent = EasyAgent._from_dict(d)
    self.assert_are_equal(p1, p2)
def test_save_load_play(self):
    # a trained agent survives a save/load round trip and can still play
    oldseed = agents.seed
    agents.seed = 123
    p1 = PpoAgent(gym_env_name=_line_world_name, fc_layers=(10, 20, 30), backend='tfagents')
    p1.train(callbacks=[duration._SingleEpisode()], default_plots=False)
    d = p1.save()
    agents.seed = oldseed
    p2: EasyAgent = agents.load(d)
    self.assert_are_equal(p1, p2)
    p2.play(default_plots=False, num_episodes=1)
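# A minimal non-test sketch of the same save/load round trip, assuming save()
# returns the directory the agent was written to (as the test above suggests)
# and agents.load() reconstructs the agent, including its trained policy.
def save_and_reload_sketch():
    agent = PpoAgent(gym_env_name='CartPole-v0')
    agent.train(callbacks=[duration._SingleEpisode()], default_plots=False)
    directory = agent.save()            # persist agent definition and policy
    restored = agents.load(directory)   # rebuild the agent from disk
    restored.play(num_episodes=1, default_plots=False)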
def test_train_cartpole(self):
    # a short PPO training run on CartPole completes on every backend
    for backend in get_backends(PpoAgent):
        ppo = PpoAgent(gym_env_name="CartPole-v0", backend=backend)
        tc = core.PpoTrainContext()
        tc.num_iterations = 3
        tc.num_episodes_per_iteration = 10
        tc.max_steps_per_episode = 500
        tc.num_epochs_per_iteration = 5
        tc.num_iterations_between_eval = 2
        tc.num_episodes_per_eval = 5
        ppo.train([log.Iteration()], train_context=tc)
def test_train(self):
    # after 10 training iterations the best evaluation should reach an
    # average reward of at least 50 on every backend
    agents.seed = 0
    for backend in get_backends(PpoAgent):
        ppo = PpoAgent(gym_env_name=_cartpole_name, backend=backend)
        tc = core.PpoTrainContext()
        tc.num_iterations = 10
        tc.num_episodes_per_iteration = 10
        tc.max_steps_per_episode = 200
        tc.num_epochs_per_iteration = 5
        tc.num_iterations_between_eval = 5
        tc.num_episodes_per_eval = 5
        ppo.train([log.Iteration()], train_context=tc, default_plots=False)
        assert max_avg_rewards(tc) >= 50
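# Hypothetical reconstruction of the max_avg_rewards helper used above: the
# train context is assumed to collect one list of episode rewards per
# evaluation (here in tc.eval_rewards, a dict keyed by training step), and the
# helper returns the best per-evaluation average.
def max_avg_rewards(tc: core.TrainContext) -> float:
    assert tc.eval_rewards, "no evaluation was run during training"
    return max(sum(rewards) / len(rewards) for rewards in tc.eval_rewards.values())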
def test_callback_single(self):
    # the _SingleEpisode duration callback stops training after one episode,
    # so the env should be reset at most twice
    for backend in get_backends(PpoAgent):
        env._StepCountEnv.clear()
        agent = PpoAgent(_env_name, backend=backend)
        agent.train(duration._SingleEpisode())
        assert env._StepCountEnv.reset_count <= 2
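# Hypothetical sketch of a step/reset-counting env in the spirit of
# env._StepCountEnv: class-level counters plus a clear() classmethod let a test
# check how often the training loop touched the env. The real helper lives in
# easyagents.env; everything below (including the pre-0.26 gym API) is an
# assumption for illustration only.
import gym
import numpy as np
from gym import spaces

class StepCountEnvSketch(gym.Env):
    reset_count = 0
    step_count = 0

    @classmethod
    def clear(cls):
        cls.reset_count = 0
        cls.step_count = 0

    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

    def reset(self):
        type(self).reset_count += 1
        return np.zeros(1, dtype=np.float32)

    def step(self, action):
        type(self).step_count += 1
        done = type(self).step_count % 10 == 0  # end an episode every 10 steps
        return np.zeros(1, dtype=np.float32), 0.0, done, {}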
def test_save_no_trained_policy_exception(self):
    # saving an agent that was never trained (and thus has no policy) must raise
    p1 = PpoAgent(gym_env_name=_line_world_name, fc_layers=(10, 20, 30), backend='tfagents')
    with pytest.raises(Exception):
        p1.save()