Exemplo n.º 1
0
    def test_marginal_mag(self):
        """Marginalizing node 1 out of small DAGs yields the expected MAGs over {2, 3}."""
        # Fork 2 <- 1 -> 3: once 1 is marginalized, 2 and 3 are joined by 2 <-> 3.
        fork = cd.DAG(arcs={(1, 3), (1, 2)})
        fork_expected = cd.AncestralGraph(bidirected={(2, 3)})
        self.assertEqual(fork.marginal_mag(1), fork_expected)

        # Same fork plus 2 -> 3: the marginal keeps only the directed edge 2 -> 3.
        fork_with_arc = cd.DAG(arcs={(1, 3), (2, 3), (1, 2)})
        arc_expected = cd.AncestralGraph(directed={(2, 3)})
        self.assertEqual(fork_with_arc.marginal_mag(1), arc_expected)
Exemplo n.º 2
0
 def test_fast_markov_equivalence_simple(self):
     """g1 is not Markov equivalent to g2 or g3, while g2 and g3 are equivalent."""
     g1 = cd.AncestralGraph(
         directed={(0, 1), (1, 3)},
         bidirected={(1, 2), (2, 3)},
     )
     g2 = cd.AncestralGraph(
         directed={(0, 1), (1, 2), (1, 3)},
         bidirected={(2, 3)},
     )
     g3 = cd.AncestralGraph(directed={(0, 1), (1, 2), (1, 3), (3, 2)})

     for other in (g2, g3):
         self.assertFalse(g1.fast_markov_equivalent(other))
     self.assertTrue(g2.fast_markov_equivalent(g3))
Exemplo n.º 3
0
    def test_legitimate_mark_changes(self):
        """legitimate_mark_changes returns a (directed LMCs, bidirected LMCs) pair of sets."""
        cases = [
            (
                cd.AncestralGraph(directed={(0, 1)}, bidirected={(1, 2)}),
                ({(0, 1)}, {(2, 1)}),
            ),
            (
                cd.AncestralGraph(directed={(0, 1), (1, 2)}),
                ({(0, 1)}, set()),
            ),
            (
                cd.AncestralGraph(
                    directed={(2, 1), (2, 3), (3, 5)},
                    bidirected={(1, 3), (4, 5), (2, 4)},
                ),
                (set(), {(1, 3), (2, 4), (3, 1)}),
            ),
        ]
        for graph, expected in cases:
            self.assertEqual(graph.legitimate_mark_changes(), expected)
Exemplo n.º 4
0
def fci(samples, alpha, dag_num, fci_plus=False):
    """Run FCI (or FCI+) through an external R script and parse its adjacency matrix.

    The samples are saved to ``tmp_file_{dag_num}.npy`` in the current working
    directory and passed to ``Rscript FCI_FILENAME <file> <alpha> <fci_plus>``.
    The temporary file is always removed, even if the R call fails.

    Parameters
    ----------
    samples:
        2-D array of samples; its number of columns ``p`` is the number of variables.
    alpha:
        Significance level, forwarded to the R script as ``str(alpha)``.
    dag_num:
        Identifier used to make the temporary file name unique.
    fci_plus:
        Forwarded to the R script as ``str(fci_plus)``.

    Returns
    -------
    (mag, skeleton)
        ``mag`` is a ``cd.AncestralGraph`` built from the (transposed) matrix printed
        by R, or an edgeless ancestral graph on ``range(p)`` if that conversion fails.
        ``skeleton`` is a ``cd.UndirectedGraph`` built from the same matrix.
    """
    p = samples.shape[1]

    # === SAVE SAMPLES AND RUN FCI (temporary file is cleaned up no matter what)
    samples_filename = f'tmp_file_{dag_num}.npy'
    np.save(samples_filename, samples)
    try:
        r_output = subprocess.check_output(
            ['Rscript', FCI_FILENAME, samples_filename,
             str(alpha),
             str(fci_plus)])
    finally:
        os.remove(samples_filename)

    # === CONVERT OUTPUT
    # The matrix is read from the last line of stdout as whitespace-separated ints.
    # Strip first so a trailing newline does not make the "last line" empty.
    last_line = r_output.strip().split(b'\n')[-1]
    amat = np.array([int(tok) for tok in last_line.decode().split()]).reshape([p, p]).T
    try:
        mag = cd.AncestralGraph.from_amat(amat)
    except Exception:  # deliberate best-effort: R sometimes returns a non-ancestral graph
        mag = cd.AncestralGraph(nodes=set(range(p)))
    skeleton = cd.UndirectedGraph.from_amat(amat)
    return mag, skeleton
