def infer(*, num_iterations=50, initial_timestamp=30, initial_num_samples=50,
          timestamp_step=10, num_samples_step=50):
    """Iteratively refine a trajectory sampler with importance sampling.

    Each round runs ``Importance`` on the module-level ``model``. The model
    receives, positionally:

    * the previous round's trajectory sampler (``None`` in the first round);
    * ``env.action_space``;
    * the current timestamp.

    The posterior is wrapped in an ``EmpiricalMarginal`` (default site), which
    becomes the trajectory sampler for the next round. After each round the
    distinct sampled values and their empirical frequencies are printed. Then
    the timestamp and the sample count are both increased.

    All defaults reproduce the previously hard-coded schedule.

    :param num_iterations: number of refinement rounds.
    :param initial_timestamp: timestamp passed to ``model`` in the first round.
        Its meaning is defined by ``model``; presumably a time horizon
        (TODO confirm).
    :param initial_num_samples: importance samples drawn in the first round.
    :param timestamp_step: amount added to the timestamp after every round.
    :param num_samples_step: amount added to the sample count after every
        round.
    :returns: the ``EmpiricalMarginal`` from the last round, or ``None`` when
        ``num_iterations`` is 0.
    """
    action_sampler = env.action_space
    trajectory_sampler = None
    imm_timestamp = initial_timestamp
    num_samples = initial_num_samples
    for i in range(num_iterations):
        posterior = Importance(model, num_samples=num_samples).run(
            trajectory_sampler, action_sampler, imm_timestamp)
        trajectory_sampler = EmpiricalMarginal(posterior)

        # Draw shape is (num_samples, 1, *event). Flattening the first two
        # dims gives one row per draw, so unique rows are the distinct
        # sampled values.
        samples = trajectory_sampler.sample((num_samples, 1))
        possible_vals, counts = torch.unique(
            torch.flatten(samples, end_dim=1),
            sorted=True, return_counts=True, dim=0)
        probs = torch.true_divide(counts, num_samples)
        # Internal sanity check: the counts cover every draw exactly once.
        assert torch.allclose(torch.sum(probs), torch.tensor(1.0))

        print("possible_Traj")
        print(possible_vals)
        print(probs)

        imm_timestamp += timestamp_step
        num_samples += num_samples_step
        print("in {}th inference, sample traj is".format(i),
              trajectory_sampler.sample())
    return trajectory_sampler
def policy_control_as_inference_like(env, *, trajectory_model, agent_model, log=False):
    """Sample a first action from a control-as-inference-like policy.

    The trajectory model is importance-sampled with ``factor_G=True``. Per the
    design of this policy, that option weights each sampled trace by a
    return-based likelihood factor (alpha * G) instead of hard-conditioning on
    a high return. The action is then drawn from the empirical marginal of the
    ``A_0`` site. This is why the policy is only "CaI-like": it tilts toward
    high-return traces rather than truly conditioning on them.

    :param env: OpenAI Gym environment
    :param trajectory_model: trajectory probabilistic program
    :param agent_model: agent's probabilistic program, forwarded to the
        trajectory model
    :param log: boolean; if True, print the empirical action distribution as
        a table
    :returns: a single sample of ``A_0``
    """
    num_samples = args.num_samples
    importance = Importance(trajectory_model, num_samples=num_samples)
    traces = importance.run(env, agent_model=agent_model, factor_G=True)
    first_action = EmpiricalMarginal(traces, 'A_0')

    if log:
        drawn_actions = first_action.sample((num_samples, ))
        tally = Counter(drawn_actions.tolist())
        frequencies = [tally[action] / num_samples
                       for action in range(env.action_space.n)]
        print('policy:')
        print(tabulate([frequencies], headers=env.actions, tablefmt='fancy_grid'))

    return first_action.sample()
