コード例 #1
0
def test_reach():
    """Check the reach values of a single-action MDP with no horizon given.

    Reading each row of the matrix as the outgoing distribution of a
    state, state 0 stays put with probability 0.5 and moves to state 1 or
    state 2 with probability 0.25 each, while states 1 and 2 only loop
    back to themselves. With state 2 as the sole accepting state, the
    first value vector returned by ``solve_reach`` must be
    ``[0.5, 0, 1]`` to four decimals.
    """
    transitions = np.array(
        [
            [0.5, 0.25, 0.25],
            [0, 1, 0],
            [0, 0, 1],
        ]
    )
    model = MDP([transitions])

    def is_target(y):
        # Only state 2 counts as reached.
        return y == 2

    # The second return value is not checked by this test.
    values, _ = model.solve_reach(accept=is_target)

    expected = [0.5, 0, 1]
    np.testing.assert_almost_equal(values[0], expected, decimal=4)
コード例 #2
0
def test_reach_finitetime():
    """Check finite-horizon reach values and choices for a two-action MDP.

    Reading each row as the outgoing distribution of a state: under the
    first matrix, state 0 stays put with probability 0.9 and moves to
    state 2 with probability 0.1; under the second, state 0 moves to
    state 1 or state 2 with probability 0.5 each. States 1 and 2 only
    loop back to themselves in both matrices. State 2 is the accepting
    state and the horizon is 3.

    The test pins, for state 0 at each of the three time indices, the
    value returned in the first list and the entry returned in the second
    list (presumably the index of the chosen action -- confirm against
    ``MDP.solve_reach``).
    """
    slow_action = np.array(
        [
            [0.9, 0, 0.1],
            [0, 1, 0],
            [0, 0, 1],
        ]
    )
    gamble_action = np.array(
        [
            [0, 0.5, 0.5],
            [0, 1, 0],
            [0, 0, 1],
        ]
    )
    model = MDP([slow_action, gamble_action])

    def accept(n):
        # Only state 2 counts as reached.
        return n == 2

    vlist, plist = model.solve_reach(accept, horizon=3)

    # Expected state-0 values, indexed by time step.
    expected_values = (
        0.1 + 0.9 * 0.1 + 0.9**2 * 0.5,
        0.1 + 0.9 * 0.5,
        0.5,
    )
    # Expected state-0 policy entries, indexed by time step.
    expected_choices = (0, 0, 1)

    # Values are verified before policy entries, as separate passes.
    for step, want in enumerate(expected_values):
        np.testing.assert_almost_equal(vlist[step][0], want)

    for step, want in enumerate(expected_choices):
        np.testing.assert_almost_equal(plist[step][0], want)