Example 1
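An asymmetric self-play test: two distinct brains (a learner and an "Opp" opponent) are placed on opposing teams of a SimpleEnvironment, and both are trained with the same PPO self-play configuration.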
# Assumed imports for this snippet; BRAIN_NAME, PPO_TF_CONFIG, SimpleEnvironment,
# and _check_environment_trains are helpers defined in the ml-agents test suite.
import attr
import pytest

from mlagents.trainers.settings import FrameworkType, SelfPlaySettings


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost(use_discrete):
    # Make opponent for asymmetric case
    brain_name_opp = BRAIN_NAME + "Opp"
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"],
        use_discrete=use_discrete)
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=1.0,  # always face the newest opponent snapshot
        save_steps=10000,
        swap_steps=10000,
        team_change=400,  # trainer steps between switching the learning team
    )
    config = attr.evolve(
        PPO_TF_CONFIG,
        self_play=self_play_settings,
        max_steps=4000,
        framework=FrameworkType.TENSORFLOW,
    )
    _check_environment_trains(env, {
        BRAIN_NAME: config,
        brain_name_opp: config
    })
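
Each example derives its trainer config from a shared base via attr.evolve, which returns a copy of an attrs instance with the named fields replaced. A minimal standalone sketch of that pattern, using a hypothetical Settings class in place of the real TrainerSettings:

import attr


@attr.s(auto_attribs=True)
class Settings:  # hypothetical stand-in for the real TrainerSettings
    max_steps: int = 1000
    framework: str = "tensorflow"


base = Settings()
tuned = attr.evolve(base, max_steps=4000)  # copy with max_steps overridden
assert base.max_steps == 1000
assert tuned.max_steps == 4000 and tuned.framework == "tensorflow"

Because evolve never mutates the base instance, one module-level config such as PPO_TF_CONFIG can safely be specialized per test.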
Example 2
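A negative self-play test: both teams share a single brain, and because the first swap (swap_steps=4000) would only happen after training ends (max_steps=2500), the ghosted team never receives a competent policy. The final assertion checks that at least one processed reward sits above the success threshold (the learning team) and at least one sits below it (the ghost).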
# Assumed to be parametrized like the test above so pytest supplies use_discrete.
@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ghost_fails(use_discrete):
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"],
        use_discrete=use_discrete,
    )
    # This config should fail because the ghosted policy is never swapped
    # in for a competent one: the first swap (swap_steps=4000) would only
    # occur after max_steps (2500) is reached.
    self_play_settings = SelfPlaySettings(
        play_against_latest_model_ratio=1.0,
        save_steps=2000,
        swap_steps=4000,
    )
    config = attr.evolve(
        PPO_TF_CONFIG,
        self_play=self_play_settings,
        max_steps=2500,
        framework=FrameworkType.TENSORFLOW,
    )
    # success_threshold=None: run training without asserting that it succeeds.
    _check_environment_trains(
        env, {BRAIN_NAME: config}, success_threshold=None
    )
    processed_rewards = [
        default_reward_processor(rewards)
        for rewards in env.final_rewards.values()
    ]
    # One team should have learned (reward above the threshold) while the
    # never-swapped ghost team should not have (reward below it).
    success_threshold = 0.9
    assert any(reward > success_threshold for reward in processed_rewards)
    assert any(reward < success_threshold for reward in processed_rewards)
Example 3
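A minimal pytest fixture returning a TrainerSettings with default SelfPlaySettings, used as a baseline configuration in the ghost-trainer tests.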
# Assumed to be a pytest fixture; TrainerSettings and SelfPlaySettings come
# from mlagents.trainers.settings.
@pytest.fixture
def dummy_config():
    return TrainerSettings(self_play=SelfPlaySettings())
Example 4
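The same fixture, but explicitly selecting the PyTorch framework via FrameworkType.PYTORCH.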
# Assumed to be a pytest fixture, as in the previous example.
@pytest.fixture
def dummy_config():
    return TrainerSettings(
        self_play=SelfPlaySettings(), framework=FrameworkType.PYTORCH
    )