def intervention(cls, input):
        """Apply a do-intervention to a node for each supplied state distribution,
        query the target node's marginal after each intervention, and reset the
        intervention before moving on to the next one.

        Args:
            input (dict): Expects keys "node" (node to intervene on), "states"
                (list of dicts mapping state -> probability) and "target_node"
                (node whose marginal is queried after each intervention).

        Returns:
            list: Marginal distribution of the target node, one entry per
            state dict in ``input["states"]``.
        """

        from causalnex.inference import InferenceEngine

        engine = InferenceEngine(cls.get_model())
        node = input["node"]
        states = input["states"]
        target = input["target_node"]

        print(node, states, target)

        results = []
        # Each entry of `states` is a dict of state -> probability.
        for raw_state in states:
            # Coerce JSON-style string keys/values to ints before intervening.
            coerced = {int(key): int(value) for key, value in raw_state.items()}
            engine.do_intervention(node, coerced)
            marginal = engine.query()[target]
            results.append(marginal)
            print("Updated marginal", marginal)
            # Undo the intervention so the next state starts from the original CPDs.
            engine.reset_do(node)

        return results
Example #2
0
    def test_reset_do_sets_probabilities_back_to_initial_state(
            self, bn, train_data_discrete_marginals):
        """Resetting Do operator should re-introduce the original conditional dependencies"""

        ie = InferenceEngine(bn)
        ie.do_intervention("d", {False: 0.7, True: 0.3})
        ie.reset_do("d")

        # Check BOTH states of "d": the original asserted the False state twice
        # (copy-paste bug), leaving the True state unverified.
        assert math.isclose(ie.query()["d"][False],
                            train_data_discrete_marginals["d"][False])
        assert math.isclose(ie.query()["d"][True],
                            train_data_discrete_marginals["d"][True])
Example #3
0
    def test_reset_do_sets_probabilities_back_to_initial_state(
        self, train_model, train_data_idx, train_data_idx_marginals
    ):
        """After reset_do, node marginals must match the pre-intervention data marginals."""

        network = BayesianNetwork(train_model)
        network.fit_node_states(train_data_idx).fit_cpds(train_data_idx)

        engine = InferenceEngine(network)
        engine.do_intervention("d", {0: 0.7, 1: 0.3})
        engine.reset_do("d")

        # Both states of "d" should be back at their original marginal values.
        for state in (0, 1):
            assert math.isclose(engine.query()["d"][state], train_data_idx_marginals["d"][state])
Example #4
0
# For example, can ask what would happen if 100% of students wanted to go on to do higher education.
# %% codecell
print("'higher' marginal distribution before DO: ", eng.query()['higher'])

# Make the intervention on the network
eng.do_intervention(node='higher', state={
    'yes': 1.0,
    'no': 0.0
})  # all students yes

print("'higher' marginal distribution after DO: ", eng.query()['higher'])
# %% markdown [markdown]
# ### Resetting a Node Distribution
# We can reset any interventions that we make using the `reset_do` method and providing the node we want to reset:
# %% codecell
eng.reset_do('higher')

eng.query()['higher']  # same as before

# %% markdown [markdown]
# ### Effect of DO on Marginals
# We can use `query` to find the effect that an intervention has on our marginal likelihoods of OTHER variables, not just on the INTERVENED variable.
#
# **Example 1:** change 'higher' and check grade 'G1' (how the likelihood of achieving a pass changes if 100% of students wanted to do higher education)
#
# Answer: if 100% of students wanted to do higher education (as opposed to 90% in our data population) , then we estimate the pass rate would increase from 74.7% to 79.3%.
# %% codecell
print('marginal G1', eng.query()['G1'])

eng.do_intervention(node='higher', state={'yes': 1.0, 'no': 0.0})
print('updated marginal G1', eng.query()['G1'])
# Perform an intervention (do-operation) on the 'higher' variable (preference for higher education)
# -> arbitrarily control the variable's distribution (injecting domain knowledge)
# Here we assume everyone will prefer higher education and force the variable's distribution accordingly
# NOTE(review): `ie` is not defined in this snippet — presumably an InferenceEngine
# built elsewhere, like `eng` above; confirm before running this cell standalone.
print("distribution before do", ie.query()["higher"])
ie.do_intervention("higher",
                   {'yes': 1.0,
                    'no': 0.0})
