def test_to_dict_from_dict(self):
     """Round-trip a PpoAgent through ``_to_dict`` / ``EasyAgent._from_dict``.

     ``agents.seed`` is set to 123 while the agent is constructed and
     serialized, then restored to its previous value before the agent is
     rebuilt from the dict. The restore happens even if construction or
     serialization raises.
     """
     oldseed = agents.seed
     agents.seed = 123
     try:
         p1 = PpoAgent(gym_env_name=_line_world_name, fc_layers=(10, 20, 30), backend='tfagents')
         d = p1._to_dict()
     finally:
         # Always restore the module-level seed so a failure above does not
         # leak seed=123 into subsequently run tests.
         agents.seed = oldseed
     p2: EasyAgent = EasyAgent._from_dict(d)
     self.assert_are_equal(p1, p2)
 def test_save_load_play(self):
     """Train one episode, save, reload with ``agents.load``, compare, then play.

     ``agents.seed`` is pinned to 123 for construction, training and saving.
     It is restored to its previous value afterwards, including when any of
     those steps raises, and before the saved agent is loaded.
     """
     oldseed = agents.seed
     agents.seed = 123
     try:
         p1 = PpoAgent(gym_env_name=_line_world_name, fc_layers=(10, 20, 30), backend='tfagents')
         p1.train(callbacks=[duration._SingleEpisode()], default_plots=False)
         d = p1.save()
     finally:
         # Restore the global seed unconditionally so it does not leak into
         # other tests if training or saving fails.
         agents.seed = oldseed
     p2: EasyAgent = agents.load(d)
     self.assert_are_equal(p1, p2)
     p2.play(default_plots=False, num_episodes=1)
Example #3
0
 def test_train_cartpole(self):
     """Run a short PPO training on CartPole-v0 for every available backend.

     No result is asserted; the test only checks that training completes.
     """
     for backend_name in get_backends(PpoAgent):
         agent = PpoAgent(gym_env_name="CartPole-v0", backend=backend_name)
         context = core.PpoTrainContext()
         context.num_iterations = 3
         context.num_episodes_per_iteration = 10
         context.max_steps_per_episode = 500
         context.num_epochs_per_iteration = 5
         # Evaluate every 2nd iteration, using 5 episodes per evaluation.
         context.num_iterations_between_eval = 2
         context.num_episodes_per_eval = 5
         agent.train([log.Iteration()], train_context=context)
 def test_train(self):
     """Train PPO on the cartpole env for each backend and check the reward.

     ``agents.seed`` is pinned to 0 for the duration of the test and restored
     to its previous value afterwards, so it does not leak into other tests.
     For every backend, the value returned by ``max_avg_rewards(tc)`` must
     reach at least 50.
     """
     oldseed = agents.seed
     agents.seed = 0
     try:
         for backend in get_backends(PpoAgent):
             ppo = PpoAgent(gym_env_name=_cartpole_name, backend=backend)
             tc = core.PpoTrainContext()
             tc.num_iterations = 10
             tc.num_episodes_per_iteration = 10
             tc.max_steps_per_episode = 200
             tc.num_epochs_per_iteration = 5
             tc.num_iterations_between_eval = 5
             tc.num_episodes_per_eval = 5
             ppo.train([log.Iteration()], train_context=tc, default_plots=False)
             assert max_avg_rewards(tc) >= 50
     finally:
         # Previously the seed was left at 0 after this test; restore it like
         # the other seed-pinning tests do.
         agents.seed = oldseed
Example #5
0
 def test_callback_single(self):
     """Training with the ``_SingleEpisode`` callback resets the env at most twice.

     For each backend, the ``_StepCountEnv`` counters are cleared first, so
     ``reset_count`` reflects only this agent's training run.
     """
     for backend_name in get_backends(PpoAgent):
         env._StepCountEnv.clear()
         ppo_agent = PpoAgent(_env_name, backend=backend_name)
         ppo_agent.train(duration._SingleEpisode())
         resets = env._StepCountEnv.reset_count
         assert resets <= 2
 def test_save_no_trained_policy_exception(self):
     """Calling ``save()`` on a freshly constructed, untrained agent must raise."""
     layers = (10, 20, 30)
     untrained = PpoAgent(gym_env_name=_line_world_name, fc_layers=layers, backend='tfagents')
     # NOTE(review): any Exception is accepted; narrow once the concrete type is known.
     with pytest.raises(Exception):
         untrained.save()