def test_marginal_mag(self):
    """Check DAG.marginal_mag(1) on two small DAGs in which node 1 is a parent of both 2 and 3."""
    cases = [
        # 2 and 3 share only the parent 1 -> expect a bidirected edge between them
        ({(1, 2), (1, 3)}, dict(bidirected={(2, 3)})),
        # 2 -> 3 is also present -> expect the directed edge to survive
        ({(1, 2), (1, 3), (2, 3)}, dict(directed={(2, 3)})),
    ]
    for arcs, expected_edges in cases:
        dag = cd.DAG(arcs=arcs)
        self.assertEqual(dag.marginal_mag(1), cd.AncestralGraph(**expected_edges))
def test_fast_markov_equivalence_simple(self):
    """Check AncestralGraph.fast_markov_equivalent on three small MAGs over nodes 0-3."""
    mag_a = cd.AncestralGraph(directed={(0, 1), (1, 3)}, bidirected={(1, 2), (2, 3)})
    mag_b = cd.AncestralGraph(directed={(0, 1), (1, 3), (1, 2)}, bidirected={(2, 3)})
    mag_c = cd.AncestralGraph(directed={(0, 1), (1, 2), (1, 3), (3, 2)})

    # mag_a is expected to be equivalent to neither of the others ...
    for other in (mag_b, mag_c):
        self.assertFalse(mag_a.fast_markov_equivalent(other))
    # ... while mag_b and mag_c are expected to be equivalent.
    self.assertTrue(mag_b.fast_markov_equivalent(mag_c))
def test_legitimate_mark_changes(self):
    """Check that legitimate_mark_changes() returns the expected (directed, bidirected) pair of LMC sets."""
    cases = [
        (
            dict(directed={(0, 1)}, bidirected={(1, 2)}),
            ({(0, 1)}, {(2, 1)}),
        ),
        (
            dict(directed={(0, 1), (1, 2)}),
            ({(0, 1)}, set()),
        ),
        (
            dict(directed={(2, 1), (2, 3), (3, 5)}, bidirected={(1, 3), (4, 5), (2, 4)}),
            (set(), {(1, 3), (2, 4), (3, 1)}),
        ),
    ]
    for graph_kwargs, expected_lmcs in cases:
        graph = cd.AncestralGraph(**graph_kwargs)
        self.assertEqual(graph.legitimate_mark_changes(), expected_lmcs)
def fci(samples, alpha, dag_num, fci_plus=False):
    """
    Run the R FCI script on ``samples`` and parse the adjacency matrix it prints.

    Parameters
    ----------
    samples:
        2-D array; ``samples.shape[1]`` is taken as the number of variables ``p``.
    alpha:
        Significance level, passed to the R script as a string.
    dag_num:
        Identifier used to name the temporary ``.npy`` file (``tmp_file_{dag_num}.npy``).
    fci_plus:
        Passed to the R script as ``str(fci_plus)``.

    Return
    ------
    (mag, skeleton):
        ``mag`` is the AncestralGraph parsed from the matrix. If ``AncestralGraph.from_amat`` fails, it is
        an edgeless AncestralGraph on ``range(p)``. ``skeleton`` is the UndirectedGraph built from the
        same matrix.
    """
    p = samples.shape[1]

    # === SAVE SAMPLES
    samples_filename = f'tmp_file_{dag_num}.npy'
    np.save(samples_filename, samples)

    # === RUN FCI
    try:
        r_output = subprocess.check_output(
            ['Rscript', FCI_FILENAME, samples_filename, str(alpha), str(fci_plus)])
    finally:
        # Remove the temp file even when Rscript fails, so repeated runs don't leak files.
        os.remove(samples_filename)

    # === CONVERT OUTPUT
    # The matrix is read from the last line of output. Strip first so a trailing newline doesn't leave an
    # empty last field, and split on any whitespace so repeated spaces don't yield empty tokens.
    last_line = r_output.strip().split(b'\n')[-1]
    # NOTE(review): the transpose presumably matches R's column-major output to from_amat's orientation;
    # confirm against the R script.
    amat = np.array(list(map(int, last_line.decode().split()))).reshape([p, p]).T
    try:
        mag = cd.AncestralGraph.from_amat(amat)
    except Exception:
        # Deliberate best-effort: the R output is sometimes not a valid ancestral graph,
        # so fall back to an edgeless graph on the same nodes.
        mag = cd.AncestralGraph(nodes=set(range(p)))
    skeleton = cd.UndirectedGraph.from_amat(amat)
    return mag, skeleton
