def to_gauss_dag(self, perm):
    """
    Return a GaussDAG with the same mean and covariance as this GGM, which is a minimal IMAP of this GGM
    consistent with the node ordering `perm`.

    Parameters
    ----------
    perm:
        The desired permutation, or total order, of the nodes in the result. Must be a sequence containing
        each of ``0, ..., self.num_nodes - 1`` exactly once.

    Returns
    -------
    GaussDAG
        DAG over nodes ``0, ..., self.num_nodes - 1``. Each arc weight is a coefficient of the linear
        regression (computed from ``self.covariance``) of the child on its parents, and each node's mean and
        variance are those of its regression residual. ``means[k]`` and ``variances[k]`` refer to node ``k``.

    Raises
    ------
    ValueError
        If `perm` is not a permutation of ``range(self.num_nodes)``.

    Examples
    --------
    TODO
    """
    from causaldag import DAG, GaussDAG

    num_nodes = self.num_nodes
    # Means/variances below are stored by node index, so every node must appear in perm exactly once.
    if sorted(perm) != list(range(num_nodes)):
        raise ValueError(f"perm must be a permutation of range({num_nodes}), got {perm!r}")

    # === LEARN THE MINIMAL IMAP STRUCTURE
    # Pairs are visited as: for each position j, every earlier position i < j. The order matters because the
    # conditioning set (Markov blanket of pi_i) depends on the arcs added so far.
    d = DAG(nodes=self.nodes)
    for j, pi_j in enumerate(perm):
        for pi_i in perm[:j]:
            rho = self.partial_correlation(pi_i, pi_j, d.markov_blanket(pi_i))
            if not np.isclose(rho, 0):
                d.add_arc(pi_i, pi_j, unsafe=True)

    # === FIT EDGE WEIGHTS AND RESIDUAL MEANS/VARIANCES
    cov = self.covariance
    arcs = {}
    # Indexed by node (not by position in perm) so they line up with the node list passed to GaussDAG.
    means = [None] * num_nodes
    variances = [None] * num_nodes
    for node in perm:
        parents = list(d.parents_of(node))
        if not parents:
            # No regressors: the residual is the node itself.
            means[node] = self.means[node]
            variances[node] = cov[node, node]
            continue
        # Linear regression of `node` on its parents; solve() avoids explicitly forming the inverse.
        coeffs = np.linalg.solve(cov[np.ix_(parents, parents)], cov[parents, node])
        means[node] = self.means[node] - self.means[parents] @ coeffs
        variances[node] = cov[node, node] - cov[node, parents] @ coeffs
        for parent, coeff in zip(parents, coeffs):
            arcs[(parent, node)] = coeff

    return GaussDAG(list(range(num_nodes)), arcs, means=means, variances=variances)
def perm2dag(perm, ci_tester: CI_Tester, verbose=False, fixed_adjacencies=frozenset(), fixed_gaps=frozenset(),
             node2nbrs=None, older=False):
    """
    Find the minimal IMAP consistent with a permutation and the results of conditional independence tests.

    For every pair (pi_i, pi_j) with pi_i before pi_j in `perm`, the arc pi_i -> pi_j is added when the pair is
    a fixed adjacency, or when it is not a fixed gap and `ci_tester` does not judge the two nodes conditionally
    independent.

    Parameters
    ----------
    perm:
        Sequence of nodes giving the total order.
    ci_tester:
        Object providing ``is_ci(i, j, cond_set)``, which returns a truthy value when i and j are judged
        conditionally independent given cond_set.
    verbose:
        If True, print the result of every CI test performed.
    fixed_adjacencies:
        Pairs ``(a, b)`` that are always adjacent (either orientation matches); no test is run for them.
    fixed_gaps:
        Pairs ``(a, b)`` that are never adjacent (either orientation matches); no test is run for them.
    node2nbrs:
        Optional mapping from each node to a set of candidate neighbors. When given (and `older` is False),
        the conditioning set is the nodes that precede pi_j in `perm`, excluding pi_i, that lie in
        ``node2nbrs[pi_i] | node2nbrs[pi_j]``.
    older:
        If True, condition on every node preceding pi_j in `perm` except pi_i, ignoring `node2nbrs`.

    Returns
    -------
    DAG
        DAG over ``set(perm)`` containing the arcs described above.

    Examples
    --------
    TODO
    """
    d = DAG(nodes=set(perm))
    # For each position j, visit every earlier position i < j. The order matters because the Markov
    # blanket used as the default conditioning set depends on the arcs added so far.
    pairs = [(i, j) for j in range(len(perm)) for i in range(j)]
    for i, j in pairs:
        pi_i, pi_j = perm[i], perm[j]

        # === IF FIXED, DON'T TEST
        if (pi_i, pi_j) in fixed_adjacencies or (pi_j, pi_i) in fixed_adjacencies:
            d.add_arc(pi_i, pi_j)
            continue
        if (pi_i, pi_j) in fixed_gaps or (pi_j, pi_i) in fixed_gaps:
            continue

        # === CHOOSE CONDITIONING SET (only the one actually used is computed)
        if older:
            mb = set(perm[:j]) - {pi_i}
        elif node2nbrs is None:
            mb = d.markov_blanket(pi_i)
        else:
            mb = (set(perm[:j]) - {pi_i}) & (node2nbrs[pi_i] | node2nbrs[pi_j])

        is_ci = ci_tester.is_ci(pi_i, pi_j, mb)
        if not is_ci:
            d.add_arc(pi_i, pi_j, unsafe=True)
        if verbose:
            print(f"{pi_i} indep of {pi_j} given {mb}: {is_ci}")
    return d