Exemplo n.º 5
0
def poset2mag(poset: Poset, ci_tester, maximal_completion=True):
    """Build an ancestral graph from a poset by CI-testing every pair of nodes.

    For each unordered pair ``(i, j)`` of ``poset.underlying_dag.nodes``, the pair is
    tested with ``ci_tester.is_ci(i, j, S)``. Here ``S`` is the union of both nodes'
    entries in ``poset._ancestors``, minus ``{i, j}``. When the test reports
    dependence:

    * an incomparable pair gets a bidirected edge ``i <-> j``;
    * otherwise, a directed edge is added from whichever node is in the other's ancestor set.

    Parameters
    ----------
    poset:
        Poset providing ``underlying_dag.nodes``, ``_ancestors`` and ``incomparable``.
    ci_tester:
        Object with an ``is_ci(i, j, S)`` method returning a boolean.
    maximal_completion:
        If True, ``to_maximal()`` is called on the graph before it is returned.

    Returns
    -------
    The constructed ``cd.AncestralGraph``.
    """
    node_set = poset.underlying_dag.nodes
    mag = cd.AncestralGraph(nodes=node_set)
    for a, b in itr.combinations(node_set, r=2):
        anc_a = poset._ancestors[a]
        anc_b = poset._ancestors[b]
        cond_set = (anc_a | anc_b) - {a, b}
        if ci_tester.is_ci(a, b, cond_set):
            continue  # independent pairs stay non-adjacent
        if poset.incomparable(a, b):
            mag.add_bidirected(a, b)
        elif a in anc_b:
            mag.add_directed(a, b)
        else:
            mag.add_directed(b, a)

    if maximal_completion:
        mag.to_maximal()
    return mag
Exemplo n.º 6
0
 def test_ancestor_dict(self):
     """ancestor_dict maps each node to its ancestors (excluding the node itself)."""
     graph = cd.AncestralGraph(
         bidirected={(0, 1)},
         directed={(0, 2), (1, 3), (2, 4), (3, 4)},
     )
     result = graph.ancestor_dict()
     # The bidirected edge 0 <-> 1 does not make either node an ancestor of the other.
     expected = {
         0: set(),
         1: set(),
         2: {0},
         3: {1},
         4: {0, 1, 2, 3},
     }
     for node, ancestors in expected.items():
         self.assertEqual(result[node], ancestors)
Exemplo n.º 7
0
    def test_disc_paths(self):
        """discriminating_paths maps each discriminating path (a node tuple) to 'c' or 'n'.

        NOTE(review): 'c'/'n' presumably mean collider / non-collider at the
        discriminated node; confirm against AncestralGraph.discriminating_paths.
        """
        cases = [
            (
                dict(nodes=set(range(1, 5)),
                     directed={(1, 2), (2, 4), (3, 2), (3, 4)}),
                {(1, 2, 3, 4): 'n'},
            ),
            (
                dict(nodes=set(range(1, 5)),
                     directed={(1, 2), (2, 4)},
                     bidirected={(3, 2), (3, 4)}),
                {(1, 2, 3, 4): 'c'},
            ),
            (
                dict(nodes=set(range(1, 6)),
                     directed={(1, 2), (2, 5), (3, 5)},
                     bidirected={(2, 3), (3, 4), (4, 5)}),
                {(1, 2, 3, 5): 'n', (1, 2, 3, 4, 5): 'c'},
            ),
        ]
        for graph_kwargs, expected in cases:
            graph = cd.AncestralGraph(**graph_kwargs)
            self.assertEqual(graph.discriminating_paths(), expected)
