def infer():
    """Iteratively refine a trajectory posterior via importance sampling.

    Runs 50 rounds; each round importance-samples trajectories conditioned on
    the previous round's empirical marginal, growing both the horizon
    (``imm_timestamp``) and the sample budget (``num_samples``) every round.

    NOTE(review): relies on module-level ``env``, ``model``, ``Importance``,
    ``EmpiricalMarginal`` and ``torch`` being in scope — confirm against the
    enclosing module.

    :returns: the final ``EmpiricalMarginal`` trajectory sampler
    """
    action_sampler = env.action_space
    trajectory_sampler = None  # first round passes None (prior) to the model
    imm_timestamp = 30
    num_samples = 50
    for i in range(50):
        # Re-run importance sampling, seeding with last round's marginal.
        posterior = Importance(model, num_samples=num_samples).run(
            trajectory_sampler, action_sampler, imm_timestamp)
        trajectory_sampler = EmpiricalMarginal(posterior)
        samples = trajectory_sampler.sample((num_samples, 1))
        # Collapse the (num_samples, 1, ...) draw and tally unique trajectories.
        possible_vals, counts = torch.unique(
            input=torch.flatten(samples, end_dim=1),
            sorted=True, return_counts=True, dim=0)
        probs = torch.true_divide(counts, num_samples)
        # Sanity check: empirical probabilities must sum to 1.
        assert torch.allclose(torch.sum(probs), torch.tensor(1.0))
        print("possible_Traj")
        print(possible_vals)
        print(probs)
        # Grow horizon and sample budget for the next round.
        imm_timestamp += 10
        num_samples += 50
        print("in {}th inference, sample traj is".format(i),
              trajectory_sampler.sample())
    return trajectory_sampler
def _vi_ape(model, design, observation_labels, target_labels, vi_parameters,
            is_parameters, y_dist=None):
    """Estimate the average posterior entropy (APE) of a design via variational
    inference.

    For each simulated outcome ``y``, the model is conditioned on ``y``, a
    guide is fit with SVI, and the mean-field entropy of the fitted guide is
    computed; the entropies are then averaged over the distribution of ``y``.

    :param model: pyro model taking ``design``
    :param design: design tensor passed to the model/guide
    :param observation_labels: sample-site names treated as observations
    :param target_labels: sample-site names whose posterior entropy is measured
    :param vi_parameters: kwargs for ``SVI``; must also contain ``'num_steps'``
        and ``'guide'``. NOTE(review): ``'num_steps'`` is popped, mutating the
        caller's dict.
    :param is_parameters: kwargs for ``Importance`` (used only when ``y_dist``
        is None)
    :param y_dist: optional distribution over outcomes; simulated from the
        model by importance sampling when omitted
    :returns: scalar APE estimate
    """
    svi_num_steps = vi_parameters.pop('num_steps')

    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)
        svi = SVI(conditioned_model, **vi_parameters)
        # block() keeps SVI's internal sites out of the enclosing Search trace
        with poutine.block():
            for _ in range(svi_num_steps):
                svi.step(design)
        # Recover the entropy
        with poutine.block():
            guide = vi_parameters["guide"]
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design),
                                   sites=observation_labels)

    # Calculate the expected posterior entropy under this distn of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    loss = loss_dist.mean

    return loss
def infer_prob(self, posterior, num_samples):
    """Estimate the empirical distribution over posterior values.

    Draws ``num_samples`` samples from the posterior's marginal and tallies
    the unique values.

    :returns: tuple of (unique values, their empirical probabilities)
    """
    marginal_dist = EmpiricalMarginal(posterior)
    draws = marginal_dist.sample((num_samples, 1))
    flattened = torch.flatten(draws, end_dim=1)
    possible_vals, counts = torch.unique(
        input=flattened, sorted=True, return_counts=True, dim=0)
    probs = torch.true_divide(counts, num_samples)
    # Empirical probabilities over all unique values must sum to one.
    assert torch.allclose(torch.sum(probs), torch.tensor(1.0))
    return possible_vals, probs
