Ejemplo n.º 1
0
    def do_causal(self,
                  their_tensors,
                  absent_tensors,
                  movie_inds,
                  num_samples=1000):
        """Estimate the interventional effect of ``x`` on the summed ``y``.

        The ``model`` from the enclosing scope is intervened on twice: once
        with site ``"x"`` fixed to ``their_tensors`` and once with it fixed
        to ``absent_tensors``. Each intervened model is run ``num_samples``
        times with ``self.step2_params``; every run sums the ``"y"`` entries
        selected by ``movie_inds``. The estimate still includes confounding.

        Args:
            their_tensors: value imposed on site ``"x"`` in the first arm.
            absent_tensors: value imposed on site ``"x"`` in the second arm.
            movie_inds: index selecting entries of the ``"y"`` output.
            num_samples (int): forward runs per arm.

        Returns:
            The difference of the two Monte Carlo means
            (first arm minus second arm).
        """
        def mean_selected_y(x_value):
            # Repeatedly run the intervened model and average the summed y.
            forced = pyro.do(model, data={"x": x_value})
            totals = [
                torch.sum(forced(self.step2_params)['y'][movie_inds]).item()
                for _ in range(num_samples)
            ]
            return np.mean(totals)

        with_their = mean_selected_y(their_tensors)
        with_absent = mean_selected_y(absent_tensors)
        return with_their - with_absent
Ejemplo n.º 2
0
    def twin_query(self,
                   node_of_interest,
                   evidence,
                   intervention,
                   n_samples,
                   distribution=False,
                   merge=False):
        """Run the twin network counterfactual inference procedure.

        Args:
            node_of_interest (str): the name of the node of interest. A
                ``"tn"`` suffix is appended to address its twin copy unless
                the name already ends with ``"tn"``.
            evidence (dict): a dictionary of {node_name: value} evidence.
            intervention (dict): a dictionary of {node_name: value}
                interventions; they are applied to the twin (``"tn"``)
                copies of the named nodes.
            n_samples (int): the number of samples to take during importance
                sampling and from the posterior.
            distribution (bool): if True, return samples. If False, return
                the mean and standard deviation of the samples.
            merge (bool): if True, call ``self.scm.merge_in_twin`` with the
                node of interest and the intervention before inference.

        Returns:
            The posterior samples if ``distribution`` is True, otherwise a
            ``(mean, std)`` pair of numpy values.
        """
        if not self.scm.twin_exists:
            self.scm.create_twin_network()
        if merge:
            self.scm.merge_in_twin(node_of_interest, intervention)
        self.G_inference = self.scm.twin_G
        # Interventions target the twin copies of the requested nodes.
        intervention = {
            "{}tn".format(k): torch.tensor([v]).double().flatten()
            for k, v in intervention.items()
        }
        # Only a trailing "tn" marks a twin node; a substring test would
        # wrongly skip names that merely contain "tn" somewhere.
        if not node_of_interest.endswith("tn"):
            node_of_interest = "{}tn".format(node_of_interest)
        self._intv_nodes = list(intervention)
        self._noi_nodes = [node_of_interest]
        intervened_model = pyro.do(self.model, data=intervention)
        intervened_guide = pyro.do(self.guide, data=intervention)
        if self.verbose:
            print("Performing Twin Network inference... ", end="", flush=True)
        t_twin = time.time()
        try:
            posterior = self.get_posterior(self._noi_nodes,
                                           evidence,
                                           n_samples,
                                           custom_model=intervened_model,
                                           custom_guide=intervened_guide,
                                           twin=True)
        finally:
            # Clear the per-query bookkeeping even when inference fails, so a
            # failed query cannot leak its nodes into the next one.
            self._intv_nodes = []
            self._noi_nodes = []
        t_twin = time.time() - t_twin
        if self.verbose:
            print("✓ ({}s)".format(round(t_twin, 3)))
        # sample((n,)) is the non-deprecated equivalent of sample_n(n).
        samples = posterior.sample((n_samples,))
        if not distribution:
            return samples.mean().numpy(), samples.std().numpy()
        return samples
