示例#1
0
    def msep(self, A, B, C=None):
        """
        Check whether A and B are m-separated given C, using the Bayes ball algorithm.

        Parameters
        ----------
        A:
            Node or collection of nodes the traversal starts from (coerced with
            ``core_utils.to_set``).
        B:
            Node or collection of nodes to test against (coerced likewise).
        C:
            Conditioning node(s). ``None`` (the default) means no conditioning.

        Returns
        -------
        bool
            ``False`` as soon as the traversal reaches any node of ``B``;
            ``True`` if the schedule empties without reaching one.
        """
        # type coercion; None is mapped to a fresh empty set so no mutable
        # default object is shared between calls
        A = core_utils.to_set(A)
        B = core_utils.to_set(B)
        C = set() if C is None else core_utils.to_set(C)

        # shade ancestors of C (C itself plus whatever _add_upstream collects)
        shaded_nodes = set(C)
        for node in C:
            self._add_upstream(shaded_nodes, node)

        visited = set()
        # marks whether the node has been encountered along a path where it has a tail or an arrowhead
        _t = 'tail'  # tail
        _a = 'arrowhead'  # arrowhead

        schedule = {(node, _t) for node in A}
        while schedule:
            node, _dir = schedule.pop()
            if node in B:
                return False
            # each (node, mark) pair is expanded at most once
            if (node, _dir) in visited:
                continue
            visited.add((node, _dir))

            # if coming through a tail, won't encounter v-structure
            if _dir == _t and node not in C:
                schedule.update({(parent, _t)
                                 for parent in self._parents[node]})
                schedule.update({(child, _a)
                                 for child in self._children[node]})
                schedule.update({(spouse, _a)
                                 for spouse in self._spouses[node]})
                schedule.update({(nbr, _t) for nbr in self._neighbors[node]})

            if _dir == _a:
                # if coming through an arrowhead and see shaded node, can go through v-structure
                if node in shaded_nodes:
                    schedule.update({(parent, _t)
                                     for parent in self._parents[node]})
                    schedule.update({(spouse, _a)
                                     for spouse in self._spouses[node]})

                # if coming through an arrowhead and see unconditioned node, can go through children and neighbors
                if node not in C:
                    schedule.update({(child, _a)
                                     for child in self._children[node]})
                    schedule.update({(nbr, _a)
                                     for nbr in self._neighbors[node]})

        return True
示例#2
0
    def general_identification(
            self, y: set, x: Optional[Set[Node]],
            available_experiments: Optional[Set[frozenset]]):
        """
        Recursively build an expression for the distribution of ``y`` under
        intervention on ``x`` from the given experiments.

        The numbered ``LINE`` comments presumably follow the pseudo-code of the
        identification algorithm this method implements -- TODO confirm source.

        Parameters
        ----------
        y:
            Outcome node(s); coerced with ``to_set``.
        x:
            Intervened node(s); coerced with ``to_set``.
        available_experiments:
            Collection of intervention sets for which a distribution is
            available. It is iterated directly, so it must not be ``None``
            despite the ``Optional`` annotation.

        Returns
        -------
        A ``ProbabilityTerm``, the result of ``Product.get_conditional``, or
        whatever ``sub_identification`` returns, depending on which step fires.

        Raises
        ------
        NonIdentifiabilityError
            If none of the steps yields an expression.
        """
        x, y = to_set(x), to_set(y)

        # LINE 2: return matching experiment if one exists
        matching_experiment = next(
            (e for e in available_experiments if x == (e & self._nodes)), None)
        # Compare against None explicitly: an empty experiment is falsy, but it
        # is still a valid match when x is empty.
        if matching_experiment is not None:
            return ProbabilityTerm(y, intervened=matching_experiment)

        # LINE 3: simplify graph to only the ancestors of Y
        ancestors_y = self.ancestors_of(y)
        if ancestors_y != self._nodes:
            new_graph = self.induced_subgraph(ancestors_y)
            return new_graph.general_identification(y, x & ancestors_y,
                                                    available_experiments)

        # LINE 4: fix values of ancestors of Y
        x_mutilated_graph = self.mutilated_graph(incoming_mutilated=x)
        ancestors_y_no_x = x_mutilated_graph.ancestors_of(y)
        # w: nodes that are neither in x nor ancestors of y once x is mutilated
        w = self._nodes - x - ancestors_y_no_x
        if len(w) != 0:
            return self.general_identification(y, x | w, available_experiments)

        # LINE 5: get c-components of graph without X
        g_minus_x = self.induced_subgraph(removed_nodes=x)
        s_components = g_minus_x.c_components()

        # LINE 6: factorize into c-components
        if len(s_components) > 1:
            prod = Product([
                self.general_identification(s, self._nodes - s,
                                            available_experiments)
                for s in s_components
            ])
            return prod.get_conditional(marginal=y | x)

        # LINE 7: identify P_x using P_z as base distribution, if possible
        for z in available_experiments:
            # only experiments whose in-graph part is contained in x qualify
            if (z & self._nodes) <= x:
                g_no_zx = self.induced_subgraph(removed_nodes=z | x)
                distribution = ProbabilityTerm(
                    self._nodes - z - x,
                    intervened=(z - self._nodes) |
                    (x & z))  # TODO right distribution?
                sub_id = g_no_zx.sub_identification(y, x - z, distribution)
                if sub_id is not None:
                    return sub_id

        # LINE 8
        raise NonIdentifiabilityError