def _laplace_vi_ape(model, design, observation_labels, target_labels, guide, loss,
                    optim, num_steps, final_num_samples, y_dist=None):
    """Laplace-style APE estimate for a design.

    Per simulated outcome: MAP-fit ``guide`` with SVI, finalize a Laplace
    approximation at the optimum, then take its mean-field entropy; finally
    average entropies over the outcome distribution ``y_dist``.

    NOTE(review): ``guide`` must provide ``train()`` and
    ``finalize(loss, target_labels)`` — confirm against the guide class.

    :returns: scalar APE estimate
    """
    def posterior_entropy(y_dist, design):
        # Important that y_dist is sampled *within* the function
        y = pyro.sample("conditioning_y", y_dist)
        y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
        conditioned_model = pyro.condition(model, data=y_dict)
        # Here just using SVI to run the MAP optimization
        guide.train()
        svi = SVI(conditioned_model, guide=guide, loss=loss, optim=optim)
        with poutine.block():
            for _ in range(num_steps):
                svi.step(design)
        # Recover the entropy
        with poutine.block():
            final_loss = loss(conditioned_model, guide, design)
            guide.finalize(final_loss, target_labels)
            entropy = mean_field_entropy(guide, [design], whitelist=target_labels)
        return entropy

    if y_dist is None:
        y_dist = EmpiricalMarginal(
            Importance(model, num_samples=final_num_samples).run(design),
            sites=observation_labels)

    # Calculate the expected posterior entropy under this distn of y
    loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
    ape = loss_dist.mean
    return ape
def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" # Do not make pyro a requirement from pyro.infer import EmpiricalMarginal try: # Try pyro>=0.3 release syntax data = { name: utils.expand_dims(samples.enumerate_support().squeeze()) if self.posterior.num_chains == 1 else samples.enumerate_support().squeeze() for name, samples in self.posterior.marginal( sites=self.latent_vars).empirical.items() } except AttributeError: # Use pyro<0.3 release syntax data = {} for var_name in self.latent_vars: # pylint: disable=no-member samples = EmpiricalMarginal( self.posterior, sites=var_name).get_samples_and_weights()[0] samples = samples.numpy().squeeze() data[var_name] = utils.expand_dims(samples) return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def scm_ras_erk_counterfactual(rates, totals, observation, ras_intervention,
                               spike_width=1.0, svi=True, num_samples=500):
    """Counterfactual causal-effect samples for Erk under a Ras intervention.

    Abducts exogenous noise from ``observation`` (via SVI or importance
    sampling), applies ``ras_intervention`` with ``do``, and returns samples
    of ``observed Erk - counterfactual Erk``.

    :param rates: rate parameters for the GF_SCM model
    :param totals: species totals for the GF_SCM model
    :param observation: dict of observed values; must contain key 'Erk'
    :param ras_intervention: intervention passed to ``do``
    :param svi: if True use SVI noise abduction, else importance sampling
    :param num_samples: number of causal-effect samples to draw
        (default 500 — previously hard-coded; now parameterized for
        consistency with ``scm_covid_counterfactual``)
    :returns: list of float causal-effect samples
    """
    gf_scm = GF_SCM(rates, totals, spike_width)
    # Standard-normal exogenous noise for every node in the pathway.
    noise = {
        'N_SOS': Normal(0., 1.),
        'N_Ras': Normal(0., 1.),
        'N_PI3K': Normal(0., 1.),
        'N_AKT': Normal(0., 1.),
        'N_Raf': Normal(0., 1.),
        'N_Mek': Normal(0., 1.),
        'N_Erk': Normal(0., 1.)
    }
    # Noise abduction conditioned on the observation.
    if svi:
        updated_noise, _ = gf_scm.update_noise_svi(observation, noise)
    else:
        updated_noise = gf_scm.update_noise_importance(observation, noise)
    # Intervene and infer the counterfactual posterior.
    counterfactual_model = do(gf_scm.model, ras_intervention)
    cf_posterior = gf_scm.infer(counterfactual_model, updated_noise)
    cf_erk_marginal = EmpiricalMarginal(cf_posterior, sites='Erk')
    scm_causal_effect_samples = [
        observation['Erk'] - float(cf_erk_marginal.sample())
        for _ in range(num_samples)
    ]
    return scm_causal_effect_samples
