Exemplo n.º 1
0
def test_add_transition_3():
    """Two self-loops on distinct states with distinct actions register two of each."""
    mdp = MDP()
    for state, action in (("s0", "a0"), ("s1", "a1")):
        # Each transition loops back to its own start state.
        mdp.add_transition(Transition(state, action, state, 0, 0))

    assert len(mdp.states) == 2
    assert len(mdp.actions) == 2
Exemplo n.º 2
0
def test_compile_1():
    """Compiling a 2-state, 1-action MDP yields P and R of shape (2, 1, 2)."""
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 0))
    mdp.add_transition(Transition("s1", "a0", "s0", 1, 0))

    mdp.ensure_compiled()
    # Axes are (start state, action, end state).
    expected_shape = (2, 1, 2)
    assert mdp.P.shape == expected_shape
    assert mdp.R.shape == expected_shape
Exemplo n.º 3
0
def test_add_transition_4():
    """Adding an identical transition twice raises ValueError without adding states/actions."""
    mdp = MDP()
    fields = ("s0", "a0", "s1", 0, 0)
    mdp.add_transition(Transition(*fields))

    # A second, identical transition must be rejected.
    with pytest.raises(ValueError):
        mdp.add_transition(Transition(*fields))

    assert len(mdp.states) == 2
    assert len(mdp.actions) == 1
Exemplo n.º 4
0
def test_compile_2():
    """Compiled P and R store each transition's probability and reward at [start, action, end]."""
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 5))
    mdp.add_transition(Transition("s1", "a0", "s0", 1, 2))

    mdp.ensure_compiled()
    # Resolve the compiled integer indices once, after compilation.
    s0 = mdp._state_dict["s0"]
    s1 = mdp._state_dict["s1"]
    a0 = mdp._action_dict["a0"]

    assert mdp.R[s0, a0, s1] == 5
    assert mdp.R[s1, a0, s0] == 2

    assert mdp.P[s0, a0, s1] == 1
    assert mdp.P[s1, a0, s0] == 1
Exemplo n.º 5
0
def test_compile_6():
    """A state registered via add_terminal_state is flagged in terminal_mask; others are not."""
    mdp = MDP()
    t1 = Transition("s0", "a0", "s1", 1, 0)
    mdp.add_transition(t1)
    mdp.add_terminal_state("s1")
    mdp.ensure_compiled()
    # Check truthiness directly rather than comparing with `== True` / `== False`
    # (PEP 8 / flake8 E712).
    assert mdp.terminal_mask[mdp._state_dict["s1"]]
    assert not mdp.terminal_mask[mdp._state_dict["s0"]]
Exemplo n.º 6
0
def test_compile_3():
    """Compilation fails while s1's a0 transitions total 0.5 probability, then succeeds once completed.

    After s1 -> s2 (0.5) and s2 -> s0 (1.0) are added, every state has outgoing
    transitions whose probabilities sum to 1, and compilation must succeed.
    """
    mdp = MDP()
    t1 = Transition("s0", "a0", "s1", 1, 0)
    t2 = Transition("s1", "a0", "s0", 0.5, 0)
    mdp.add_transition(t1)
    mdp.add_transition(t2)

    with pytest.raises(ValueError):
        mdp.ensure_compiled()

    t3 = Transition("s1", "a0", "s2", 0.5, 0)
    t4 = Transition("s2", "a0", "s0", 1, 0)

    mdp.add_transition(t3)
    mdp.add_transition(t4)

    mdp.ensure_compiled()
    # Verify the successful compile instead of only checking that it did not raise.
    assert mdp.compiled
    # 3 states (s0, s1, s2) x 1 action (a0) x 3 states.
    assert mdp.P.shape == (3, 1, 3)
Exemplo n.º 7
0
def test_add_transition_1():
    """A single transition registers its start and end states and its action."""
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 0, 0))

    for state in ("s0", "s1"):
        assert state in mdp.states
    assert "a0" in mdp.actions

    assert len(mdp.states) == 2
    assert len(mdp.actions) == 1