def perm2dag(
        perm: list,
        ci_tester: CI_Tester,
        verbose=False,
        fixed_adjacencies: Set[UndirectedEdge] = frozenset(),
        fixed_gaps: Set[UndirectedEdge] = frozenset(),
        node2nbrs=None,
        older=False,
        progress=False
):
    """
    Given a permutation, find the minimal IMAP consistent with that permutation and the results of conditional
    independence tests from ci_tester.

    For every pair (pi_i, pi_j) with pi_i before pi_j in `perm`, the arc pi_i -> pi_j is added when the pair is
    a fixed adjacency, or when it is not a fixed gap and `ci_tester` does not judge the two nodes conditionally
    independent.

    Parameters
    ----------
    perm:
        list of nodes representing the permutation.
    ci_tester:
        object providing ``is_ci(i, j, cond_set)``, which returns a truthy value when i and j are judged
        conditionally independent given cond_set.
    verbose:
        if True, print the result of each CI test.
    fixed_adjacencies:
        set of frozensets ``{a, b}``, each an unordered pair of nodes known to be adjacent. These pairs are
        not tested.
    fixed_gaps:
        set of frozensets ``{a, b}``, each an unordered pair of nodes known not to be adjacent. These pairs
        are not tested.
    node2nbrs:
        optional mapping from each node to a set of candidate neighbors. When given (and `older` is False),
        the conditioning set is the nodes that precede pi_j in `perm`, excluding pi_i, that lie in
        ``node2nbrs[pi_i] | node2nbrs[pi_j]``. When omitted, the Markov blanket of pi_i in the DAG built so
        far is used.
    older:
        if True, condition on every node preceding pi_j in `perm` except pi_i, ignoring `node2nbrs`.
    progress:
        if True, wrap the pair iteration in a tqdm progress bar.

    Returns
    -------
    DAG
        DAG over ``set(perm)`` containing the arcs described above.

    Raises
    ------
    ValueError
        If `fixed_adjacencies` or `fixed_gaps` contains anything other than frozensets.

    Examples
    --------
    >>> import numpy as np
    >>> from causaldag.utils.ci_tests import MemoizedCI_Tester, gauss_ci_test, gauss_ci_suffstat
    >>> perm = [0,1,2]
    >>> samples = np.random.normal(size=(100, 3))
    >>> suffstat = gauss_ci_suffstat(samples)
    >>> ci_tester = MemoizedCI_Tester(gauss_ci_test, suffstat)
    >>> d = perm2dag(perm, ci_tester, fixed_gaps={frozenset({1, 2})})
    """
    # Check every element: a single tuple among frozensets would otherwise silently never match below.
    for arg_name, edges in (('fixed_adjacencies', fixed_adjacencies), ('fixed_gaps', fixed_gaps)):
        if not all(isinstance(edge, frozenset) for edge in edges):
            raise ValueError(f'{arg_name} should contain frozensets')

    d = DAG(nodes=set(perm))
    # For each position j, visit every earlier position i < j. The order matters because the Markov
    # blanket used as the default conditioning set depends on the arcs added so far.
    pairs = [(i, j) for j in range(len(perm)) for i in range(j)]
    if progress:
        pairs = tqdm(pairs)

    for i, j in pairs:
        pi_i, pi_j = perm[i], perm[j]
        edge = frozenset({pi_i, pi_j})

        # === IF FIXED, DON'T TEST
        if edge in fixed_adjacencies:
            d.add_arc(pi_i, pi_j)
            continue
        if edge in fixed_gaps:
            continue

        # === CHOOSE CONDITIONING SET (only the one actually used is computed)
        if older:
            mb = set(perm[:j]) - {pi_i}
        elif node2nbrs is None:
            mb = d.markov_blanket(pi_i)
        else:
            mb = (set(perm[:j]) - {pi_i}) & (node2nbrs[pi_i] | node2nbrs[pi_j])

        is_ci = ci_tester.is_ci(pi_i, pi_j, mb)
        if not is_ci:
            d.add_arc(pi_i, pi_j, unsafe=True)
        if verbose:
            print(f"{pi_i} is independent of {pi_j} given {mb}: {is_ci}")
    return d