Exemplo n.º 8
0
    def test_msep(self):
        """Check ``msep`` on small MAGs and ``dsep`` on a random DAG against an R reference.

        The test is disabled. It is skipped explicitly so that test runners report it
        as skipped rather than silently passing. The random-DAG comparison needs
        ``Rscript`` plus ``test_msep.R`` in ``CURR_DIR``.
        """
        self.skipTest("test_msep is currently disabled")

        # 1 -> 3 <-> 4 <- 2
        d = cd.AncestralGraph(directed={(1, 3), (2, 4)}, bidirected={(3, 4)})
        self.assertTrue(d.msep({1, 3}, 2))
        self.assertTrue(d.msep(1, 2))
        self.assertTrue(d.msep({1}, 4))
        self.assertFalse(d.msep({1, 3}, 4))
        self.assertFalse(d.msep(1, 2, {3, 4}))

        # undirected 4-cycle
        d = cd.AncestralGraph(undirected={(1, 2), (2, 3), (3, 4), (4, 1)})
        self.assertFalse(d.msep(1, 3))
        self.assertTrue(d.msep(1, 3, {2, 4}))

        # bidirected 4-cycle
        d = cd.AncestralGraph(bidirected={(1, 2), (2, 3), (3, 4), (4, 1)})
        self.assertTrue(d.msep(1, 3))
        self.assertFalse(d.msep(1, 3, 2))

        # discriminating path with discriminated node (3) as collider
        d = cd.AncestralGraph(directed={(1, 2), (2, 4)},
                              bidirected={(2, 3), (3, 4)})
        self.assertTrue(d.msep(1, 4, 2))
        self.assertFalse(d.msep(1, 4, {2, 3}))

        # big random DAG: compare dsep against the R reference implementation
        np.random.seed(1729)
        random.seed(1729)
        nnodes = 10
        nodes = set(range(nnodes))
        g = cd.rand.directed_erdos(nnodes, 1 / (nnodes - 1))
        amat_file = os.path.join(CURR_DIR, 'random_mag.txt')
        np.savetxt(amat_file, g.to_amat(list(nodes))[0])

        rfile = os.path.join(CURR_DIR, 'test_msep.R')

        ntests = 50
        for _ in range(ntests):
            # these currently need to be the same size b/c ggm is buggy
            set_size = random.randint(1, 3)
            # random.sample requires a sequence (sets raise TypeError since Python 3.11);
            # sorting keeps draws reproducible for a fixed seed.
            nodes1 = random.sample(sorted(nodes), set_size)
            nodes2 = random.sample(sorted(nodes - set(nodes1)), set_size)
            cond_set = random.sample(sorted(nodes - set(nodes1) - set(nodes2)),
                                     random.randint(1, 3))

            args = ['Rscript', rfile, amat_file,
                    ','.join(map(str, nodes1)), ','.join(map(str, nodes2))]
            if cond_set:
                args.append(','.join(map(str, cond_set)))
            # Strip so a trailing newline from R does not make every comparison False.
            r_output = subprocess.check_output(args).decode().strip() == 'TRUE'
            my_output = g.dsep(nodes1, nodes2, cond_set)
            self.assertEqual(
                r_output, my_output,
                msg=f"nodes1={nodes1} nodes2={nodes2} cond_set={cond_set}")
Exemplo n.º 9
0
 def test_msep_from_given(self):
     """Construct the MAG 1 -> 2 <- 3, 2 -> 4, 3 -> 4.

     NOTE(review): no assertions are made on the graph; this test looks unfinished.
     """
     graph = cd.AncestralGraph(
         directed={(1, 2), (3, 2), (2, 4), (3, 4)},
     )
Exemplo n.º 10
0
 def setUp(self):
     """Build the shared fixture ``self.d``: 2 - 1 -> 3 <-> 4."""
     directed_edges = {(1, 3)}
     bidirected_edges = {(3, 4)}
     undirected_edges = {(1, 2)}
     self.d = cd.AncestralGraph(
         directed=directed_edges,
         bidirected=bidirected_edges,
         undirected=undirected_edges,
     )
