Exemplo n.º 1
0
def test_compile_6():
    """Compiling builds a terminal mask that flags only declared terminal states.

    ``s1`` is added as a terminal state and ``s0`` is not.
    Each state's position in ``terminal_mask`` is found through ``_state_dict``.
    """
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 0))
    mdp.add_terminal_state("s1")
    mdp.ensure_compiled()

    # Assert on truthiness instead of ``== True`` / ``== False`` (PEP 8,
    # pycodestyle E712). Equivalent for bool / bool-like mask elements.
    assert mdp.terminal_mask[mdp._state_dict["s1"]]
    assert not mdp.terminal_mask[mdp._state_dict["s0"]]
Exemplo n.º 2
0
def test_compile_1():
    """Compile a two-state, one-action cycle and check the tensor shapes.

    ``P`` and ``R`` are indexed as ``[state, action, next_state]``.
    Here that gives a shape of (2, 1, 2).
    """
    mdp = MDP()
    for src, dst in (("s0", "s1"), ("s1", "s0")):
        mdp.add_transition(Transition(src, "a0", dst, 1, 0))

    mdp.ensure_compiled()
    expected_shape = (2, 1, 2)
    assert mdp.P.shape == expected_shape
    assert mdp.R.shape == expected_shape
Exemplo n.º 3
0
def test_compile_4():
    """Compilation depends on whether the dead-end state ``s1`` is terminal.

    ``s1`` is only ever a destination: it has no outgoing transitions.
    Compiling must raise ``ValueError`` while ``s1`` is not terminal.
    It must succeed once ``s1`` is declared terminal.
    """
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 0))

    # s1 has no exit and is not terminal yet, so compilation is rejected.
    with pytest.raises(ValueError):
        mdp.ensure_compiled()

    # Declaring the dead end terminal makes the same MDP compile cleanly.
    mdp.add_terminal_state("s1")
    mdp.ensure_compiled()
Exemplo n.º 4
0
def test_compile_5():
    """Compiling freezes the state/action collections; decompiling thaws them.

    After ``ensure_compiled`` the ``compiled`` flag is set and the collections
    are exact ``tuple`` instances. After ``_decompile`` the flag is cleared and
    they are exact ``set`` instances again.
    """
    collection_attrs = ("states", "actions", "terminal_states")

    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 0))
    mdp.add_terminal_state("s1")

    mdp.ensure_compiled()
    assert mdp.compiled
    # ``type(...) is`` on purpose: subclasses of tuple/set must not pass.
    for attr in collection_attrs:
        assert type(getattr(mdp, attr)) is tuple

    mdp._decompile()
    assert not mdp.compiled
    for attr in collection_attrs:
        assert type(getattr(mdp, attr)) is set
Exemplo n.º 5
0
def test_compile_2():
    """Rewards and probabilities are stored at ``[state, action, next_state]``.

    ``Transition`` arguments are used here as
    ``(state, action, next_state, probability, reward)``.
    The 4th argument is checked against ``P`` and the 5th against ``R``.
    """
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 5))
    mdp.add_transition(Transition("s1", "a0", "s0", 1, 2))

    mdp.ensure_compiled()
    # Resolve tensor indices only after compilation has built the lookups.
    s0 = mdp._state_dict["s0"]
    s1 = mdp._state_dict["s1"]
    a0 = mdp._action_dict["a0"]

    assert mdp.R[s0, a0, s1] == 5
    assert mdp.R[s1, a0, s0] == 2

    assert mdp.P[s0, a0, s1] == 1
    assert mdp.P[s1, a0, s0] == 1
Exemplo n.º 6
0
def test_compile_3():
    """Compilation is rejected while (s1, a0) has incomplete probability mass.

    With only 0.5 of probability leaving ``s1`` under ``a0``, compiling must
    raise ``ValueError``. Presumably this is because outgoing probabilities
    must sum to 1.
    Adding the missing 0.5 toward a new state ``s2`` lets it compile. ``s2``
    also gets its own outgoing transition.
    """
    mdp = MDP()
    for edge in (("s0", "a0", "s1", 1, 0), ("s1", "a0", "s0", 0.5, 0)):
        mdp.add_transition(Transition(*edge))

    with pytest.raises(ValueError):
        mdp.ensure_compiled()

    # Supply the remaining mass for (s1, a0) and give s2 an exit so it is
    # not a non-terminal dead end.
    for edge in (("s1", "a0", "s2", 0.5, 0), ("s2", "a0", "s0", 1, 0)):
        mdp.add_transition(Transition(*edge))

    mdp.ensure_compiled()