Ejemplo n.º 3
0
def scm_ras_erk_counterfactual(rates,
                               totals,
                               observation,
                               ras_intervention,
                               spike_width=1.0,
                               svi=True,
                               samples=500):
    """Sample the counterfactual effect on Erk of an intervention on a ``GF_SCM``.

    The exogenous noise is updated from ``observation`` (SVI or importance
    sampling), the model is intervened on with ``ras_intervention`` via
    ``do``, and the counterfactual ``Erk`` marginal is sampled.

    Args:
        rates, totals, spike_width: forwarded to ``GF_SCM``.
        observation (dict): observed values; must contain ``'Erk'``.
        ras_intervention (dict): {site_name: value} passed to ``do``.
        svi (bool): update the noise with SVI if True, otherwise with
            importance sampling.
        samples (int): number of counterfactual draws (default 500, the
            previously hard-coded value).

    Returns:
        list: ``observation['Erk'] - counterfactual Erk`` for each draw.
    """
    gf_scm = GF_SCM(rates, totals, spike_width)
    # Normal(0, 1) prior over every exogenous noise term.
    noise = {
        name: Normal(0., 1.)
        for name in ('N_SOS', 'N_Ras', 'N_PI3K', 'N_AKT', 'N_Raf', 'N_Mek',
                     'N_Erk')
    }
    if svi:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)
    else:
        updated_noise = gf_scm.update_noise_importance(observation, noise)
    counterfactual_model = do(gf_scm.model, ras_intervention)
    cf_posterior = gf_scm.infer(counterfactual_model, updated_noise)
    cf_erk_marginal = EmpiricalMarginal(cf_posterior, sites='Erk')

    return [
        observation['Erk'] - float(cf_erk_marginal.sample())
        for _ in range(samples)
    ]
Ejemplo n.º 4
0
 def _pull_bandit(self):
     """Observe the environment, pick an arm, play it, and learn from the result.

     The chosen arm is imposed on ``self.bandits.gambler_model`` at site
     ``'arm'`` via ``do``. The intervened model is run on the current
     observation, and its output is handed to ``update_parameters``.
     """
     context = self.environment.observe()
     arm = self._select_arm(context)
     forced_model = do(self.bandits.gambler_model, {'arm': arm})
     self.update_parameters(arm, forced_model(context), context)
Ejemplo n.º 5
0
    def intervention(self, intervention_data: dict):
        """Return ``self.model`` with the values in ``intervention_data`` imposed.

        Args:
            intervention_data (dict): {site_name: value} pairs handed to
                ``pyro.do``.

        Returns:
            The intervened Pyro model.
        """
        return pyro.do(self.model, intervention_data)
