def do_causal(self, their_tensors, absent_tensors, movie_inds, num_samples=1000):
    """Estimate the difference in summed ``y`` between two settings of ``x``.

    Two intervened copies of the module-level ``model`` are built with
    ``pyro.do``, one with site ``"x"`` set to ``their_tensors`` and one with it
    set to ``absent_tensors``. Each copy is run ``num_samples`` times with
    ``self.step2_params``; every run contributes the sum of its ``'y'`` entry
    restricted to ``movie_inds``.

    Args:
        their_tensors: value used for ``"x"`` in the first intervention.
        absent_tensors: value used for ``"x"`` in the second intervention.
        movie_inds: index applied to the ``'y'`` output before summing.
        num_samples (int): number of runs of each intervened model.

    Returns:
        Mean summed ``y`` under ``their_tensors`` minus the mean summed ``y``
        under ``absent_tensors``.
    """
    # With confounding
    def summed_outcomes(intervened_model):
        # One scalar per run: total of the selected 'y' entries.
        return [
            torch.sum(
                intervened_model(self.step2_params)['y'][movie_inds]).item()
            for _ in range(num_samples)
        ]

    present_model = pyro.do(model, data={"x": their_tensors})
    removed_model = pyro.do(model, data={"x": absent_tensors})

    present_totals = summed_outcomes(present_model)
    removed_totals = summed_outcomes(removed_model)

    return np.mean(present_totals) - np.mean(removed_totals)
def twin_query(self, node_of_interest, evidence, intervention, n_samples, distribution=False, merge=False):
    """Run the twin network counterfactual inference procedure.

    Args:
        node_of_interest (str): the name of the node of interest
        intervention (dict): a dictionary of {node_name: value} interventions
        evidence (dict): a dictionary of {node_name: value} evidence
        n_samples (int): the number of samples to take during importance
            sampling and from the posterior.
        distribution (bool): if True, return samples. If False, returns
            mean and sd. of distribution.
        merge (bool): if True, call
            ``self.scm.merge_in_twin(node_of_interest, intervention)`` before
            running inference.

    Returns:
        The samples drawn from the posterior when ``distribution`` is True,
        otherwise the tuple ``(samples.mean().numpy(), samples.std().numpy())``.
    """
    # NOTE(review): the nesting of the next three statements was reconstructed
    # from a whitespace-collapsed source -- confirm against the original file.
    if not self.scm.twin_exists:
        self.scm.create_twin_network()
    if merge:
        self.scm.merge_in_twin(node_of_interest, intervention)
    self.G_inference = self.scm.twin_G
    # Re-key the intervention onto the "tn"-suffixed node names and turn each
    # value into a flat double tensor.
    intervention = {
        "{}tn".format(k): torch.tensor([intervention[k]]).double().flatten()
        for k in intervention
    }
    # The suffix is added only when "tn" does not already occur anywhere in
    # the name (substring test, not a suffix test).
    node_of_interest = "{}tn".format(
        node_of_interest
    ) if "tn" not in node_of_interest else node_of_interest
    # Recorded on the instance for the duration of the inference call and
    # cleared again below.
    self._intv_nodes = [k for k in intervention]
    self._noi_nodes = [node_of_interest]
    # Both the model and the guide receive the same intervention.
    intervened_model = pyro.do(self.model, data=intervention)
    intervened_guide = pyro.do(self.guide, data=intervention)
    if self.verbose:
        print("Performing Twin Network inference... ", end="", flush=True)
    t_twin = time.time()
    posterior = self.get_posterior(self._noi_nodes, evidence, n_samples, custom_model=intervened_model, custom_guide=intervened_guide, twin=True)
    # Elapsed wall-clock time of get_posterior, in seconds.
    t_twin = time.time() - t_twin
    if self.verbose:
        print("✓ ({}s)".format(round(t_twin, 3)))
    # Clear the per-call bookkeeping (not reached if get_posterior raises).
    self._intv_nodes = []
    self._noi_nodes = []
    samples = posterior.sample_n((n_samples))
    if not distribution:
        return samples.mean().numpy(), samples.std().numpy()
    else:
        return samples