示例#3
0
    def get_conditional(self, marginal=None, cond_set=None):
        """
        Return a new ``ProbabilityTerm`` with ``cond_set`` added to the
        conditioning set.

        The parent implementation is called first with the raw arguments.
        If ``marginal`` is empty, the new term covers the current active
        variables minus ``cond_set``; otherwise it covers ``marginal``, which
        must be a subset of the current active variables. The intervened set
        is carried over unchanged.
        """
        super().get_conditional(marginal, cond_set)
        marginal = to_set(marginal)
        cond_set = to_set(cond_set)

        if marginal:
            # an explicit marginal has to be drawn from the active variables
            assert marginal <= self.active_variables
            remaining = marginal
        else:
            remaining = self.active_variables - cond_set

        return ProbabilityTerm(
            remaining,
            cond_set=self.cond_set | cond_set,
            intervened=self.intervened,
        )
示例#4
0
    def get_conditional(self, marginal=None, cond_set=None):
        """
        Condition the wrapped distribution on ``cond_set`` and marginalize.

        The parent implementation is called first with the raw arguments.
        With an empty ``cond_set`` the wrapped distribution is re-wrapped
        unchanged, with ``marginal`` as its marginal variables. Otherwise the
        wrapped distribution is conditioned on ``cond_set``; an empty
        ``marginal`` then defaults to the active variables minus ``cond_set``.

        Returns
        -------
        MarginalDistribution
        """
        super().get_conditional(marginal, cond_set)
        marginal, cond_set = to_set(marginal), to_set(cond_set)

        if len(cond_set) == 0:
            return MarginalDistribution(self.distribution,
                                        marginal_variables=marginal)
        if len(marginal) == 0:
            marginal = self.active_variables - cond_set

        conditional = self.distribution.get_conditional(cond_set=cond_set)
        # marginal cannot be None here (len() was already taken on it), so it
        # is passed through directly.
        return MarginalDistribution(conditional,
                                    marginal_variables=marginal)
示例#5
0
    def is_invariant(self, node, context, cond_set=None):
        """
        Check if the conditional distribution of node, given cond_set, is invariant to the context.

        Results are memoized in ``self.invariance_dict`` under the key
        ``(node, context, frozenset(cond_set))``.

        Parameters
        ----------
        node:
            Node whose conditional distribution is tested.
        context:
            Context passed to ``self.invariance_test``.
        cond_set:
            Conditioning node(s). ``None`` (the default) means an empty set.

        Returns
        -------
        bool
            ``True`` when the test result's ``'reject'`` entry is falsy.
        """
        # A fresh set per call: the set is handed to the external test
        # function, so a shared default object could be mutated across calls.
        cond_set = set() if cond_set is None else to_set(cond_set)
        index = (node, context, frozenset(cond_set))

        # check if result exists and return
        _is_invariant = self.invariance_dict.get(index)
        if _is_invariant is not None:
            return _is_invariant

        # otherwise, compute result and save
        if self.track_times:
            start = time.time()
        test_results = self.invariance_test(self.suffstat,
                                            context,
                                            node,
                                            cond_set=cond_set,
                                            **self.kwargs)
        if self.track_times:
            self.invariance_times[index] = time.time() - start
        if self.detailed:
            self.invariance_dict_detailed[index] = test_results
        _is_invariant = not test_results['reject']
        self.invariance_dict[index] = _is_invariant

        return _is_invariant
示例#6
0
    def get_conditional(self, marginal=None, cond_set=None):
        """
        Condition every term of this product on ``cond_set``.

        The parent implementation is called first with the raw arguments.
        With an empty ``cond_set`` the product itself is wrapped in a
        ``MarginalDistribution`` over ``marginal``. Otherwise each term is
        conditioned on ``cond_set`` and collected into a new ``Product``. That
        product is returned as-is when ``marginal`` and ``cond_set`` together
        equal the active variables, and wrapped in a ``MarginalDistribution``
        over ``marginal`` when they do not. An empty ``marginal`` defaults to
        the active variables minus ``cond_set``.
        """
        super().get_conditional(marginal, cond_set)
        marginal = to_set(marginal)
        cond_set = to_set(cond_set)

        # nothing to condition on: only marginalize
        if not cond_set:
            return MarginalDistribution(self, marginal)

        if not marginal:
            marginal = self.active_variables - cond_set

        conditioned_terms = [
            term.get_conditional(cond_set=cond_set) for term in self.terms
        ]
        conditioned_product = Product(conditioned_terms)

        covers_everything = (marginal | cond_set) == self.active_variables
        if covers_everything:
            return conditioned_product
        return MarginalDistribution(conditioned_product, marginal)