def scm_ras_erk_counterfactual(rates, totals, observation, ras_intervention,
                               spike_width=1.0, svi=True, samples=500):
    """Sample the counterfactual effect of an intervention on ``Erk``.

    Follows the abduction / action / prediction recipe:

    1. Abduction: update standard-normal priors over the exogenous noise
       terms so they are consistent with ``observation``. SVI is used when
       ``svi`` is true; otherwise the SCM's importance-sampling routine is
       used.
    2. Action: apply ``ras_intervention`` to the SCM's model with ``do``.
    3. Prediction: run the intervened model under the updated noise and read
       off the ``Erk`` site.

    :param rates: forwarded to ``GF_SCM``.
    :param totals: forwarded to ``GF_SCM``.
    :param observation: observed values; must contain the key ``'Erk'``.
    :param ras_intervention: intervention mapping passed to ``do``.
    :param spike_width: forwarded to ``GF_SCM``.
    :param svi: if True abduct noise with SVI, else with importance sampling.
    :param samples: number of effect samples to draw (default 500, the
        previously hard-coded count).
    :returns: a list of ``observation['Erk'] - counterfactual Erk`` values,
        one per draw.
    """
    gf_scm = GF_SCM(rates, totals, spike_width)
    # Standard-normal prior for every exogenous noise term.
    noise = {
        'N_SOS': Normal(0., 1.),
        'N_Ras': Normal(0., 1.),
        'N_PI3K': Normal(0., 1.),
        'N_AKT': Normal(0., 1.),
        'N_Raf': Normal(0., 1.),
        'N_Mek': Normal(0., 1.),
        'N_Erk': Normal(0., 1.),
    }
    if svi:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)
    else:
        updated_noise = gf_scm.update_noise_importance(observation, noise)

    counterfactual_model = do(gf_scm.model, ras_intervention)
    cf_posterior = gf_scm.infer(counterfactual_model, updated_noise)
    cf_erk_marginal = EmpiricalMarginal(cf_posterior, sites='Erk')

    scm_causal_effect_samples = [
        observation['Erk'] - float(cf_erk_marginal.sample())
        for _ in range(samples)
    ]
    return scm_causal_effect_samples
def infer_prob(self, posterior, num_samples):
    """Estimate a discrete distribution from an importance posterior.

    Draws ``num_samples`` values from the posterior's default empirical
    marginal. It then tallies the distinct values (compared row-wise along the
    first axis) and returns them with their relative frequencies.

    :param posterior: a trace posterior accepted by ``EmpiricalMarginal``.
    :param num_samples: how many draws to take.
    :returns: ``(possible_vals, probs)``. ``possible_vals`` holds the sorted
        unique draws; ``probs`` holds their frequencies, which sum to 1.
    """
    draws = EmpiricalMarginal(posterior).sample((num_samples, 1))
    # Shape is (num_samples, 1, *event). Merge the first two dims so each
    # draw is one row.
    rows = torch.flatten(draws, end_dim=1)
    support, tally = torch.unique(rows, sorted=True, return_counts=True, dim=0)
    frequencies = torch.true_divide(tally, num_samples)
    assert torch.allclose(frequencies.sum(), torch.tensor(1.0))
    return support, frequencies