def policy_control_as_inference_like(env, *, trajectory_model, agent_model, log=False):
    """policy_control_as_inference_like

    A control-as-inference-like policy that "maximizes"
    $\\Pr(A_0 \\mid S_0, high G)$. This is not standard CaI: instead of
    conditioning on G, sample traces are weighted by the likelihood factor
    $\\alpha G$.

    :param env: OpenAI Gym environment
    :param trajectory_model: trajectory probabilistic program
    :param agent_model: agent's probabilistic program
    :param log: boolean; if True, print log info
    """
    sampler = Importance(trajectory_model, num_samples=args.num_samples)
    traces = sampler.run(env, agent_model=agent_model, factor_G=True)
    action_marginal = EmpiricalMarginal(traces, 'A_0')

    if log:
        drawn = action_marginal.sample((args.num_samples, ))
        tally = Counter(drawn.tolist())
        hist = [tally[a] / args.num_samples for a in range(env.action_space.n)]
        print('policy:')
        print(tabulate([hist], headers=env.actions, tablefmt='fancy_grid'))

    return action_marginal.sample()
def observed_data_to_xarray(self):
    """Convert observed data to xarray."""
    from pyro.infer import EmpiricalMarginal

    # One (1, ...) array per observed site, pulled from the posterior traces.
    data = {
        site: np.expand_dims(
            EmpiricalMarginal(self.posterior, sites=site)
            .get_samples_and_weights()[0].numpy().squeeze(), 0)
        for site in self.observed_vars
    }
    return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def posterior_to_xarray(self):
    """Convert the posterior to an xarray dataset."""
    # Do not make pyro a requirement
    from pyro.infer import EmpiricalMarginal

    data = {}
    for site in self.latent_vars:
        marginal = EmpiricalMarginal(self.posterior, sites=site)
        draws = marginal.get_samples_and_weights()[0]
        # Leading axis of size 1 represents the single chain.
        data[site] = np.expand_dims(draws.numpy().squeeze(), 0)
    return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def main(args):
    """Infer the coefficient of friction by importance sampling and report stats.

    NOTE(review): relies on module-level ``model``, ``observed_data``,
    ``observed_mean``, ``simulate`` and ``analytic_T`` — confirm against the
    enclosing script.
    """
    # create an importance sampler (the prior is used as the proposal distribution)
    importance = Importance(model, guide=None, num_samples=args.num_samples)

    # get posterior samples of mu (which is the return value of model)
    # from the raw execution traces provided by the importance sampler.
    print("doing importance sampling...")
    emp_marginal = EmpiricalMarginal(importance.run(observed_data))

    # calculate statistics over posterior samples
    posterior_mean = emp_marginal.mean
    posterior_std_dev = emp_marginal.variance.sqrt()

    # report results
    inferred_mu = posterior_mean.item()
    inferred_mu_uncertainty = posterior_std_dev.item()
    print("the coefficient of friction inferred by pyro is %.3f +- %.3f" %
          (inferred_mu, inferred_mu_uncertainty))

    # note that, given the finite step size in the simulator, the simulated descent times will
    # not precisely match the numbers from the analytic result.
    # in particular the first two numbers reported below should match each other pretty closely
    # but will be systematically off from the third number
    print("the mean observed descent time in the dataset is: %.4f seconds" % observed_mean)
    print(
        "the (forward) simulated descent time for the inferred (mean) mu is: %.4f seconds"
        % simulate(posterior_mean).item())
    print((
        "disregarding measurement noise, elementary calculus gives the descent time\n"
        + "for the inferred (mean) mu as: %.4f seconds") % analytic_T(posterior_mean.item()))
