import os

# NOTE: these import paths assume the rlberry API in which AgentStats and
# the plotting/comparison helpers live in rlberry.stats.
from rlberry.envs import GridWorld
from rlberry.stats import AgentStats, plot_episode_rewards, compare_policies


def test_agent_stats_1():
    # Define train and evaluation envs
    train_env = GridWorld()
    eval_env = GridWorld()

    # Parameters
    params = {"n_episodes": 500}
    horizon = 20

    # Check DummyAgent (a minimal test agent; see the sketch after
    # test_agent_stats_2 below)
    agent = DummyAgent(train_env, **params)
    agent.fit()
    agent.policy(None)

    # Run AgentStats
    stats_agent1 = AgentStats(DummyAgent, train_env, init_kwargs=params,
                              n_fit=4, eval_horizon=10)
    stats_agent2 = AgentStats(DummyAgent, train_env, init_kwargs=params,
                              n_fit=4, eval_horizon=10)
    agent_stats_list = [stats_agent1, stats_agent2]

    # Learning curves
    plot_episode_rewards(agent_stats_list, cumulative=True, show=False)

    # Compare final policies
    compare_policies(agent_stats_list, eval_env,
                     eval_horizon=horizon, n_sim=10, show=False)
    compare_policies(agent_stats_list, eval_env,
                     eval_horizon=horizon, n_sim=10, show=False,
                     stationary_policy=False)

    # Check if fitted
    for agent_stats in agent_stats_list:
        assert len(agent_stats.fitted_agents) == 4
        for agent in agent_stats.fitted_agents:
            assert agent.fitted

    # Test saving/loading
    stats_agent1.save('test_agent_stats_file.pickle')
    loaded_stats = AgentStats.load('test_agent_stats_file.pickle')
    assert stats_agent1.identifier == loaded_stats.identifier

    # Delete file
    os.remove('test_agent_stats_file.pickle')

    # Test hyperparameter optimization
    loaded_stats.optimize_hyperparams()
    loaded_stats.optimize_hyperparams(continue_previous=True)
def test_agent_stats_2():
    # Define train and evaluation envs
    train_env = GridWorld()
    eval_env = GridWorld()

    # Parameters
    params = {"n_episodes": 500}

    # Run AgentStats
    stats_agent1 = AgentStats(DummyAgent, train_env, eval_env=eval_env,
                              init_kwargs=params, n_fit=4,
                              eval_horizon=10, n_jobs=1)
    stats_agent2 = AgentStats(DummyAgent, train_env, eval_env=eval_env,
                              init_kwargs=params, n_fit=4,
                              eval_horizon=10, n_jobs=1)
    agent_stats_list = [stats_agent1, stats_agent2]

    # Set some writers
    stats_agent1.set_writer(1, None)
    stats_agent1.set_writer(2, None)

    # Compare final policies
    compare_policies(agent_stats_list, n_sim=10, show=False)
    compare_policies(agent_stats_list, n_sim=10, show=False,
                     stationary_policy=False)

    # Learning curves
    plot_episode_rewards(agent_stats_list, cumulative=True, show=False)

    # Check if fitted
    for agent_stats in agent_stats_list:
        assert len(agent_stats.fitted_agents) == 4
        for agent in agent_stats.fitted_agents:
            assert agent.fitted

    # Test saving/loading
    dirname = stats_agent1.output_dir
    fname = dirname / 'stats'
    stats_agent1.save()
    loaded_stats = AgentStats.load(fname)
    assert stats_agent1.identifier == loaded_stats.identifier

    # Delete file
    os.remove(fname.with_suffix('.pickle'))
    dirname.rmdir()

    # Test hyperparameter optimization
    loaded_stats.optimize_hyperparams()
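# ---------------------------------------------------------------------
# The two tests above assume a DummyAgent defined elsewhere in this test
# module. The sketch below shows what a minimal such agent might look
# like: it only needs fit(), policy(), a `fitted` flag checked by the
# tests, and a sample_parameters() classmethod for optimize_hyperparams.
# The base-class import path and the search-space details are assumptions,
# not the original definition.
# ---------------------------------------------------------------------
from rlberry.agents import Agent  # assumed import path


class DummyAgent(Agent):
    """Minimal agent used only to exercise AgentStats in the tests above."""
    name = 'DummyAgent'

    def __init__(self, env, n_episodes=100, **kwargs):
        Agent.__init__(self, env, **kwargs)
        self.n_episodes = n_episodes
        self.fitted = False

    def fit(self, **kwargs):
        # No actual learning: just mark the agent as fitted.
        self.fitted = True
        return None

    def policy(self, observation, **kwargs):
        # Always play action 0, regardless of the observation.
        return 0

    @classmethod
    def sample_parameters(cls, trial):
        # Hypothetical search space for optimize_hyperparams; `trial` is
        # an Optuna trial object.
        return {'n_episodes': trial.suggest_int('n_episodes', 100, 1000)}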
ppo_stats = AgentStats(PPOAgent, train_env,
                       init_kwargs=params_ppo, n_fit=4)

# Hyperparameter optimization
best_trial, data = ppo_stats.optimize_hyperparams(
    n_trials=10, timeout=None,
    n_sim=5, n_fit=2, n_jobs=2,
    sampler_method='optuna_default')

initial_n_trials = len(ppo_stats.study.trials)

# Save
ppo_stats.save('ppo_stats_backup')
del ppo_stats

# Load
ppo_stats = AgentStats.load('ppo_stats_backup')

# Continue the previous optimization, now with a 5-second timeout
best_trial, data = ppo_stats.optimize_hyperparams(
    n_trials=10, timeout=5,
    n_sim=5, n_fit=2, n_jobs=2,
    continue_previous=True)

print("number of initial trials =", initial_n_trials)
print("number of trials after continuing =", len(ppo_stats.study.trials))
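# Note (added): continue_previous=True reuses the Optuna study stored on
# the AgentStats instance (and pickled along with it), so trials from the
# second call accumulate on top of the first. A quick sanity check of that
# behavior, not present in the original script:
assert len(ppo_stats.study.trials) >= initial_n_trials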
HORIZON = 50
BONUS_SCALE_FACTOR = 0.1
MIN_DIST = 0.1

params_ppo = {"n_episodes": N_EPISODES,
              "gamma": GAMMA,
              "horizon": HORIZON,
              "learning_rate": 0.0003}

# --------------------------------
# Run AgentStats and save results
# --------------------------------
ppo_stats = AgentStats(PPOAgent, train_env,
                       init_kwargs=params_ppo, n_fit=4)
ppo_stats.fit()  # fit the 4 agents
ppo_stats.save('ppo_stats')
del ppo_stats

# --------------------------------
# Load and plot results
# --------------------------------
ppo_stats = AgentStats.load('ppo_stats')
agent_stats_list = [ppo_stats]

# Learning curves
plot_episode_rewards(agent_stats_list, cumulative=True, show=False)

# Compare final policies
output = compare_policies(agent_stats_list, eval_env,
                          eval_horizon=HORIZON, n_sim=10)
print(output)
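# ---------------------------------------------------------------------
# The two PPO snippets above are truncated at the top: they rely on
# imports, environment instances, and constants (N_EPISODES, GAMMA,
# train_env, eval_env) defined earlier in the original scripts. A minimal
# sketch of that missing preamble follows; the environment choice, the
# PPOAgent import path, and the constant values are illustrative
# assumptions, not the originals.
# ---------------------------------------------------------------------
from rlberry.agents.ppo import PPOAgent  # assumed import path
from rlberry.envs import GridWorld       # assumed: any rlberry env would do
from rlberry.stats import (AgentStats, plot_episode_rewards,
                           compare_policies)

N_EPISODES = 500  # assumed value
GAMMA = 0.99      # assumed value
train_env = GridWorld()
eval_env = GridWorld()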