class DiscPathTest(unittest.TestCase):
    """Tests for ``PAG.hasDiscPath`` (finding discriminating paths on a graph).

    Every test builds a small PAG on nodes 1-4, applies ``fully_direct_edge``
    to some of its edges, and checks the result of ``hasDiscPath(1, 4, 3)``.
    ``self.assert*`` methods are used instead of bare ``assert`` so the checks
    are not stripped under ``python -O`` and failures report useful messages.
    """

    def setUp(self):
        self.pag = PAG()
        self.pag.add_nodes_from([1, 2, 3, 4])

    def test1(self):
        # Edges 1-2, 2-3, 3-4, 2-4, fully directed as (1,2), (3,2), (2,4), (3,4).
        self.pag.add_edges_from([[1, 2], [2, 3], [3, 4], [2, 4]])
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        self.pag.fully_direct_edge(2, 4)
        self.pag.fully_direct_edge(3, 4)
        self.assertTrue(self.pag.hasDiscPath(1, 4, 3))

    def test2(self):
        # Same as test1 but without edge 3-4.
        self.pag.add_edges_from([[1, 2], [2, 3], [2, 4]])
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        self.pag.fully_direct_edge(2, 4)
        self.assertFalse(self.pag.hasDiscPath(1, 4, 3))

    def test3(self):
        # Same as test1 but edge 2-3 is fully directed as (2, 3) instead of (3, 2).
        self.pag.add_edges_from([[1, 2], [2, 3], [3, 4], [2, 4]])
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(2, 3)
        self.pag.fully_direct_edge(2, 4)
        self.pag.fully_direct_edge(3, 4)
        self.assertFalse(self.pag.hasDiscPath(1, 4, 3))

    def test4(self):
        # Same as test1 but edge 2-4 is left without fully_direct_edge applied.
        self.pag.add_edges_from([[1, 2], [2, 3], [3, 4], [2, 4]])
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        self.pag.fully_direct_edge(3, 4)
        self.assertFalse(self.pag.hasDiscPath(1, 4, 3))

    def test5(self):
        # Same as test4 plus fully_direct_edge(1, 4).
        # NOTE(review): edge 1-4 is never added via add_edges_from, so whether
        # fully_direct_edge(1, 4) has any effect depends on PAG's handling of
        # missing edges -- confirm this test exercises what it intends to.
        self.pag.add_edges_from([[1, 2], [2, 3], [3, 4], [2, 4]])
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        self.pag.fully_direct_edge(3, 4)
        self.pag.fully_direct_edge(1, 4)
        self.assertFalse(self.pag.hasDiscPath(1, 4, 3))


