def test_evaluate_mrp(self):
    # Start from the zero value function over both flip-flop states.
    start = Dynamic({s: 0.0 for s in self.finite_flip_flop.states()})

    # Approximate evaluation: Bellman backups over sampled non-terminal
    # states, iterated until successive approximations agree to 1e-4.
    v = iterate.converged(
        evaluate_mrp(
            self.finite_flip_flop,
            γ=0.99,
            approx_0=start,
            non_terminal_states_distribution=Choose(
                set(self.finite_flip_flop.states())),
            num_state_samples=5,
        ),
        done=lambda a, b: a.within(b, 1e-4),
    )

    self.assertEqual(len(v.values_map), 2)
    for s in v.values_map:
        self.assertLess(abs(v(s) - 170), 1.0)

    # The exact (finite) evaluation should agree with the sampled version.
    v_finite = iterate.converged(
        evaluate_finite_mrp(self.finite_flip_flop, γ=0.99, approx_0=start),
        done=lambda a, b: a.within(b, 1e-4),
    )
    assert_allclose(v.evaluate([True, False]),
                    v_finite.evaluate([True, False]),
                    rtol=0.01)
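# For reference, a minimal sketch of the iterate.converged pattern used
# above, assuming it consumes an iterator of successive approximations and
# returns the first element whose predecessor satisfies the done predicate.
# This is an illustrative re-implementation with a hypothetical name, not
# the library code.
from typing import Callable, Iterator, TypeVar

X = TypeVar('X')

def converged_sketch(values: Iterator[X],
                     done: Callable[[X, X], bool]) -> X:
    a = next(values)
    for b in values:
        if done(a, b):
            return b
        a = b
    return a  # iterator exhausted before the tolerance was met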
def test_evaluate_mrp(self):
    # Exact value function of the implied MRP, used as the reference.
    mrp_vf1: np.ndarray = self.implied_mrp.get_value_function_vec(self.gamma)

    fa = Dynamic({s: 0.0 for s in self.states})

    # Finite (full-sweep) evaluation should match the exact solution tightly.
    mrp_finite_fa = iterate.converged(
        evaluate_finite_mrp(self.implied_mrp, self.gamma, fa),
        done=lambda a, b: a.within(b, 1e-4),
    )
    mrp_vf2: np.ndarray = mrp_finite_fa.evaluate(self.states)
    self.assertLess(max(abs(mrp_vf1 - mrp_vf2)), 0.001)

    # Sampled evaluation is noisier, so converge and compare more loosely.
    mrp_fa = iterate.converged(
        evaluate_mrp(
            self.implied_mrp,
            self.gamma,
            fa,
            Choose(self.states),
            num_state_samples=30,
        ),
        done=lambda a, b: a.within(b, 0.1),
    )
    mrp_vf3: np.ndarray = mrp_fa.evaluate(self.states)
    self.assertLess(max(abs(mrp_vf1 - mrp_vf3)), 1.0)
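# A hedged sketch of the fixed-point step these evaluations iterate: one
# synchronous Bellman backup v <- R + γ·P·v over the non-terminal states.
# reward_vec and transition_matrix stand in for the finite MRP's reward
# vector and transition matrix; this is illustrative, not the library's
# evaluate_finite_mrp.
import numpy as np

def bellman_backup_sketch(reward_vec: np.ndarray,
                          transition_matrix: np.ndarray,
                          gamma: float,
                          v: np.ndarray) -> np.ndarray:
    return reward_vec + gamma * transition_matrix.dot(v)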
def update(vf_policy: Tuple[FunctionApprox[S], ThisPolicy[S, A]]) \
        -> Tuple[FunctionApprox[S], ThisPolicy[S, A]]:
    nt_states: Sequence[S] = non_terminal_states_distribution\
        .sample_n(num_state_samples)
    vf, pi = vf_policy

    # Policy evaluation: approximate the value function of the MRP
    # obtained by running the current policy on the MDP.
    mrp: MarkovRewardProcess[S] = mdp.apply_policy(pi)
    new_vf: FunctionApprox[S] = converged(
        evaluate_mrp(mrp, γ, vf, non_terminal_states_distribution,
                     num_state_samples),
        done=lambda a, b: a.within(b, 1e-4)
    )

    # One-step return under the newly evaluated value function.
    def return_(s_r: Tuple[S, float]) -> float:
        s1, r = s_r
        return r + γ * new_vf.evaluate([s1]).item()

    # Policy improvement: back up the best action value at each sampled
    # state and pair the updated approximation with the greedy policy.
    return (new_vf.update([(s, max(mdp.step(s, a).expectation(return_)
                                   for a in mdp.actions(s)))
                           for s in nt_states]),
            ThisPolicy(mdp, return_))
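# A hedged sketch (hypothetical, not the library's API) of how the update
# step above would typically be driven, mirroring the iterate/converged
# pattern in the tests: repeat until successive value functions agree to
# within a tolerance, then return the final (value function, policy) pair.
def policy_iteration_sketch(vf_0, pi_0, tolerance=1e-4):
    pair = (vf_0, pi_0)
    while True:
        new_vf, new_pi = update(pair)
        if new_vf.within(pair[0], tolerance):
            return new_vf, new_pi
        pair = (new_vf, new_pi)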