def get_avg_estimates(x, y):
    """Run SVI and return posterior-mean weights, noise scale, and bias.

    :param x: model inputs passed to ``svi.run``
    :param y: targets passed to ``svi.run``
    :returns: tuple ``(post_weights, sigma, bias)`` of float tensors — the
        posterior means of the 19 weight sites, the 'sigma' site, and the
        'b_prior' site respectively.

    NOTE(review): relies on module-level ``svi`` and ``summary``.
    """
    print('Getting posterior estimates')
    estimates = svi.run(x, y)
    # 19 weight sites plus bias and observation-noise sites
    # (replaces the hand-written 21-element literal with identical strings).
    sites = ['w_prior{}'.format(i) for i in range(1, 20)] + ['b_prior', 'sigma']
    svi_samples = {
        site: EmpiricalMarginal(
            estimates, sites=site).enumerate_support().detach().cpu().numpy()
        for site in sites
    }
    summ = summary(svi_samples)
    post_means = []
    for site in summ:
        if site != 'sigma' and site != 'b_prior':
            post_means.append(summ[site]['mean'].values)
    # BUG FIX: flatten() returns a new array; the original discarded its
    # result, so multi-element site means were never flattened.
    post_means = np.array(post_means).flatten()
    post_means = torch.tensor(post_means)
    post_weights = post_means.type(torch.FloatTensor)
    sigma = torch.tensor(summ['sigma']['mean'].values).type(torch.FloatTensor)
    bias = torch.tensor(summ['b_prior']['mean'].values).type(torch.FloatTensor)
    return post_weights, sigma, bias
def test_nuts_conjugate_gaussian(fixture, num_samples, warmup_steps, hmc_params,
                                 expected_means, expected_precs, mean_tol, std_tol):
    """NUTS on a conjugate Gaussian chain: posterior moments must match the
    closed-form means/precisions supplied by the fixture, within tolerances."""
    pyro.get_param_store().clear()
    nuts_kernel = NUTS(fixture.model, hmc_params['step_size'])
    mcmc_run = MCMC(nuts_kernel, num_samples, warmup_steps).run(fixture.data)
    # One latent location per link in the chain, named 'loc_1'..'loc_<chain_len>'.
    for i in range(1, fixture.chain_len + 1):
        param_name = 'loc_' + str(i)
        marginal = EmpiricalMarginal(mcmc_run, sites=param_name)
        latent_loc = marginal.mean
        latent_std = marginal.variance.sqrt()
        expected_mean = torch.ones(fixture.dim) * expected_means[i - 1]
        # std = 1 / sqrt(precision)
        expected_std = 1 / torch.sqrt(
            torch.ones(fixture.dim) * expected_precs[i - 1])
        # Actual vs expected posterior means for the latents
        logger.info('Posterior mean (actual) - {}'.format(param_name))
        logger.info(latent_loc)
        logger.info('Posterior mean (expected) - {}'.format(param_name))
        logger.info(expected_mean)
        assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol)
        # Actual vs expected posterior precisions for the latents
        logger.info('Posterior std (actual) - {}'.format(param_name))
        logger.info(latent_std)
        logger.info('Posterior std (expected) - {}'.format(param_name))
        logger.info(expected_std)
        assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol)
def main(args):
    """Run NUTS on the eight-schools model and log posterior summary stats."""
    nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
    posterior = MCMC(nuts_kernel, num_samples=args.num_samples,
                     warmup_steps=args.warmup_steps).run(model, data.sigma, data.y)

    # Pull raw sample arrays for the scalar sites and the per-school etas.
    mu_tau_draws = EmpiricalMarginal(posterior, sites=["mu", "tau"]) \
        .get_samples_and_weights()[0].squeeze().numpy()
    eta_draws = EmpiricalMarginal(posterior, sites=["eta"]) \
        .get_samples_and_weights()[0].squeeze().numpy()
    joined = np.concatenate([mu_tau_draws, eta_draws], axis=1)

    params = [
        'mu', 'tau', 'eta[0]', 'eta[1]', 'eta[2]', 'eta[3]', 'eta[4]',
        'eta[5]', 'eta[6]', 'eta[7]'
    ]
    frame = pd.DataFrame(joined, columns=params).transpose()
    df_summary = frame.apply(pd.Series.describe,
                             axis=1)[["mean", "std", "25%", "50%", "75%"]]
    logging.info(df_summary)