def scm_covid_counterfactual(
    rates, totals, observation, ras_intervention, spike_width=1.0, svi=True,
    samples=5000,
):
    """Counterfactual cytokine effect plus downstream marginals for a COVID SCM.

    1. Abduction: the exogenous noise terms below are updated to agree with
       ``observation``. SVI is used when ``svi`` is true; otherwise the SCM's
       importance-sampling routine is used.
    2. Action: ``ras_intervention`` is applied to the model with ``do``.
    3. Prediction: the intervened model is run under the updated noise. The
       ``cytokine``, ``IL_6_AMP``, ``NF_xB`` and ``IL6_STAT3`` sites are read
       from the posterior.

    :param rates: forwarded to ``COVID_SCM``.
    :param totals: forwarded to ``COVID_SCM``.
    :param observation: observed values; must contain the key ``'cytokine'``.
    :param ras_intervention: intervention mapping passed to ``do``.
    :param spike_width: forwarded to ``COVID_SCM``.
    :param svi: if True abduct noise with SVI, else with importance sampling.
    :param samples: number of draws per returned list (default 5000, the
        previously hard-coded count).
    :returns: ``(il6amp_samples, nfxb_samples, il6stat3_samples,
        scm_causal_effect_samples)``. The first three come from ``.tolist()``
        on draws of one-element site lists, so each entry is likely a nested
        one-element list (TODO confirm). The last is a list of
        ``observation['cytokine'] - counterfactual cytokine`` values.
    """
    covid_scm = COVID_SCM(rates, totals, spike_width)
    # NOTE(review): values are plain (a, b) tuples rather than distributions;
    # presumably (loc, scale) pairs interpreted by COVID_SCM -- confirm.
    noise = {
        'N_SARS_COV2': (0., 5.),
        'N_TOCI': (0., 5.),
        'N_PRR': (0., 1.),
        'N_ACE2': (0., 1.),
        'N_AngII': (0., 1.),
        'N_AGTR1': (0., 1.),
        'N_ADAM17': (0., 1.),
        'N_IL_6Ralpha': (0., 1.),
        'N_sIL_6_alpha': (0., 1.),
        'N_STAT3': (0., 1.),
        'N_EGF': (0., 1.),
        'N_TNF': (0., 1.),
        'N_EGFR': (0., 1.),
        'N_IL6_STAT3': (0., 1.),
        'N_NF_xB': (0., 1.),
        'N_IL_6_AMP': (0., 1.),
        'N_cytokine': (0., 1.),
    }
    if svi:
        updated_noise, _ = covid_scm.update_noise_svi(observation, noise)
    else:
        updated_noise = covid_scm.update_noise_importance(observation, noise)

    counterfactual_model = do(covid_scm.model, ras_intervention)
    cf_posterior = covid_scm.infer(counterfactual_model, updated_noise)

    cf_cytokine_marginal = EmpiricalMarginal(cf_posterior, sites=['cytokine'])
    cf_il6amp_marginal = EmpiricalMarginal(cf_posterior, sites=['IL_6_AMP'])
    cf_nfxb_marginal = EmpiricalMarginal(cf_posterior, sites=['NF_xB'])
    cf_il6stat3_marginal = EmpiricalMarginal(cf_posterior, sites=['IL6_STAT3'])

    scm_causal_effect_samples = [
        observation['cytokine'] - float(cf_cytokine_marginal.sample())
        for _ in range(samples)
    ]
    il6amp_samples = cf_il6amp_marginal.sample((samples,)).tolist()
    nfxb_samples = cf_nfxb_marginal.sample((samples,)).tolist()
    il6stat3_samples = cf_il6stat3_marginal.sample((samples,)).tolist()
    return il6amp_samples, nfxb_samples, il6stat3_samples, scm_causal_effect_samples
def policy(t, env, *, trajectory_model, log=False):
    """Sample the agent's action at time-step ``t``.

    ``softmax_agent_model`` is importance-sampled with the time-step,
    environment and trajectory model. The action is drawn from the empirical
    marginal of the ``A_{t}`` site.

    :param t: time-step; selects which action site is marginalised
    :param env: OpenAI Gym environment
    :param trajectory_model: trajectory probabilistic program, forwarded to
        the agent model
    :param log: boolean; if True, print the empirical action distribution as
        a table
    :returns: a single sample of ``A_{t}``
    """
    num_samples = args.num_samples
    importance = Importance(softmax_agent_model, num_samples=num_samples)
    traces = importance.run(t, env, trajectory_model=trajectory_model)
    action_at_t = EmpiricalMarginal(traces, f'A_{t}')

    if log:
        drawn_actions = action_at_t.sample((num_samples, ))
        tally = Counter(drawn_actions.tolist())
        frequencies = [tally[action] / num_samples
                       for action in range(env.action_space.n)]
        print('policy:')
        print(tabulate([frequencies], headers=env.actions, tablefmt='fancy_grid'))

    return action_at_t.sample()
def policy(env, log=False):
    """Sample an action for the current state from a sampled policy vector.

    ``softmax_agent_model`` is importance-sampled on ``env``. A whole
    ``policy_vector`` is drawn from its empirical marginal and indexed by
    ``env.state`` to produce the action.

    :param env: OpenAI Gym environment
    :param log: boolean; if True, print the empirical action distribution
        for the current state as a table
    :returns: the action entry of a sampled policy vector at ``env.state``
    """
    num_samples = args.num_samples
    importance = Importance(softmax_agent_model, num_samples=num_samples)
    vector_marginal = EmpiricalMarginal(importance.run(env), 'policy_vector')

    if log:
        sampled_vectors = vector_marginal.sample((num_samples, ))
        # Keep only the column for the state the environment is in.
        state_actions = sampled_vectors[:, env.state]
        tally = Counter(state_actions.tolist())
        frequencies = [tally[action] / num_samples
                       for action in range(env.action_space.n)]
        print('policy:')
        print(tabulate([frequencies], headers=env.actions, tablefmt='fancy_grid'))

    return vector_marginal.sample()[env.state]
