def test_default_plots_True_plotcallback(self):
    """With default_plots=True, both the default plot and explicit callbacks are kept."""
    agent = agents.PpoAgent("CartPole-v0")
    loss_plot = plot.Loss()
    rewards_plot = plot.Rewards()
    callbacks = agent._add_plot_callbacks([rewards_plot], True, [loss_plot])
    assert loss_plot in callbacks
    assert rewards_plot in callbacks
def test_default_plots_None_plotcallback(self):
    """With default_plots=None and an explicit plot callback given, the default
    plot is suppressed while the explicit callback is preserved."""
    agent = agents.PpoAgent("CartPole-v0")
    p = plot.Loss()
    r = plot.Rewards()
    c = agent._prepare_callbacks([r], None, [p])
    # idiomatic membership negation (`p not in c`) instead of `not p in c`
    assert p not in c
    assert r in c
def test_prepare_callbacks(self):
    """_prepare_callbacks wraps user callbacks between _PreProcess and _PostProcess,
    moving ToMovie to the very end."""
    agent = agents.PpoAgent("CartPole-v0")
    user_callbacks = [plot.ToMovie(), plot.Rewards()]
    prepared = agent._prepare_callbacks(
        user_callbacks, default_plots=None, default_plot_callbacks=[])
    assert isinstance(prepared[0], plot._PreProcess)
    assert isinstance(prepared[1], plot.Rewards)
    assert isinstance(prepared[-2], plot._PostProcess)
    assert isinstance(prepared[-1], plot.ToMovie)
def test_train_single_episode(self):
    """Training a single episode fires balanced gym callbacks on every backend."""
    for backend in get_backends(PpoAgent):
        ppo = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        counts = log._CallbackCounts()
        ppo.train([log.Agent(), counts, duration._SingleEpisode()])
        # init begin/end must both fire exactly once
        assert counts.gym_init_begin_count == counts.gym_init_end_count == 1
        # every step-begin has a matching step-end
        assert counts.gym_step_begin_count == counts.gym_step_end_count
        assert counts.gym_step_begin_count < 10 + counts.gym_reset_begin_count
def test_save_load(self):
    """An agent saved after training can be reloaded and played on every backend."""
    for backend in get_backends(PpoAgent):
        ppo = agents.PpoAgent(gym_env_name=_step_count_name, backend=backend)
        ppo.train([duration._SingleEpisode()], default_plots=False)
        saved_dir = ppo.save()
        reloaded = agents.load(saved_dir)
        reloaded.play(default_plots=False, num_episodes=1, callbacks=[])
        # clean up the temporary save directory
        easyagents.backends.core._rmpath(saved_dir)
def test_play_single_episode(self):
    """Playing one episode fires balanced gym callbacks on every backend."""
    for backend in get_backends(PpoAgent):
        ppo = agents.PpoAgent(gym_env_name=_env_name, backend=backend)
        count = log._CallbackCounts()
        cb = [log.Agent(), count, duration._SingleEpisode()]
        # pass callbacks as a list, consistent with every other test in this file
        # (the original passed a bare callback: ppo.train(duration._SingleEpisode()))
        ppo.train([duration._SingleEpisode()])
        ppo.play(cb)
        assert count.gym_init_begin_count == count.gym_init_end_count == 1
        assert count.gym_step_begin_count == count.gym_step_end_count <= 10
def test_train_multiple_subplots(self):
    """Training with several plot callbacks at once completes without error."""
    agent = agents.PpoAgent("CartPole-v0")
    plot_callbacks = [
        duration._SingleIteration(),
        plot.State(),
        plot.Rewards(),
        plot.Loss(),
        plot.Steps(),
    ]
    agent.train(plot_callbacks)
def test_every(self):
    """save.Every persists the agent after every evaluation.

    3 iterations with an eval after each (plus the initial save) yield 4
    saved agents, each in an existing directory.
    """
    agent = agents.PpoAgent(_line_world_name)
    every = save.Every(num_evals_between_save=1)
    agent.train([every], num_iterations_between_eval=1,
                num_episodes_per_iteration=10, num_iterations=3,
                default_plots=False)
    # the original called os.path.isdir without asserting the result,
    # so directory checks could never fail
    assert os.path.isdir(every.directory)
    assert len(every.saved_agents) == 4
    for (episode, reward, saved_dir) in every.saved_agents:
        # avoid shadowing the builtin `dir`
        assert os.path.isdir(saved_dir)
def test_train_tomovie_with_filename(self):
    """plot.ToMovie writes a movie to an explicitly chosen filepath."""
    # reserve a unique path, then delete the file so ToMovie creates it
    f = tempfile.NamedTemporaryFile(delete=False)
    filepath = f.name
    f.close()
    os.remove(filepath)
    assert not os.path.isfile(filepath)
    filepath += '.gif'
    agent = agents.PpoAgent("CartPole-v0")
    m = plot.ToMovie(filepath=filepath, fps=10)
    agent.train([duration._SingleIteration(), plot.Rewards(), m])
    # best-effort cleanup: narrow the original bare `except:` so real bugs
    # (e.g. NameError) are no longer silently swallowed
    try:
        os.remove(m.filepath)
    except OSError:
        pass
