예제 #1
0
def test_boltzmann():
    """Check the Boltzmann policy against fixed reference values.

    The global NumPy RNG is seeded so that the random Q-table and the sampled
    action are reproducible; the hard-coded expectations depend on that seed.
    """
    np.random.seed(88)
    policy = Boltzmann(Parameter(0.1))

    q_function = Table((10, 3))
    q_function.table = np.random.randn(10, 3)
    policy.set_q(q_function)

    state = np.array([2])
    action = np.array([1])

    # Policy evaluated on the state alone, compared with a 3-entry reference.
    state_probs = policy(state)
    assert np.allclose(
        state_probs, np.array([0.30676679, 0.36223227, 0.33100094])
    )

    # Policy evaluated on the (state, action) pair.
    prob_initial = policy(state, action)
    assert np.allclose(prob_initial, np.array([0.36223227]))

    # From here on `action` is the sampled action, not the initial one.
    action = policy.draw_action(state)
    assert action.item() == 2

    # Replace beta and compare the new value with the earlier one.
    policy.set_beta(LinearParameter(0.2, 0.1, 2))
    prob_new_beta = policy(state, action)
    assert prob_new_beta < prob_initial

    # One policy update, then re-check against the reference value.
    policy.update(state, action)
    prob_after_update = policy(state, action)
    assert np.allclose(prob_after_update, np.array([0.33100094]))
예제 #2
0
def test_eps_greedy():
    """Check the EpsGreedy policy against fixed reference values.

    The global NumPy RNG is seeded so that the random Q-table and the drawn
    action are reproducible; the hard-coded expectations depend on that seed.
    """
    np.random.seed(88)
    eps = Parameter(0.1)
    pi = EpsGreedy(eps)

    Q = Table((10, 3))
    Q.table = np.random.randn(10, 3)

    pi.set_q(Q)

    s = np.array([2])
    a = np.array([1])

    p_s = pi(s)
    p_s_test = np.array([0.03333333, 0.93333333, 0.03333333])
    assert np.allclose(p_s, p_s_test)

    p_sa = pi(s, a)
    p_sa_test = np.array([0.93333333])
    assert np.allclose(p_sa, p_sa_test)

    # `a` is rebound to the sampled action and used for the remaining checks.
    a = pi.draw_action(s)
    a_test = 1
    assert a.item() == a_test

    # With the new epsilon parameter the probability of `a` must be lower.
    eps_2 = LinearParameter(0.2, 0.1, 2)
    pi.set_epsilon(eps_2)
    p_sa_2 = pi(s, a)
    assert p_sa_2 < p_sa

    # After two updates the probability is expected to match the value
    # obtained with the original epsilon. Compare with a tolerance instead of
    # exact float equality, consistent with the other checks in this test.
    pi.update(s, a)
    pi.update(s, a)
    p_sa_3 = pi(s, a)
    assert np.allclose(p_sa_3, p_sa)
예제 #3
0
def test_mellowmax():
    """Check the Mellowmax policy against fixed reference values.

    Also verifies that ``set_beta`` and ``update`` raise ``RuntimeError`` on
    this policy. The global NumPy RNG is seeded so that the random Q-table
    and the drawn action are reproducible.
    """
    np.random.seed(88)
    omega = Parameter(3)
    pi = Mellowmax(omega)

    Q = Table((10, 3))
    Q.table = np.random.randn(10, 3)

    pi.set_q(Q)

    s = np.array([2])
    a = np.array([1])

    p_s = pi(s)
    p_s_test = np.array([0.08540336, 0.69215916, 0.22243748])
    assert np.allclose(p_s, p_s_test)

    p_sa = pi(s, a)
    p_sa_test = np.array([0.69215916])
    assert np.allclose(p_sa, p_sa_test)

    # `a` is rebound to the sampled action and reused by the update check.
    a = pi.draw_action(s)
    a_test = 2
    assert a.item() == a_test

    # Build the parameter outside the try block so that only a RuntimeError
    # coming from set_beta itself can satisfy the check.
    beta = Parameter(0.1)
    try:
        pi.set_beta(beta)
    except RuntimeError:
        pass
    else:
        raise AssertionError("Mellowmax.set_beta did not raise RuntimeError")

    try:
        pi.update(s, a)
    except RuntimeError:
        pass
    else:
        raise AssertionError("Mellowmax.update did not raise RuntimeError")