def find_terminal_states(self) -> Set[S]:
    """Return the terminal states of the MDP.

    A terminal state is a sink state (no transitions out to other
    states) whose every associated reward is approximately zero.
    """
    sink = self.find_sink_states()
    return {
        s for s in sink
        # Only the reward values matter here; the original iterated
        # .items() and discarded the keys.
        if all(is_approx_eq(r, 0.0) for r in self.rewards[s].values())
    }
def policy_evaluation(self, pol: Policy) -> np.ndarray:
    """Iteratively compute the value function of the given policy.

    Builds the MRP induced by *pol* and applies the Bellman expectation
    update v <- R + gamma * P v until successive iterates agree.

    :param pol: the policy to evaluate.
    :return: value-function vector over the non-terminal states.
    """
    mrp = self.find_mrp(pol)
    v0 = np.zeros(len(self.nt_states))
    while True:
        v1 = mrp.reward_vector + self.gamma * mrp.trans_matrix.dot(v0)
        # Converge on the norm of the DIFFERENCE: comparing norm(v1)
        # with norm(v0) (as the original did) can stop early for two
        # distinct vectors of equal magnitude.
        if is_approx_eq(np.linalg.norm(v1 - v0), 0.0):
            return v1
        v0 = v1
def find_terminal_states(self) -> Set[S]:
    """Return the sink states whose reward is approximately zero."""
    terminal = set()
    for state in self.find_sink_states():
        if is_approx_eq(self.reward_graph[state], 0.0):
            terminal.add(state)
    return terminal
def verify_policy(policy_data: SAf) -> bool:
    """Check that every state's action distribution sums to (about) 1.

    :param policy_data: mapping from state to {action: probability}.
    :return: True iff each per-state distribution is normalized.
    """
    # Only the per-state distributions are needed; the original unpacked
    # (s, v) from .items() and never used s.
    return all(
        is_approx_eq(sum(dist.values()), 1.0)
        for dist in policy_data.values()
    )
def verify_transitions(states: Set[S], tr_seq: Sequence[Mapping[S, float]]) -> bool:
    """Validate a sequence of transition distributions.

    Checks that every referenced state belongs to *states*, that all
    probabilities are non-negative, and that each distribution sums to
    approximately 1.
    """
    referenced = set().union(*tr_seq)
    within_states = referenced.issubset(states)
    nonnegative = all(
        prob >= 0 for dist in tr_seq for prob in dist.values()
    )
    normalized = all(
        is_approx_eq(sum(dist.values()), 1.0) for dist in tr_seq
    )
    return within_states and nonnegative and normalized
def find_lean_transitions(d: Mapping[S, float]) -> Mapping[S, float]:
    """Return *d* with entries whose probability is approximately zero removed."""
    lean = {}
    for state, prob in d.items():
        if not is_approx_eq(prob, 0.0):
            lean[state] = prob
    return lean