Beispiel #1
0
def causalgraph_check():
    """Build a small example DAG for exposure X and outcome Y, then draw it.

    Arrows added: X -> Y and V -> Y through ``add_arrow``, followed by
    W -> X and W -> Y through ``add_arrows``. The graph is rendered with
    ``draw_dag`` and handed to ``plt.show()``.
    """
    graph = DirectedAcyclicGraph(exposure="X", outcome="Y")

    # Single-arrow API: one call per edge.
    for tail, head in (("X", "Y"), ("V", "Y")):
        graph.add_arrow(source=tail, endpoint=head)

    # Bulk API: W points at both the exposure and the outcome.
    graph.add_arrows(pairs=(("W", "X"), ("W", "Y")))

    graph.draw_dag()
    plt.show()
Beispiel #2
0
    def test_no_mediator(self):
        """Check the adjustment sets of a DAG that contains the path X -> M -> Y.

        The single expected set is {W, V}; M, which sits between X and Y,
        is not part of it. Both the full and the minimal adjustment sets
        are compared against that expectation.
        """
        expected_sets = [{'W', 'V'}]
        arrows = (
            ("X", "Y"),
            ("W", "X"),
            ("W", "Y"),
            ("V", "X"),
            ("V", "Y"),
            ("X", "M"),
            ("M", "Y"),
        )

        dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
        dag.add_arrows(arrows)
        dag.calculate_adjustment_sets()
        found_sets = dag.adjustment_sets

        # As many sets were found as are expected
        assert len(found_sets) == len(expected_sets)

        # No set is reported more than once
        assert len(found_sets) == len(set(found_sets))

        # Every reported set is one of the expected sets
        for candidate in found_sets:
            assert set(candidate) in expected_sets

        # The minimal sets are drawn from the same expected list
        for candidate in dag.minimal_adjustment_sets:
            assert set(candidate) in expected_sets
Beispiel #3
0
    def test_adjustment_set_1(self, arrow_list_1):
        """Check the adjustment sets computed for the ``arrow_list_1`` fixture.

        Three sets are expected: {W, Z}, {V, Z} and {W, V, Z}.
        """
        expected_sets = [{"W", "Z"}, {"V", "Z"}, {"W", "V", "Z"}]

        dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
        dag.add_arrows(arrow_list_1)
        dag.calculate_adjustment_sets()
        found_sets = dag.adjustment_sets

        # As many sets were found as are expected
        assert len(found_sets) == len(expected_sets)

        # No set is reported more than once
        assert len(found_sets) == len(set(found_sets))

        # Every reported set is one of the expected sets
        for candidate in found_sets:
            assert set(candidate) in expected_sets
Beispiel #4
0
    def test_adjustment_set_3(self, arrow_list_3):
        """Check the adjustment sets computed for the ``arrow_list_3`` fixture.

        The expected list holds one entry, an empty tuple. Only the number of
        sets is compared against it; in addition the full and the minimal
        adjustment sets must be equal.
        """
        expected_sets = [()]

        dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
        dag.add_arrows(arrow_list_3)
        dag.calculate_adjustment_sets()
        found_sets = dag.adjustment_sets
        print(found_sets)

        # As many sets were found as are expected
        assert len(found_sets) == len(expected_sets)

        # No set is reported more than once
        assert len(found_sets) == len(set(found_sets))

        # Full and minimal adjustment sets coincide
        assert found_sets == dag.minimal_adjustment_sets
Beispiel #5
0
    def test_butterfly(self, arrow_butterfly):
        """Check full and minimal adjustment sets for the ``arrow_butterfly`` fixture.

        Expected full sets: {U1, B}, {B, U2} and {U1, B, U2}.
        Expected minimal sets: {U1, B} and {B, U2}.
        """
        expected_sets = [{'U1', 'B'}, {'B', 'U2'}, {'U1', 'B', 'U2'}]
        expected_minimal = [{'U1', 'B'}, {'B', 'U2'}]

        dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
        dag.add_arrows(arrow_butterfly)
        dag.calculate_adjustment_sets()
        found_sets = dag.adjustment_sets

        # As many sets were found as are expected
        assert len(found_sets) == len(expected_sets)

        # No set is reported more than once
        assert len(found_sets) == len(set(found_sets))

        # Every reported set is one of the expected sets
        for candidate in found_sets:
            assert set(candidate) in expected_sets

        # Every minimal set is one of the two-variable sets
        for candidate in dag.minimal_adjustment_sets:
            assert set(candidate) in expected_minimal