class RulesTests(unittest.TestCase):
    """Tests for the FCI orientation rules exposed on ``FCIAlg``.

    Each test builds a small PAG on nodes 1-5, sets some edge marks, applies a
    single rule (``FCIAlg.rule1`` ... ``FCIAlg.rule10``, with rules 6 and 7
    sharing ``FCIAlg.rule67``) to explicit node arguments, and checks the
    resulting marks. ``self.assert*`` methods are used instead of bare
    ``assert`` so the checks survive ``python -O`` and report useful failures.
    """

    def setUp(self):
        self.pag = PAG()
        self.pag.add_nodes_from([1, 2, 3, 4, 5])

    # Rule 1 Tests
    def test11(self):
        # direct_edge(1, 2) applied; rule1 is expected to fully direct 2-3.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.direct_edge(1, 2)
        FCIAlg.rule1(self.pag, 1, 2, 3)
        self.assertTrue(self.pag.has_fully_directed_edge(2, 3))

    def test12(self):
        # No marks set: rule1 must not fully direct 2-3.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        FCIAlg.rule1(self.pag, 1, 2, 3)
        self.assertFalse(self.pag.has_fully_directed_edge(2, 3))

    # Rule 2 Tests
    def test21(self):
        # Triangle 1-2-3 with direct_edge(1, 2) and fully_direct_edge(2, 3).
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 3)
        self.pag.direct_edge(1, 2)
        self.pag.fully_direct_edge(2, 3)
        FCIAlg.rule2(self.pag, 1, 2, 3)
        self.assertTrue(self.pag.has_directed_edge(1, 3))

    def test22(self):
        # Triangle 1-2-3 with fully_direct_edge(1, 2) and direct_edge(2, 3).
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 3)
        self.pag.direct_edge(2, 3)
        self.pag.fully_direct_edge(1, 2)
        FCIAlg.rule2(self.pag, 1, 2, 3)
        self.assertTrue(self.pag.has_directed_edge(1, 3))

    def test23(self):
        # Triangle where both 1-2 and 3-2 are fully directed towards node 2.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 3)
        self.pag.fully_direct_edge(3, 2)
        self.pag.fully_direct_edge(1, 2)
        FCIAlg.rule2(self.pag, 1, 2, 3)
        self.assertFalse(self.pag.has_directed_edge(1, 3))

    # Rule 3 Tests
    def test31(self):
        # 1-2 and 3-2 directed towards 2; node 4 adjacent to 1, 2 and 3.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 4)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(4, 2)
        self.pag.direct_edge(1, 2)
        self.pag.direct_edge(3, 2)
        FCIAlg.rule3(self.pag, 1, 2, 3, 4)
        self.assertTrue(self.pag.has_directed_edge(4, 2))

    def test32(self):
        # Same as test31 plus direct_edge(1, 4): rule3 must not direct 4-2.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 4)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(4, 2)
        self.pag.direct_edge(1, 2)
        self.pag.direct_edge(3, 2)
        self.pag.direct_edge(1, 4)
        FCIAlg.rule3(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_directed_edge(4, 2))

    # Rule 4 Tests
    def test41(self):
        # Sepset of (1, 5) contains node 4.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(3, 4)
        self.pag.add_edge(4, 5)
        self.pag.add_edge(2, 5)
        self.pag.add_edge(3, 5)
        self.pag.direct_edge(1, 2)
        self.pag.direct_edge(2, 3)
        self.pag.direct_edge(2, 5)
        self.pag.direct_edge(3, 2)
        self.pag.direct_edge(3, 4)
        self.pag.direct_edge(3, 5)
        self.pag.direct_edge(4, 3)
        # NOTE(review): direct_edge(3, 2) is already applied above; this
        # repeat looks like a copy-paste slip -- confirm the intended edge.
        self.pag.direct_edge(3, 2)
        sepset = {(1, 5): [4], (5, 1): [4]}
        FCIAlg.rule4(self.pag, 3, 4, 5, 1, sepset)
        self.assertTrue(self.pag.has_directed_edge(4, 5))

    def test42(self):
        # Same graph as test41, but the sepset of (1, 5) contains node 2
        # instead of 4; both directions of 4-5 and 4-3 are expected directed.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(3, 4)
        self.pag.add_edge(4, 5)
        self.pag.add_edge(2, 5)
        self.pag.add_edge(3, 5)
        self.pag.direct_edge(1, 2)
        self.pag.direct_edge(2, 3)
        self.pag.direct_edge(2, 5)
        self.pag.direct_edge(3, 2)
        self.pag.direct_edge(3, 4)
        self.pag.direct_edge(3, 5)
        self.pag.direct_edge(4, 3)
        # NOTE(review): duplicate of the direct_edge(3, 2) call above.
        self.pag.direct_edge(3, 2)
        sepset = {(1, 5): [2], (5, 1): [2]}
        FCIAlg.rule4(self.pag, 3, 4, 5, 1, sepset)
        self.assertTrue(self.pag.has_directed_edge(4, 5))
        self.assertTrue(self.pag.has_directed_edge(5, 4))
        self.assertTrue(self.pag.has_directed_edge(4, 3))
        self.assertTrue(self.pag.has_directed_edge(3, 4))

    # Rule 5 Tests
    def test51(self):
        # Cycle 1-2-4-3-1 with no marks set: every edge becomes fully undirected.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(1, 3)
        self.pag.add_edge(3, 4)
        self.pag.add_edge(2, 4)
        FCIAlg.rule5(self.pag, 1, 2, 3, 4)
        self.assertTrue(self.pag.has_fully_undirected_edge(1, 2))
        self.assertTrue(self.pag.has_fully_undirected_edge(1, 3))
        self.assertTrue(self.pag.has_fully_undirected_edge(3, 4))
        self.assertTrue(self.pag.has_fully_undirected_edge(4, 2))

    def test52(self):
        # Edge 3-4 missing, so the cycle is broken.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(1, 3)
        self.pag.add_edge(2, 4)
        FCIAlg.rule5(self.pag, 1, 2, 3, 4)
        # Kept as one short-circuiting expression: later calls touch the
        # missing edge only if earlier ones succeed, as in the original.
        self.assertFalse(self.pag.has_fully_undirected_edge(1, 2)
                         and self.pag.has_fully_undirected_edge(1, 3)
                         and self.pag.has_fully_undirected_edge(3, 4)
                         and self.pag.has_fully_undirected_edge(4, 2))

    def test53(self):
        # Edge 2-4 missing, so the cycle is broken.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(1, 3)
        self.pag.add_edge(3, 4)
        FCIAlg.rule5(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_undirected_edge(1, 2)
                         and self.pag.has_fully_undirected_edge(1, 3)
                         and self.pag.has_fully_undirected_edge(3, 4)
                         and self.pag.has_fully_undirected_edge(4, 2))

    def test54(self):
        # Same cycle as test51 with an extra edge 1-5 outside it.
        self.pag.add_edge(1, 5)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(1, 3)
        self.pag.add_edge(3, 4)
        self.pag.add_edge(2, 4)
        FCIAlg.rule5(self.pag, 1, 2, 3, 4)
        self.assertTrue(self.pag.has_fully_undirected_edge(1, 2))
        self.assertTrue(self.pag.has_fully_undirected_edge(1, 3))
        self.assertTrue(self.pag.has_fully_undirected_edge(3, 4))
        self.assertTrue(self.pag.has_fully_undirected_edge(4, 2))

    # Rule 6 Tests
    def test61(self):
        # undirect_edge(1, 2) applied: the mark at node 2 on edge 2-3 becomes '-'.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.undirect_edge(1, 2)
        FCIAlg.rule67(self.pag, 1, 2, 3)
        self.assertEqual(self.pag.get_edge_data(2, 3)[2], '-')

    def test62(self):
        # No marks set: the mark at node 2 on edge 2-3 must not become '-'.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        FCIAlg.rule67(self.pag, 1, 2, 3)
        self.assertNotEqual(self.pag.get_edge_data(2, 3)[2], '-')

    def test63(self):
        # Edge 1-2 is never added here; undirect_edge(1, 2) is called anyway
        # and rule67 must not set the mark at node 2 on edge 2-3 to '-'.
        self.pag.add_edge(2, 3)
        self.pag.undirect_edge(1, 2)
        FCIAlg.rule67(self.pag, 1, 2, 3)
        self.assertNotEqual(self.pag.get_edge_data(2, 3)[2], '-')

    # Rule 7 Tests
    def test71(self):
        # Mark at node 1 on edge 1-2 set to '-' via setTag.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.setTag([1, 2], 1, '-')
        FCIAlg.rule67(self.pag, 1, 2, 3)
        self.assertEqual(self.pag.get_edge_data(2, 3)[2], '-')

    def test72(self):
        # NOTE(review): identical to test62; a distinct rule-7 negative case
        # may have been intended.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        FCIAlg.rule67(self.pag, 1, 2, 3)
        self.assertNotEqual(self.pag.get_edge_data(2, 3)[2], '-')

    # Rule 8 Tests
    def test81(self):
        # direct_edge(1, 3) with 1-2 and 2-3 fully directed along 1 -> 2 -> 3.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 3)
        self.pag.direct_edge(1, 3)
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(2, 3)
        FCIAlg.rule8(self.pag, 1, 2, 3)
        self.assertTrue(self.pag.has_fully_directed_edge(1, 3))

    def test82(self):
        # As test81, but edge 1-2 only has its mark at node 1 set to '-'.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(1, 3)
        self.pag.direct_edge(1, 3)
        self.pag.setTag([1, 2], 1, '-')
        self.pag.fully_direct_edge(2, 3)
        FCIAlg.rule8(self.pag, 1, 2, 3)
        self.assertTrue(self.pag.has_fully_directed_edge(1, 3))

    # Rule 9 Tests
    def test91(self):
        # Path 1-2-4-3 alongside edge 1-3; direct_edge(1, 3) and direct_edge(1, 2).
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 4)
        self.pag.add_edge(4, 3)
        self.pag.direct_edge(1, 3)
        self.pag.direct_edge(1, 2)
        FCIAlg.rule9(self.pag, 1, 2, 3, 4)
        self.assertTrue(self.pag.has_fully_directed_edge(1, 3))

    def test92(self):
        # As test91, but edge 1-2 is undirected instead of directed.
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 4)
        self.pag.add_edge(4, 3)
        self.pag.direct_edge(1, 3)
        self.pag.undirect_edge(1, 2)
        FCIAlg.rule9(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_directed_edge(1, 3))

    def test93(self):
        # As test91, but edge 1-2 is directed as (2, 1).
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 4)
        self.pag.add_edge(4, 3)
        self.pag.direct_edge(1, 3)
        self.pag.direct_edge(2, 1)
        FCIAlg.rule9(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_directed_edge(1, 3))

    # Rule 10 Tests
    def test101(self):
        # 2-3 and 4-3 fully directed into 3, direct_edge(1, 3), and path 1-5-4.
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(1, 5)
        self.pag.add_edge(5, 4)
        self.pag.direct_edge(1, 3)
        self.pag.fully_direct_edge(2, 3)
        self.pag.direct_edge(1, 5)
        self.pag.fully_direct_edge(4, 3)
        FCIAlg.rule10(self.pag, 1, 2, 3, 4)
        self.assertTrue(self.pag.has_fully_directed_edge(1, 3))

    def test102(self):
        # As test101 but without direct_edge(1, 3).
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(1, 5)
        self.pag.add_edge(5, 4)
        self.pag.fully_direct_edge(2, 3)
        self.pag.direct_edge(1, 5)
        self.pag.fully_direct_edge(4, 3)
        FCIAlg.rule10(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_directed_edge(1, 3))

    def test103(self):
        # As test101 but without edge 5-4 and without fully_direct_edge(2, 3).
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(1, 5)
        self.pag.direct_edge(1, 3)
        self.pag.direct_edge(1, 5)
        self.pag.fully_direct_edge(4, 3)
        FCIAlg.rule10(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_directed_edge(1, 3))

    def test104(self):
        # As test101 but without edge 5-4.
        self.pag.add_edge(1, 3)
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.add_edge(4, 3)
        self.pag.add_edge(1, 5)
        self.pag.direct_edge(1, 3)
        self.pag.fully_direct_edge(2, 3)
        self.pag.direct_edge(1, 5)
        self.pag.fully_direct_edge(4, 3)
        FCIAlg.rule10(self.pag, 1, 2, 3, 4)
        self.assertFalse(self.pag.has_fully_directed_edge(1, 3))