print("distribution after do", ie.query()["higher"])
"""
distribution before do {'no': 0.10752688172043011, 'yes': 0.8924731182795698}
distribution after do {'no': 0.0, 'yes': 0.9999999999999998}
=> higher 변수를 임의로 marginal까지 조정하여 강제로 변화시킨 결과
"""
# To undo the intervention and restore the original distribution
ie.reset_do("higher")


# Change in computed probabilities of OTHER variables after intervening on 'higher'
print("marginal G1", ie.query()["G1"])
ie.do_intervention("higher",
                   {'yes': 1.0,
                    'no': 0.0})
print("updated marginal G1", ie.query()["G1"])

"""
marginal G1 {'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}
updated marginal G1 {'Fail': 0.20682952942551894, 'Pass': 0.7931704705744809}
=> higher를 marginal까지 조정한 뒤 G1의 결과 변화 예상.
=> 고등 교육을 선호할 수록 시험에 통과할 확률이 높아진다라고 예측.
"""
Example #6
0
    def test_query_after_do_intervention_has_split_graph(self, chain_network):
        """
        chain network: a → b → c → d → e

        test 1.
        - do intervention on node c generates 2 graphs (a → b) and (c → d → e)
        - assert the query can be run (it used to hang before)
        - assert reset_do works

        test 2 repeats the same checks intervening on b, so that the upstream
        subgraph is a single isolate (a).
        """
        ie = InferenceEngine(chain_network)
        original_margs = ie.query()

        var = "c"
        state_dict = {0: 1.0, 1: 0.0}
        ie.do_intervention(var, state_dict)
        # assert the intervention node has indeed the right state
        assert ie.query()[var][0] == state_dict[0]
        assert ie.query()[var][1] == state_dict[1]

        # assert the upstream nodes have the default marginals (no info
        # propagates in the upstream graph)
        assert ie.query()["a"][0] == original_margs["a"][0]
        assert ie.query()["a"][1] == original_margs["a"][1]
        assert ie.query()["b"][0] == original_margs["b"][0]
        assert ie.query()["b"][1] == original_margs["b"][1]

        # assert the _cpds of the upstream nodes are stored correctly
        orig_cpds = ie._cpds_original  # pylint: disable=protected-access
        upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
        assert orig_cpds["a"] == upstream_cpds["a"]
        assert orig_cpds["b"] == upstream_cpds["b"]

        ie.reset_do(var)
        reset_margs = ie.query()

        for node in original_margs.keys():
            dict_left = original_margs[node]
            dict_right = reset_margs[node]
            for (kl, kr) in zip(dict_left.keys(), dict_right.keys()):
                assert math.isclose(dict_left[kl], dict_right[kr])

        # repeating above tests intervening on b, so that there is one single
        # isolate
        var_b = "b"
        state_dict_b = {0: 1.0, 1: 0.0}
        ie.do_intervention(var_b, state_dict_b)
        # assert the intervention node has indeed the right state.
        # BUG FIX: originally compared against `state_dict` (the dict for the
        # first intervention) rather than `state_dict_b` — benign only because
        # the two dicts happen to hold identical values.
        assert ie.query()[var_b][0] == state_dict_b[0]
        assert ie.query()[var_b][1] == state_dict_b[1]

        # assert the upstream nodes have the default marginals (no info
        # propagates in the upstream graph)
        assert ie.query()["a"][0] == original_margs["a"][0]
        assert ie.query()["a"][1] == original_margs["a"][1]

        # assert the _cpds of the upstream nodes are stored correctly
        orig_cpds = ie._cpds_original  # pylint: disable=protected-access
        upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
        assert orig_cpds["a"] == upstream_cpds["a"]

        ie.reset_do(var_b)
        reset_margs = ie.query()

        for node in original_margs.keys():
            dict_left = original_margs[node]
            dict_right = reset_margs[node]
            for (kl, kr) in zip(dict_left.keys(), dict_right.keys()):
                assert math.isclose(dict_left[kl], dict_right[kr])