def test_unwrap_finite_horizon_MDP(self):
    """Unwrapping a 10-step finite-horizon MDP yields 10 stage mappings
    whose action distributions match the wrapped MDP's, and the mapping
    is None past the horizon."""
    finite = finite_horizon_MDP(self.finite_flip_flop, 10)
    unwrapped = unwrap_finite_horizon_MDP(finite)
    self.assertEqual(len(unwrapped), 10)

    def expected_mapping(
            state: WithTime[bool]) -> ActionMapping[bool, WithTime[bool]]:
        # Staying puts probability 0.7 on "same" under action True and
        # 0.3 under action False; rewards are 1.0 (same) / 2.0 (flip).
        stay = state.step_time()
        flip = dataclasses.replace(state.step_time(), state=not state.state)
        return {
            True: Categorical({(stay, 1.0): 0.7, (flip, 2.0): 0.3}),
            False: Categorical({(stay, 1.0): 0.3, (flip, 2.0): 0.7})
        }

    for step in range(10):
        for flag in (True, False):
            timed_state = WithTime(state=flag, time=step)
            for action in (True, False):
                distribution.assert_almost_equal(
                    self,
                    finite.action_mapping(timed_state)[action],
                    expected_mapping(timed_state)[action])

    # Beyond the horizon there is no action mapping at all.
    self.assertEqual(
        finite.action_mapping(WithTime(state=True, time=10)), None)
def test_compare_to_backward_induction(self):
    """The MRP value function over the time-augmented state space must
    agree, stage by stage, with backward-induction evaluation of the
    unwrapped finite-horizon MRP."""
    finite_horizon = finite_horizon_MRP(self.finite_flip_flop, 10)

    v = evaluate_mrp_result(finite_horizon, gamma=1)
    self.assertEqual(len(v), 20)

    finite_v = list(
        evaluate(unwrap_finite_horizon_MRP(finite_horizon), gamma=1))

    for step in range(10):
        for flag in (True, False):
            self.assertAlmostEqual(
                v[WithTime(state=flag, time=step)], finite_v[step][flag])
def test_finite_horizon_MDP(self):
    """A 10-step finite-horizon MDP over the flip-flop has 20
    non-terminal states (2 per step), every state offers both boolean
    actions, and the t=0 transition distribution is as expected."""
    finite = finite_horizon_MDP(self.finite_flip_flop, limit=10)

    self.assertEqual(len(finite.non_terminal_states), 20)
    for state in finite.non_terminal_states:
        self.assertEqual(set(finite.actions(state)), {False, True})

    # Action False from (True, t=0): flip with p=0.7 (reward 2.0),
    # stay with p=0.3 (reward 1.0).
    origin = NonTerminal(WithTime(state=True, time=0))
    actual = finite.mapping[origin][False]
    expected = Categorical({
        (NonTerminal(WithTime(False, time=1)), 2.0): 0.7,
        (NonTerminal(WithTime(True, time=1)), 1.0): 0.3
    })
    distribution.assert_almost_equal(self, actual, expected)
def test_compare_to_backward_induction(self):
    """Iterative function-approximation evaluation of the finite-horizon
    MRP must converge to the same stage values as backward induction on
    the unwrapped process."""
    finite_horizon = finite_horizon_MRP(self.finite_flip_flop, 10)

    # Start from an all-zero tabular approximation over every state
    # (including terminals), then iterate to convergence.
    zeros = Dynamic({s: 0.0 for s in finite_horizon.states()})
    v = FunctionApprox.converged(
        evaluate_finite_mrp(finite_horizon, γ=1, approx_0=zeros))
    self.assertEqual(len(v.values_map), 22)

    finite_v = list(
        evaluate(unwrap_finite_horizon_MRP(finite_horizon), gamma=1))

    for step in range(10):
        for flag in (True, False):
            self.assertAlmostEqual(
                v(WithTime(state=flag, time=step)), finite_v[step][flag])
def test_finite_horizon_MDP(self):
    """A 10-step finite-horizon MDP has 22 states (20 non-terminal plus
    2 terminal), every non-terminal state offers both boolean actions,
    the t=0 transition is as expected, and stepping past the horizon
    yields None."""
    finite = finite_horizon_MDP(self.finite_flip_flop, limit=10)

    self.assertEqual(len(finite.states()), 22)
    for state in finite.states():
        available = set(finite.actions(state))
        # Terminal states expose no actions; every other state offers both.
        if len(available) > 0:
            self.assertEqual(available, {False, True})

    # Action False from (True, t=0): flip with p=0.7 (reward 2.0),
    # stay with p=0.3 (reward 1.0).
    origin = WithTime(state=True, time=0)
    actual = finite.action_mapping(origin)[False]
    expected = Categorical({
        (WithTime(False, time=1), 2.0): 0.7,
        (WithTime(True, time=1), 1.0): 0.3
    })
    distribution.assert_almost_equal(self, actual, expected)

    # Stepping from a state at the horizon produces no distribution.
    self.assertEqual(finite.step(WithTime(True, 10), True), None)
def test_finite_horizon_MRP(self):
    """The finite-horizon MRP's non-terminal states are exactly the
    time-stamped flip-flop states for t in [0, 10), and each state's
    transition-reward distribution steps time forward — into Terminal
    wrappers on the final step."""
    finite = finite_horizon_MRP(self.finite_flip_flop, 10)

    expected_non_terminal = {
        NonTerminal(WithTime(flag, step))
        for flag in (True, False)
        for step in range(10)
    }
    self.assertEqual(set(finite.non_terminal_states), expected_non_terminal)

    for state in expected_non_terminal:
        step: int = state.state.time
        flag: bool = state.state.state
        # Stay with p=0.3 (reward 1.0), flip with p=0.7 (reward 2.0);
        # successors at the last step are Terminal.
        wrap = NonTerminal if step < 9 else Terminal
        expected = Categorical({
            (wrap(WithTime(flag, step + 1)), 1.0): 0.3,
            (wrap(WithTime(not flag, step + 1)), 2.0): 0.7
        })
        distribution.assert_almost_equal(
            self, finite.transition_reward(state), expected)
def test_finite_horizon_MRP(self):
    """The finite-horizon MRP's state space is the time-stamped states
    for t in [0, 10] (t=10 being terminal); non-terminal transitions
    step time forward, and terminal states have no transition."""
    finite = finite_horizon_MRP(self.finite_flip_flop, 10)

    non_terminal_states = {
        WithTime(flag, step)
        for flag in (True, False)
        for step in range(10)
    }
    terminal_states = {WithTime(True, 10), WithTime(False, 10)}
    self.assertEqual(
        set(finite.states()), non_terminal_states | terminal_states)

    # Stay with p=0.3 (reward 1.0), flip with p=0.7 (reward 2.0).
    expected_transition = {
        state: Categorical({
            (WithTime(state.state, state.time + 1), 1.0): 0.3,
            (WithTime(not state.state, state.time + 1), 2.0): 0.7
        })
        for state in non_terminal_states
    }

    for state in non_terminal_states:
        distribution.assert_almost_equal(
            self,
            finite.transition_reward(state),
            expected_transition[state])

    for state in terminal_states:
        self.assertEqual(finite.transition(state), None)