Esempio n. 1
0
def test_can_unhook_tql_agenthook(RPSTask, tabular_q_learning_config_dict):
    """Unhook a tabular Q-learning agent from a hook built without a save path.

    The original agent and the agent returned by ``AgentHook.unhook`` are
    handed to ``compare_against_expected_retrieved_agent`` together with an
    empty list as the third argument.
    """
    original_agent = build_TabularQ_Agent(
        RPSTask, tabular_q_learning_config_dict, 'TQL')
    restored_agent = AgentHook.unhook(AgentHook(original_agent))

    compare_against_expected_retrieved_agent(
        original_agent, restored_agent, [])
Esempio n. 2
0
def test_can_load_tql_from_agenthook(RPSTask, tabular_q_learning_config_dict):
    """Round-trip a tabular Q-learning agent through a hook with a save path.

    Builds the agent, wraps it in an ``AgentHook`` created with ``save_path``,
    unhooks it and asserts that the retrieved agent's ``Q_table`` equals the
    original one (``np.array_equal``).

    The save file is removed afterwards, whether or not the assertion holds,
    so that it cannot leak into other tests that use the same path.
    """
    import os  # local import so the cleanup below does not rely on file-level imports

    agent = build_TabularQ_Agent(RPSTask, tabular_q_learning_config_dict,
                                 'TQL')
    save_path = '/tmp/test_save.agent'
    try:
        hook = AgentHook(agent, save_path=save_path)

        retrieved_agent = AgentHook.unhook(hook)
        assert np.array_equal(agent.algorithm.Q_table,
                              retrieved_agent.algorithm.Q_table)
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 3
0
def test_can_load_ppo_from_agenthook_disabling_cuda(RPSTask, ppo_config_dict):
    """Unhook a CUDA-configured PPO agent with ``use_cuda=False``.

    The agent is built with ``use_cuda`` set to True in its config, hooked
    with a save path, and then unhooked with ``use_cuda=False``. The test
    asserts that no parameter of the retrieved model reports ``is_cuda``.

    NOTE(review): building the agent with ``use_cuda=True`` presumably
    requires a CUDA-capable device -- confirm how the suite skips otherwise.

    The save file is removed afterwards, whether or not the assertion holds,
    so that it cannot leak into other tests that use the same path.
    """
    import os  # local import so the cleanup below does not rely on file-level imports

    ppo_config_dict['use_cuda'] = True
    agent = build_PPO_Agent(RPSTask, ppo_config_dict, 'PPO')
    save_path = '/tmp/test_save.agent'
    try:
        hook = AgentHook(agent, save_path=save_path)

        retrieved_agent = AgentHook.unhook(hook, use_cuda=False)
        model = retrieved_agent.algorithm.model
        assert all(not param.is_cuda for param in model.parameters())
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 4
0
def test_can_unhook_ppo_agenthook_with_cuda(RPSTask, ppo_config_dict):
    """Unhook a CUDA-configured PPO agent from a hook built without a save path.

    First checks that every parameter of the freshly built model reports
    ``is_cuda``; then hooks and unhooks the agent and passes the original
    agent, the retrieved agent and the retrieved model to
    ``compare_against_expected_retrieved_agent``.
    """
    ppo_config_dict['use_cuda'] = True
    ppo_agent = build_PPO_Agent(RPSTask, ppo_config_dict, 'PPO')
    # Precondition: the model was actually placed on a CUDA device.
    assert all(p.is_cuda for p in ppo_agent.algorithm.model.parameters())

    agent_hook = AgentHook(ppo_agent)
    restored_agent = AgentHook.unhook(agent_hook)

    restored_models = [restored_agent.algorithm.model]
    compare_against_expected_retrieved_agent(ppo_agent, restored_agent,
                                             restored_models)
Esempio n. 5
0
def test_can_load_ppo_from_agenthook_with_cuda(RPSTask, ppo_config_dict):
    """Round-trip a CUDA-configured PPO agent through a hook with a save path.

    Asserts that a hook created with ``save_path`` has no ``agent``
    attribute, then unhooks it and passes the retrieved model to
    ``assert_model_parameters_are_cuda_tensors``.

    The save file is removed afterwards, whether or not the assertions hold,
    so that it cannot leak into other tests that use the same path.
    """
    import os  # local import so the cleanup below does not rely on file-level imports

    ppo_config_dict['use_cuda'] = True
    agent = build_PPO_Agent(RPSTask, ppo_config_dict, 'PPO')
    save_path = '/tmp/test_save.agent'
    try:
        hook = AgentHook(agent, save_path=save_path)

        assert not hasattr(hook, 'agent')

        retrieved_agent = AgentHook.unhook(hook)
        model_list = [retrieved_agent.algorithm.model]
        assert_model_parameters_are_cuda_tensors(model_list)
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 6
0
def test_can_load_dqn_from_agenthook_with_cuda(RPSTask, dqn_config_dict):
    """Round-trip a CUDA-configured DQN agent through a hook with a save path.

    After unhooking, both ``model`` and ``target_model`` of the retrieved
    agent are passed to ``assert_model_parameters_are_cuda_tensors``.

    The save file is removed afterwards, whether or not the assertions hold,
    so that it cannot leak into other tests that use the same path.
    """
    import os  # local import so the cleanup below does not rely on file-level imports

    dqn_config_dict['use_cuda'] = True
    agent = build_DQN_Agent(RPSTask, dqn_config_dict, 'DQN')
    save_path = '/tmp/test_save.agent'
    try:
        hook = AgentHook(agent, save_path=save_path)

        retrieved_agent = AgentHook.unhook(hook)
        model_list = [
            retrieved_agent.algorithm.model,
            retrieved_agent.algorithm.target_model,
        ]
        assert_model_parameters_are_cuda_tensors(model_list)
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 7
0
def test_can_unhook_dqn_agenthook_cuda(RPSTask, dqn_config_dict):
    """Unhook a CUDA-configured DQN agent from a hook built without a save path.

    First checks that all parameters of both ``model`` and ``target_model``
    report ``is_cuda``; then hooks and unhooks the agent and passes the
    original agent, the retrieved agent and the retrieved agent's two models
    to ``compare_against_expected_retrieved_agent``.
    """
    dqn_config_dict['use_cuda'] = True
    dqn_agent = build_DQN_Agent(RPSTask, dqn_config_dict, 'DQN')
    # Preconditions: both networks were actually placed on a CUDA device.
    assert all(p.is_cuda for p in dqn_agent.algorithm.model.parameters())
    assert all(
        p.is_cuda for p in dqn_agent.algorithm.target_model.parameters())

    agent_hook = AgentHook(dqn_agent)
    restored_agent = AgentHook.unhook(agent_hook)

    restored_models = [
        restored_agent.algorithm.model,
        restored_agent.algorithm.target_model,
    ]
    compare_against_expected_retrieved_agent(dqn_agent, restored_agent,
                                             restored_models)
