def test_agent_stats_2():
    """Fit two identically configured DummyAgent AgentStats and use the helpers.

    Exercises policy comparison (once with default options and once with
    ``stationary_policy=False``), cumulative learning-curve plotting, the
    fitted-agent count, a save/load round trip (followed by removal of the
    saved pickle and the output directory) and hyperparameter optimization on
    the reloaded object.
    """
    train_env, eval_env = GridWorld(), GridWorld()

    params = {"n_episodes": 500}

    # Two AgentStats created with exactly the same arguments.
    agent_stats_list = [
        AgentStats(DummyAgent,
                   train_env,
                   eval_env=eval_env,
                   init_kwargs=params,
                   n_fit=4,
                   eval_horizon=10,
                   n_jobs=1)
        for _ in range(2)
    ]
    first_stats = agent_stats_list[0]

    # Writers (None) for workers 1 and 2 of the first AgentStats only.
    for worker_idx in (1, 2):
        first_stats.set_writer(worker_idx, None)

    # Final-policy comparison: default options, then stationary_policy=False.
    compare_policies(agent_stats_list, n_sim=10, show=False)
    compare_policies(agent_stats_list, n_sim=10, show=False,
                     stationary_policy=False)

    # Learning curves.
    plot_episode_rewards(agent_stats_list, cumulative=True, show=False)

    # Every AgentStats must hold 4 agents, all marked as fitted.
    for agent_stats in agent_stats_list:
        fitted = agent_stats.fitted_agents
        assert len(fitted) == 4
        assert all(agent.fitted for agent in fitted)

    # Save/load round trip: the reloaded object keeps the same identifier.
    output_dir = first_stats.output_dir
    stats_path = output_dir / 'stats'
    first_stats.save()
    loaded_stats = AgentStats.load(stats_path)
    assert first_stats.identifier == loaded_stats.identifier

    # Remove the pickle written by save(), then the output directory itself.
    stats_path.with_suffix('.pickle').unlink()
    output_dir.rmdir()

    # Hyperparameter optimization on the reloaded object.
    loaded_stats.optimize_hyperparams()


def test_agent_stats_partial_fit_and_tuple_env():
    """Check AgentStats built from a ``(constructor, kwargs)`` env tuple.

    ``stats`` is driven with ``partial_fit``:
    - after fractions 0.1 and 0.5, every fitted agent must report
      ``fraction_fitted`` close to 0.6;
    - after each of two further 0.5 calls, it must report 1.0.

    ``stats2`` is configured identically and trained with ``fit()``. It must
    then hold 4 agents that are all marked as fitted. Finally, the
    learning-curve and policy-comparison helpers run on ``stats``.
    """
    # Function-scope import: fractions are accumulated in floating point, so
    # the comparisons below use a tolerance instead of exact equality.
    import math

    # A (constructor, kwargs) tuple must also work in AgentStats.
    train_env = (GridWorld, None)

    # Parameters
    params = {"n_episodes": 500}
    horizon = 20

    # Run AgentStats
    stats = AgentStats(DummyAgent,
                       train_env,
                       init_kwargs=params,
                       n_fit=4,
                       eval_horizon=10)
    stats2 = AgentStats(DummyAgent,
                        train_env,
                        init_kwargs=params,
                        n_fit=4,
                        eval_horizon=10)
    # set some writers
    stats.set_writer(0, None)
    stats.set_writer(3, None)

    # Run partial fit
    stats.partial_fit(0.1)
    stats.partial_fit(0.5)
    for agent in stats.fitted_agents:
        assert math.isclose(agent.fraction_fitted, 0.6)
    for _ in range(2):
        stats.partial_fit(0.5)
        for agent in stats.fitted_agents:
            assert math.isclose(agent.fraction_fitted, 1.0)

    # Run fit, and verify it actually produced fitted agents
    # (previously stats2 was fitted but never checked).
    stats2.fit()
    assert len(stats2.fitted_agents) == 4
    assert all(agent.fitted for agent in stats2.fitted_agents)

    # learning curves
    plot_episode_rewards([stats], cumulative=True, show=False)

    # compare final policies
    compare_policies([stats], eval_horizon=horizon, n_sim=10, show=False)


# Discount factor and horizon shared by the PPO configuration (HORIZON is
# also used by later module-level code).
GAMMA = 0.99
HORIZON = 50

params_ppo = dict(
    n_episodes=N_EPISODES,
    gamma=GAMMA,
    horizon=HORIZON,
    learning_rate=0.0003,
)

# -----------------------------
# Run AgentStats
# -----------------------------
ppo_stats = AgentStats(PPOAgent, train_env, init_kwargs=params_ppo, n_fit=4)

# Only workers 0 and 1 get a SummaryWriter, each tagged with its own comment.
ppo_stats.set_writer(0, SummaryWriter, writer_kwargs={'comment': 'worker_0'})
ppo_stats.set_writer(1, SummaryWriter, writer_kwargs={'comment': 'worker_1'})

agent_stats_list = [ppo_stats]

# agent_stats_list[0] is ppo_stats, so fit and save it directly.
ppo_stats.fit()
ppo_stats.save()  # after fit, writers are set to None to avoid pickle problems.

# learning curves
plot_episode_rewards(agent_stats_list, cumulative=True, show=False)
output = compare_policies(agent_stats_list,
                          eval_env,
                          eval_horizon=HORIZON,