Exemplo n.º 11
0
def gspo(
        nodes: set,
        ci_tester,
        depth=4,
        initial_imap='permutation',
        strict=True,
        verbose=False,
        max_iters=float('inf'),
        nruns=5,
):
    """
    Estimate a MAG using the Greedy Sparsest Poset algorithm.

    Parameters
    ----------
    nodes:
        Labels of nodes in the graph.
    ci_tester:
        A conditional independence tester, which has a method is_ci taking two sets A and B, and a conditioning set C,
        and returns True/False.
    depth:
        Maximum depth in depth-first search. Use None for infinite search depth.
    initial_imap:
        String indicating how to obtain the initial IMAP. Must be "permutation", "empty" or "gsp".
    strict:
        If True, check discriminating paths condition for legitimate mark changes.
    verbose:
        If True, print information about algorithm progress.
    max_iters:
        Maximum number of depth-first search backtracking steps without score improvement before stopping
        a run. The counter is reset whenever a sparser IMAP is found.
    nruns:
        Number of times to run the algorithm (each run may vary due to randomness in tie-breaking and/or starting
        imap).

    Return
    ------
    The sparsest estimated MAG found over all runs (None if nruns is 0).

    Raises
    ------
    ValueError
        If initial_imap is not one of the supported options.
    """
    starting_imaps = _gspo_starting_imaps(nodes, ci_tester, initial_imap, nruns)

    sparsest_imap = None
    for r in range(nruns):
        current_imap = starting_imaps[r]
        if verbose:
            print(f"Starting run {r} with {current_imap.num_edges} edges")

        current_lmcs, lmcs2altered_edges, lmcs2edge_delta = _gspo_candidate_moves(
            current_imap, ci_tester, strict, visited={})

        mag2number = dict()  # hash of each visited MAG -> order in which it was reached
        graph_counter = 0
        trace = []  # DFS stack of (imap, lmcs, remaining moves, remaining deltas)
        iters_since_improvement = 0
        while True:
            if iters_since_improvement > max_iters:
                break

            mag_hash = (frozenset(current_imap._directed),
                        bidirected_frozenset(current_imap))
            if mag_hash not in mag2number:
                mag2number[mag_hash] = graph_counter
            graph_num = mag2number[mag_hash]
            if verbose:
                print(
                    f"Number of visited MAGs: {len(mag2number)}. Exploring MAG #{graph_num} with {current_imap.num_edges} edges."
                )
            max_delta = max([delta for lmc, delta in lmcs2edge_delta],
                            default=0)

            sparser_exists = max_delta > 0
            keep_searching_mec = len(trace) != depth and len(
                lmcs2altered_edges) > 0

            if sparser_exists:
                # Greedy step to a strictly sparser IMAP; ties broken uniformly at random.
                trace = []
                iters_since_improvement = 0
                lmc_ix = random.choice([
                    ix for ix, (lmc, delta) in enumerate(lmcs2edge_delta)
                    if delta == max_delta
                ])
                (i, j), (removed_dir,
                         removed_bidir) = lmcs2altered_edges.pop(lmc_ix)
                apply_lmc(current_imap, i, j)
                current_imap.remove_edges(removed_dir | removed_bidir)

                if verbose:
                    print(
                        f"Starting over at a sparser IMAP with {current_imap.num_edges} edges"
                    )
            elif keep_searching_mec:
                if verbose:
                    print(
                        f"{'='*len(trace)}Continuing search through the MEC at {current_imap.num_edges} edges. "
                        f"Picking from {len(lmcs2altered_edges)} neighbors of #{graph_num}."
                    )
                # The saved lists are the same objects popped below, so a later
                # backtrack will not retry the move taken here.
                trace.append((current_imap.copy(), current_lmcs,
                              lmcs2altered_edges, lmcs2edge_delta))
                (i, j), _ = lmcs2altered_edges.pop(0)
                lmcs2edge_delta.pop(0)
                apply_lmc(current_imap, i, j)
            elif len(trace) != 0:  # BACKTRACK IF POSSIBLE
                if verbose:
                    print(f"{'='*len(trace)}Backtracking")
                current_imap, current_lmcs, lmcs2altered_edges, lmcs2edge_delta = trace.pop()
                iters_since_improvement += 1
            else:
                break

            # IF WE MOVED TO A NOVEL IMAP, WE NEED TO UPDATE LMCs
            if sparser_exists or keep_searching_mec:
                graph_counter += 1
                current_lmcs, lmcs2altered_edges, lmcs2edge_delta = _gspo_candidate_moves(
                    current_imap, ci_tester, strict, visited=mag2number)

        if sparsest_imap is None or sparsest_imap.num_edges > current_imap.num_edges:
            sparsest_imap = current_imap

    return sparsest_imap