def poset2mag(poset: Poset, ci_tester, maximal_completion=True):
    """
    Build an AncestralGraph from a poset by CI-testing every pair of nodes.

    For each pair (i, j), i and j are tested for conditional independence given the union of their
    ``poset._ancestors`` sets, excluding i and j. Dependent pairs receive an edge:

    - bidirected if ``poset.incomparable(i, j)``;
    - otherwise directed, from whichever of i and j appears in the other's ancestor set.

    Parameters
    ----------
    poset:
        Poset whose ``underlying_dag.nodes`` supply the node set.
    ci_tester:
        Object with an ``is_ci(i, j, S)`` method returning True/False.
    maximal_completion:
        If True, call ``to_maximal()`` on the graph before returning it.

    Return
    ------
    The constructed AncestralGraph.
    """
    nodes = poset.underlying_dag.nodes
    mag = cd.AncestralGraph(nodes=nodes)
    for i, j in itr.combinations(nodes, r=2):
        anc_i = poset._ancestors[i]
        anc_j = poset._ancestors[j]
        conditioning_set = (anc_i | anc_j) - {i, j}
        if ci_tester.is_ci(i, j, conditioning_set):
            continue  # independent -> no edge

        if poset.incomparable(i, j):
            mag.add_bidirected(i, j)
        elif i in anc_j:
            mag.add_directed(i, j)
        else:
            mag.add_directed(j, i)

    if maximal_completion:
        mag.to_maximal()
    return mag
def test_ancestor_dict(self):
    """Check that ancestor_dict() maps each node to its set of proper ancestors.

    The bidirected edge 0 <-> 1 is expected not to create ancestry between 0 and 1.
    """
    graph = cd.AncestralGraph(bidirected={(0, 1)}, directed={(0, 2), (1, 3), (2, 4), (3, 4)})
    ancestors = graph.ancestor_dict()

    expected = {
        0: set(),
        1: set(),
        2: {0},
        3: {1},
        4: {0, 1, 2, 3},
    }
    for node, expected_ancestors in expected.items():
        self.assertEqual(ancestors[node], expected_ancestors)
def test_disc_paths(self):
    """Check discriminating_paths() on three small MAGs.

    The result maps each discriminating path (a tuple of nodes) to a label. Presumably 'c' means the
    discriminated node is a collider on the path and 'n' means it is not.
    """
    cases = [
        (
            dict(nodes=set(range(1, 5)), directed={(1, 2), (2, 4), (3, 2), (3, 4)}),
            {(1, 2, 3, 4): 'n'},
        ),
        (
            dict(nodes=set(range(1, 5)), directed={(1, 2), (2, 4)}, bidirected={(3, 2), (3, 4)}),
            {(1, 2, 3, 4): 'c'},
        ),
        (
            dict(nodes=set(range(1, 6)), directed={(1, 2), (2, 5), (3, 5)},
                 bidirected={(2, 3), (3, 4), (4, 5)}),
            {(1, 2, 3, 5): 'n', (1, 2, 3, 4, 5): 'c'},
        ),
    ]
    for graph_kwargs, expected_paths in cases:
        mag = cd.AncestralGraph(**graph_kwargs)
        self.assertEqual(mag.discriminating_paths(), expected_paths)
