def bound_t_new(self, node, gamma, R, O, T, S, fringe_function: "Bound"):
    r"""Compute the bound :math:`B_T(b, a)` of a node in the search tree.

    The bound of a belief :math:`b` is defined recursively as:

    .. math::

        B_T(b) = \begin{cases}
            B(b), & \text{if } b \in F(T) \\
            \max_{a \in A} B_T(b, a), & \text{otherwise}
        \end{cases}

        B_T(b, a) = R_B(b, a) + \gamma \sum_{z \in Z} \Pr(z \mid b, a) \, B_T(\tau(b, a, z))

    :math:`B_T(b)` is implemented implicitly in the ``update_x_bound``
    functions, so only :math:`B_T(b, a)` is implemented here.

    The belief of a node's child is exactly :math:`\tau(b, a, z)`. So
    :math:`B_T(\tau(b, a, z))` is just a lookup of the bound stored on the
    z-th child of this (AND) node. The computation therefore becomes:

    .. math::

        B_T(node) = R_B(node.b, node.action) + \gamma \sum_{child \in node.children}
            \Pr(child.observation \mid node.b, node.action) \cdot child.bound

    :param node: The node to bound. If it is in ``self.fringe``, its bound
        is ``fringe_function.value(node.belief_state)``. Otherwise it must
        be an AND node (checked via its type name). It must expose
        ``belief_state``, ``action`` and ``children``. Each child must
        expose ``observation`` and ``upper_bound`` / ``lower_bound``.
    :param gamma: Discount factor applied to the observation-weighted sum
        of child bounds.
    :param R: Mapping from ``(action, state)`` to an immediate payoff;
        forwarded to ``self.R_B``.
    :param O: Observation function (state -> action -> observation ->
        probability); forwarded to ``self.pr_z``.
    :param T: Transition function (state -> action -> state ->
        probability); forwarded to ``self.pr_z``.
    :param S: The states of the POMDP; forwarded to ``self.pr_z`` and
        ``self.R_B``.
    :param fringe_function: An ``UpperBound`` or a ``LowerBound``. It
        selects which child bound (``upper_bound`` / ``lower_bound``) is
        propagated. It also supplies the value of fringe nodes.
    :return: The computed bound for ``node``.
    :raises ValueError: If ``node`` is not on the fringe and
        ``fringe_function`` is neither an ``UpperBound`` nor a
        ``LowerBound``.
    """
    if node in self.fringe:
        return fringe_function.value(node.belief_state)

    assert "AND_Node" in str(type(node))

    # The kind of bound is fixed for the whole call. Pick the child
    # attribute once instead of re-checking for every child. Validating up
    # front also rejects a bad fringe_function when node has no children.
    if isinstance(fringe_function, UpperBound):
        bound_attr = "upper_bound"
    elif isinstance(fringe_function, LowerBound):
        bound_attr = "lower_bound"
    else:
        raise ValueError(
            "fringe_function to bound_t_new is neither {} nor {}, but {}.".format(
                UpperBound, LowerBound, type(fringe_function)
            )
        )

    # Sum over children of Pr(child.observation | b, a) * child bound, discounted by gamma.
    discounted_sum = sum(
        self.pr_z(child.observation, node.belief_state, node.action, O, T, S)
        * getattr(child, bound_attr)
        for child in node.children
    ) * gamma

    # We calculate R_B(belief_state, action) here, from R.
    R_B = self.R_B(belief_state=node.belief_state, action=node.action, S=S, R=R)
    return R_B + discounted_sum
def bound_t(self, belief_state, gamma, R, O, T, S, fringe_function: "Bound"):
    r"""Abstract update to upper bound and lower bound.

    Bound propagation is abstracted by taking a function parameter:

    - If `belief_state` is a fringe belief (in ``self.fringe_beliefs``),
      its bound is ``fringe_function.value(belief_state)``.
    - Otherwise the bound is the maximum over ``self.actions`` of
      ``self.bound_action``, i.e. :math:`B_T(b) = \max_{a \in A} B_T(b, a)`.

    Thus, if `fringe_function` is a ``LowerBound``, `bound_t` calculates
    the lower bound.

    Args:
        belief_state (list[float]): The belief state to use for updating.
        gamma (float): Discount factor; forwarded to ``self.bound_action``.
        R (Dict[tuple[ACTION,STATE], float]): Mapping from action/state tuple
            to a payoff real value. Represents the immediate payoff of taking
            an action in a state.
        O (dict[str, dict[str, dict[str, float]]]): The observation function
            that maps a state to an action to an observation to a probability.
            Represents the conditional probability of observing w given state
            s and action a.
        T (dict[str, dict[str, dict[str, float]]]): The transition function
            that maps a state to an action to a state to a probability.
            Represents the conditional probability of transitioning from state
            s to s' given action a.
        S (list[str]): The states of the POMDP.
        fringe_function (Bound): A Bound that exposes a value method to use
            for fringe node Bound calculation.

    Returns:
        float: The new bound for the node corresponding to the belief state
        `belief_state`. If ``self.actions`` is empty, 0 is returned.
    """
    if belief_state in self.fringe_beliefs:
        return fringe_function.value(belief_state)

    action_bounds = (
        self.bound_action(
            belief_state=belief_state,
            gamma=gamma,
            R=R,
            O=O,
            T=T,
            S=S,
            action=action,
            fringe_function=fringe_function,
        )
        for action in self.actions
    )
    # Take the true maximum over actions. Seeding the comparison with 0
    # would wrongly return 0 whenever every action bound is negative.
    # `default=0` keeps the previous result for an empty action set.
    return max(action_bounds, default=0)