def update_noise_importance(self, observed_steady_state, initial_noise):
    """Abduct exogenous noise by importance sampling under the observation.

    :param observed_steady_state: evidence dict to condition the noisy model on
    :param initial_noise: dict of prior noise distributions keyed by site name
    :returns: dict mapping each noise site to its posterior EmpiricalMarginal
    """
    conditioned = condition(self.noisy_model, observed_steady_state)
    posterior = self.infer(conditioned, initial_noise)
    return {site: EmpiricalMarginal(posterior, sites=site)
            for site in initial_noise.keys()}
def update_belief(self, state):
    """Refresh the agent's belief by importance sampling on new observations.

    Does nothing when ``state`` yields no observations.
    """
    # can use cache strategy to speed up running time
    observations = self.observe(state)  # observations at state
    if not observations:
        return
    sample_count = 20
    posterior = Importance(self.belief_model, num_samples=sample_count).run(
        self.belief, observations)
    self.belief = EmpiricalMarginal(posterior)
def get_posterior_mean(posterior, n_samples=30):
    """Calculate the posterior mean from ``n_samples`` marginal draws."""
    draws = EmpiricalMarginal(posterior).sample((n_samples, 1)).float()
    # assumed to be all the same
    return torch.mean(draws)
def test_importance_prior(self):
    """Importance sampling with the prior proposal recovers prior moments."""
    posterior = pyro.infer.Importance(
        self.model, guide=None, num_samples=10000
    ).run()
    marginal = EmpiricalMarginal(posterior)
    mean_gap = torch.norm(marginal.mean - self.loc_mean).item()
    std_gap = torch.norm(marginal.variance.sqrt() - self.loc_stddev).item()
    assert_equal(0, mean_gap, prec=0.01)
    assert_equal(0, std_gap, prec=0.1)
def infer_prob(self, posterior, possible_vals, num_samples):
    """Empirical probability of each candidate value under the posterior.

    Draws ``num_samples`` samples and, for each entry of ``possible_vals``,
    counts how many drawn elements equal it.

    :returns: tuple of (``possible_vals`` unchanged, probability tensor)
    """
    draws = EmpiricalMarginal(posterior).sample((num_samples, 1)).float()
    counts = torch.zeros(possible_vals.shape).float()
    # count the sample for each possible values
    for idx in range(len(possible_vals)):
        counts[idx] = (draws == possible_vals[idx]).sum()
    return possible_vals, counts / num_samples
