def test_save_load(self):
    """Round-trip a TfRandomAgent through save() and load() in a temp directory.

    The directory is obtained from bcore._get_temp_path() and created with
    bcore._mkdir(). bcore._rmpath() is called in a ``finally`` clause, so it
    runs even when save() or load() raises.
    """
    model_config = core.ModelConfig(_lineworld_name)
    random_agent = tfagents.TfRandomAgent(model_config=model_config)
    tempdir = bcore._get_temp_path()
    bcore._mkdir(tempdir)
    try:
        random_agent.save(directory=tempdir, callbacks=[])
        random_agent.load(directory=tempdir, callbacks=[])
    finally:
        # Previously cleanup was skipped whenever save/load failed, leaving
        # stale temp directories behind; always release it.
        bcore._rmpath(tempdir)
def test_play(self):
    """Play a TfRandomAgent on "CartPole-v0" for one episode of at most 10 steps."""
    agent = tfagents.TfRandomAgent(model_config=core.ModelConfig("CartPole-v0"))

    play_context = core.PlayContext()
    play_context.max_steps_per_episode = 10
    play_context.num_episodes = 1

    agent.play(play_context=play_context, callbacks=[])

    # NOTE(review): this only re-checks the value assigned above, so it passes
    # whenever play() returns without raising and does not overwrite
    # num_episodes. Consider asserting on a counter that play() updates,
    # once the PlayContext API is confirmed.
    assert play_context.num_episodes == 1
def test_train(self):
    """Train a TfRandomAgent with a default TrainContext and check the episode count.

    Uses the environment named by _lineworld_name and the duration.Fast() and
    log.Iteration() callbacks. Asserts exactly one episode was recorded in the
    (last) training iteration.
    """
    config = core.ModelConfig(_lineworld_name)
    context = core.TrainContext()
    agent = tfagents.TfRandomAgent(model_config=config)

    train_callbacks = [duration.Fast(), log.Iteration()]
    agent.train(train_context=context, callbacks=train_callbacks)

    assert context.episodes_done_in_iteration == 1