def counterfactual_inference(self, condition_data: dict, intervention_data: dict, target, svi=True,
                             *, num_counterfactual_samples=1000, num_effect_samples=5000):
    """Run condition -> abduct -> intervene counterfactual inference on ``target``.

    :param condition_data: observed values used to condition the model; must
        contain ``target``.
    :param intervention_data: values passed to ``self.intervention``.
    :param target: name of the site whose counterfactual distribution is
        sampled.
    :param svi: must be True. Noise abduction is only implemented via
        ``self.update_noise_svi``.
    :param num_counterfactual_samples: number of raw counterfactual draws
        returned (default 1000).
    :param num_effect_samples: number of causal-effect samples returned
        (default 5000).
    :returns: ``(scm_causal_effect_samples, counterfactual_samples)``.
        ``scm_causal_effect_samples`` holds
        ``condition_data[target] - counterfactual draw`` values.
        ``counterfactual_samples`` holds raw draws from the counterfactual
        marginal. The two lists are drawn independently, not paired.
    :raises NotImplementedError: if ``svi`` is False.
    """
    if not svi:
        # There is no non-SVI abduction path. Fail up front instead of
        # reaching ``self.infer`` with an undefined noise posterior.
        raise NotImplementedError(
            "counterfactual_inference only supports noise abduction via SVI (svi=True)")

    # Step 1. Condition the model on the observed data.
    conditioned_model = self.condition(condition_data)
    # Step 2. Noise abduction: fit the exogenous noise to the observations.
    updated_noise, _ = self.update_noise_svi(conditioned_model)
    # Step 3. Intervene, then run the intervened model under the abducted noise.
    intervention_model = self.intervention(intervention_data)
    cf_posterior = self.infer(intervention_model, updated_noise)
    marginal = EmpiricalMarginal(cf_posterior, target)

    counterfactual_samples = [marginal.sample() for _ in range(num_counterfactual_samples)]
    # Causal effect = observed value minus counterfactual value.
    scm_causal_effect_samples = [
        condition_data[target] - float(marginal.sample())
        for _ in range(num_effect_samples)
    ]
    return scm_causal_effect_samples, counterfactual_samples
def scm_covid_counterfactual(
    betas,
    max_abundance,
    observation,
    ras_intervention,
    spike_width: float = 1.0,
    svi: bool = True,
    samples: int = 5000,
) -> List[float]:
    """Sample the counterfactual cytokine effect of an intervention.

    A ``SigmoidSCM`` is built, and the module-level ``NOISE`` priors are
    updated to agree with ``observation``. SVI is used when ``svi`` is true;
    otherwise importance sampling is used. The intervention is applied with
    ``do``, and the intervened model is run under the updated noise.

    :param betas: forwarded to ``SigmoidSCM``.
    :param max_abundance: forwarded to ``SigmoidSCM``.
    :param observation: observed values; must contain the key ``'cytokine'``.
    :param ras_intervention: intervention mapping passed to ``do``.
    :param spike_width: forwarded to ``SigmoidSCM``.
    :param svi: if True abduct noise with SVI, else with importance sampling.
    :param samples: number of effect samples to draw.
    :returns: ``observation['cytokine'] - counterfactual cytokine`` values,
        one per draw.
    """
    scm = SigmoidSCM(betas, max_abundance, spike_width)

    if not svi:
        abducted_noise = scm.update_noise_importance(observation, NOISE)
    else:
        abducted_noise, _ = scm.update_noise_svi(observation, NOISE)

    intervened = do(scm.model, ras_intervention)
    posterior = scm.infer(intervened, abducted_noise)
    cytokine = EmpiricalMarginal(posterior, sites=['cytokine'])

    return [
        observation['cytokine'] - float(cytokine.sample())
        for _ in range(samples)
    ]
obs = truncation_label[i]) pyro.clear_param_store() hmc_kernel = HMC(model, step_size = 0.1, num_steps = 4) mcmc_run = MCMC(hmc_kernel, num_samples=5, warmup_steps=1).run(x, y, truncation_label) marginal_a = EmpiricalMarginal(mcmc_run, sites="a_model") posterior_a = [marginal_a.sample() for i in range(50)] sns.distplot(posterior_a) """# Modeling using HMC with Vectorized Data Here we try to make the estimation faster using the `plate` and `mask` function. """ def model(x, y, truncation_label): a_model = pyro.sample("a_model", dist.Normal(0, 10)) b_model = pyro.sample("b_model", dist.Normal(0, 10)) link = torch.nn.functional.softplus(a_model * x + b_model) with pyro.plate("data"):