Esempio n. 1
0
File: AO_Tree.py Progetto: acgs/PDHS
    def bound_t_new(self, node, gamma, R, O, T, S, fringe_function: Bound):
        r"""Implement bound calculation relative to nodes.

        In the usual calculation, we define:

        .. math::
            B_T(b) = \begin{cases}B(b), & \text{if } b \in F(T) \\ \text{max}_{a \in A} B_T(b,a), & \text{otherwise}\end{cases}

            B_T(b,a) = R_B(b,a) + \gamma \sum_{z \in Z} \text{Pr}(z | b,a)B_T(\tau(b,a,z))

        Note:
            :math:`B_T(b)` is implemented implicitly in the ``update_x_bound``
            functions, so only :math:`B_T(b,a)` is implemented here.

        The belief of a node's child is exactly :math:`\tau(b,a,z)`, so
        :math:`B_T(\tau(b,a,z))` is just a lookup of the bound stored on the
        z-th child of this (AND) node. We therefore implement :math:`B_T(b,a)` as:

        .. math::
            B_T(node) = R_B(node.b,node.action) + \gamma \sum_{child \in node.children} \text{Pr}(child.observation | node.b,node.action)*child.bound

        Args:
            node: The node whose bound is computed. If it is in ``self.fringe``,
                its bound is ``fringe_function.value(node.belief_state)``.
                Otherwise it must be an ``AND_Node`` with ``belief_state``,
                ``action`` and ``children`` attributes, where every child exposes
                ``observation`` and ``upper_bound`` / ``lower_bound``.
            gamma: Discount factor multiplying the probability-weighted sum of
                the children's bounds.
            R: Reward function, forwarded to ``self.R_B``.
            O: Observation function, forwarded to ``self.pr_z``.
            T: Transition function, forwarded to ``self.pr_z``.
            S: The POMDP states, forwarded to ``self.pr_z`` and ``self.R_B``.
            fringe_function (Bound): An ``UpperBound`` or ``LowerBound``. It
                supplies the value of fringe nodes and selects whether the
                children's ``upper_bound`` or ``lower_bound`` is summed.

        Returns:
            The fringe value for fringe nodes; otherwise
            ``R_B(b, a) + gamma * sum_z Pr(z | b, a) * child_bound``.

        Raises:
            AssertionError: If a non-fringe node is not an ``AND_Node``.
            ValueError: If ``fringe_function`` is neither an ``UpperBound`` nor a
                ``LowerBound``. This is checked for every non-fringe node, even
                one without children.
        """
        if node in self.fringe:
            return fringe_function.value(node.belief_state)

        assert "AND_Node" in str(type(node))

        # The kind of bound is fixed for the whole call: resolve which child
        # attribute to read once, instead of re-checking for every child. This
        # also rejects an invalid fringe_function when the node has no children.
        if isinstance(fringe_function, UpperBound):
            bound_attr = "upper_bound"
        elif isinstance(fringe_function, LowerBound):
            bound_attr = "lower_bound"
        else:
            raise ValueError(
                "fringe_function to bound_t_new is neither {} nor {}, but {}.".format(
                    UpperBound, LowerBound, type(fringe_function)
                )
            )

        discounted_sum = 0
        for child in node.children:
            child_bound = getattr(child, bound_attr)
            discounted_sum += self.pr_z(child.observation, node.belief_state, node.action, O, T, S) * child_bound
        discounted_sum *= gamma
        # We calculate R_B(belief_state, action) here, from R.
        immediate_reward = self.R_B(belief_state=node.belief_state, action=node.action, S=S, R=R)
        return immediate_reward + discounted_sum
Esempio n. 2
0
File: AO_Tree.py Progetto: acgs/PDHS
    def bound_t(self, belief_state, gamma, R, O, T, S, fringe_function: Bound):
        """Abstract update to upper bound and lower bound.

        By taking a function parameter, abstract the calculation for bound propagation.

        If `belief_state` is a fringe belief, we use `fringe_function` to calculate its bound.
        Thus, if `fringe_function` is a ``Lower_Bound``, `bound_t` calculates the lower bound.
        Otherwise the bound is the maximum over ``self.actions`` of
        ``self.bound_action(...)``, i.e. :math:`B_T(b) = \\max_a B_T(b, a)`.

        Args:
            belief_state (list[float]): the belief state to use for updating
            gamma (float): The discount factor, forwarded to ``bound_action``.
            R (Dict[tuple[ACTION,STATE], float]): Mapping from action/state tuple to a payoff real value. Represents the immediate payoff of taking an action in a state.
            O (dict[str, dict[str, dict[str, float]]]): The observation function that maps a state to an action to an observation to a probability. Represents the conditional probability of observing w given state s and action a.
            T (dict[str, dict[str, dict[str, float]]]): The transition function that maps a state to an action to a state to a probability. Represents the conditional probability of transitioning from state s to s' given action a.
            S (list[str]): The states of the POMDP.
            fringe_function (Bound): A Bound that exposes a value method to use for fringe node Bound calculation.

        Returns:
            float: The new bound for the node corresponding to the belief state `belief_state`.
            If ``self.actions`` is empty, ``0`` is returned.
        """
        if belief_state in self.fringe_beliefs:
            return fringe_function.value(belief_state)

        # Take the true maximum over actions. Seeding the running maximum with 0
        # would wrongly clamp all-negative action bounds (e.g. negative rewards)
        # to 0; ``default=0`` keeps the previous result when there are no actions.
        return max(
            (
                self.bound_action(
                    belief_state=belief_state,
                    gamma=gamma,
                    R=R,
                    O=O,
                    T=T,
                    S=S,
                    action=action,
                    fringe_function=fringe_function,
                )
                for action in self.actions
            ),
            default=0,
        )