def test_best(self):
    """save.Best saves the best-performing agent, which can be reloaded and evaluated."""
    agent = agents.PpoAgent(_line_world_name)
    best = save.Best()
    agent.train([best], num_iterations_between_eval=1,
                num_episodes_per_iteration=10, num_iterations=3,
                default_plots=False)
    # the original called os.path.isdir without asserting the result,
    # so directory checks could never fail
    assert os.path.isdir(best.directory)
    assert len(best.saved_agents) > 0
    # avoid shadowing the builtin `dir`
    (episode, reward, saved_dir) = best.saved_agents[0]
    assert os.path.isdir(saved_dir)
    agent2 = easyagents.agents.load(saved_dir)
    assert agent2
    agent2.evaluate()
def test_train_plotsteps(self):
    """Training with the Steps plot callback completes without error."""
    ppo_agent = agents.PpoAgent("CartPole-v0")
    ppo_agent.train([duration._SingleIteration(), plot.Steps()])
def test_train_plotsteprewards(self):
    """Training with the StepRewards plot callback completes without error."""
    agent = agents.PpoAgent('CartPole-v0')
    agent.train([plot.StepRewards(), duration.Fast()])
def test_single_episode(self):
    """A single-episode training run triggers exactly one iteration-begin callback."""
    agent = agents.PpoAgent("CartPole-v0")
    callback_counts = log._CallbackCounts()
    agent.train([duration._SingleEpisode(), log._Callbacks(), callback_counts])
    assert callback_counts.train_iteration_begin_count == 1
def test_default_plots_None_nocallback(self):
    """With default_plots=None and no explicit plot callbacks, the default plot is added."""
    agent = agents.PpoAgent("CartPole-v0")
    default_plot = plot.Loss()
    result = agent._add_plot_callbacks([], None, [default_plot])
    assert default_plot in result
def test_play_plotrewards(self):
    """Playing with a Rewards plot callback after training completes without error."""
    ppo_agent = agents.PpoAgent("CartPole-v0")
    ppo_agent.train([duration._SingleIteration()])
    ppo_agent.play([plot.Rewards()])
def test_play_plotstate(self):
    """Playing with a State plot callback after training completes without error."""
    ppo_agent = agents.PpoAgent("CartPole-v0")
    ppo_agent.train([duration._SingleEpisode()])
    ppo_agent.play([plot.State()])
def test_cartpole_log_iteration(self):
    """The tfagents backend trains one iteration with iteration logging enabled."""
    agent = agents.PpoAgent(gym_env_name="CartPole-v0", backend='tfagents')
    agent.train([log.Iteration(), duration._SingleIteration()])
def test_log_agent(self):
    """Training with agent-level logging enabled completes without error."""
    ppo_agent = agents.PpoAgent(_step_count_name)
    ppo_agent.train([log.Agent(), duration.Fast()])
def test_default_plots_None_durationcallback(self):
    """With default_plots=None, a non-plot callback does not suppress the default plot."""
    agent = agents.PpoAgent("CartPole-v0")
    default_plot = plot.Loss()
    result = agent._prepare_callbacks([duration.Fast()], None, [default_plot])
    assert default_plot in result
def test_default_plots_False_nocallback(self):
    """With default_plots=False, the default plot callback is suppressed."""
    agent = agents.PpoAgent("CartPole-v0")
    p = plot.Loss()
    c = agent._prepare_callbacks([], False, [p])
    # idiomatic membership negation (`p not in c`) instead of `not p in c`
    assert p not in c
def test_log_step(self):
    """Training with per-step logging enabled completes without error."""
    ppo_agent = agents.PpoAgent(_step_count_name)
    ppo_agent.train([log.Step(), duration._SingleIteration()])
def test_train_tomovie(self):
    """Training with a ToMovie callback (default filepath) completes without error."""
    ppo_agent = agents.PpoAgent("CartPole-v0")
    callbacks = [duration._SingleIteration(), plot.Rewards(), plot.ToMovie()]
    ppo_agent.train(callbacks)
def test_log_callbacks(self):
    """Training with callback-invocation logging enabled completes without error."""
    ppo_agent = agents.PpoAgent(_step_count_name)
    ppo_agent.train([log._Callbacks(), duration._SingleIteration()])
def test_log_iteration(self):
    """Training with iteration logging enabled completes without error."""
    # NOTE(review): original referenced `env_name`, but sibling tests use the
    # module constant `_env_name` — `env_name` is likely an undefined name.
    # Confirm against the module header (not visible in this chunk).
    agent = agents.PpoAgent(_env_name)
    agent.train([log.Iteration(), duration._SingleIteration()])