def test_msep(self):
    """Check msep() on small hand-built MAGs, then compare dsep() on a random DAG against an R script.

    The original disabled this test with a bare ``return`` at the top, which made it report as passed.
    It now reports as skipped. The body is kept for when it is re-enabled.
    """
    self.skipTest('test_msep is disabled')

    # 1 -> 3 <-> 4 <- 2
    d = cd.AncestralGraph(directed={(1, 3), (2, 4)}, bidirected={(3, 4)})
    self.assertTrue(d.msep({1, 3}, 2))
    self.assertTrue(d.msep(1, 2))
    self.assertTrue(d.msep({1}, 4))
    self.assertFalse(d.msep({1, 3}, 4))
    self.assertFalse(d.msep(1, 2, {3, 4}))

    # undirected 4-cycle
    d = cd.AncestralGraph(undirected={(1, 2), (2, 3), (3, 4), (4, 1)})
    self.assertFalse(d.msep(1, 3))
    self.assertTrue(d.msep(1, 3, {2, 4}))

    # bidirected 4-cycle
    d = cd.AncestralGraph(bidirected={(1, 2), (2, 3), (3, 4), (4, 1)})
    self.assertTrue(d.msep(1, 3))
    self.assertFalse(d.msep(1, 3, 2))

    # discriminating path with discriminated node (3) as collider
    d = cd.AncestralGraph(directed={(1, 2), (2, 4)}, bidirected={(2, 3), (3, 4)})
    self.assertTrue(d.msep(1, 4, 2))
    self.assertFalse(d.msep(1, 4, {2, 3}))

    # big random graph (a DAG from directed_erdos, compared via dsep)
    np.random.seed(1729)
    random.seed(1729)
    nnodes = 10
    nodes = set(range(nnodes))
    g = cd.rand.directed_erdos(nnodes, 1 / (nnodes - 1))
    print(g.arcs)
    amat_file = os.path.join(CURR_DIR, 'random_mag.txt')
    np.savetxt(amat_file, g.to_amat(list(nodes))[0])
    rfile = os.path.join(CURR_DIR, 'test_msep.R')

    ntests = 50
    for _ in range(ntests):
        set_size = random.randint(1, 3)  # these currently need to be the same size b/c ggm is buggy
        # random.sample() rejects sets since Python 3.11, so sample from sorted lists instead.
        nodes1 = random.sample(sorted(nodes), set_size)
        nodes2 = random.sample(sorted(nodes - set(nodes1)), set_size)
        cond_set = random.sample(sorted(nodes - set(nodes1) - set(nodes2)), random.randint(1, 3))
        print(nodes1, nodes2, cond_set)

        nodes1_str = ','.join(map(str, nodes1))
        nodes2_str = ','.join(map(str, nodes2))
        cond_set_str = ','.join(map(str, cond_set))
        if len(cond_set) > 0:
            r_output = subprocess.check_output([
                'Rscript', rfile, amat_file, nodes1_str, nodes2_str, cond_set_str
            ])
        else:
            r_output = subprocess.check_output(
                ['Rscript', rfile, amat_file, nodes1_str, nodes2_str])
        print(r_output.decode())
        # Strip so a trailing newline from R doesn't make every comparison False.
        r_output = r_output.decode().strip() == 'TRUE'
        my_output = g.dsep(nodes1, nodes2, cond_set)
        print(r_output, my_output)
        self.assertEqual(r_output, my_output)
def test_msep_from_given(self):
    """Placeholder: builds a small MAG but makes no assertions yet.

    TODO(review): add the m-separation assertions this test is meant to cover.
    """
    arcs = {(1, 2), (3, 2), (2, 4), (3, 4)}
    mag = cd.AncestralGraph(directed=arcs)
def setUp(self):
    """Build the shared fixture ``self.d``: a graph with 1 -> 3, 3 <-> 4 and undirected 1 - 2."""
    directed_edges = {(1, 3)}
    bidirected_edges = {(3, 4)}
    undirected_edges = {(1, 2)}
    self.d = cd.AncestralGraph(
        directed=directed_edges,
        bidirected=bidirected_edges,
        undirected=undirected_edges,
    )
def _gspo_lmc_candidates(imap, ci_tester, strict):
    """Return (all LMCs of imap, [(lmc, (removed_dir, removed_bidir))]), dropping LMCs whose altered edges are None."""
    lmcs_dir, lmcs_bidir = imap.legitimate_mark_changes(strict=strict)
    lmcs = lmcs_dir | lmcs_bidir
    altered = [(lmc, get_lmc_altered_edges(imap, *lmc, ci_tester)) for lmc in lmcs]
    altered = [(lmc, (a, b)) for lmc, (a, b) in altered if a is not None]
    return lmcs, altered