class D_SepTests(unittest.TestCase):
    """Tests for ``FCIAlg.possible_d_seps`` (possible d-sep set calculation).

    Each test builds a PAG on nodes 1-6 and compares the returned mapping
    (node -> list of nodes) against the exact expected dict, including list
    order. ``assertEqual`` is used instead of bare ``assert`` so a failure
    shows a dict diff and the check is not stripped under ``python -O``.
    """

    def setUp(self):
        self.pag = PAG()
        self.pag.add_nodes_from([1, 2, 3, 4, 5, 6])

    def test1(self):
        # Path 1-2-3 with no marks set.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(3, 2)
        dseps = FCIAlg.possible_d_seps(self.pag)
        self.assertEqual(dseps, {1: [2], 2: [1, 3], 3: [2], 4: [], 5: [], 6: []})

    def test2(self):
        # Path 1-2-3 with both edges fully directed towards node 2.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        dseps = FCIAlg.possible_d_seps(self.pag)
        self.assertEqual(dseps, {
            1: [2, 3],
            2: [1, 3],
            3: [1, 2],
            4: [],
            5: [],
            6: [],
        })

    def test3(self):
        # As test2, plus edges 2-4 and 3-4 with no marks set.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(2, 3)
        self.pag.fully_direct_edge(1, 2)
        self.pag.fully_direct_edge(3, 2)
        self.pag.add_edge(2, 4)
        self.pag.add_edge(3, 4)
        dseps = FCIAlg.possible_d_seps(self.pag)
        self.assertEqual(dseps, {
            1: [2, 3, 4],
            2: [1, 3, 4],
            3: [1, 2, 4],
            4: [1, 2, 3],
            5: [],
            6: [],
        })

    def test4(self):
        # Triangle 1-2-3 with no marks set.
        self.pag.add_edge(1, 2)
        self.pag.add_edge(3, 2)
        self.pag.add_edge(3, 1)
        dseps = FCIAlg.possible_d_seps(self.pag)
        self.assertEqual(dseps, {
            1: [2, 3],
            2: [1, 3],
            3: [1, 2],
            4: [],
            5: [],
            6: [],
        })