import gym
import numpy as np

# The imports below assume the module layout of the surrounding project.
from agents.Trainer import Trainer
from agents.DQN_agents.DQN import DQN
from agents.DQN_agents.DDQN import DDQN
from agents.DQN_agents.DQN_HER import DQN_HER
from agents.DQN_agents.DQN_With_Fixed_Q_Targets import DQN_With_Fixed_Q_Targets
from agents.DQN_agents.DDQN_With_Prioritised_Experience_Replay import DDQN_With_Prioritised_Experience_Replay
from agents.DQN_agents.DRQN import DRQN  # assumed path; DRQN is used in the __main__ block below
from agents.policy_gradient_agents.PPO import PPO
from agents.actor_critic_agents.A2C import A2C
from agents.actor_critic_agents.A3C import A3C
from agents.actor_critic_agents.DDPG import DDPG
from agents.actor_critic_agents.SAC import SAC
from agents.actor_critic_agents.TD3 import TD3
from agents.hierarchical_agents.SNN_HRL import SNN_HRL
from environments.Four_Rooms_Environment import Four_Rooms_Environment
from utilities.data_structures.Config import Config

config = Config()  # shared configuration object; the tests below fill in environment-specific fields


def test_agent_solve_bit_flipping_game():
    """Checks that each agent reaches a non-negative score on the bit-flipping game.
    config.environment is assumed to be set to a bit-flipping environment elsewhere."""
    AGENTS = [PPO, DDQN, DQN_With_Fixed_Q_Targets, DDQN_With_Prioritised_Experience_Replay, DQN, DQN_HER]
    trainer = Trainer(config, AGENTS)
    results = trainer.train()
    for agent in AGENTS:
        agent_results = results[agent.agent_name]
        # Best score achieved after the first 50 episodes of the first training run
        best_score = np.max(agent_results[0][1][50:])
        assert best_score >= 0.0, "Failed for {} -- score {}".format(agent.agent_name, best_score)


def test_agents_can_play_games_of_different_dimensions():
    """Checks that agents can train on environments whose state and action spaces have different dimensions."""
    config.num_episodes_to_run = 10
    config.hyperparameters["DQN_Agents"]["batch_size"] = 3

    # Discrete-action environment
    AGENTS = [A2C, A3C, PPO, DDQN, DQN_With_Fixed_Q_Targets, DDQN_With_Prioritised_Experience_Replay, DQN]
    config.environment = gym.make("CartPole-v0")
    trainer = Trainer(config, AGENTS)
    results = trainer.train()
    for agent in AGENTS:
        assert agent.agent_name in results.keys()

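    # Continuous-action environment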
    AGENTS = [SAC, TD3, PPO, DDPG]
    config.environment = gym.make("MountainCarContinuous-v0")
    trainer = Trainer(config, AGENTS)
    results = trainer.train()
    for agent in AGENTS:
        assert agent.agent_name in results.keys()

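    # Gridworld environment with stochastic actions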
    AGENTS = [DDQN, SNN_HRL]
    config.environment = Four_Rooms_Environment(15, 15, stochastic_actions_probability=0.25,
                                                random_start_user_place=True, random_goal_place=False)
    trainer = Trainer(config, AGENTS)
    results = trainer.train()
    for agent in AGENTS:
        assert agent.agent_name in results.keys()
        "batch_size": 128,
        "buffer_size": 100000,
        "epsilon": 1.0,
        "epsilon_decay_rate_denominator": 150,
        "discount_rate": 0.999,
        "alpha_prioritised_replay": 0.6,
        "beta_prioritised_replay": 0.1,
        "incremental_td_error": 1e-8,
        "update_every_n_steps": 15,
        "tau": 1e-2,
        "linear_hidden_units": [256, 256],
        "final_layer_activation": "softmax",
        # "y_range": (-1, 14),
        "batch_norm": False,
        "gradient_clipping_norm": 5,
        "HER_sample_proportion": 0.8,
        "learning_iterations": 1,
        "clip_rewards": False
    }
}

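# NOTE: FCNN is assumed to be a fully connected network class defined elsewhere in the project.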
config.model = FCNN()

if __name__ == '__main__':
    AGENTS = [DQN, DRQN]  # disabled: DDQN, Dueling_DDQN, DDQN_With_Prioritised_Experience_Replay
    trainer = Trainer(config, AGENTS)
    trainer.train()