Example #1
0
 def test_save_load(self):
     """Round-trip a random agent through save() and load() using a temp path.

     A random agent for the ``_lineworld_name`` model is saved to a directory
     obtained from ``bcore._get_temp_path()`` and then loaded back from it,
     with no callbacks. The directory is removed in a ``finally`` clause so it
     is cleaned up even when save() or load() raises.
     """
     model_config = core.ModelConfig(_lineworld_name)
     random_agent = tfagents.TfRandomAgent(model_config=model_config)
     tempdir = bcore._get_temp_path()
     bcore._mkdir(tempdir)
     try:
         random_agent.save(directory=tempdir, callbacks=[])
         random_agent.load(directory=tempdir, callbacks=[])
     finally:
         # Without this, a failing save/load would leak the directory on disk.
         bcore._rmpath(tempdir)
Example #2
0
 def test_play(self):
     """Play a random agent on CartPole-v0 with a small PlayContext.

     The PlayContext is configured with ``num_episodes = 1`` and
     ``max_steps_per_episode = 10``, and play() is called without callbacks.
     """
     model_config = core.ModelConfig("CartPole-v0")
     random_agent = tfagents.TfRandomAgent(model_config=model_config)
     pc = core.PlayContext()
     pc.max_steps_per_episode = 10
     pc.num_episodes = 1
     random_agent.play(play_context=pc, callbacks=[])
     # NOTE(review): num_episodes was set to 1 just above, so this only detects
     # play() overwriting the configured episode count. Consider also asserting
     # on a result attribute of PlayContext (e.g. episodes actually played)
     # once its name is confirmed.
     assert pc.num_episodes == 1
Example #3
0
 def test_train(self):
     """Train a random agent on the ``_lineworld_name`` model.

     Training uses a default TrainContext plus the ``duration.Fast`` and
     ``log.Iteration`` callbacks; afterwards the context must report exactly
     one episode done in the iteration.
     """
     config = core.ModelConfig(_lineworld_name)
     train_context = core.TrainContext()
     agent = tfagents.TfRandomAgent(model_config=config)
     train_callbacks = [duration.Fast(), log.Iteration()]
     agent.train(train_context=train_context, callbacks=train_callbacks)
     assert train_context.episodes_done_in_iteration == 1