def test_train_multiple_subplots(self):
    """Train a PpoAgent on CartPole-v0 with several plot callbacks attached together.

    Training is bounded by ``duration._SingleIteration`` (presumably one
    iteration); the plot callbacks are State, Rewards, Loss and Steps.
    """
    ppo = agents.PpoAgent("CartPole-v0")
    callbacks = [
        duration._SingleIteration(),
        plot.State(),
        plot.Rewards(),
        plot.Loss(),
        plot.Steps(),
    ]
    ppo.train(callbacks)
def test_train_tomovie_with_filename(self):
    """Train with ``plot.ToMovie`` writing to an explicitly given ``.gif`` path.

    The target path lives inside a ``tempfile.TemporaryDirectory``. This
    guarantees a fresh, non-existent path without the create-then-delete
    race of ``NamedTemporaryFile``. The directory, and anything written into
    it, is removed when the ``with`` block exits, even if training raises.
    There is no bare ``except`` that could hide real failures.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        filepath = os.path.join(tempdir, 'tomovie.gif')
        # Sanity check: the movie file must not exist before training.
        assert not os.path.isfile(filepath)
        agent = agents.PpoAgent("CartPole-v0")
        movie = plot.ToMovie(filepath=filepath, fps=10)
        # NOTE(review): assumes ToMovie writes to (or under) the given filepath;
        # if it rewrites m.filepath elsewhere, that file would not be cleaned up.
        agent.train([duration._SingleIteration(), plot.Rewards(), movie])
def test_save_load(self):
    """Train a TfPpoAgent, save it, load it into a fresh agent and play one episode.

    The temporary directory is removed in a ``finally`` clause so that a
    failure in save, load or play does not leave it behind on disk.
    """
    model_config = core.ModelConfig(_lineworld_name)
    tc = core.PpoTrainContext()
    ppo_agent = tfagents.TfPpoAgent(model_config=model_config)
    ppo_agent.train(
        train_context=tc,
        callbacks=[duration._SingleIteration(), log.Iteration()])
    tempdir = bcore._get_temp_path()
    bcore._mkdir(tempdir)
    try:
        ppo_agent.save(tempdir, [])
        # Load into a brand-new agent so the save/load round trip is exercised.
        loaded_agent = tfagents.TfPpoAgent(model_config=model_config)
        loaded_agent.load(tempdir, [])
        pc = core.PlayContext()
        pc.max_steps_per_episode = 10
        pc.num_episodes = 1
        loaded_agent.play(play_context=pc, callbacks=[])
    finally:
        bcore._rmpath(tempdir)
def test_log_callbacks(self):
    """Train a PpoAgent on the step-count environment with log._Callbacks attached."""
    ppo = agents.PpoAgent(_step_count_name)
    callbacks = [
        log._Callbacks(),
        duration._SingleIteration(),
    ]
    ppo.train(callbacks)
def test_cartpole_log_iteration(self):
    """Train a tfagents-backed PpoAgent on CartPole-v0 with log.Iteration attached."""
    agent = agents.PpoAgent(
        gym_env_name="CartPole-v0",
        backend='tfagents',
    )
    callbacks = [
        log.Iteration(),
        duration._SingleIteration(),
    ]
    agent.train(callbacks)
def test_log_step(self):
    """Train a PpoAgent on the step-count environment with log.Step attached."""
    ppo = agents.PpoAgent(_step_count_name)
    callbacks = [
        log.Step(),
        duration._SingleIteration(),
    ]
    ppo.train(callbacks)
def test_train_plotsteps(self):
    """Train a PpoAgent on CartPole-v0 with the plot.Steps callback attached."""
    ppo = agents.PpoAgent("CartPole-v0")
    callbacks = [
        duration._SingleIteration(),
        plot.Steps(),
    ]
    ppo.train(callbacks)
def test_play_plotrewards(self):
    """Briefly train a PpoAgent on CartPole-v0, then play it with plot.Rewards attached."""
    ppo = agents.PpoAgent("CartPole-v0")
    train_callbacks = [duration._SingleIteration()]
    ppo.train(train_callbacks)
    play_callbacks = [plot.Rewards()]
    ppo.play(play_callbacks)
def test_train_tomovie(self):
    """Train a PpoAgent on CartPole-v0 with plot.Rewards and a default plot.ToMovie."""
    ppo = agents.PpoAgent("CartPole-v0")
    callbacks = [
        duration._SingleIteration(),
        plot.Rewards(),
        plot.ToMovie(),
    ]
    ppo.train(callbacks)
def test_log_iteration(self):
    """Train a PpoAgent on ``env_name`` with log.Iteration attached."""
    ppo = agents.PpoAgent(env_name)
    callbacks = [
        log.Iteration(),
        duration._SingleIteration(),
    ]
    ppo.train(callbacks)