コード例 #1
0
def test_egreedy_select_action_exploration():
    """Check that epsilon=1.0 does not lock onto a single action.

    Over 50 calls, the summed first entry of ``select_action()`` must lie
    strictly between 0 and 50.
    """
    n_draws = 50
    explorer = EpsilonGreedy(n_actions=2, epsilon=1.0)
    explorer.action_counts = np.array([3, 3])
    explorer.reward_counts = np.array([3, 0])

    draws = []
    for _ in range(n_draws):
        draws.append(explorer.select_action())

    # Summing the per-call results element-wise; only the first slot is checked.
    first_slot_total = sum(draws)[0]
    assert 0 < first_slot_total < n_draws
コード例 #2
0
def test_egreedy_select_action_exploitation():
    """Check that epsilon=0.0 yields action 0 in the first slot on every call.

    Both actions have the same action count (3) but only action 0 has a
    non-zero reward count, and the test expects action 0 on each of 50 calls.
    """
    n_draws = 50
    exploiter = EpsilonGreedy(n_actions=2, epsilon=0.0)
    exploiter.action_counts = np.array([3, 3])
    exploiter.reward_counts = np.array([3, 0])

    # The generator stops at the first mismatch, like a per-iteration assert.
    assert all(exploiter.select_action()[0] == 0 for _ in range(n_draws))
コード例 #3
0
def test_egreedy_abnormal_epsilon():
    """Constructing EpsilonGreedy with epsilon above 1 or below 0 must raise ValueError."""
    out_of_range_cases = (
        (2, 1.2),
        (5, -0.2),
    )
    for n_actions, bad_epsilon in out_of_range_cases:
        with pytest.raises(ValueError):
            EpsilonGreedy(n_actions=n_actions, epsilon=bad_epsilon)
コード例 #4
0
def test_egreedy_normal_epsilon():
    """The epsilon attribute must lie in [0, 1] for default and explicit values."""
    constructor_kwargs = (
        {"n_actions": 2},
        {"n_actions": 3, "epsilon": 0.3},
    )
    for kwargs in constructor_kwargs:
        policy = EpsilonGreedy(**kwargs)
        assert 0 <= policy.epsilon <= 1
コード例 #5
0
ファイル: test_contextfree.py プロジェクト: zwcdp/zr-obp
def test_egreedy_normal_epsilon():
    """Check the epsilon range and the reported policy type of EpsilonGreedy.

    The epsilon attribute must lie in [0, 1] for default and explicit values,
    and ``policy_type`` must equal ``PolicyType.CONTEXT_FREE``.
    """
    constructor_kwargs = (
        {"n_actions": 2},
        {"n_actions": 3, "epsilon": 0.3},
    )
    for kwargs in constructor_kwargs:
        policy = EpsilonGreedy(**kwargs)
        assert 0 <= policy.epsilon <= 1

    # policy type
    typed_policy = EpsilonGreedy(n_actions=2)
    assert typed_policy.policy_type == PolicyType.CONTEXT_FREE
コード例 #6
0
def test_contextfree_base_exception():
    """Each invalid constructor argument combination must raise ValueError."""
    invalid_kwargs = (
        # invalid n_actions
        {"n_actions": 0},
        {"n_actions": "3"},
        # invalid len_list
        {"n_actions": 2, "len_list": -1},
        {"n_actions": 2, "len_list": "5"},
        # invalid batch_size
        {"n_actions": 2, "batch_size": -3},
        {"n_actions": 2, "batch_size": "3"},
        # invalid relationship between n_actions and len_list
        {"n_actions": 5, "len_list": 10},
        {"n_actions": 2, "len_list": 3},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            EpsilonGreedy(**kwargs)
コード例 #7
0
def test_egreedy_update_params():
    """Check the counters after one ``update_params`` call for action 0.

    Starting from action counts [4, 3] and reward counts [2.0, 0.0], an update
    with reward 1.0 is expected to give action counts [5, 3] and reward counts
    [3.0, 0.0].
    """
    learner = EpsilonGreedy(n_actions=2, epsilon=1.0)

    # Seed both the temporary and the committed counters with the same values.
    learner.action_counts_temp = np.array([4, 3])
    learner.action_counts = np.copy(learner.action_counts_temp)
    learner.reward_counts_temp = np.array([2.0, 0.0])
    learner.reward_counts = np.copy(learner.reward_counts_temp)

    chosen_action, observed_reward = 0, 1.0
    learner.update_params(chosen_action, observed_reward)

    expected_action_counts = np.array([5, 3])
    expected_reward_counts = np.array([2.0 + observed_reward, 0.0])
    assert np.array_equal(learner.action_counts, expected_action_counts)
    assert np.allclose(learner.reward_counts, expected_reward_counts)
コード例 #8
0
def test_egreedy_update_params():
    """Check the counters after one ``update_params`` call for action 0.

    Starting from action counts [4, 3] and reward counts [2.0, 0.0], an update
    with reward 1.0 is expected to give action counts [5, 3] and a reward entry
    for action 0 equal to the incrementally updated mean.
    """
    learner = EpsilonGreedy(n_actions=2, epsilon=1.0)

    # Seed both the temporary and the committed counters with the same values.
    learner.action_counts_temp = np.array([4, 3])
    learner.action_counts = np.copy(learner.action_counts_temp)
    learner.reward_counts_temp = np.array([2.0, 0.0])
    learner.reward_counts = np.copy(learner.reward_counts_temp)

    chosen_action, observed_reward = 0, 1.0
    learner.update_params(chosen_action, observed_reward)

    assert np.array_equal(learner.action_counts, np.array([5, 3]))

    # in epsilon greedy, reward_counts is defined as the mean of observed rewards for each action
    prior_mean = 2.0
    updated_count = 5
    expected_mean = (prior_mean * (updated_count - 1) / updated_count) + (
        observed_reward / updated_count
    )
    assert np.allclose(learner.reward_counts, np.array([expected_mean, 0.0]))