Ejemplo n.º 6
0
def doCausal(their_tensors, absent_tensors, movie_inds, num_samples=1000):
    """Estimate the interventional effect of ``x`` on the summed ``y``.

    Uses the ``model`` and parameters ``p2`` from the enclosing (module)
    scope. Site ``"x"`` is fixed to ``their_tensors`` in one arm and to
    ``absent_tensors`` in the other; each arm is run ``num_samples`` times.
    Every run sums the ``"y"`` entries selected by ``movie_inds``.

    Args:
        their_tensors: value imposed on site ``"x"`` in the first arm.
        absent_tensors: value imposed on site ``"x"`` in the second arm.
        movie_inds: index selecting entries of the ``"y"`` output.
        num_samples (int): forward runs per arm (default 1000, the
            previously hard-coded value).

    Returns:
        The difference of the two Monte Carlo means (first minus second).
    """
    # With confounding
    their_do = pyro.do(model, data={"x": their_tensors})
    absent_do = pyro.do(model, data={"x": absent_tensors})

    their_do_y = [
        torch.sum(their_do(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]
    absent_do_y = [
        torch.sum(absent_do(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]

    return np.mean(their_do_y) - np.mean(absent_do_y)
Ejemplo n.º 7
0
def imagine_next_step(env, action, i):
    """Agent imagines next time step.

    ``action`` is played on a deep copy of ``env``, so the caller's
    environment object is never handed to the model. The action site
    ``action{state}{i}`` of the module-level ``model`` is fixed via
    ``pyro.do``, where ``state`` is the copy's current ``s``.

    Returns:
        The simulated environment after the imagined step.
    """
    imagined_env = deepcopy(env)
    stepper = pyro.do(model, {f'action{imagined_env.s}{i}': action})
    imagined_env, _, _, _, _ = stepper(imagined_env, i)
    # sanity check: the intervention must have forced the requested action
    assert imagined_env.lastaction == action
    return imagined_env
Ejemplo n.º 8
0
    def model_do_sample(self, do_dict):
        """Draw one sample of the graph under ``do(do_dict)``.

        Args:
            do_dict (dict): {variable_name: value} interventions. A shallow
                copy is passed to ``pyro.do``.

        Returns:
            Whatever one call of the intervened ``self.model_sample`` returns.
        """
        interventions = {name: do_dict[name] for name in do_dict}
        return pyro.do(self.model_sample, data=interventions)()
Ejemplo n.º 9
0
    def do_predict(self, X_test, num_samples=1000):
        """Average ``num_samples`` forward draws of ``y`` under ``do(x=X_test)``.

        If ``self.test_params`` is None, ``self.infer_z(X_test)`` is called
        first (presumably it populates ``self.test_params`` — confirm).

        Returns:
            tuple: (mean prediction array of length ``X_test.shape[0]``,
            ``self.test_params``).
        """
        if self.test_params is None:
            self.infer_z(X_test)

        forced_model = pyro.do(model, data={"x": X_test})
        running_sum = np.zeros((X_test.shape[0]))
        for _ in range(num_samples):
            running_sum += forced_model(self.test_params)['y'].detach().numpy()
        return running_sum / num_samples, self.test_params
Ejemplo n.º 10
0
	def play(self, state, time_left):
		"""
		Puts policy into action and moves to next state.

		Acts recursively until ``time_left`` equals 0, printing each state
		and the chosen action. Returns ``None`` once no time is left;
		otherwise a nested pair ``(action, <result of the next play call>)``.
		"""
		if torch.eq(time_left, torch.tensor(0.)):
			return None
		action = self.policy(self.infer(state, time_left, 'action'))
		step_key = '{}_{}'.format(state, time_left)
		forced_model = pyro.do(self.model, data={'action_' + step_key: action})
		next_state = forced_model(state, time_left)['next_state_' + step_key]

		# Print Trajectory and Actions
		print('State:', state.item(),
			', Action: ', list(Environment().action_dictionary.keys())[action])
		return action, self.play(next_state, time_left - 1)
Ejemplo n.º 11
0
def main():
    """Run one FrozenLake episode (at most 25 steps) driven by the model.

    At each step, ``policy`` picks an action. That action is imposed on the
    module-level ``model`` via ``pyro.do`` at site ``action{state}{t}``, and
    the intervened model advances the environment. The environment is
    rendered every step and is closed even if a step raises.
    """
    fl_env = gym.make('FrozenLake-v0', is_slippery=False)
    try:
        fl_env.reset()
        fl_env.s = 1  # NOTE: forces the starting state to 1
        for t in range(25):
            pyro.clear_param_store()
            action = policy(fl_env)
            int_model = pyro.do(model,
                                {f'action{fl_env.s}{t}': torch.tensor(action)})
            fl_env, observation, reward, done, info = int_model(fl_env, t)
            fl_env.render()
            if done:
                print("Episode finished after {} timesteps".format(t + 1))
                break
    finally:
        # Release the environment even when the model or policy raises.
        fl_env.close()
Ejemplo n.º 12
0
def scm_covid_counterfactual(
    rates,
    totals,
    observation,
    ras_intervention,
    spike_width=1.0,
    svi=True,
    samples=5000,
):
    """Sample counterfactual cytokine effects and intermediate nodes of a ``COVID_SCM``.

    The exogenous noise is updated from ``observation`` (SVI or importance
    sampling), and the model is intervened on with ``ras_intervention``
    via ``do``. The counterfactual marginals are then sampled.

    Args:
        rates, totals, spike_width: forwarded to ``COVID_SCM``.
        observation (dict): observed values; must contain ``'cytokine'``.
        ras_intervention (dict): {site_name: value} passed to ``do``.
        svi (bool): update the noise with SVI if True, otherwise with
            importance sampling.
        samples (int): number of draws per returned list (default 5000,
            the previously hard-coded value).

    Returns:
        tuple: (IL_6_AMP samples, NF_xB samples, IL6_STAT3 samples,
        ``observation['cytokine'] - counterfactual cytokine`` samples).
    """
    gf_scm = COVID_SCM(rates, totals, spike_width)
    # Per-term noise parameter pairs, presumably (loc, scale) — confirm
    # against COVID_SCM.
    noise = {
        'N_SARS_COV2': (0., 5.),
        'N_TOCI': (0., 5.),
        'N_PRR': (0., 1.),
        'N_ACE2': (0., 1.),
        'N_AngII': (0., 1.),
        'N_AGTR1': (0., 1.),
        'N_ADAM17': (0., 1.),
        'N_IL_6Ralpha': (0., 1.),
        'N_sIL_6_alpha': (0., 1.),
        'N_STAT3': (0., 1.),
        'N_EGF': (0., 1.),
        'N_TNF': (0., 1.),
        'N_EGFR': (0., 1.),
        'N_IL6_STAT3': (0., 1.),
        'N_NF_xB': (0., 1.),
        'N_IL_6_AMP': (0., 1.),
        'N_cytokine': (0., 1.)
    }
    if svi:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)
    else:
        updated_noise = gf_scm.update_noise_importance(observation, noise)
    counterfactual_model = do(gf_scm.model, ras_intervention)
    cf_posterior = gf_scm.infer(counterfactual_model, updated_noise)
    cf_cytokine_marginal = EmpiricalMarginal(cf_posterior, sites=['cytokine'])
    cf_il6amp_marginal = EmpiricalMarginal(cf_posterior, sites=['IL_6_AMP'])
    cf_nfxb_marginal = EmpiricalMarginal(cf_posterior, sites=['NF_xB'])
    cf_il6stat3_marginal = EmpiricalMarginal(cf_posterior, sites=['IL6_STAT3'])
    scm_causal_effect_samples = [
        observation['cytokine'] - float(cf_cytokine_marginal.sample())
        for _ in range(samples)
    ]
    il6amp_samples = cf_il6amp_marginal.sample((samples,)).tolist()
    nfxb_samples = cf_nfxb_marginal.sample((samples,)).tolist()
    il6stat3_samples = cf_il6stat3_marginal.sample((samples,)).tolist()
    return il6amp_samples, nfxb_samples, il6stat3_samples, scm_causal_effect_samples
Ejemplo n.º 13
0
    def model_do_cond_sample(self, do_dict, data_dict):
        """Sample the graph under ``do(do_dict)``, conditioned on ``data_dict``.

        A variable may not be both intervened on and conditioned on. If the
        two dicts share any key, ``'overlapping lists!'`` is printed and
        ``None`` is returned without sampling.

        Args:
            do_dict (dict): {variable_name: value} interventions.
            data_dict (dict): {variable_name: value} observations.

        Returns:
            One return value of the intervened and conditioned
            ``self.model_sample``, or ``None`` when the keys overlap.
        """
        # A key-set intersection replaces the pairwise O(len(do) * len(data))
        # comparison grid with one linear-time check.
        if set(do_dict) & set(data_dict):
            print('overlapping lists!')
            return None

        # Shallow copies, so Pyro never holds the caller's dict objects.
        do_model = pyro.do(self.model_sample, data=dict(do_dict))
        cond_model = pyro.condition(do_model, data=dict(data_dict))
        return cond_model()
def infer_Q(env, action, *, trajectory_model, agent_model, log=False,
            num_samples=None):
    """infer_Q

    Infer Q(state, action) via pyro's importance sampling.

    :param env: OpenAI Gym environment
    :param action: integer action
    :param trajectory_model: trajectory probabilistic program
    :param agent_model: agent's probabilistic program
    :param log: boolean; if True, print log info
    :param num_samples: number of importance samples; when None, falls back
        to the module-level ``args.num_samples`` (the previous behavior)
    :return: posterior mean of site ``'G'`` with ``A_0`` fixed to ``action``
    """
    if num_samples is None:
        num_samples = args.num_samples
    posterior = Importance(
        pyro.do(trajectory_model, {'A_0': torch.as_tensor(action)}),
        num_samples=num_samples,
    ).run(env, agent_model=agent_model)
    Q = EmpiricalMarginal(posterior, 'G').mean

    if log:
        print(f'Q({env.actions[action]}) = {Q.item()}')

    return Q
Ejemplo n.º 15
0
	def simulate(self, state, time_left, output):
		"""
		Agent Simulation: Forward simulation process.

		Samples an action from ``self.model`` and re-runs the model with that
		action imposed via ``pyro.do``. The expected utility is the immediate
		utility, plus (when more than one step remains) the mean of the
		inferred future expected utility. It is registered with
		``self.add_factor``.

		Returns the action if ``output == 'action'``, the expected utility if
		``output == 'exp_utility'``, and ``None`` otherwise.
		"""
		step_key = '{}_{}'.format(state, time_left)
		action = self.model(state, time_left)['action_' + step_key]
		interv_model = pyro.do(self.model, data={'action_' + step_key: action})
		# Run the intervened model once and read both sites from that run.
		# With a stochastic model, two separate calls could pair a utility
		# with a different next state than the one returned.
		outcome = interv_model(state, time_left)
		next_state = outcome['next_state_' + step_key]
		next_utility = outcome['next_utility_' + step_key]

		if not torch.eq(time_left, torch.tensor(1.)):  # if there are moves left
			future_utility_posterior = self.infer(next_state, time_left - 1, 'exp_utility')
			exp_utility = next_utility + torch.mean(future_utility_posterior)
		else:
			exp_utility = next_utility
		self.add_factor(exp_utility, 'exp_util_' + step_key)

		if output == 'action':
			return action
		elif output == 'exp_utility':
			return exp_utility
Ejemplo n.º 16
0
    def intervention_prediction(node_of_interest, intervention, posterior,
                                n_samples):
        """Sample ``node_of_interest`` from ``generative_model`` under an intervention.

        Each of the ``n_samples`` iterations does three things:
        - draws one sample from ``posterior`` over the exogenous sites
          (positionally ordered as ``exog_sites_list``);
        - drops the sites that are intervened on;
        - runs the intervened model with the remaining exogenous values.

        Returns:
            list: the ``node_of_interest`` entry from each model run.
        """
        as_tensors = {
            name: torch.tensor(value).float().flatten()
            for name, value in intervention.items()
        }
        do_model = pyro.do(generative_model, data=as_tensors)

        draws = []
        for _ in range(n_samples):
            exog_draw = posterior.sample()
            # Intervened sites are fixed by do(), so their noise is omitted.
            exog_values = {
                site: exog_draw[position]
                for position, site in enumerate(exog_sites_list)
                if site not in as_tensors
            }
            draws.append(do_model(exog_values=exog_values)[node_of_interest])
        return draws
Ejemplo n.º 17
0
    def intervention(model, evidence, infer, val, num_samples=1000):
        """
        Estimate the probability that site ``infer`` equals ``val`` under an intervention.

        ``pyro.do`` fixes the sites in ``evidence``; this is an intervention,
        not conditioning. Importance sampling is then run on the intervened
        model. The result is the fraction of ``num_samples`` draws from the
        empirical marginal of ``infer`` that equal ``val``.

        :param func model: Probabilistic model defined with pyro sample methods.
        :param dict(str, torch.tensor) evidence: Sample-site names mapped to the values they are fixed to.
        :param str infer: Name of the sample site to infer.
        :param int val: Value of the site whose probability is required.
        :param int num_samples: Number of samples for the importance sampler and
            for the Monte Carlo estimate.

        :return: Estimated probability of the site taking the value ``val``.
        :rtype: float
        """
        intervention_model = pyro.do(model, data=evidence)
        posterior = pyro.infer.Importance(intervention_model,
                                          num_samples=num_samples).run()
        marginal = pyro.infer.EmpiricalMarginal(posterior, infer)
        # Each draw is a scalar tensor; .item() turns it into a Python number.
        hits = sum(1 for _ in range(num_samples) if marginal().item() == val)
        return hits / num_samples
Ejemplo n.º 18
0
def scm_covid_counterfactual(
    betas,
    max_abundance,
    observation,
    ras_intervention,
    spike_width: float = 1.0,
    svi: bool = True,
    samples: int = 5000,
) -> List[float]:
    """Return ``samples`` draws of observed minus counterfactual cytokine.

    A ``SigmoidSCM`` has its ``NOISE`` priors updated from ``observation``
    (SVI when ``svi`` is truthy, importance sampling otherwise). The model
    is intervened on with ``ras_intervention`` via ``do``, and the
    counterfactual ``cytokine`` marginal is sampled ``samples`` times.
    """
    scm = SigmoidSCM(betas, max_abundance, spike_width)

    if not svi:
        posterior_noise = scm.update_noise_importance(observation, NOISE)
    else:
        posterior_noise, _ = scm.update_noise_svi(observation, NOISE)

    cf_model = do(scm.model, ras_intervention)
    cf_cytokine = EmpiricalMarginal(
        scm.infer(cf_model, posterior_noise), sites=['cytokine'])
    return [
        observation['cytokine'] - float(cf_cytokine.sample())
        for _ in range(samples)
    ]
Ejemplo n.º 19
0
 def intervention_prediction(self, node_of_interest, intervention,
                             posterior, n_samples):
     """Sample the node of interest from the intervened model.

     Args:
         node_of_interest (str): the name of the node of interest.
         intervention (dict): a dictionary of {node_name: value} interventions.
         posterior: exogenous noise passed as ``noise=`` to every model run.
         n_samples (int): the number of model runs.

     Returns:
         list: the ``node_of_interest`` value from each of the ``n_samples``
         runs (the caller can reduce it, e.g. take the mean).
     """
     intervention = {
         k: torch.tensor([v]).double().flatten()
         for k, v in intervention.items()
     }
     self._intv_nodes = list(intervention)
     try:
         intervened_model = pyro.do(self.model, data=intervention)
         estimate = [
             intervened_model(noise=posterior)[node_of_interest]
             for _ in range(n_samples)
         ]
     finally:
         # Reset even if a model run raises, so stale intervention nodes
         # cannot leak into later model runs.
         self._intv_nodes = []
     return estimate
def infer_Q(env, action, infer_type='intervention', *, agent_model):
    """infer_Q

    Infer Q(state, action) via pyro's importance sampling, using either
    conditioning or intervention.

    :param env: OpenAI Gym environment
    :param action: integer action
    :param infer_type: type of inference; 'none', 'condition', or
        'intervention'
    :param agent_model: agent's probabilistic program
    :return: posterior mean of site ``'G'`` under the chosen model
    :raises ValueError: if ``infer_type`` is not one of the three options
    """
    if infer_type not in ('intervention', 'condition', 'none'):
        raise ValueError(f'Invalid inference type {infer_type}')

    if infer_type == 'intervention':
        model = pyro.do(trajectory_model, {'A_0': torch.tensor(action)})
    elif infer_type == 'condition':
        model = pyro.condition(trajectory_model, {'A_0': torch.tensor(action)})
    else:  # infer_type == 'none'
        model = trajectory_model

    posterior = Importance(model, num_samples=args.num_samples).run(
        env, agent_model=agent_model)
    return EmpiricalMarginal(posterior, 'G').mean
Ejemplo n.º 21
0
                .10,  # P(C = 'on' | A = 'off', B = 'off')
                .90  # P(C = 'off' | A = 'off', B = 'off')
            ]
        ]
    ])

    A = pyro.sample('A', dist.Categorical(probs=prob_A))
    B = pyro.sample('B', dist.Categorical(probs=prob_B[A]))
    C = pyro.sample('C', dist.Categorical(probs=prob_C[A][B]))

    return C


