Ejemplo n.º 1
0
def verify_transitions(
    states: Set[S],
    tr_seq: Sequence[Mapping[S, float]]
) -> bool:
    """Check that every mapping in ``tr_seq`` is a valid distribution over ``states``.

    :param states: the set of admissible states.
    :param tr_seq: a sequence of mappings from next-state to probability.
    :return: True iff every next-state appearing in ``tr_seq`` is in
        ``states``, every probability is non-negative, and each mapping's
        probabilities sum to 1.0 (within ``is_approx_eq`` tolerance).
        An empty ``tr_seq`` is trivially valid.
    """
    all_in_states = set().union(*tr_seq).issubset(states)
    # A mapping such as {a: 1.5, b: -0.5} sums to 1.0 but is not a
    # probability distribution, so negative entries are rejected explicitly.
    non_negative = all(all(p >= 0 for p in d.values()) for d in tr_seq)
    sums_to_one = all(is_approx_eq(sum(d.values()), 1.0) for d in tr_seq)
    return all_in_states and non_negative and sums_to_one
Ejemplo n.º 2
0
 def get_terminal_states(self) -> Set[S]:
     """
     A terminal state is a sink state (100% probability of going back
     to itself, FOR EACH ACTION) whose rewards on those transitions back
     to itself are all zero.

     :return: the subset of ``self.get_sink_states()`` for which every
         value in ``self.rewards[s]`` is approximately 0.0.
     """
     sink = self.get_sink_states()
     # Only the reward values matter; the keys of self.rewards[s] are unused.
     return {s for s in sink if
             all(is_approx_eq(r, 0.0) for r in self.rewards[s].values())}
Ejemplo n.º 3
0
def verify_transitions(
    states: Set[S],
    tr_seq: Sequence[Mapping[S, float]]
) -> bool:
    """Check that every mapping in ``tr_seq`` is a valid distribution over ``states``.

    :param states: the set of admissible states.
    :param tr_seq: a sequence of mappings from next-state to probability.
    :return: True iff every next-state appearing in ``tr_seq`` is in
        ``states``, every probability is non-negative, and each mapping's
        probabilities sum to 1.0 (within ``is_approx_eq`` tolerance).
    """
    all_in_states = set().union(*tr_seq).issubset(states)
    non_negative = all(all(p >= 0 for p in d.values()) for d in tr_seq)
    # The sum-to-one check must take part in the result; it was previously
    # computed and then ignored, letting unnormalised mappings through.
    sums_to_one = all(is_approx_eq(sum(d.values()), 1.0) for d in tr_seq)
    return all_in_states and non_negative and sums_to_one
Ejemplo n.º 4
0
 def verify_data(transitions_rewards: Sequence[SASTff],
                 terminal_opt_val: Mapping[S, float], gamma: float) -> bool:
     """Validate a finite-horizon sequence of transition/reward tables.

     Each table in ``transitions_rewards`` maps a state to a mapping of
     actions, and each action to a mapping of next-state ->
     (probability, reward) pairs, as unpacked below. The data is valid when
     ``gamma`` lies in [0, 1] and, at every time step:

     * every state has at least one action,
     * every next-state probability is non-negative,
     * each action's next-state probabilities sum to 1.0
       (within ``is_approx_eq`` tolerance), and
     * every next-state is a key of the following step's table, or of
       ``terminal_opt_val`` for the final step.

     Checking stops at the first time step that fails.
     """
     verdict = 0. <= gamma <= 1.
     if not verdict:
         return verdict
     horizon = len(transitions_rewards)
     for step, table in enumerate(transitions_rewards):
         has_actions = all(len(acts) > 0 for acts in table.values())
         # Strip rewards, keeping one next-state -> probability dict per
         # (state, action) pair.
         dists = []
         for acts in table.values():
             for outcome in acts.values():
                 dists.append({nxt: p for nxt, (p, _) in outcome.items()})
         nonneg = all(all(p >= 0 for p in dist.values()) for dist in dists)
         normalized = all(
             is_approx_eq(sum(dist.values()), 1.0) for dist in dists)
         # Successor states come from the next step, or from the terminal
         # values once the last step is reached.
         successor = (transitions_rewards[step + 1] if step + 1 < horizon
                      else terminal_opt_val)
         reachable = set(successor.keys())
         closed = all(set(dist.keys()).issubset(reachable) for dist in dists)
         verdict = has_actions and nonneg and normalized and closed
         if not verdict:
             break
     return verdict
Ejemplo n.º 5
0
def verify_policy(policy_data: SAf) -> bool:
    """Return True iff each inner mapping of ``policy_data`` sums to 1.0.

    The sum is compared to 1.0 using ``is_approx_eq``. ``policy_data`` is
    presumably a state -> {action: probability} mapping (inferred from the
    ``SAf`` alias name; confirm against its definition). Only the inner
    mappings are inspected, so the outer keys are not needed.
    """
    return all(is_approx_eq(sum(action_probs.values()), 1.0)
               for action_probs in policy_data.values())
Ejemplo n.º 6
0
def get_lean_transitions(d: Mapping[S, float]) -> Mapping[S, float]:
    """Return a new dict holding only the entries of ``d`` that are not ~0.

    An entry is dropped when ``is_approx_eq(value, 0.0)`` holds. ``d``
    itself is not modified, and the surviving entries keep their original
    iteration order.
    """
    lean = {}
    for state, prob in d.items():
        if is_approx_eq(prob, 0.0):
            continue
        lean[state] = prob
    return lean
Ejemplo n.º 7
0
 def get_terminal_states(self) -> Set[S]:
     """Return the sink states whose reward is approximately zero.

     A state from ``self.get_sink_states()`` is kept when
     ``is_approx_eq(self.rewards[s], 0.0)`` holds. ``self.rewards[s]`` is
     presumably a single scalar reward per state here; confirm against the
     class definition.
     """
     terminal = set()
     for s in self.get_sink_states():
         if is_approx_eq(self.rewards[s], 0.0):
             terminal.add(s)
     return terminal