Esempio n. 8
0
def test_can_save_ppo_to_memory(RPSTask, ppo_config_dict):
    """Hook a PPO agent with a save path and check where it was stored.

    The check itself is delegated to
    ``assess_file_has_been_saved_on_disk_and_not_on_ram``, which receives the
    hook and the save path.

    Any file already present at the save path is removed before hooking, so a
    leftover from an earlier test cannot satisfy the check. The file is also
    removed afterwards even if the check fails.
    """
    agent = build_PPO_Agent(RPSTask, ppo_config_dict, 'PPO')
    save_path = '/tmp/test_save.agent'
    # The path is shared with other tests; start from a clean slate.
    if os.path.exists(save_path):
        os.remove(save_path)
    try:
        hook = AgentHook(agent, save_path=save_path)

        assess_file_has_been_saved_on_disk_and_not_on_ram(hook, save_path)
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 9
0
def test_can_save_tql_to_memory(RPSTask, tabular_q_learning_config_dict):
    """Hook a tabular Q-learning agent with a save path and check its storage.

    The check itself is delegated to
    ``assess_file_has_been_saved_on_disk_and_not_on_ram``, which receives the
    hook and the save path.

    Any file already present at the save path is removed before hooking, so a
    leftover from an earlier test cannot satisfy the check. The file is also
    removed afterwards even if the check fails.
    """
    agent = build_TabularQ_Agent(RPSTask, tabular_q_learning_config_dict,
                                 'TQL')
    save_path = '/tmp/test_save.agent'
    # The path is shared with other tests; start from a clean slate.
    if os.path.exists(save_path):
        os.remove(save_path)
    try:
        # Keyword form, consistent with every other AgentHook call in this file.
        hook = AgentHook(agent, save_path=save_path)

        assess_file_has_been_saved_on_disk_and_not_on_ram(hook, save_path)
    finally:
        # Guarded removal: the file may be absent if hooking failed early.
        if os.path.exists(save_path):
            os.remove(save_path)
Esempio n. 10
0
def test_can_hook_ppo_agent_using_cuda(RPSTask, ppo_config_dict):
    """Hook a CUDA-configured PPO agent without a save path.

    First checks that every parameter of the freshly built model reports
    ``is_cuda``; then passes the agent, the hook, ``AgentType.PPO`` and the
    model held by the hook's agent to ``compare_against_expected_agenthook``.
    """
    ppo_config_dict['use_cuda'] = True
    ppo_agent = build_PPO_Agent(RPSTask, ppo_config_dict, 'PPO')
    # Precondition: the model was actually placed on a CUDA device.
    assert all(p.is_cuda for p in ppo_agent.algorithm.model.parameters())

    agent_hook = AgentHook(ppo_agent)

    hooked_models = [agent_hook.agent.algorithm.model]
    compare_against_expected_agenthook(ppo_agent, agent_hook, AgentType.PPO,
                                       hooked_models)
Esempio n. 11
0
def test_can_hook_dqn_agent_using_cuda(RPSTask, dqn_config_dict):
    """Hook a CUDA-configured DQN agent without a save path.

    First checks that all parameters of both ``model`` and ``target_model``
    report ``is_cuda``; then passes the agent, the hook, ``AgentType.DQN``
    and the two models held by the hook's agent to
    ``compare_against_expected_agenthook``.
    """
    dqn_config_dict['use_cuda'] = True
    dqn_agent = build_DQN_Agent(RPSTask, dqn_config_dict, 'DQN')
    # Preconditions: both networks were actually placed on a CUDA device.
    assert all(p.is_cuda for p in dqn_agent.algorithm.model.parameters())
    assert all(
        p.is_cuda for p in dqn_agent.algorithm.target_model.parameters())

    agent_hook = AgentHook(dqn_agent)

    hooked_models = [
        agent_hook.agent.algorithm.model,
        agent_hook.agent.algorithm.target_model,
    ]
    compare_against_expected_agenthook(dqn_agent, agent_hook, AgentType.DQN,
                                       hooked_models)
Esempio n. 12
0
def test_can_hook_tql_agent(RPSTask, tabular_q_learning_config_dict):
    """Hook a tabular Q-learning agent without a save path.

    The agent, the hook, ``AgentType.TQL`` and an empty list are passed to
    ``compare_against_expected_agenthook``, which performs the checks.
    """
    tql_agent = build_TabularQ_Agent(
        RPSTask, tabular_q_learning_config_dict, 'TQL')
    agent_hook = AgentHook(tql_agent)

    compare_against_expected_agenthook(
        tql_agent, agent_hook, AgentType.TQL, [])