if __name__ == '__main__':

    # NOTE(review): `intervened` is built but never used below; the
    # estimate is computed from the conditioned model only.
    intervened = pyro.do(intervention, {'B': torch.tensor(0)})
    conditioned = pyro.condition(intervention, {
        'B': torch.tensor(0),
        'C': torch.tensor(0)
    })
    posterior = pyro.infer.Importance(conditioned, num_samples=10000).run()
    marginal = pyro.infer.EmpiricalMarginal(posterior, 'A')
    samples = [marginal() for _ in range(10000)]

    # Category index 0 encodes 'on' (see the xticks below), so this prints
    # the estimate of P(A = 'on' | B = 'on', C = 'on').
    ons = sum(1 for sample in samples if sample.item() == 0)
    print(ons / len(samples))

    # Plot the distribution of P(A | B = 'on', C = 'on')
    plt.figure(figsize=(14, 7))
    plt.hist(samples, bins='auto')
    plt.xticks([0, 1], ['on', 'off'])
Ejemplo n.º 22
0
# Update the SCM's exogenous noise from the observed `cond_data` via SVI.
cond_noise = scm.update_noise_svi(cond_data)
print(cond_data)

# Draw 100 samples from the SCM under the updated noise and compare them with
# the original `ox` (compare_to_density is defined elsewhere in the notebook).
# NOTE: loop variables and `rxs` are module-level and may be used by later cells.
rxs = []
for i in range(100):
     (rx,ry,_), _ = scm.model(cond_noise)
     rxs.append(rx)
