Exemplo n.º 1
0
def test_can_keep_track_of_window_of_winrate_for_learning_policy(RPS_task):
    """The rolling winrate window holds only the last 3 match outcomes.

    Four updates are pushed into a window of size 3, so the oldest outcome
    is evicted and the window ends up as [1, 0, 1] (a 1 when the training
    agent was the winner at index 1, a 0 when it was at index 0).
    """
    psro = PSRONashResponse(task=RPS_task, match_outcome_rolling_window_size=3)
    training_agent_indices = [1, 1, 0, 1]
    expected_window = [1, 0, 1]

    # TODO this is very ugly. It always chooses player 2 (1-index) as winner
    # We should really find a way of mocking this.
    winning_trajectory = [([], [], [0, 1], [])]  # (s, a, r, s')
    for agent_index in training_agent_indices:
        psro.update_rolling_winrates(episode_trajectory=winning_trajectory,
                                     training_agent_index=agent_index)

    np.testing.assert_array_equal(expected_window,
                                  psro.match_outcome_rolling_window)
Exemplo n.º 2
0
def test_can_keep_track_of_window_of_winrate_for_learning_policy(RPS_task):
    """The rolling winrate window holds only the last 3 match outcomes.

    Four updates are pushed into a window of size 3, so the oldest outcome
    is evicted and the window ends up as [1, 0, 1] (a 1 when the training
    agent was the winner at index 1, a 0 when it was at index 0).
    """
    psro = PSRONashResponse(task=RPS_task, match_outcome_rolling_window_size=3)
    training_agent_indices = [1, 1, 0, 1]
    expected_window = [1, 0, 1]

    # TODO this is very ugly. It always chooses player 2 (1-index) as winner
    # We should really find a way of mocking this.
    winning_trajectory = Trajectory(
        env_type=EnvType.MULTIAGENT_SIMULTANEOUS_ACTION, num_agents=2)
    winning_trajectory.add_timestep(None, None, [0, 1], None, True)

    for agent_index in training_agent_indices:
        psro.update_rolling_winrates(episode_trajectory=winning_trajectory,
                                     training_agent_index=agent_index)

    np.testing.assert_array_equal(expected_window,
                                  psro.match_outcome_rolling_window)