def test_unwrap_finite_horizon_MDP(self):
        """Check unwrapping and the per-step action mapping of a 10-step MDP.

        Unwrapping should produce 10 entries. For every (state, time) pair
        before the horizon, the wrapped MDP's action mapping should match the
        expected flip-flop outcomes. At the horizon (time 10) the action
        mapping should be None.
        """
        horizon = 10
        finite = finite_horizon_MDP(self.finite_flip_flop, horizon)
        unwrapped = unwrap_finite_horizon_MDP(finite)

        self.assertEqual(len(unwrapped), horizon)

        def expected_actions(
                s: WithTime[bool]) -> ActionMapping[bool, WithTime[bool]]:
            # Both successors are at the next time step; "stay" keeps the
            # boolean state, "flip" negates it.
            stay = s.step_time()
            flip = dataclasses.replace(s.step_time(), state=not s.state)

            def outcomes(p_stay: float, p_flip: float) -> Categorical:
                # Staying pays reward 1.0, flipping pays reward 2.0.
                return Categorical({
                    (stay, 1.0): p_stay,
                    (flip, 2.0): p_flip
                })

            return {True: outcomes(0.7, 0.3), False: outcomes(0.3, 0.7)}

        for t in range(horizon):
            for flag in (True, False):
                s_time = WithTime(state=flag, time=t)
                expected = expected_actions(s_time)
                for action in (True, False):
                    distribution.assert_almost_equal(
                        self,
                        finite.action_mapping(s_time)[action],
                        expected[action])

        at_horizon = WithTime(state=True, time=horizon)
        self.assertEqual(finite.action_mapping(at_horizon), None)
    def test_compare_to_backward_induction(self):
        """Compare direct evaluation of the finite-horizon MRP with backward
        induction.

        The horizon-wrapped MRP is evaluated directly with gamma=1. Its values
        are compared against backward induction over the unwrapped per-step
        MRPs, at every time step and for both boolean states.
        """
        horizon = 10
        finite_horizon = finite_horizon_MRP(self.finite_flip_flop, horizon)

        v = evaluate_mrp_result(finite_horizon, gamma=1)
        # One value is expected per (state, time) pair: 2 states x 10 steps.
        self.assertEqual(len(v), 20)

        step_mrps = unwrap_finite_horizon_MRP(finite_horizon)
        finite_v = list(evaluate(step_mrps, gamma=1))

        for time in range(horizon):
            for flag in (True, False):
                self.assertAlmostEqual(v[WithTime(state=flag, time=time)],
                                       finite_v[time][flag])
