def intervention(cls, input):  # pylint: disable=redefined-builtin
    """Apply do-interventions to one node and report their effect on a target node.

    Users can apply an intervention to any node in the data, updating its
    distribution using a do operator, examine the effect of that intervention
    by querying marginals, and reset the intervention afterwards. Each
    candidate distribution in ``input["states"]`` is evaluated independently:
    the intervention is reset before the next candidate is applied.

    Args:
        input (dict): The intervention request, with keys
            ``"node"``: name of the node to intervene on;
            ``"states"``: list of dicts, each mapping a state to its
            probability (keys are converted with ``int()``, values with
            ``float()``, so string-encoded numbers are accepted);
            ``"target_node"``: node whose marginal is queried.

    Returns:
        list: The queried marginal of the target node for each entry of
        ``input["states"]``, in the same order.
    """
    from causalnex.inference import InferenceEngine

    bn = cls.get_model()
    ie = InferenceEngine(bn)
    i_node = input["node"]
    i_states = input["states"]
    i_target = input["target_node"]
    print(i_node, i_states, i_target)

    lst = []
    for raw_state in i_states:
        # Keys are state indices; values are probabilities and must remain
        # floats. Casting them with int() would truncate e.g. 0.7 to 0 and
        # produce a distribution that no longer sums to 1.
        state = {int(k): float(v) for k, v in raw_state.items()}
        ie.do_intervention(i_node, state)
        try:
            intervention_result = ie.query()[i_target]
            lst.append(intervention_result)
            print("Updated marginal", intervention_result)
        finally:
            # Always undo the intervention so the next candidate is evaluated
            # against the original model, even if the query above fails.
            ie.reset_do(i_node)
    return lst
def test_reset_do_sets_probabilities_back_to_initial_state(
    self, bn, train_data_discrete_marginals
):
    """Resetting Do operator should re-introduce the original conditional dependencies

    After intervening on node ``d`` and then calling ``reset_do``, the
    marginals of ``d`` must match the marginals computed from the training
    data for every state.
    """
    ie = InferenceEngine(bn)
    ie.do_intervention("d", {False: 0.7, True: 0.3})
    ie.reset_do("d")

    expected = train_data_discrete_marginals["d"]
    # Check both states of the boolean node.
    assert math.isclose(ie.query()["d"][False], expected[False])
    assert math.isclose(ie.query()["d"][True], expected[True])
def test_reset_do_sets_probabilities_back_to_initial_state(
    self, train_model, train_data_idx, train_data_idx_marginals
):
    """Resetting Do operator should re-introduce the original conditional dependencies

    A network is fitted on the index-encoded training data, node ``d`` is
    intervened on and then reset; each state of ``d`` must then have the
    marginal computed from the training data.
    """
    network = BayesianNetwork(train_model)
    network.fit_node_states(train_data_idx).fit_cpds(train_data_idx)

    engine = InferenceEngine(network)
    engine.do_intervention("d", {0: 0.7, 1: 0.3})
    engine.reset_do("d")

    expected_d = train_data_idx_marginals["d"]
    for state_idx in (0, 1):
        assert math.isclose(engine.query()["d"][state_idx], expected_d[state_idx])
# For example, we can ask what would happen if 100% of students wanted to go on to do higher education.
# %% codecell
print("'higher' marginal distribution before DO: ", eng.query()['higher'])
# Make the intervention on the network
eng.do_intervention(node='higher', state={'yes': 1.0, 'no': 0.0})  # all students yes
print("'higher' marginal distribution after DO: ", eng.query()['higher'])
# %% markdown [markdown]
# ### Resetting a Node Distribution
# We can reset any intervention that we make by calling the `reset_do` method with the node we want to reset:
# %% codecell
eng.reset_do('higher')
eng.query()['higher']  # same as before
# %% markdown [markdown]
# ### Effect of DO on Marginals
# We can use `query` to find the effect that an intervention has on the marginal likelihoods of OTHER variables, not just on the INTERVENED variable.
#
# **Example 1:** change 'higher' and check grade 'G1' (how the likelihood of achieving a pass changes if 100% of students wanted to do higher education)
#
# Answer: if 100% of students wanted to do higher education (as opposed to 90% in our data population), then we estimate the pass rate would increase from 74.7% to 79.3%.
# %% codecell
print('marginal G1', eng.query()['G1'])
eng.do_intervention(node='higher', state={'yes': 1.0, 'no': 0.0})
print('updated marginal G1', eng.query()['G1'])
# --- Intervention on `higher` (preference for higher education) ---
# An intervention (do-operation) manually fixes the distribution of a
# variable, e.g. to inject domain knowledge. Here we assume that every
# student wants higher education and force that distribution on the node.
print("distribution before do", ie.query()["higher"])
ie.do_intervention("higher", {'yes': 1.0, 'no': 0.0})
print("distribution after do", ie.query()["higher"])
# Recorded output:
#   distribution before do {'no': 0.10752688172043011, 'yes': 0.8924731182795698}
#   distribution after do {'no': 0.0, 'yes': 0.9999999999999998}
# => the marginal of `higher` has been forcibly set by the intervention.