compare_to_density(ox, torch.cat(rxs))
_ =plt.suptitle("SCM Conditioned on Original", fontsize=18, fontstyle='italic')

"""## Counterfactuals"""

# Intervene on Shape, posX and posY — presumably sites Y_1, Y_4 and Y_5
# respectively, by the order of the keys below; confirm against the SCM.
intervened_model = pyro.do(scm.model, data={
    "Y_1": torch.tensor(0.),
    "Y_4": torch.tensor(30.),
    "Y_5": torch.tensor(25.),
})
# Collect the `.loc` of each updated noise distribution.
# NOTE(review): `noise_data` is only consumed by the commented-out line below.
noise_data = {}
for term, d in cond_noise.items():
  noise_data[term] = d.loc
# intervened_noise = scm.update_noise_svi(noise_data, intervened_model)

# One intervened sample using `scm.init_noise` (presumably the prior noise).
(rx1,ry,_), _ = intervened_model(scm.init_noise)
compare_to_density(ox, rx1)
print(ry)

# 5000 counterfactual samples: intervened model run with the updated noise.
rxs = []
for i in range(5000):
     (cfo1,ny1,nz1), _= intervened_model(cond_noise)
     rxs.append(cfo1)
Ejemplo n.º 23
0
# The exciting thing is that Pyro comes with a "do" function for causal inference. It is called
# exactly like "condition" (same call signature), but under the hood it implements intervention
# rather than conditioning.
# Note, however, that this only covers the "intervention" (action) step of a counterfactual
# computation: "do" cannot compute other-world queries out of the box, such as
#      ---- Y in the world where we do(T=1), given Y=5 in the world where we do(T=0) ----
# We still have to perform abduction and prediction ourselves.

