示例#1
0
def test_recurrent_poca(action_sizes, is_multiagent):
    """Train POCA with an LSTM memory (memory_size=16).

    When ``is_multiagent`` is truthy, a two-agent MultiAgentEnvironment is
    used. That environment is not recurrent, so the run only smoke-tests the
    LSTM path: it uses max_steps=500 and no success threshold. Otherwise a
    MemoryEnvironment exercises the LSTM, with max_steps=6000 and
    success_threshold=0.9 passed to check_environment_trains.

    Hyperparameters are overridden to learning_rate=1e-3, batch_size=64 and
    buffer_size=128; everything else comes from POCA_TORCH_CONFIG.
    """
    if not is_multiagent:
        # Actually test LSTM here
        environment = MemoryEnvironment([BRAIN_NAME], action_sizes=action_sizes)
    else:
        # This is not a recurrent environment, just check if LSTM doesn't crash
        environment = MultiAgentEnvironment(
            [BRAIN_NAME], action_sizes=action_sizes, num_agents=2
        )

    base = POCA_TORCH_CONFIG
    memory_network = attr.evolve(
        base.network_settings,
        memory=NetworkSettings.MemorySettings(memory_size=16),
    )
    tuned_hparams = attr.evolve(
        base.hyperparameters,
        learning_rate=1.0e-3,
        batch_size=64,
        buffer_size=128,
    )

    # The multi-agent variant is a short crash check with no pass/fail threshold.
    if is_multiagent:
        step_budget, threshold = 500, None
    else:
        step_budget, threshold = 6000, 0.9

    trainer_config = attr.evolve(
        base,
        hyperparameters=tuned_hparams,
        network_settings=memory_network,
        max_steps=step_budget,
    )
    check_environment_trains(
        environment, {BRAIN_NAME: trainer_config}, success_threshold=threshold
    )
示例#2
0
def test_visual_poca(num_visual):
    """Train POCA on a two-agent environment with ``num_visual`` visual observations.

    The environment uses action_sizes=(0, 1). Only the learning rate is
    overridden (3e-4); all other settings come from POCA_TORCH_CONFIG.
    """
    visual_env = MultiAgentEnvironment(
        [BRAIN_NAME], action_sizes=(0, 1), num_agents=2, num_visual=num_visual
    )
    hparams = attr.evolve(POCA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4)
    poca_config = attr.evolve(POCA_TORCH_CONFIG, hyperparameters=hparams)
    check_environment_trains(visual_env, {BRAIN_NAME: poca_config})
示例#3
0
def test_var_len_obs_and_goal_poca(num_vis, num_vector, num_var_len,
                                   conditioning_type):
    """Train POCA on a two-agent environment with mixed observations and a goal.

    The environment is built with the requested numbers of visual, vector and
    variable-length observations, step_size=0.2 and goal_indices=[0]. The
    network is configured with ``goal_conditioning_type=conditioning_type``,
    and the learning rate is overridden to 3e-4.
    """
    environment = MultiAgentEnvironment(
        [BRAIN_NAME],
        num_agents=2,
        action_sizes=(0, 1),
        step_size=0.2,
        num_visual=num_vis,
        num_vector=num_vector,
        num_var_len=num_var_len,
        goal_indices=[0],
    )
    goal_network = attr.evolve(
        POCA_TORCH_CONFIG.network_settings,
        goal_conditioning_type=conditioning_type,
    )
    lr_hparams = attr.evolve(
        POCA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
    )
    trainer_config = attr.evolve(
        POCA_TORCH_CONFIG,
        network_settings=goal_network,
        hyperparameters=lr_hparams,
    )
    check_environment_trains(environment, {BRAIN_NAME: trainer_config})
示例#4
0
def test_simple_poca(action_sizes):
    """Train POCA with the unmodified default config on a two-agent environment."""
    two_agent_env = MultiAgentEnvironment(
        [BRAIN_NAME], action_sizes=action_sizes, num_agents=2
    )
    # attr.evolve with no overrides returns a new instance equal to the shared
    # module-level config, rather than passing the shared object itself.
    default_config = attr.evolve(POCA_TORCH_CONFIG)
    check_environment_trains(two_agent_env, {BRAIN_NAME: default_config})