def scm_ras_erk_counterfactual(rates, totals, observation, ras_intervention, spike_width=1.0, svi=True):
    """Sample differences between observed and counterfactual ``Erk``.

    A ``GF_SCM`` is built from ``rates``, ``totals`` and ``spike_width``. Its
    noise terms (each starting as ``Normal(0., 1.)``) are updated from
    ``observation``, the model is intervened on with ``ras_intervention``, and
    inference is run on the intervened model with the updated noise.

    Args:
        rates: forwarded to ``GF_SCM``.
        totals: forwarded to ``GF_SCM``.
        observation: mapping passed to the noise update; must contain ``'Erk'``.
        ras_intervention: intervention data passed to ``do``.
        spike_width: forwarded to ``GF_SCM``.
        svi (bool): if True update the noise with ``update_noise_svi``,
            otherwise with ``update_noise_importance``.

    Returns:
        list: 500 values of ``observation['Erk']`` minus a sample from the
        ``'Erk'`` marginal of the counterfactual posterior.
    """
    noise_sites = (
        'N_SOS', 'N_Ras', 'N_PI3K', 'N_AKT', 'N_Raf', 'N_Mek', 'N_Erk',
    )

    gf_scm = GF_SCM(rates, totals, spike_width)
    # Every noise site gets its own standard-normal prior object.
    noise = {site: Normal(0., 1.) for site in noise_sites}

    if not svi:
        updated_noise = gf_scm.update_noise_importance(observation, noise)
    else:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)

    cf_posterior = gf_scm.infer(
        do(gf_scm.model, ras_intervention), updated_noise)
    erk_marginal = EmpiricalMarginal(cf_posterior, sites='Erk')

    effect_samples = []
    for _ in range(500):
        effect_samples.append(
            observation['Erk'] - float(erk_marginal.sample()))
    return effect_samples
def _pull_bandit(self):
    """Observe the environment, pull the selected arm, and update parameters.

    The arm returned by ``self._select_arm`` is forced onto the ``'arm'`` site
    of ``self.bandits.gambler_model`` with ``do``; the intervened model is run
    on the observation and its result is passed to ``self.update_parameters``.
    """
    context = self.environment.observe()
    chosen_arm = self._select_arm(context)

    forced_arm_model = do(self.bandits.gambler_model, {'arm': chosen_arm})
    outcome = forced_arm_model(context)

    self.update_parameters(chosen_arm, outcome, context)
def intervention(self, intervention_data: dict):
    """Apply ``intervention_data`` to ``self.model`` with ``pyro.do``.

    Args:
        intervention_data: data handed to ``pyro.do`` as the intervention.

    Returns:
        The intervened pyro model.
    """
    return pyro.do(self.model, intervention_data)