# "do" is unsupported for direct definitions of posteriors/augmented graphs
# def scale_intervene(guess, measurement=9.5):
#     weight = pyro.sample("weight", dist.Normal(guess, 1.))
#     return pyro.sample ("measure", dist.Normal(weight, 0.75), do=ERROR)

# Here we intervene to set the measurement to 9.5; unlike conditioning on it, this should not affect P(weight).
intervened_scale = pyro.do(scale, data=dict(measure=9.5))

def intervene_wrapper(measurement=9.5):
    """Return ``scale`` with sample site ``'measure'`` fixed to ``measurement`` via ``pyro.do``."""
    fixed_sites = {'measure': measurement}
    return pyro.do(scale, data=fixed_sites)

# Unfortunately, this isn't the end of the story.
# Pyro isn't so plug-and-play that our conditioned/intervened stochastic functions are now
# ready to draw samples from the posterior distribution: we still need to specify the
# nitty-gritty of the inference algorithm and "press run".

# Inference algorithms include:
# importance sampling, rejection sampling, sequential Monte Carlo, MCMC
# (including independent Metropolis-Hastings), and variational inference (VI).
# These all require as input some notion of a "proposal distribution", e.g. the random-walk
# proposal in MCMC, the proposal for rejection sampling, or the variational distribution in VI.
# We need to choose this proposal distribution intelligently.
Ejemplo n.º 24
0
def intervene_wrapper(measurement=9.5):
    """Return ``scale`` with sample site ``'measure'`` fixed to ``measurement`` via ``pyro.do``."""
    fixed_sites = {'measure': measurement}
    return pyro.do(scale, data=fixed_sites)