Exemplo n.º 8
0
def test_compile_4():
    """A state with no outgoing transitions blocks compilation until it is marked terminal."""
    mdp = MDP()
    t1 = Transition("s0", "a0", "s1", 1, 0)
    mdp.add_transition(t1)

    # s1 has no outgoing transitions yet, so compiling must fail.
    with pytest.raises(ValueError):
        mdp.ensure_compiled()

    mdp.add_terminal_state("s1")
    mdp.ensure_compiled()
    # Verify the successful compile instead of only checking that it did not raise.
    assert mdp.compiled
Exemplo n.º 9
0
def test_compile_5():
    """ensure_compiled turns states/actions/terminal_states into tuples; _decompile turns them back into sets."""
    mdp = MDP()
    mdp.add_transition(Transition("s0", "a0", "s1", 1, 0))
    mdp.add_terminal_state("s1")
    container_names = ("states", "actions", "terminal_states")

    mdp.ensure_compiled()
    assert mdp.compiled
    for attr in container_names:
        # Exact type check on purpose: subclasses are not acceptable here.
        assert type(getattr(mdp, attr)) is tuple

    mdp._decompile()
    assert not mdp.compiled
    for attr in container_names:
        assert type(getattr(mdp, attr)) is set
Exemplo n.º 10
0
def test_set_terminal():
    """Marking states terminal records them, and marking an unknown state also registers it as a state."""
    mdp = MDP()
    t1 = Transition("s0", "a0", "s1", 0, 0)
    mdp.add_transition(t1)

    mdp.add_terminal_state("s0")

    assert len(mdp.states) == 2
    assert len(mdp.actions) == 1
    assert len(mdp.terminal_states) == 1
    assert "s0" in mdp.terminal_states

    # "s4" is not part of any transition; making it terminal must also add it as a state.
    mdp.add_terminal_state("s4")
    assert len(mdp.states) == 3
    # Previously only checked non-emptiness; both "s0" and "s4" are now terminal.
    assert len(mdp.terminal_states) == 2
    assert "s4" in mdp.terminal_states
    assert "s4" in mdp.states
Exemplo n.º 11
0
        :param theta (float, optional): stop threshold, defaults to 1e-6
        :return (Tuple[np.ndarray of float with dim (num of states, num of actions),
                       np.ndarray of float with dim (num of states)]):
            Tuple of calculated policy and value function
        """
        self.mdp.ensure_compiled()
        self.theta = theta
        return self._policy_improvement()


if __name__ == "__main__":
    # Demo: build a small two-state ("high" / "low") MDP with actions
    # "wait", "search" and "recharge", then solve it with value iteration.
    # NOTE(review): this looks like the recycling-robot example from
    # Sutton & Barto — confirm before documenting it as such.
    mdp = MDP()
    # add_transition is called here with several Transition objects at once.
    mdp.add_transition(
        #         start action end prob reward
        Transition("high", "wait", "high", 1, 2),
        Transition("high", "search", "high", 0.8, 5),
        Transition("high", "search", "low", 0.2, 5),
        Transition("high", "recharge", "high", 1, 0),
        Transition("low", "recharge", "high", 1, 0),
        Transition("low", "wait", "low", 1, 2),
        Transition("low", "search", "high", 0.6, -3),
        Transition("low", "search", "low", 0.4, 5),
    )

    # NOTE(review): 0.9 is presumably the discount factor (gamma) — confirm
    # against ValueIteration's constructor signature.
    solver = ValueIteration(mdp, 0.9)
    # solve() returns a (policy, value function) pair.
    policy, valuefunc = solver.solve()
    print("---Value Iteration---")
    print("Policy:")
    # decode_policy presumably maps the policy array back to state/action
    # names for display — verify against its definition.
    print(solver.decode_policy(policy))
    print("Value Function")