def doCausal(their_tensors, absent_tensors, movie_inds, num_samples=1000):
    """Estimate the difference in summed ``y`` between two settings of ``x``.

    Two intervened copies of the module-level ``model`` are built with
    ``pyro.do``, one with site ``"x"`` set to ``their_tensors`` and one with it
    set to ``absent_tensors``. Each copy is run ``num_samples`` times with the
    module-level ``p2``; every run contributes the sum of its ``'y'`` entry
    restricted to ``movie_inds``.

    Args:
        their_tensors: value used for ``"x"`` in the first intervention.
        absent_tensors: value used for ``"x"`` in the second intervention.
        movie_inds: index applied to the ``'y'`` output before summing.
        num_samples (int): number of runs of each intervened model. Defaults
            to 1000, the value that used to be hard-coded.

    Returns:
        Mean summed ``y`` under ``their_tensors`` minus the mean summed ``y``
        under ``absent_tensors``.
    """
    # With confounding
    their_do = pyro.do(model, data={"x": their_tensors})
    absent_do = pyro.do(model, data={"x": absent_tensors})

    # All runs of the first model are drawn before any run of the second, as
    # before, so the sequence of random draws is unchanged at the default.
    their_do_y = [
        torch.sum(their_do(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]
    absent_do_y = [
        torch.sum(absent_do(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]

    return np.mean(their_do_y) - np.mean(absent_do_y)
def imagine_next_step(env, action, i):
    """Simulate one step on a copy of ``env`` with the action forced.

    Args:
        env: environment to copy; the caller's object is not passed to the
            model.
        action: value forced onto the action site for this step.
        i: step index, used in the site name and passed to the model.

    Returns:
        The environment returned by the intervened model.
    """
    imagined_env = deepcopy(env)
    # The site name combines the copied environment's current state and the
    # step index.
    site_name = f'action{imagined_env.s}{i}'
    forced_model = pyro.do(model, {site_name: action})

    # The model returns a 5-tuple; only the first element is kept.
    (imagined_env, _, _, _, _) = forced_model(imagined_env, i)

    # sanity check
    assert imagined_env.lastaction == action
    return imagined_env
def model_do_sample(self, do_dict):
    """Sample the graph once with the variables in ``do_dict`` intervened on.

    Args:
        do_dict: mapping of variable name to the value it is set to.

    Returns:
        Whatever ``self.model_sample`` returns when run under the intervention.
    """
    # Hand pyro a fresh dict rather than the caller's object.
    intervention_data = {name: do_dict[name] for name in do_dict}
    intervened = pyro.do(self.model_sample, data=intervention_data)
    return intervened()
def do_predict(self, X_test, num_samples=1000):
    """Average the model's ``'y'`` output with ``"x"`` intervened to ``X_test``.

    If ``self.test_params`` is ``None``, ``self.infer_z(X_test)`` is called
    first. The module-level ``model`` is then intervened on and run
    ``num_samples`` times with ``self.test_params``.

    Args:
        X_test: value forced onto site ``"x"``; ``X_test.shape[0]`` sets the
            length of the prediction vector.
        num_samples (int): number of model runs to average over.

    Returns:
        tuple: ``(mean prediction array, self.test_params)``.
    """
    if self.test_params is None:
        self.infer_z(X_test)

    intervened_model = pyro.do(model, data={"x": X_test})

    running_total = np.zeros((X_test.shape[0]))
    for _ in range(num_samples):
        running_total += (
            intervened_model(self.test_params)['y'].detach().numpy())

    return running_total / num_samples, self.test_params
def play(self, state, time_left):
    """Put the policy into action and move to the next state, recursively.

    An action is chosen from the posterior returned by ``self.infer``, forced
    onto the model with ``pyro.do``, and the resulting next state is played
    with one less unit of time.

    Args:
        state: current state; ``state.item()`` is printed and the value is
            formatted into the site names.
        time_left: remaining time, compared against ``torch.tensor(0.)``.

    Returns:
        ``None`` when ``time_left`` equals zero; otherwise the tuple
        ``(action, self.play(next_state, time_left - 1))``.
    """
    # Base case. This was previously written as a `while` loop whose body
    # always returned, so it could never run more than once.
    if torch.eq(time_left, torch.tensor(0.)):
        return None

    action_posterior = self.infer(state, time_left,'action')
    action = self.policy(action_posterior)

    action_site = 'action_{}_{}'.format(state,time_left)
    next_state_site = 'next_state_{}_{}'.format(state,time_left)
    interv_model = pyro.do(self.model, data={action_site: action})
    next_state = interv_model(state, time_left)[next_state_site]

    # Print Trajectory and Actions
    print('State:',state.item(), ', Action: ', list(Environment().action_dictionary.keys())[action])
    return action, self.play(next_state, time_left-1)
def main():
    """Run up to 25 policy-driven steps on a non-slippery FrozenLake.

    Each step clears the pyro parameter store, asks ``policy`` for an action,
    forces that action onto the module-level ``model`` with ``pyro.do``, runs
    the intervened model, and renders the returned environment. The loop stops
    early when the model reports ``done``.
    """
    env = gym.make('FrozenLake-v0', is_slippery=False)
    env.reset()
    env.s = 1

    for step in range(25):
        pyro.clear_param_store()
        chosen_action = policy(env)

        # The site name combines the environment's current state and the step.
        site_name = f'action{env.s}{step}'
        forced_model = pyro.do(model, {site_name: torch.tensor(chosen_action)})

        # The model hands back the environment it stepped, which replaces ours.
        env, observation, reward, done, info = forced_model(env, step)
        env.render()

        if done:
            print("Episode finished after {} timesteps".format(step + 1))
            break

    env.close()
def scm_covid_counterfactual(
        rates, totals, observation, ras_intervention, spike_width=1.0, svi=True
):
    """Sample counterfactual quantities from a ``COVID_SCM``.

    The SCM's noise terms are updated from ``observation``, the model is
    intervened on with ``ras_intervention``, and inference is run on the
    intervened model with the updated noise.

    Args:
        rates: forwarded to ``COVID_SCM``.
        totals: forwarded to ``COVID_SCM``.
        observation: mapping passed to the noise update; must contain
            ``'cytokine'``.
        ras_intervention: intervention data passed to ``do``.
        spike_width: forwarded to ``COVID_SCM``.
        svi (bool): if True update the noise with ``update_noise_svi``,
            otherwise with ``update_noise_importance``.

    Returns:
        tuple: ``(il6amp_samples, nfxb_samples, il6stat3_samples,
        scm_causal_effect_samples)``. The first three are lists of 5000 draws
        from the ``IL_6_AMP``, ``NF_xB`` and ``IL6_STAT3`` marginals; the last
        is a list of 5000 values of ``observation['cytokine']`` minus a draw
        from the ``cytokine`` marginal.
    """
    wide_sites = ('N_SARS_COV2', 'N_TOCI')
    unit_sites = (
        'N_PRR', 'N_ACE2', 'N_AngII', 'N_AGTR1', 'N_ADAM17', 'N_IL_6Ralpha',
        'N_sIL_6_alpha', 'N_STAT3', 'N_EGF', 'N_TNF', 'N_EGFR', 'N_IL6_STAT3',
        'N_NF_xB', 'N_IL_6_AMP', 'N_cytokine',
    )

    gf_scm = COVID_SCM(rates, totals, spike_width)

    # Each site maps to a pair of floats: (0., 5.) for the first two sites,
    # (0., 1.) for all the others.
    noise = {site: (0., 5.) for site in wide_sites}
    noise.update((site, (0., 1.)) for site in unit_sites)

    if not svi:
        updated_noise = gf_scm.update_noise_importance(observation, noise)
    else:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)

    cf_posterior = gf_scm.infer(
        do(gf_scm.model, ras_intervention), updated_noise)

    def marginal_of(site):
        return EmpiricalMarginal(cf_posterior, sites=[site])

    cytokine_marginal = marginal_of('cytokine')
    il6amp_marginal = marginal_of('IL_6_AMP')
    nfxb_marginal = marginal_of('NF_xB')
    il6stat3_marginal = marginal_of('IL6_STAT3')

    # The cytokine effect is sampled first, then the three marginals in this
    # order, so the sequence of random draws stays the same.
    scm_causal_effect_samples = []
    for _ in range(5000):
        scm_causal_effect_samples.append(
            observation['cytokine'] - float(cytokine_marginal.sample()))

    il6amp_samples = il6amp_marginal.sample((5000,)).tolist()
    nfxb_samples = nfxb_marginal.sample((5000,)).tolist()
    il6stat3_samples = il6stat3_marginal.sample((5000,)).tolist()

    return (il6amp_samples, nfxb_samples, il6stat3_samples,
            scm_causal_effect_samples)
def model_do_cond_sample(self, do_dict, data_dict):
    """Sample the graph once under an intervention and a conditioning.

    Args:
        do_dict: mapping of variable name to the value it is set to.
        data_dict: mapping of variable name to the value it is conditioned on.

    Returns:
        The result of running the intervened-then-conditioned
        ``self.model_sample``, or ``None`` (after printing a message) when a
        name appears in both dictionaries.
    """
    # A variable cannot be both intervened on and conditioned on.
    names_clash = any(
        do_name == data_name
        for data_name in data_dict
        for do_name in do_dict
    )
    if names_clash:
        print('overlapping lists!')
        return

    # Hand pyro fresh dicts rather than the caller's objects.
    do_values = {name: do_dict[name] for name in do_dict}
    observed_values = {name: data_dict[name] for name in data_dict}

    intervened = pyro.do(self.model_sample, data=do_values)
    conditioned = pyro.condition(intervened, data=observed_values)
    return conditioned()
def infer_Q(env, action, *, trajectory_model, agent_model, log=False):
    """infer_Q

    Infer Q(state, action) via pyro's importance sampling.

    The trajectory model is intervened on at site ``'A_0'`` and sampled
    ``args.num_samples`` times (``args`` is a module-level object); Q is the
    mean of the empirical marginal of site ``'G'``.

    :param env: OpenAI Gym environment
    :param action: integer action
    :param trajectory_model: trajectory probabilistic program
    :param agent_model: agent's probabilistic program
    :param log: boolean; if True, print log info
    :return: mean of the ``'G'`` marginal
    """
    forced_first_action = pyro.do(
        trajectory_model, {'A_0': torch.as_tensor(action)})
    sampler = Importance(forced_first_action, num_samples=args.num_samples)
    posterior = sampler.run(env, agent_model=agent_model)

    Q = EmpiricalMarginal(posterior, 'G').mean

    if log:
        print(f'Q({env.actions[action]}) = {Q.item()}')

    return Q
def simulate(self, state, time_left, output):
    """Agent simulation: forward-simulate one step and score it.

    An action is drawn from ``self.model``, forced back onto the model with
    ``pyro.do``, and the intervened model is run twice: once to read the next
    state and once to read the next utility. Unless this is the final move
    (``time_left`` equal to 1), the mean of the inferred future utility is
    added. The resulting expected utility is registered via
    ``self.add_factor``.

    Args:
        state: current state, formatted into the site names.
        time_left: remaining time, formatted into the site names.
        output (str): ``'action'`` to return the sampled action,
            ``'exp_utility'`` to return the expected utility.

    Returns:
        The action or the expected utility, depending on ``output``; ``None``
        for any other ``output`` value.
    """
    # Every site of this step shares the same "<state>_<time_left>" suffix.
    site_suffix = '{}_{}'.format(state,time_left)
    action_site = 'action_' + site_suffix

    action = self.model(state, time_left)[action_site]
    interv_model = pyro.do(self.model, data={action_site: action})

    # Two separate runs of the intervened model, as in the original flow.
    next_state = interv_model(state, time_left)['next_state_' + site_suffix]
    next_utility = interv_model(state, time_left)['next_utility_' + site_suffix]

    is_last_move = torch.eq(time_left, torch.tensor(1.))
    if is_last_move:
        exp_utility = next_utility
    else:
        # if there are moves left
        future_utility_posterior = self.infer(next_state, time_left-1,'exp_utility')
        exp_utility = next_utility + torch.mean(future_utility_posterior)

    self.add_factor(exp_utility, 'exp_util_' + site_suffix)

    if output=='action':
        return action
    if output=='exp_utility':
        return exp_utility
    return None
def intervention_prediction(node_of_interest, intervention, posterior, n_samples):
    """Sample a node from the intervened generative model.

    The module-level ``generative_model`` is intervened on with
    ``intervention`` (values converted to flat float tensors). For each of the
    ``n_samples`` runs, one draw is taken from ``posterior`` and mapped onto
    the names in the module-level ``exog_sites_list``, skipping names that are
    being intervened on.

    Args:
        node_of_interest: key read from each model run's result.
        intervention (dict): {name: value} interventions.
        posterior: object whose ``sample()`` returns values indexable by
            position in ``exog_sites_list``.
        n_samples (int): number of model runs.

    Returns:
        list: the value of ``node_of_interest`` from each run.
    """
    intervention = {
        name: torch.tensor(value).float().flatten()
        for name, value in intervention.items()
    }
    intervened_model = pyro.do(generative_model, data=intervention)

    estimate = []
    for _ in range(n_samples):
        draw = posterior.sample()
        # Intervened sites take their value from the intervention, not the draw.
        exog_values = {
            site: draw[position]
            for position, site in enumerate(exog_sites_list)
            if site not in intervention
        }
        run_result = intervened_model(exog_values=exog_values)
        estimate.append(run_result[node_of_interest])
    return estimate
def intervention(model, evidence, infer, val, num_samples=1000):
    """
    Uses pyro's ``do`` function with importance sampling to estimate the
    probability of a particular value for the random variable under
    inference, after intervening on the model.

    :param func model: Probabilistic model defined with pyro sample methods.
    :param dict(str, torch.tensor) evidence: Dictionary of trace objects and
        the values they are set to by the intervention.
    :param str infer: Trace object which needs to be inferred.
    :param int val: Value of the trace object for which the probability is
        required.
    :param int num_samples: Number of samples used both for the importance
        sampler and for the draws from the resulting marginal.
    :return: Fraction of marginal draws equal to ``val``.
    :rtype: float
    """
    intervention_model = pyro.do(model, data=evidence)
    posterior = pyro.infer.Importance(intervention_model, num_samples=num_samples).run()
    marginal = pyro.infer.EmpiricalMarginal(posterior, infer)

    # Count matching draws directly; there is no need to collect them into an
    # array first.
    matches = sum(
        1 for _ in range(num_samples) if marginal().item() == val)
    return matches / num_samples
def scm_covid_counterfactual(
    betas,
    max_abundance,
    observation,
    ras_intervention,
    spike_width: float = 1.0,
    svi: bool = True,
    samples: int = 5000,
) -> List[float]:
    """Sample differences between observed and counterfactual ``cytokine``.

    A ``SigmoidSCM`` is built, its noise (starting from the module-level
    ``NOISE``) is updated from ``observation``, the model is intervened on
    with ``ras_intervention``, and inference is run on the intervened model
    with the updated noise.

    Args:
        betas: forwarded to ``SigmoidSCM``.
        max_abundance: forwarded to ``SigmoidSCM``.
        observation: mapping passed to the noise update; must contain
            ``'cytokine'``.
        ras_intervention: intervention data passed to ``do``.
        spike_width: forwarded to ``SigmoidSCM``.
        svi: if True update the noise with ``update_noise_svi``, otherwise
            with ``update_noise_importance``.
        samples: number of effect samples to return.

    Returns:
        ``samples`` values of ``observation['cytokine']`` minus a draw from
        the counterfactual ``cytokine`` marginal.
    """
    gf_scm = SigmoidSCM(betas, max_abundance, spike_width)

    if not svi:
        updated_noise = gf_scm.update_noise_importance(observation, NOISE)
    else:
        updated_noise, _ = gf_scm.update_noise_svi(observation, NOISE)

    cf_posterior = gf_scm.infer(
        do(gf_scm.model, ras_intervention), updated_noise)
    cytokine_marginal = EmpiricalMarginal(cf_posterior, sites=['cytokine'])

    effect_samples = []
    for _ in range(samples):
        effect_samples.append(
            observation['cytokine'] - float(cytokine_marginal.sample()))
    return effect_samples
def intervention_prediction(self, node_of_interest, intervention, posterior, n_samples):
    """Sample the node of interest from the intervened model.

    ``self.model`` is intervened on with ``pyro.do`` and run ``n_samples``
    times with ``noise=posterior``.

    Args:
        node_of_interest (str): the name of the node of interest
        intervention (dict): a dictionary of {node_name: value} interventions
        posterior: the exogenous posterior, passed to the model as ``noise``.
        n_samples (int): the number of times the intervened model is run.

    Returns:
        list: the value of ``node_of_interest`` from each run (the samples
        themselves, not their mean).
    """
    intervention = {
        k: torch.tensor([intervention[k]]).double().flatten()
        for k in intervention
    }
    # Recorded on the instance while the model runs; always cleared afterwards
    # so a failing run cannot leave stale intervention nodes behind.
    self._intv_nodes = list(intervention)
    try:
        intervened_model = pyro.do(self.model, data=intervention)
        estimate = [
            intervened_model(noise=posterior)[node_of_interest]
            for _ in range(n_samples)
        ]
    finally:
        self._intv_nodes = []
    return estimate
def infer_Q(env, action, infer_type='intervention', *, agent_model):
    """infer_Q

    Infer Q(state, action) via pyro's importance sampling, via conditioning or
    intervention.

    The module-level ``trajectory_model`` is optionally intervened on or
    conditioned at site ``'A_0'`` and sampled ``args.num_samples`` times
    (``args`` is a module-level object).

    :param env: OpenAI Gym environment
    :param action: integer action
    :param infer_type: type of inference; none, condition, or intervention
    :param agent_model: agent's probabilistic program
    :return: mean of the empirical marginal of site ``'G'``
    :raises ValueError: if ``infer_type`` is not one of the accepted values
    """
    if infer_type not in ('intervention', 'condition', 'none'):
        # The f-prefix was missing, so the offending value never appeared in
        # the message.
        raise ValueError(f'Invalid inference type {infer_type}')

    if infer_type == 'intervention':
        model = pyro.do(trajectory_model, {'A_0': torch.tensor(action)})
    elif infer_type == 'condition':
        model = pyro.condition(trajectory_model, {'A_0': torch.tensor(action)})
    else:  # infer_type == 'none'
        model = trajectory_model

    posterior = Importance(model, num_samples=args.num_samples).run(
        env, agent_model=agent_model)
    return EmpiricalMarginal(posterior, 'G').mean
.10, # P(C = 'on' | A = 'off', B = 'off') .90 # P(C = 'off' | A = 'off', B = 'off') ] ] ]) A = pyro.sample('A', dist.Categorical(probs=prob_A)) B = pyro.sample('B', dist.Categorical(probs=prob_B[A])) C = pyro.sample('C', dist.Categorical(probs=prob_C[A][B])) return C if __name__ == '__main__': intervened = pyro.do(intervention, {'B': torch.tensor(0)}) conditioned = pyro.condition(intervention, { 'B': torch.tensor(0), 'C': torch.tensor(0) }) posterior = pyro.infer.Importance(conditioned, num_samples=10000).run() marginal = pyro.infer.EmpiricalMarginal(posterior, 'A') samples = [marginal() for _ in range(10000)] ons = [sample for sample in samples if sample == torch.tensor(0)] print(len(ons) / len(samples)) # Plot the distribution of P(A | O = 'self', R = 'big') plt.figure(figsize=(14, 7)) plt.hist(samples, bins='auto') plt.xticks([0, 1], ['on', 'off'])
# NOTE(review): statement nesting below was reconstructed from a
# whitespace-collapsed source; loop extents are inferred -- confirm.

# Update the SCM's noise from the conditioning data.
cond_noise = scm.update_noise_svi(cond_data)
print(cond_data)
# Run the model 100 times with the updated noise, keeping the first element of
# each returned triple.
rxs = []
for i in range(100):
    (rx,ry,_), _ = scm.model(cond_noise)
    rxs.append(rx)
compare_to_density(ox, torch.cat(rxs))
_ =plt.suptitle("SCM Conditioned on Original", fontsize=18, fontstyle='italic')

"""## Counterfactuals"""

# intervening on Shape, posX and PosY
intervened_model = pyro.do(scm.model, data={
    "Y_1": torch.tensor(0.),
    "Y_4": torch.tensor(30.),
    "Y_5": torch.tensor(25.),
})
# Collect the `.loc` of every updated noise term. noise_data is only referenced
# by the commented-out line below.
noise_data = {}
for term, d in cond_noise.items():
    noise_data[term] = d.loc
# intervened_noise = scm.update_noise_svi(noise_data, intervened_model)
# One run of the intervened model with the SCM's initial noise.
(rx1,ry,_), _ = intervened_model(scm.init_noise)
compare_to_density(ox, rx1)
print(ry)
# 5000 runs of the intervened model with the updated noise; rxs is reset and
# reused to hold the first element of each returned triple.
rxs = []
for i in range(5000):
    (cfo1,ny1,nz1), _= intervened_model(cond_noise)
    rxs.append(cfo1)
# The exciting thing is Pyro comes with a "do" function for causal inference, it works exactly # the same semantically as "condition" but under the hood it implements intervention rather # than conditioning. # Note however, this will only save the "intervention step" of counterfactual computation, the # "do" function cannot compute other-world scenarios out of the box like # ---- Y in the world where we do(T=1) given Y=5 in the world where we do(T=0) ---- # we still have to do abduction and prediction ourselves. # "do" is unsupported for direct definitions of posteriors/augmented graphs # def scale_intervene(guess, measurement=9.5): # weight = pyro.sample("weight", dist.Normal(guess, 1.)) # return pyro.sample ("measure", dist.Normal(weight, 0.75), do=ERROR) # here we intervene on the measuremnet == 9.5 which should not affect P(weight) intervened_scale = pyro.do(scale, data={'measure': 9.5}) def intervene_wrapper(measurement=9.5): return pyro.do(scale, data={'measure': measurement}) # Unfortunately this isn't the end of the story # Pyro isn't SO plug-and-play that our posterior stocfuns are now ready to draw samples from # the posterior distribution. Unfortunateley, we still need to specify the nitty-gritty # of the inference algorithm and "press run". # Inference algorithms: # importance sampling, rejection sampling, sequential Monte Carlo, MCMC # and independent Metropolis-Hastings, and as variational distributions # These all require as input some notion of a "proposal distribution" e.g. for the random walk # in MCMC or the proposal for rejection sampling. Or the variational distribution for VI. # We need to (intelligently) choose the proposal distribution.
def intervene_wrapper(measurement=9.5):
    """Return ``scale`` with its ``'measure'`` site intervened to ``measurement``.

    Args:
        measurement: value forced onto the ``'measure'`` site.

    Returns:
        The intervened model produced by ``pyro.do``.
    """
    forced_sites = {'measure': measurement}
    return pyro.do(scale, data=forced_sites)