# Undo the intervention.
ie.reset_do("higher")

# --- Effect of the intervention on another variable's probabilities ---
print("marginal G1", ie.query()["G1"])
ie.do_intervention("higher", {'yes': 1.0, 'no': 0.0})
print("updated marginal G1", ie.query()["G1"])
# Recorded output:
#   marginal G1 {'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}
#   updated marginal G1 {'Fail': 0.20682952942551894, 'Pass': 0.7931704705744809}
# => expected change in G1 after fixing the marginal of `higher`:
#    the model predicts that wanting higher education raises the
#    probability of passing the exam.
def test_query_after_do_intervention_has_split_graph(self, chain_network):
    """
    chain network: a → b → c → d → e

    test 1.
        - do intervention on node c generates 2 graphs (a → b) and (c → d → e)
        - assert the query can be run (it used to hang before)
        - assert reset_do works

    test 2.
        - do intervention on node b, so that a becomes a single isolated node
          and the rest of the graph is (b → c → d → e)
        - same assertions as test 1
    """
    ie = InferenceEngine(chain_network)
    original_margs = ie.query()

    def _assert_margs_restored():
        # After reset_do, every node's marginals must match the originals.
        # States are compared positionally, as both dicts come from queries
        # on the same network.
        reset_margs = ie.query()
        for node, dict_left in original_margs.items():
            dict_right = reset_margs[node]
            for kl, kr in zip(dict_left.keys(), dict_right.keys()):
                assert math.isclose(dict_left[kl], dict_right[kr])

    # --- test 1: intervene on c ---
    var = "c"
    state_dict = {0: 1.0, 1: 0.0}
    ie.do_intervention(var, state_dict)

    # assert the intervention node has indeed the right state
    assert ie.query()[var][0] == state_dict[0]
    assert ie.query()[var][1] == state_dict[1]

    # assert the upstream nodes have the default marginals (no info
    # propagates in the upstream graph)
    assert ie.query()["a"][0] == original_margs["a"][0]
    assert ie.query()["a"][1] == original_margs["a"][1]
    assert ie.query()["b"][0] == original_margs["b"][0]
    assert ie.query()["b"][1] == original_margs["b"][1]

    # assert the _cpds of the upstream nodes are stored correctly
    orig_cpds = ie._cpds_original  # pylint: disable=protected-access
    upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
    assert orig_cpds["a"] == upstream_cpds["a"]
    assert orig_cpds["b"] == upstream_cpds["b"]

    ie.reset_do(var)
    _assert_margs_restored()

    # --- test 2: intervene on b, so that a is a single isolated node ---
    var_b = "b"
    state_dict_b = {0: 1.0, 1: 0.0}
    ie.do_intervention(var_b, state_dict_b)

    # assert the intervention node has indeed the right state
    assert ie.query()[var_b][0] == state_dict_b[0]
    assert ie.query()[var_b][1] == state_dict_b[1]

    # assert the upstream node has the default marginals (no info
    # propagates in the upstream graph)
    assert ie.query()["a"][0] == original_margs["a"][0]
    assert ie.query()["a"][1] == original_margs["a"][1]

    # assert the _cpds of the upstream node are stored correctly
    orig_cpds = ie._cpds_original  # pylint: disable=protected-access
    upstream_cpds = ie._detached_cpds  # pylint: disable=protected-access
    assert orig_cpds["a"] == upstream_cpds["a"]

    ie.reset_do(var_b)
    _assert_margs_restored()