Beispiel #6
0
import statsmodels.api as sm
import statsmodels.formula.api as smf

from scipy.stats import logistic
from pygam import LinearGAM
from sklearn.neural_network import MLPRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor
from zepid.causal.causalgraph import DirectedAcyclicGraph

# Seed NumPy's global random generator.
np.random.seed(20210316)

###############################################
# Directed Acyclic Graph
###############################################
# Arrows of the diagram, handed to add_arrows as two-node pairs.
dag_arrows = [
    [r'$Z$', r'$X$'],
    [r'$Z$', r'$Y$'],
    [r'$X$', r'$M$'],
    [r'$M$', r'$Y$'],
    [r'$U$', r'$Z$'],
    [r'$U$', r'$X$'],
]

# Plot coordinates for each node, keyed by node label.
node_positions = {
    r'$Z$': [-0.75, 0.05],
    r'$X$': [-0.5, 0],
    r'$M$': [-0.25, -0.02],
    r'$U$': [-1, 0],
    r'$Y$': [0, 0],
}

dag = DirectedAcyclicGraph(exposure=r'$X$', outcome=r'$Y$')
dag.add_arrows(dag_arrows)

dag.draw_dag(positions=node_positions, fig_size=[6, 3])
plt.ylim([-0.05, 0.08])
Beispiel #7
0
 def test_read_networkx(self):
     """Load a three-edge networkx DiGraph into a DAG via ``add_from_networkx``.

     No assertion is made; the test passes as long as nothing raises.
     """
     edges = (("X", "Y"), ("C", "Y"), ("C", "X"))
     nx_graph = nx.DiGraph()
     nx_graph.add_edges_from(edges)

     dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
     dag.add_from_networkx(nx_graph)
Beispiel #8
0
 def test_error_networkx_noY(self):
     """``add_from_networkx`` must raise ``DAGError`` for this graph.

     The networkx graph holds only the edge X -> W, so the outcome node
     "Y" never appears in it (presumably the reason for the error).
     """
     nx_graph = nx.DiGraph()
     nx_graph.add_edge("X", "W")

     dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
     with pytest.raises(DAGError):
         dag.add_from_networkx(nx_graph)
Beispiel #9
0
 def test_error_read_networkx(self):
     """``add_from_networkx`` must raise ``DAGError`` for a cyclic graph.

     The edges X -> Y, Y -> C and C -> X close a loop.
     """
     cyclic_edges = (("X", "Y"), ("Y", "C"), ("C", "X"))
     nx_graph = nx.DiGraph()
     nx_graph.add_edges_from(cyclic_edges)

     dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
     with pytest.raises(DAGError):
         dag.add_from_networkx(nx_graph)
Beispiel #10
0
 def test_error_add_from_cyclic_arrow(self):
     """``add_arrows`` must raise ``DAGError`` when the arrows form a cycle.

     The pairs X -> Y, Y -> C and C -> X close a loop.
     """
     cyclic_arrows = (("X", "Y"), ("Y", "C"), ("C", "X"))

     dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
     with pytest.raises(DAGError):
         dag.add_arrows(pairs=cyclic_arrows)
Beispiel #11
0
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#####################################################################################################################
# Causal Graphs
#####################################################################################################################
print("Running causal graphs...")

from zepid.causal.causalgraph import DirectedAcyclicGraph

# Arrows of the first diagram: U1 and U2 both point into B, U1 into X,
# U2 into Y, plus the X -> Y arrow itself.
mbias_arrows = (('X', 'Y'), ('U1', 'X'), ('U1', 'B'), ('U2', 'B'), ('U2', 'Y'))

# Plot coordinates shared by both figures.
pos = {"X": [0, 0], "Y": [1, 0], "B": [0.5, 0.5], "U1": [0, 1], "U2": [1, 1]}

dag = DirectedAcyclicGraph(exposure='X', outcome="Y")
dag.add_arrows(mbias_arrows)

# First figure
dag.draw_dag(positions=pos)
plt.tight_layout()
plt.savefig("../images/zepid_dag_mbias.png", format='png', dpi=300)
plt.close()

# Adjustment sets are computed and printed for the first diagram only.
dag.calculate_adjustment_sets()
print(dag.adjustment_sets)

# Second diagram: the same arrows are passed again to the same dag object,
# now together with B -> X and B -> Y.
dag.add_arrows(mbias_arrows + (('B', 'X'), ('B', 'Y')))

# Second figure
dag.draw_dag(positions=pos)
plt.tight_layout()
plt.savefig("../images/zepid_dag_bbias.png", format='png', dpi=300)
plt.close()