def summary(traces, sites):
    """Per-site posterior summary statistics as pandas DataFrames.

    :param traces: posterior traces accepted by ``EmpiricalMarginal``
    :param sites: ordered site names, one per marginal column
    :returns: dict site name -> DataFrame of mean/std/percentiles
    """
    draws = EmpiricalMarginal(traces, sites)._get_samples_and_weights()[0]
    draws = draws.detach().cpu().numpy()
    describe = partial(pd.Series.describe,
                       percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
    site_stats = {}
    # One column of the sample matrix per site, in `sites` order.
    for col in range(draws.shape[1]):
        frame = pd.DataFrame(draws[:, col]).transpose()
        site_stats[sites[col]] = frame.apply(describe, axis=1)[
            ["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats
def test_mcmc_interface():
    """Smoke-test the MCMC interface with a prior kernel: the empirical
    marginal should have 800 samples and roughly N(0, 1) moments."""
    observed = torch.tensor([1.0])
    chain = MCMC(kernel=PriorKernel(normal_normal_model),
                 num_samples=800, warmup_steps=100).run(observed)
    marginal = EmpiricalMarginal(chain)
    assert_equal(marginal.sample_size, 800)
    empirical_mean = marginal.mean
    empirical_std = marginal.variance.sqrt()
    assert_equal(empirical_mean, torch.tensor([0.0]), prec=0.08)
    assert_equal(empirical_std, torch.tensor([1.0]), prec=0.08)
def observed_data_to_xarray(self):
    """Convert observed data to xarray.

    Supports both pyro>=0.3 (``marginal(...).empirical``) and pyro<0.3
    (``EmpiricalMarginal(...).get_samples_and_weights``) APIs, falling back
    on AttributeError.
    """
    from pyro.infer import EmpiricalMarginal

    try:
        # Try pyro>=0.3 release syntax
        data = {
            name: np.expand_dims(samples.enumerate_support(), 0)
            for name, samples in self.posterior.marginal(
                sites=self.observed_vars).empirical.items()
        }
    except AttributeError:
        # Use pyro<0.3 release syntax
        data = {}
        for var_name in self.observed_vars:
            samples = EmpiricalMarginal(
                self.posterior, sites=var_name).get_samples_and_weights()[0]
            data[var_name] = np.expand_dims(samples.numpy().squeeze(), 0)
    return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def test_posterior_predictive():
    """Posterior-predictive mean of binomial successes should match the truth.

    With p = 0.7 over 1000 trials per block, the expected success count is
    700 per block.
    """
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    nuts_kernel = NUTS(conditioned_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200).run(num_trials)
    posterior_predictive = TracePredictive(model, mcmc_run,
                                           num_samples=10000).run(num_trials)
    marginal_return_vals = EmpiricalMarginal(posterior_predictive)
    # 1000 trials * 0.7 = 700 expected successes, with a loose tolerance.
    assert_equal(marginal_return_vals.mean, torch.ones(5) * 700, prec=30)
def policy(t, env, *, trajectory_model, log=False):
    """policy

    Sample an action for time-step ``t`` from the softmax-agent posterior.

    :param t: time-step
    :param env: OpenAI Gym environment
    :param trajectory_model: trajectory probabilistic program
    :param log: boolean; if True, print log info
    """
    sampler = Importance(softmax_agent_model, num_samples=args.num_samples)
    traces = sampler.run(t, env, trajectory_model=trajectory_model)
    action_marginal = EmpiricalMarginal(traces, f'A_{t}')

    if log:
        drawn = action_marginal.sample((args.num_samples, ))
        tally = Counter(drawn.tolist())
        hist = [tally[a] / args.num_samples for a in range(env.action_space.n)]
        print('policy:')
        print(tabulate([hist], headers=env.actions, tablefmt='fancy_grid'))

    return action_marginal.sample()
def policy(env, log=False):
    """policy

    Sample a full policy vector from the softmax-agent posterior and return
    the action it assigns to the current state.

    :param env: OpenAI Gym environment
    :param log: boolean; if True, print log info
    """
    sampler = Importance(softmax_agent_model, num_samples=args.num_samples)
    traces = sampler.run(env)
    vector_marginal = EmpiricalMarginal(traces, 'policy_vector')

    if log:
        policy_samples = vector_marginal.sample((args.num_samples, ))
        action_samples = policy_samples[:, env.state]
        tally = Counter(action_samples.tolist())
        hist = [tally[a] / args.num_samples for a in range(env.action_space.n)]
        print('policy:')
        print(tabulate([hist], headers=env.actions, tablefmt='fancy_grid'))

    policy_vector = vector_marginal.sample()
    return policy_vector[env.state]
def scm_covid_counterfactual(
    betas,
    max_abundance,
    observation,
    ras_intervention,
    spike_width: float = 1.0,
    svi: bool = True,
    samples: int = 5000,
) -> List[float]:
    """Counterfactual cytokine causal-effect samples under a Ras intervention.

    Abducts noise from ``observation``, intervenes with ``do``, and returns
    ``samples`` draws of observed-minus-counterfactual cytokine.
    """
    scm = SigmoidSCM(betas, max_abundance, spike_width)
    # Noise abduction: SVI or importance sampling.
    if svi:
        abducted_noise, _ = scm.update_noise_svi(observation, NOISE)
    else:
        abducted_noise = scm.update_noise_importance(observation, NOISE)
    intervened_model = do(scm.model, ras_intervention)
    posterior = scm.infer(intervened_model, abducted_noise)
    cytokine_marginal = EmpiricalMarginal(posterior, sites=['cytokine'])
    return [
        observation['cytokine'] - float(cytokine_marginal.sample())
        for _ in range(samples)
    ]
def test_normal_gamma():
    """NUTS on a Normal likelihood with a Gamma prior over the scale:
    the posterior mean of the scale should recover the true std."""
    def model(data):
        rate = torch.tensor([1.0, 1.0])
        concentration = torch.tensor([1.0, 1.0])
        # NOTE(review): dist.Gamma's signature is (concentration, rate) but the
        # arguments are passed in the opposite naming order; both are 1.0 here
        # so the result is unchanged.
        p_latent = pyro.sample('p_latent', dist.Gamma(rate, concentration))
        pyro.sample("obs", dist.Normal(3, p_latent), obs=data)
        return p_latent

    true_std = torch.tensor([0.5, 2])
    data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000, ))))
    kernel = NUTS(model, step_size=0.01)
    mcmc_run = MCMC(kernel, num_samples=200, warmup_steps=100).run(data)
    posterior = EmpiricalMarginal(mcmc_run, sites='p_latent')
    assert_equal(posterior.mean, true_std, prec=0.05)