def _gspo_starting_imaps(nodes, ci_tester, initial_imap, nruns):
    """Build the list of `nruns` starting IMAPs for gspo, as selected by `initial_imap`."""
    if initial_imap == 'permutation':
        ug = threshold_ug(nodes, ci_tester)
        amat = ug.to_amat()
        perms = [min_degree_alg_amat(amat) for _ in range(nruns)]
        dags = [perm2dag(perm, ci_tester) for perm in perms]
        return [
            cd.AncestralGraph(dag.nodes, directed=dag.arcs) for dag in dags
        ]
    if initial_imap == 'empty':
        edges = {(i, j)
                 for i, j in itr.combinations(nodes, 2)
                 if not ci_tester.is_ci(i, j)}
        return [
            cd.AncestralGraph(nodes, bidirected=edges) for _ in range(nruns)
        ]
    if initial_imap == 'gsp':
        ug = threshold_ug(nodes, ci_tester)
        amat = ug.to_amat()
        perms = [min_degree_alg_amat(amat) for _ in range(nruns)]
        dags = [
            gsp(nodes, ci_tester, nruns=1, initial_permutations=[perm])
            for perm in perms
        ]
        return [
            cd.AncestralGraph(dag.nodes, directed=dag.arcs) for dag, _ in dags
        ]
    raise ValueError(
        f"initial_imap must be 'permutation', 'empty' or 'gsp', got {initial_imap!r}")


def _gspo_candidate_moves(imap, ci_tester, strict, visited):
    """Return (lmcs, moves, deltas) describing the neighbors of `imap` reachable by one LMC.

    `moves` pairs each legitimate mark change with the (removed_directed,
    removed_bidirected) edge sets returned by get_lmc_altered_edges. Changes for which
    it returns None are dropped, as are changes whose resulting MAG is already a key of
    `visited`. `deltas` pairs each kept change with the number of edges it removes, in
    the same order as `moves`.
    """
    # TODO: BOTTLENECK
    lmcs_dir, lmcs_bidir = imap.legitimate_mark_changes(strict=strict)
    lmcs = lmcs_dir | lmcs_bidir

    # TODO: BOTTLENECK
    moves = [(lmc, get_lmc_altered_edges(imap, *lmc, ci_tester)) for lmc in lmcs]
    moves = [(lmc, (a, b)) for lmc, (a, b) in moves if a is not None]

    # === FILTER OUT ALREADY-VISITED IMAPS
    if visited:
        directed = frozenset(imap.directed)
        bidirected = bidirected_frozenset(imap)
        unvisited = []
        for lmc, (removed_dir, removed_bidir) in moves:
            lmc_edge = frozenset(lmc)
            removed_bidir_edges = {frozenset(e) for e in removed_bidir}
            # Parenthesized explicitly: `-` binds tighter than `|`, so without them
            # the removed edges would never be subtracted from the predicted MAG.
            if imap.has_directed(*lmc):
                new_directed = directed - {lmc} - removed_dir
                new_bidirected = (bidirected | {lmc_edge}) - removed_bidir_edges
            else:
                new_directed = (directed | {lmc}) - removed_dir
                new_bidirected = bidirected - {lmc_edge} - removed_bidir_edges
            if (new_directed, new_bidirected) not in visited:
                unvisited.append((lmc, (removed_dir, removed_bidir)))
        moves = unvisited

    deltas = [(lmc, len(removed_dir) + len(removed_bidir))
              for lmc, (removed_dir, removed_bidir) in moves]
    return lmcs, moves, deltas
Exemplo n.º 12
0
if __name__ == '__main__':
    from causaldag.utils.ci_tests import MemoizedCI_Tester, msep_test, gauss_ci_suffstat, gauss_ci_test
    from R_algs.fci_wrapper import fci
    from line_profiler import LineProfiler
    #
    lp = LineProfiler()

    import time
    import numpy as np

    seed = np.random.randint(0, 100000)
    # seed = 95831
    np.random.seed(seed)
    random.seed(seed)

    m2 = cd.AncestralGraph(directed={(3, 4), (0, 4), (1, 4)},
                           bidirected={(0, 1), (1, 3), (2, 3), (1, 2)})
    m3 = cd.AncestralGraph(directed={(1, 4), (3, 4)},
                           bidirected={(0, 4), (0, 1), (1, 2), (2, 3), (1, 3)})
    m4 = cd.AncestralGraph(directed={
        (4, 7),
        (9, 1),
        (9, 4),
        (2, 5),
        (0, 3),
        (8, 5),
        (1, 2),
        (1, 5),
        (0, 4),
        (8, 2),
        (9, 3),
        (0, 5),