    def test_finite_horizon_MRP(self):
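        """Wrapping the flip-flop MRP with a horizon of 10 should produce
        NonTerminal(WithTime(...)) states for times 0 through 9, each of
        which flips with probability 0.7 (reward 2.0) or holds with
        probability 0.3 (reward 1.0); the step at time 9 transitions into
        Terminal states at time 10.
        """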
        finite = finite_horizon_MRP(self.finite_flip_flop, 10)

        trues = [NonTerminal(WithTime(True, time)) for time in range(10)]
        falses = [NonTerminal(WithTime(False, time)) for time in range(10)]
        non_terminal_states = set(trues + falses)
        self.assertEqual(set(finite.non_terminal_states), non_terminal_states)

        expected_transition = {}
        for state in non_terminal_states:
            t: int = state.state.time
            st: bool = state.state.state
            if t < 9:
                prob = {
                    (NonTerminal(WithTime(st, t + 1)), 1.0): 0.3,
                    (NonTerminal(WithTime(not st, t + 1)), 2.0): 0.7
                }
            else:
                prob = {
                    (Terminal(WithTime(st, t + 1)), 1.0): 0.3,
                    (Terminal(WithTime(not st, t + 1)), 2.0): 0.7
                }

            expected_transition[state] = Categorical(prob)

        for state in non_terminal_states:
            distribution.assert_almost_equal(
                self,
                finite.transition_reward(state),
                expected_transition[state])
    def setUp(self):
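        """Build the clearance-pricing fixture: an inventory of 10 units
        sold over 6 time steps with four price/λ pairs, plus a stationary
        threshold policy on the inventory level, the single-step MRP that
        policy induces, and the unwrapped finite-horizon MRP and MDP
        sequences used by the tests in this class.
        """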
        ii = 10
        self.steps = 6
        pairs = [(1.0, 0.5), (0.7, 1.0), (0.5, 1.5), (0.3, 2.5)]
        self.cp: ClearancePricingMDP = ClearancePricingMDP(
            initial_inventory=ii,
            time_steps=self.steps,
            price_lambda_pairs=pairs)

        def policy_func(x: int) -> int:
            return 0 if x < 2 else (1 if x < 5 else (2 if x < 8 else 3))

        stationary_policy: FiniteDeterministicPolicy[int, int] = \
            FiniteDeterministicPolicy(
                {s: policy_func(s) for s in range(ii + 1)}
            )

        self.single_step_mrp: FiniteMarkovRewardProcess[int] = \
            self.cp.single_step_mdp.apply_finite_policy(stationary_policy)

        self.mrp_seq = unwrap_finite_horizon_MRP(
            finite_horizon_MRP(self.single_step_mrp, self.steps))

        self.single_step_mdp: FiniteMarkovDecisionProcess[int, int] = \
            self.cp.single_step_mdp

        self.mdp_seq = unwrap_finite_horizon_MDP(
            finite_horizon_MDP(self.single_step_mdp, self.steps))
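
    # Illustrative use of these fixtures (a sketch, not part of the tests):
    # backward induction over the unwrapped sequence yields one value
    # function per time step, e.g.
    #     vf_seq = list(evaluate(self.mrp_seq, gamma=1.0))
    # inside a test method, where len(vf_seq) == self.steps is expected.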
    def test_finite_horizon_MRP(self):
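        """Variant of the finite-horizon check that works directly with
        WithTime states: the wrapped process should contain non-terminal
        states for times 0 through 9 plus terminal states at time 10, every
        non-terminal state should flip with probability 0.7 (reward 2.0) or
        hold with probability 0.3 (reward 1.0), and transition() should
        return None for the terminal states.
        """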
        finite = finite_horizon_MRP(self.finite_flip_flop, 10)

        trues = [WithTime(True, time) for time in range(0, 10)]
        falses = [WithTime(False, time) for time in range(0, 10)]
        non_terminal_states = set(trues + falses)
        terminal_states = {WithTime(True, 10), WithTime(False, 10)}
        expected_states = non_terminal_states.union(terminal_states)

        self.assertEqual(set(finite.states()), expected_states)

        expected_transition = {}
        for state in non_terminal_states:
            expected_transition[state] =\
                Categorical({
                    (WithTime(state.state, state.time + 1), 1.0): 0.3,
                    (WithTime(not state.state, state.time + 1), 2.0): 0.7
                })

        for state in non_terminal_states:
            distribution.assert_almost_equal(self,
                                             finite.transition_reward(state),
                                             expected_transition[state])

        for state in terminal_states:
            self.assertEqual(finite.transition(state), None)
    def test_evaluate(self):
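        """Backward induction on the 10-step flip-flop MRP with gamma=1:
        the expected one-step reward is 0.3 * 1.0 + 0.7 * 2.0 = 1.7, so with
        k steps remaining either state is worth 1.7 * k: 17 at time 0, 8.5
        at time 5 and 1.7 at time 9.
        """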
        process = finite_horizon_MRP(self.finite_flip_flop, 10)
        vs = list(evaluate(unwrap_finite_horizon_MRP(process), gamma=1))

        self.assertEqual(len(vs), 10)

        self.assertAlmostEqual(vs[0][NonTerminal(True)], 17)
        self.assertAlmostEqual(vs[0][NonTerminal(False)], 17)

        self.assertAlmostEqual(vs[5][NonTerminal(True)], 17 / 2)
        self.assertAlmostEqual(vs[5][NonTerminal(False)], 17 / 2)

        self.assertAlmostEqual(vs[9][NonTerminal(True)], 17 / 10)
        self.assertAlmostEqual(vs[9][NonTerminal(False)], 17 / 10)
    def test_compare_to_backward_induction(self):
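        """The value function of the time-augmented MRP computed by
        evaluate_mrp_result should agree, at every time step, with the
        per-step value functions produced by backward induction on the
        unwrapped sequence.
        """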
        finite_horizon = finite_horizon_MRP(self.finite_flip_flop, 10)

        v = evaluate_mrp_result(finite_horizon, gamma=1)
        self.assertEqual(len(v), 20)

        finite_v =\
            list(evaluate(unwrap_finite_horizon_MRP(finite_horizon), gamma=1))

        for time in range(10):
            self.assertAlmostEqual(
                v[NonTerminal(WithTime(state=True, time=time))],
                finite_v[time][NonTerminal(True)])
            self.assertAlmostEqual(
                v[NonTerminal(WithTime(state=False, time=time))],
                finite_v[time][NonTerminal(False)])
    def test_compare_to_backward_induction(self):
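        """Same comparison via iterative evaluation with a Dynamic function
        approximation: the converged approximation over all 22 WithTime
        states (20 non-terminal plus the 2 terminal states at time 10)
        should match backward induction at every time step.
        """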
        finite_horizon = finite_horizon_MRP(self.finite_flip_flop, 10)

        start = Dynamic({s: 0.0 for s in finite_horizon.states()})
        v = FunctionApprox.converged(
            evaluate_finite_mrp(finite_horizon, γ=1, approx_0=start))
        self.assertEqual(len(v.values_map), 22)

        finite_v =\
            list(evaluate(unwrap_finite_horizon_MRP(finite_horizon), gamma=1))

        for time in range(0, 10):
            self.assertAlmostEqual(v(WithTime(state=True, time=time)),
                                   finite_v[time][True])
            self.assertAlmostEqual(v(WithTime(state=False, time=time)),
                                   finite_v[time][False])
    def test_unwrap_finite_horizon_MRP(self):
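        """unwrap_finite_horizon_MRP should split the 10-step process into
        one transition map per time step: steps 0 through 8 lead to
        NonTerminal successors while step 9 leads to Terminal successors,
        with the same flip/hold probabilities and rewards throughout.
        """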
        finite = finite_horizon_MRP(self.finite_flip_flop, 10)

        def transition_for(_):
            return {
                True: Categorical({
                    (NonTerminal(True), 1.0): 0.3,
                    (NonTerminal(False), 2.0): 0.7
                }),
                False: Categorical({
                    (NonTerminal(True), 2.0): 0.7,
                    (NonTerminal(False), 1.0): 0.3
                })
            }

        unwrapped = unwrap_finite_horizon_MRP(finite)
        self.assertEqual(len(unwrapped), 10)

        expected_transitions = [transition_for(n) for n in range(10)]
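        # Only the first 9 steps are compared in the loop; the last step is
        # checked separately below because its successors are Terminal.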
        for time in range(9):
            got = unwrapped[time]
            expected = expected_transitions[time]
            distribution.assert_almost_equal(
                self, got[NonTerminal(True)],
                expected[True]
            )
            distribution.assert_almost_equal(
                self, got[NonTerminal(False)],
                expected[False]
            )

        distribution.assert_almost_equal(
            self, unwrapped[9][NonTerminal(True)],
            Categorical({
                (Terminal(True), 1.0): 0.3,
                (Terminal(False), 2.0): 0.7
            })
        )
        distribution.assert_almost_equal(
            self, unwrapped[9][NonTerminal(False)],
            Categorical({
                (Terminal(True), 2.0): 0.7,
                (Terminal(False), 1.0): 0.3
            })
        )
    def test_unwrap_finite_horizon_MRP(self):
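        """Variant of the unwrap check on bare boolean states: each of the
        10 time slices should carry the same flip/hold transition-reward
        distribution.
        """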
        finite = finite_horizon_MRP(self.finite_flip_flop, 10)

        def transition_for(time):
            return {
                True: Categorical({
                    (True, 1.0): 0.3,
                    (False, 2.0): 0.7,
                }),
                False: Categorical({
                    (True, 2.0): 0.7,
                    (False, 1.0): 0.3,
                })
            }

        unwrapped = unwrap_finite_horizon_MRP(finite)
        self.assertEqual(len(unwrapped), 10)

        expected_transitions = [transition_for(n) for n in range(0, 10)]
        for time in range(0, 10):
            got = unwrapped[time]
            expected = expected_transitions[time]
            distribution.assert_almost_equal(self, got[True], expected[True])
            distribution.assert_almost_equal(self, got[False], expected[False])
    print("Clearance Pricing MDP")
    print("---------------------")
    print(cp.mdp)

    def policy_func(x: int) -> int:
        return 0 if x < 2 else (1 if x < 5 else (2 if x < 8 else 3))

    stationary_policy: FinitePolicy[int, int] = FinitePolicy(
        {s: Constant(policy_func(s)) for s in range(ii + 1)}
    )
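
    # Apply the stationary policy to the single-step MDP and evaluate it
    # over the full horizon by backward induction on the unwrapped
    # finite-horizon MRP, giving one value function per time step.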

    single_step_mrp: FiniteMarkovRewardProcess[int] = \
        cp.single_step_mdp.apply_finite_policy(stationary_policy)

    vf_for_policy: Iterator[V[int]] = evaluate(
        unwrap_finite_horizon_MRP(finite_horizon_MRP(single_step_mrp, steps)),
        1.
    )

    print("Value Function for Stationary Policy")
    print("------------------------------------")
    for t, vf in enumerate(vf_for_policy):
        print(f"Time Step {t:d}")
        print("---------------")
        pprint(vf)

    print("Optimal Value Function and Optimal Policy")
    print("------------------------------------")
    prices = []
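    # Backward induction on the finite-horizon MDP gives, per time step, the
    # optimal value function together with the optimal price index for each
    # inventory level.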
    for t, (vf, policy) in enumerate(cp.get_optimal_vf_and_policy()):
        print(f"Time Step {t:d}")