def _gspo_edge_deltas(lmcs2altered_edges):
    """Pair each LMC with the number of edges it would remove (directed + bidirected)."""
    return [(lmc, len(removed_dir) + len(removed_bidir))
            for lmc, (removed_dir, removed_bidir) in lmcs2altered_edges]


def _gspo_mag_key_after_lmc(imap, lmc, removed_dir, removed_bidir, current_directed, current_bidirected):
    """Return the (directed, bidirected) frozenset key of imap after applying lmc and removing the given edges."""
    lmc_edge = frozenset(lmc)
    removed_bidir_edges = {frozenset(e) for e in removed_bidir}
    if imap.has_directed(*lmc):
        # i -> j becomes i <-> j
        new_directed = current_directed - {lmc} - removed_dir
        new_bidirected = (current_bidirected | {lmc_edge}) - removed_bidir_edges
    else:
        # i <-> j becomes i -> j
        new_directed = (current_directed | {lmc}) - removed_dir
        new_bidirected = current_bidirected - {lmc_edge} - removed_bidir_edges
    return new_directed, new_bidirected


def gspo(
        nodes: set,
        ci_tester,
        depth=4,
        initial_imap='permutation',
        strict=True,
        verbose=False,
        max_iters=float('inf'),
        nruns=5,
):
    """
    Estimate a MAG using the Greedy Sparsest Poset algorithm.

    Parameters
    ----------
    nodes:
        Labels of nodes in the graph.
    ci_tester:
        A conditional independence tester, which has a method is_ci taking two sets A and B, and a conditioning set C,
        and returns True/False.
    depth:
        Maximum depth in depth-first search. Use None for infinite search depth.
    initial_imap:
        String indicating how to obtain the initial IMAP. Must be "permutation", "empty", or "gsp".
    strict:
        If True, check discriminating paths condition for legitimate mark changes.
    verbose:
        If True, print information about algorithm progress.
    max_iters:
        Maximum number of backtracking steps without finding a sparser IMAP before a run stops.
    nruns:
        Number of times to run the algorithm (each run may vary due to randomness in tie-breaking and/or starting imap).

    Return
    ------
    The sparsest estimated MAG found over all runs. On ties, the earliest run wins.

    Raises
    ------
    ValueError
        If ``initial_imap`` is not one of the supported options.
    """
    if initial_imap == 'permutation':
        ug = threshold_ug(nodes, ci_tester)
        amat = ug.to_amat()
        perms = [min_degree_alg_amat(amat) for _ in range(nruns)]
        dags = [perm2dag(perm, ci_tester) for perm in perms]
        starting_imaps = [
            cd.AncestralGraph(dag.nodes, directed=dag.arcs) for dag in dags
        ]
    elif initial_imap == 'empty':
        edges = {(i, j) for i, j in itr.combinations(nodes, 2) if not ci_tester.is_ci(i, j)}
        starting_imaps = [
            cd.AncestralGraph(nodes, bidirected=edges) for _ in range(nruns)
        ]
    elif initial_imap == 'gsp':
        ug = threshold_ug(nodes, ci_tester)
        amat = ug.to_amat()
        perms = [min_degree_alg_amat(amat) for _ in range(nruns)]
        dags = [
            gsp(nodes, ci_tester, nruns=1, initial_permutations=[perm])
            for perm in perms
        ]
        starting_imaps = [
            cd.AncestralGraph(dag.nodes, directed=dag.arcs) for dag, _ in dags
        ]
    else:
        raise ValueError(
            f"initial_imap must be 'permutation', 'empty', or 'gsp', got {initial_imap!r}")

    sparsest_imap = None
    for r in range(nruns):
        current_imap = starting_imaps[r]
        if verbose:
            print(f"Starting run {r} with {current_imap.num_edges} edges")

        # TODO: BOTTLENECK
        current_lmcs, lmcs2altered_edges = _gspo_lmc_candidates(current_imap, ci_tester, strict)
        lmcs2edge_delta = _gspo_edge_deltas(lmcs2altered_edges)

        mag2number = dict()
        graph_counter = 0
        trace = []
        iters_since_improvement = 0
        while True:
            if iters_since_improvement > max_iters:
                break

            mag_hash = (frozenset(current_imap._directed), bidirected_frozenset(current_imap))
            if mag_hash not in mag2number:
                mag2number[mag_hash] = graph_counter
            graph_num = mag2number[mag_hash]
            if verbose:
                print(
                    f"Number of visited MAGs: {len(mag2number)}. Exploring MAG #{graph_num} with {current_imap.num_edges} edges."
                )

            max_delta = max([delta for lmc, delta in lmcs2edge_delta], default=0)
            sparser_exists = max_delta > 0
            keep_searching_mec = len(trace) != depth and len(lmcs2altered_edges) > 0

            if sparser_exists:
                # Strict improvement: drop the DFS stack and reset the no-improvement counter.
                trace = []
                iters_since_improvement = 0
                lmc_ix = random.choice([
                    ix for ix, (lmc, delta) in enumerate(lmcs2edge_delta) if delta == max_delta
                ])
                (i, j), (removed_dir, removed_bidir) = lmcs2altered_edges.pop(lmc_ix)
                apply_lmc(current_imap, i, j)
                current_imap.remove_edges(removed_dir | removed_bidir)
                if verbose:
                    print(f"Starting over at a sparser IMAP with {current_imap.num_edges} edges")
            elif keep_searching_mec:
                if verbose:
                    print(
                        f"{'='*len(trace)}Continuing search through the MEC at {current_imap.num_edges} edges. "
                        f"Picking from {len(lmcs2altered_edges)} neighbors of #{graph_num}."
                    )
                # The saved lists are the same objects popped from below, so on backtrack the
                # already-tried LMC is no longer a candidate.
                trace.append((current_imap.copy(), current_lmcs, lmcs2altered_edges, lmcs2edge_delta))
                (i, j), _ = lmcs2altered_edges.pop(0)
                lmcs2edge_delta.pop(0)
                apply_lmc(current_imap, i, j)
            elif len(trace) != 0:  # BACKTRACK IF POSSIBLE
                if verbose:
                    print(f"{'='*len(trace)}Backtracking")
                current_imap, current_lmcs, lmcs2altered_edges, lmcs2edge_delta = trace.pop()
                iters_since_improvement += 1
            else:
                break

            # IF WE MOVED TO A NOVEL IMAP, WE NEED TO UPDATE LMCs
            if sparser_exists or keep_searching_mec:
                graph_counter += 1
                current_lmcs, lmcs2altered_edges = _gspo_lmc_candidates(current_imap, ci_tester, strict)

                # === FILTER OUT ALREADY-VISITED IMAPS
                current_directed = frozenset(current_imap.directed)
                current_bidirected = bidirected_frozenset(current_imap)
                lmcs2altered_edges = [
                    (lmc, (removed_dir, removed_bidir))
                    for lmc, (removed_dir, removed_bidir) in lmcs2altered_edges
                    if _gspo_mag_key_after_lmc(current_imap, lmc, removed_dir, removed_bidir,
                                               current_directed, current_bidirected) not in mag2number
                ]
                lmcs2edge_delta = _gspo_edge_deltas(lmcs2altered_edges)

        if sparsest_imap is None or sparsest_imap.num_edges > current_imap.num_edges:
            sparsest_imap = current_imap

    # Return the best IMAP over all runs (the original returned only the last run's result).
    return sparsest_imap
if __name__ == '__main__': from causaldag.utils.ci_tests import MemoizedCI_Tester, msep_test, gauss_ci_suffstat, gauss_ci_test from R_algs.fci_wrapper import fci from line_profiler import LineProfiler # lp = LineProfiler() import time import numpy as np seed = np.random.randint(0, 100000) # seed = 95831 np.random.seed(seed) random.seed(seed) m2 = cd.AncestralGraph(directed={(3, 4), (0, 4), (1, 4)}, bidirected={(0, 1), (1, 3), (2, 3), (1, 2)}) m3 = cd.AncestralGraph(directed={(1, 4), (3, 4)}, bidirected={(0, 4), (0, 1), (1, 2), (2, 3), (1, 3)}) m4 = cd.AncestralGraph(directed={ (4, 7), (9, 1), (9, 4), (2, 5), (0, 3), (8, 5), (1, 2), (1, 5), (0, 4), (8, 2), (9, 3), (0, 5),