def test_categorical_dirichlet():
    """NUTS on a Categorical likelihood with a flat Dirichlet prior:
    posterior mean should recover the true category probabilities."""
    def model(data):
        concentration = torch.tensor([1.0, 1.0, 1.0])
        p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration))
        pyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(
        sample_shape=(torch.Size((2000, ))))
    kernel = NUTS(model, adapt_step_size=True)
    mcmc_run = MCMC(kernel, num_samples=200, warmup_steps=100).run(data)
    posterior = EmpiricalMarginal(mcmc_run, sites='p_latent')
    assert_equal(posterior.mean, true_probs, prec=0.02)
def counterfactual_inference(self, condition_data: dict, intervention_data: dict,
                             target, svi=True):
    """Three-step counterfactual query: condition, abduct noise, intervene.

    :param condition_data: observed evidence; must contain key ``target``
    :param intervention_data: intervention assignment passed to the model
    :param target: site name whose counterfactual value is queried
    :param svi: must currently be True — importance-sampling abduction is not
        implemented for this method
    :returns: tuple (causal-effect samples, counterfactual samples)
    :raises NotImplementedError: if ``svi`` is False
    """
    # Step 1. Condition the model
    conditioned_model = self.condition(condition_data)
    # Step 2. Noise abduction
    if svi:
        updated_noise, _ = self.update_noise_svi(conditioned_model)
    else:
        # BUG FIX: the original fell through with `updated_noise` unbound,
        # crashing later with UnboundLocalError. Fail fast and explicitly.
        raise NotImplementedError(
            "counterfactual_inference currently supports only svi=True")
    # Step 3. Intervene
    intervention_model = self.intervention(intervention_data)
    # Pass abducted noises to intervention model
    cf_posterior = self.infer(intervention_model, updated_noise)
    marginal = EmpiricalMarginal(cf_posterior, target)
    counterfactual_samples = [marginal.sample() for _ in range(1000)]
    # Calculate causal effect
    scm_causal_effect_samples = [
        condition_data[target] - float(marginal.sample())
        for _ in range(5000)
    ]
    return scm_causal_effect_samples, counterfactual_samples
def summary(traces, sites, player_names, transforms={}):
    """
    Return summarized statistics for each of the ``sites`` in the traces
    corresponding to the approximate posterior.
    """
    draws = EmpiricalMarginal(traces, sites).get_samples_and_weights()[0].numpy()
    site_stats = {}
    # One column of the sample matrix per site, in `sites` order.
    for col in range(draws.shape[1]):
        name = sites[col]
        column = draws[:, col]
        if name in transforms:
            # Apply the site-specific post-processing transform, if any.
            column = transforms[name](column)
        site_stats[name] = get_site_stats(column, player_names)
    return site_stats
def get_mean_svi_est_manual_guide(self):
    """Posterior means of the weights/scale/locs sites from an SVI run.

    Stores the result on ``self.posterior_est`` and returns it.
    """
    svi_posterior = self.svi.run()
    self.posterior_est = dict()
    for site in ("weights", "scale", "locs"):
        support = EmpiricalMarginal(svi_posterior, sites=site) \
            .enumerate_support().detach().cpu().numpy()
        # Mean over the sample axis gives the posterior-mean estimate.
        self.posterior_est[site] = torch.tensor(support.mean(axis=0))
    return self.posterior_est