示例#7
0
    def msep_from_given(self, A, C=None):
        """
        Find all nodes m-separated from A given C.

        Uses an algorithm similar to that in Geiger, D., Verma, T., & Pearl, J.
        (1990). Identifying independence in Bayesian networks. Networks, 20(5),
        507-534.

        Parameters
        ----------
        A:
            Source node(s); coerced with ``core_utils.to_set``.
        C:
            Conditioning node(s). ``None`` (the default) means no conditioning.

        Returns
        -------
        The nodes of ``self._nodes`` that are not in ``A``, not in ``C`` and
        were not reached by the link-labeling search.
        """
        A = core_utils.to_set(A)
        # None is mapped to a fresh empty set so no mutable default object is
        # shared between calls
        C = set() if C is None else core_utils.to_set(C)

        # determined: the conditioning nodes themselves.
        # descendants: the conditioning nodes plus everything _add_upstream
        # collects for them; a collider is passable only at these nodes.
        determined = set()
        descendants = set()

        for c in C:
            determined.add(c)
            descendants.add(c)
            self._add_upstream(descendants, c)

        # seed the search with a dummy link (None, a) into every source node
        reachable = set(A)
        i_links = {(None, a) for a in A}
        labeled_links = set()

        while True:
            i_p_1_links = set()
            # Find all unlabeled links v->w adjacent to at least one link u->v labeled i, such that (u->v,v->w) is a legal pair.
            for u, v in i_links:
                for w in self._adjacent[v]:
                    if u != w and (v, w) not in labeled_links:
                        if self._is_collider(u, v, w):  # Is collider?
                            if v in descendants:
                                i_p_1_links.add((v, w))
                                reachable.add(w)
                        else:  # Not collider
                            if v not in determined:
                                i_p_1_links.add((v, w))
                                reachable.add(w)

            # no new links were discovered: the search is complete
            if len(i_p_1_links) == 0:
                break

            labeled_links = labeled_links.union(i_links)
            i_links = i_p_1_links

        return self._nodes.difference(A).difference(C).difference(reachable)
示例#8
0
    def partial_correlation(self, i, j, cond_set):
        """
        Return the partial correlation of i and j conditioned on `cond_set`.

        With an empty conditioning set this is the plain entry
        ``self.correlation[i, j]``. Otherwise the sub-matrix of
        ``self.correlation`` over ``[i, j, *cond_set]`` is inverted and the
        result is ``-theta[0, 1] / sqrt(theta[0, 0] * theta[1, 1])``.

        Parameters
        ----------
        i: first node.
        j: second node.
        cond_set: conditioning set.

        Examples
        --------
        TODO
        """
        cond_set = core_utils.to_set(cond_set)
        if not cond_set:
            return self.correlation[i, j]

        # i and j come first so they occupy rows/columns 0 and 1 of the inverse
        selected = [i, j, *cond_set]
        submatrix = self.correlation[np.ix_(selected, selected)]
        precision = inv(submatrix)
        scale = np.sqrt(precision[0, 0] * precision[1, 1])
        return -precision[0, 1] / scale
示例#9
0
    def __init__(self, active_variables: set, cond_set=None, intervened=None):
        """
        Store the active variables plus the conditioning and intervention sets.

        ``cond_set`` and ``intervened`` are passed through ``to_set``;
        ``active_variables`` is stored exactly as given.
        """
        super().__init__()

        self.active_variables = active_variables
        self.cond_set, self.intervened = to_set(cond_set), to_set(intervened)
0
 def ancestors_of(self,
                  node: Union[Node, Set[Node]],
                  include_argument=True) -> Set[Node]:
     """
     Collect the ancestors of ``node`` via ``self._add_ancestors``.

     Parameters
     ----------
     node:
         A single node or a set of nodes; coerced with ``to_set``.
     include_argument:
         If True (default), the result starts out containing the given
         node(s); otherwise it starts out empty.

     Returns
     -------
     A newly created set; the caller's argument is never used as the
     accumulator.
     """
     start_nodes = to_set(node)
     # Always accumulate into a fresh set. NOTE(review): to_set may hand back
     # the caller's own set object when given a set (confirm against its
     # definition); accumulating into it would mutate the argument, possibly
     # while _add_ancestors is iterating over it.
     ancestors = set(start_nodes) if include_argument else set()
     self._add_ancestors(ancestors, start_nodes)
     return ancestors
示例#11
0
 def get_conditional(self, marginal=None, cond_set=None):
     """
     Validate a (marginal, cond_set) request against the active variables.

     Both arguments are coerced with ``to_set``. They must be disjoint and
     their union must be contained in ``self.active_variables``; both
     conditions are enforced with ``assert``.
     """
     marginal, cond_set = to_set(marginal), to_set(cond_set)

     # a variable cannot be both queried and conditioned on
     overlap = marginal & cond_set
     assert not overlap

     requested = marginal | cond_set
     assert requested <= self.active_variables