Beispiel #3
0
    def test_finite_horizon_MDP(self):
        """Check the states, actions and one transition of a 10-step MDP.

        The wrapped MDP should have 20 non-terminal states, and each should
        offer both actions. A single transition is spot-checked: taking
        action False from (True, t=0).
        """
        finite = finite_horizon_MDP(self.finite_flip_flop, limit=10)

        self.assertEqual(len(finite.non_terminal_states), 20)

        both_actions = {False, True}
        for s in finite.non_terminal_states:
            self.assertEqual(set(finite.actions(s)), both_actions)

        origin = NonTerminal(WithTime(state=True, time=0))
        observed = finite.mapping[origin][False]
        # Flipping to False pays 2.0 (prob 0.7); staying True pays 1.0 (0.3).
        wanted = Categorical({
            (NonTerminal(WithTime(False, time=1)), 2.0): 0.7,
            (NonTerminal(WithTime(True, time=1)), 1.0): 0.3
        })
        distribution.assert_almost_equal(self, observed, wanted)
    def test_compare_to_backward_induction(self):
        """Compare iterative evaluation of the finite-horizon MRP with
        backward induction.

        A `Dynamic` approximation starting at zero for every state is
        iterated to convergence with gamma=1. The result is compared against
        backward induction over the unwrapped per-step MRPs.
        """
        horizon = 10
        finite_horizon = finite_horizon_MRP(self.finite_flip_flop, horizon)

        initial_guess = Dynamic(dict.fromkeys(finite_horizon.states(), 0.0))
        v = FunctionApprox.converged(
            evaluate_finite_mrp(finite_horizon, γ=1, approx_0=initial_guess))
        # 22 entries: presumably the 20 (state, time < 10) pairs plus the two
        # time-10 states — TODO confirm against finite_horizon.states().
        self.assertEqual(len(v.values_map), 22)

        step_mrps = unwrap_finite_horizon_MRP(finite_horizon)
        finite_v = list(evaluate(step_mrps, gamma=1))

        for time in range(horizon):
            for flag in (True, False):
                self.assertAlmostEqual(v(WithTime(state=flag, time=time)),
                                       finite_v[time][flag])
    def test_finite_horizon_MDP(self):
        """Check the states, actions, one transition and the horizon of a
        10-step MDP.

        The wrapped MDP should expose 22 states. Every state that has actions
        should offer both of them. One transition is spot-checked, and
        stepping from a time-10 state should yield None.
        """
        finite = finite_horizon_MDP(self.finite_flip_flop, limit=10)

        all_states = finite.states()
        self.assertEqual(len(all_states), 22)

        for s in all_states:
            available = set(finite.actions(s))
            # States with no actions (e.g. at the horizon) are skipped.
            if available:
                self.assertEqual(available, {False, True})

        origin = WithTime(state=True, time=0)
        observed = finite.action_mapping(origin)[False]
        # Flipping to False pays 2.0 (prob 0.7); staying True pays 1.0 (0.3).
        wanted = Categorical({
            (WithTime(False, time=1), 2.0): 0.7,
            (WithTime(True, time=1), 1.0): 0.3
        })
        distribution.assert_almost_equal(self, observed, wanted)

        self.assertEqual(finite.step(WithTime(True, 10), True), None)
Beispiel #6
0
    def test_finite_horizon_MRP(self):
        """Check the non-terminal states and transitions of a 10-step MRP.

        The wrapped MRP's non-terminal states should be exactly the
        (state, time) pairs for time 0..9. Each should transition to the next
        time step, keeping its state with prob 0.3 (reward 1.0) or flipping
        it with prob 0.7 (reward 2.0). Successors are Terminal only on the
        final step.
        """
        horizon = 10
        finite = finite_horizon_MRP(self.finite_flip_flop, horizon)

        non_terminal_states = {
            NonTerminal(WithTime(flag, time))
            for flag in (True, False)
            for time in range(horizon)
        }
        self.assertEqual(set(finite.non_terminal_states), non_terminal_states)

        def expected_for(state):
            t: int = state.state.time
            st: bool = state.state.state
            # Only the last step before the horizon leads to Terminal states.
            wrap = NonTerminal if t < horizon - 1 else Terminal
            return Categorical({
                (wrap(WithTime(st, t + 1)), 1.0): 0.3,
                (wrap(WithTime(not st, t + 1)), 2.0): 0.7
            })

        for state in non_terminal_states:
            distribution.assert_almost_equal(
                self,
                finite.transition_reward(state),
                expected_for(state))
    def test_finite_horizon_MRP(self):
        """Check the states, transitions and horizon behaviour of a 10-step
        MRP.

        The wrapped MRP's states should be every (state, time) pair for time
        0..10. Each pair before the horizon moves one step forward, keeping
        its state with prob 0.3 (reward 1.0) or flipping it with prob 0.7
        (reward 2.0). The time-10 states should have no transition.
        """
        horizon = 10
        finite = finite_horizon_MRP(self.finite_flip_flop, horizon)

        non_terminal_states = {
            WithTime(flag, time)
            for flag in (True, False)
            for time in range(horizon)
        }
        terminal_states = {WithTime(True, horizon), WithTime(False, horizon)}

        self.assertEqual(set(finite.states()),
                         non_terminal_states | terminal_states)

        for state in non_terminal_states:
            successor_time = state.time + 1
            wanted = Categorical({
                (WithTime(state.state, successor_time), 1.0): 0.3,
                (WithTime(not state.state, successor_time), 2.0): 0.7
            })
            distribution.assert_almost_equal(self,
                                             finite.transition_reward(state),
                                             wanted)

        for state in terminal_states:
            self.assertEqual(finite.transition(state), None)