import pyro
import pyro.optim
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from tqdm import tqdm


def auto_variational_fit(model, data, num_epochs=2500, lr=0.001):
    """Use Stochastic Variational Inference to infer the latent variables."""
    guide = AutoMultivariateNormal(model)
    svi = pyro.infer.SVI(model=model,
                         guide=guide,
                         optim=pyro.optim.Adam({'lr': lr}),
                         loss=Trace_ELBO())
    losses = []
    for i in tqdm(range(num_epochs)):
        losses.append(svi.step(data))
    return losses, guide.get_posterior()
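# Hedged usage sketch for auto_variational_fit above. The toy model and data
# below are illustrative assumptions (not from the original snippet); any Pyro
# model whose latent sites have continuous support would work the same way.
import torch
import pyro.distributions as dist


def toy_model(data):
    # Unknown location with a wide Normal prior; all observations share it.
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


toy_data = torch.randn(100) + 3.0
losses, posterior = auto_variational_fit(toy_model, toy_data, num_epochs=500)
print(posterior.mean)  # approximate posterior mean of the single latent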
from pyro import poutine
from pyro.infer import autoguide
from pyro.infer.autoguide import (AutoDelta, AutoGuideList,
                                  AutoMultivariateNormal, AutoNormal)


def AutoMixed(model_full, init_loc={}, delta=None):
    """Build a block guide: AutoNormal (or AutoDelta) for every site except
    'tau', and a full-rank AutoMultivariateNormal (or AutoDelta) for 'tau'."""
    guide = AutoGuideList(model_full)

    marginalised_guide_block = poutine.block(model_full, expose_all=True,
                                             hide_all=False, hide=['tau'])
    if delta is None:
        guide.append(AutoNormal(marginalised_guide_block,
                                init_loc_fn=autoguide.init_to_value(values=init_loc),
                                init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(AutoDelta(marginalised_guide_block,
                               init_loc_fn=autoguide.init_to_value(values=init_loc)))

    full_rank_guide_block = poutine.block(model_full, hide_all=True, expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(AutoMultivariateNormal(full_rank_guide_block,
                                            init_loc_fn=autoguide.init_to_value(values=init_loc),
                                            init_scale=0.05))
    else:
        guide.append(AutoDelta(full_rank_guide_block,
                               init_loc_fn=autoguide.init_to_value(values=init_loc)))
    return guide
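# Hedged usage sketch for AutoMixed above. The model is an illustrative
# assumption: any Pyro model containing a sample site named 'tau' fits the
# blocking scheme (full-rank AutoMultivariateNormal over 'tau', mean-field
# AutoNormal over everything else when delta is None).
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def toy_model(data):
    tau = pyro.sample("tau", dist.HalfCauchy(1.0))
    beta = pyro.sample("beta", dist.Normal(0.0, 1.0))
    with pyro.plate("obs_plate", data.shape[0]):
        pyro.sample("obs", dist.Normal(beta, tau), obs=data)


toy_data = torch.randn(50)
mixed_guide = AutoMixed(toy_model, init_loc={"tau": torch.tensor(1.0)})
svi = SVI(toy_model, mixed_guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(200):
    svi.step(toy_data)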
def do_test_auto(self, N, reparameterized, n_steps):
    logger.debug("\nGoing to do AutoGaussianChain test...")
    pyro.clear_param_store()
    self.setUp()
    self.setup_chain(N)
    self.compute_target(N)
    self.guide = AutoMultivariateNormal(self.model)
    logger.debug("target auto_loc: {}"
                 .format(self.target_auto_mus[1:].detach().cpu().numpy()))
    logger.debug("target auto_diag_cov: {}"
                 .format(self.target_auto_diag_cov[1:].detach().cpu().numpy()))

    # TODO speed up with parallel num_particles > 1
    adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())

    for k in range(n_steps):
        loss = svi.step(reparameterized)
        assert np.isfinite(loss), loss

        if (k % 1000 == 0 and k > 0) or k == n_steps - 1:
            logger.debug("[step {}] guide mean parameter: {}"
                         .format(k, self.guide.loc.detach().cpu().numpy()))
            L = self.guide.scale_tril
            diag_cov = torch.mm(L, L.t()).diag()
            logger.debug("[step {}] auto_diag_cov: {}"
                         .format(k, diag_cov.detach().cpu().numpy()))

    assert_equal(self.guide.loc.detach(), self.target_auto_mus[1:], prec=0.05,
                 msg="guide mean off")
    assert_equal(diag_cov, self.target_auto_diag_cov[1:], prec=0.07,
                 msg="guide covariance off")
def run_inference(data, gen_model, ode_model, method, iterations=10000,
                  num_particles=1, num_samples=1000, warmup_steps=500,
                  init_scale=0.1, seed=12, lr=0.5, return_sites="_RETURN"):
    torch_data = torch.tensor(data, dtype=torch.float)
    if isinstance(ode_model, (ForwardSensManualJacobians, ForwardSensTorchJacobians)):
        ode_op = ForwardSensOp
    elif isinstance(ode_model, (AdjointSensManualJacobians, AdjointSensTorchJacobians)):
        ode_op = AdjointSensOp
    else:
        raise ValueError('Unknown sensitivity solver: Use "Forward" or "Adjoint"')
    model = gen_model(ode_op, ode_model)

    pyro.set_rng_seed(seed)
    pyro.clear_param_store()

    if method == 'VI':
        guide = AutoMultivariateNormal(model, init_scale=init_scale)
        optim = AdagradRMSProp({"eta": lr})
        if num_particles == 1:
            svi = SVI(model, guide, optim, loss=Trace_ELBO())
        else:
            svi = SVI(model, guide, optim,
                      loss=Trace_ELBO(num_particles=num_particles,
                                      vectorize_particles=True))
        loss_trace = []
        t0 = timer.time()
        for j in range(iterations):
            loss = svi.step(torch_data)
            loss_trace.append(loss)
            if j % 500 == 0:
                print("[iteration %04d] loss: %.4f"
                      % (j + 1, np.mean(loss_trace[max(0, j - 1000):j + 1])))
        t1 = timer.time()
        print('VI time: ', t1 - t0)
        predictive = Predictive(model, guide=guide, num_samples=num_samples,
                                return_sites=return_sites)  # "ode_params", "scale",
        vb_samples = predictive(torch_data)
        return vb_samples
    elif method == 'NUTS':
        nuts_kernel = NUTS(model, adapt_step_size=True, init_strategy=init_to_median)
        # mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=2)
        mcmc = MCMC(nuts_kernel, num_samples=iterations,
                    warmup_steps=warmup_steps, num_chains=1)
        t0 = timer.time()
        mcmc.run(torch_data)
        t1 = timer.time()
        print('NUTS time: ', t1 - t0)
        hmc_samples = {k: v.detach().cpu().numpy()
                       for k, v in mcmc.get_samples().items()}
        return hmc_samples
    else:
        raise ValueError('Unknown method: Use "NUTS" or "VI"')
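# Hedged usage note for run_inference above (comment-only sketch).
# 'MyODEModel', 'my_gen_model' and 'observations' are placeholders for the
# project-specific sensitivity ODE wrapper (a ForwardSens*/AdjointSens*
# subclass), the Pyro model factory, and the observed data; they are
# assumptions, not names from the original code.
#
#   ode_model = MyODEModel(...)
#   vb_samples = run_inference(observations, my_gen_model, ode_model,
#                              method='VI', iterations=2000, num_particles=4)
#   hmc_samples = run_inference(observations, my_gen_model, ode_model,
#                               method='NUTS', iterations=1000, warmup_steps=500)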
def all_bands(priors, lr=0.005, n_steps=1000, n_samples=1000, verbose=True, sub=1):
    from pyro.infer import Predictive

    pyro.clear_param_store()
    guide = AutoMultivariateNormal(spire_model, init_loc_fn=init_to_mean)
    svi = SVI(spire_model, guide, optim.Adam({"lr": lr}), loss=Trace_ELBO())

    loss_history = []
    for i in range(n_steps):
        loss = svi.step(priors, sub=sub)
        if (i % 100 == 0) and verbose:
            print('ELBO loss: {}'.format(loss))
        loss_history.append(loss)
    print('ELBO loss: {}'.format(loss))

    predictive = Predictive(spire_model, guide=guide, num_samples=n_samples)
    samples = {k: v.squeeze(-1).detach().cpu().numpy()
               for k, v in predictive(priors).items()
               if k != "obs"}

    f_low_lim = torch.tensor([p.prior_flux_lower for p in priors], dtype=torch.float)
    f_up_lim = torch.tensor([p.prior_flux_upper for p in priors], dtype=torch.float)
    f_vec_multi = (f_up_lim - f_low_lim) * samples['src_f'][..., :, :] + f_low_lim
    samples['src_f'] = f_vec_multi.squeeze(-3).numpy()
    samples['sigma_conf'] = samples['sigma_conf'].squeeze(-1).squeeze(-2)
    samples['bkg'] = samples['bkg'].squeeze(-1).squeeze(-2)
    return {'loss_history': loss_history, 'samples': samples}
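# Hedged usage note for all_bands above (comment-only sketch). 'spire_model',
# 'optim', and the 'priors' objects (each exposing prior_flux_lower /
# prior_flux_upper) are project-specific names assumed to be defined elsewhere,
# e.g.
#   result = all_bands(priors, lr=0.005, n_steps=1000, n_samples=500)
#   plt.plot(result['loss_history'])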
def guide_multivariatenormal(self):
    self.guide = AutoMultivariateNormal(
        poutine.block(self.model, expose=['weights', 'locs', 'scale']))
def fit_svi(self, *, num_samples=100, num_steps=2000, num_particles=32,
            learning_rate=0.1, learning_rate_decay=0.01, betas=(0.8, 0.99),
            haar=True, init_scale=0.01, guide_rank=0, jit=False, log_every=200,
            **options):
    """
    Runs stochastic variational inference to generate posterior samples.

    This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
    attribute on completion.

    This approximate inference method is useful for quickly iterating on
    probabilistic models.

    :param int num_samples: Number of posterior samples to draw from the
        trained guide. Defaults to 100.
    :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
    :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
        particles per step.
    :param float learning_rate: Learning rate for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param float learning_rate_decay: Learning rate decay for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
        is decay over the entire schedule, not per-step decay.
    :param tuple betas: Momentum parameters for the
        :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
    :param bool haar: Whether to use a Haar wavelet reparameterizer.
    :param int guide_rank: Rank of the auto normal guide. If zero (default)
        use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
        positive integer or None, use an
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        If the string "full", use an
        :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
        latter two require more ``num_steps`` to fit.
    :param float init_scale: Initial scale of the
        :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
    :param bool jit: Whether to use a jit compiled ELBO.
    :param int log_every: How often to log svi losses.
    :param int heuristic_num_particles: Passed to :meth:`heuristic` as
        ``num_particles``. Defaults to 1024.
    :returns: Time series of SVI losses (useful to diagnose convergence).
    :rtype: list
    """
    # Save configuration for .predict().
    self.relaxed = True
    self.num_quant_bins = 1

    # Setup Haar wavelet transform.
    if haar:
        time_dim = -2 if self.is_regional else -1
        dims = {"auxiliary": time_dim}
        supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
        for name, (fn, is_regional) in self._non_compartmental.items():
            dims[name] = time_dim - fn.event_dim
            supports[name] = fn.support
        haar = _HaarSplitReparam(0, self.duration, dims, supports)

    # Heuristically initialize to feasible latents.
    heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                         for k in list(options) if k.startswith("heuristic_")}
    assert not options, "unrecognized options: {}".format(", ".join(options))
    init_strategy = self._heuristic(haar, **heuristic_options)

    # Configure variational inference.
    logger.info("Running inference...")
    model = self._relaxed_model
    if haar:
        model = haar.reparam(model)
    if guide_rank == 0:
        guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
    elif guide_rank == "full":
        guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                       init_scale=init_scale)
    elif guide_rank is None or isinstance(guide_rank, int):
        guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                              init_scale=init_scale, rank=guide_rank)
    else:
        raise ValueError("Invalid guide_rank: {}".format(guide_rank))
    Elbo = JitTrace_ELBO if jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                num_particles=num_particles, vectorize_particles=True,
                ignore_jit_warnings=True)
    optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                         "lrd": learning_rate_decay ** (1 / num_steps)})
    svi = SVI(model, guide, optim, elbo)

    # Run inference.
    start_time = default_timer()
    losses = []
    for step in range(1 + num_steps):
        loss = svi.step() / self.duration
        if step % log_every == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
        losses.append(loss)
    elapsed = default_timer() - start_time
    logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                .format(elapsed, (1 + num_steps) / elapsed))

    # Draw posterior samples.
    with torch.no_grad():
        particle_plate = pyro.plate("particles", num_samples,
                                    dim=-1 - self.max_plate_nesting)
        guide_trace = poutine.trace(particle_plate(guide)).get_trace()
        model_trace = poutine.trace(
            poutine.replay(particle_plate(model), guide_trace)).get_trace()
        self.samples = {name: site["value"]
                        for name, site in model_trace.nodes.items()
                        if site["type"] == "sample"
                        if not site["is_observed"]
                        if not site_is_subsample(site)}
        if haar:
            haar.aux_to_user(self.samples)
    assert all(v.size(0) == num_samples for v in self.samples.values()), \
        {k: tuple(v.shape) for k, v in self.samples.items()}
    return losses
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1)) sns.kdeplot(hmc_samples["bR"], hmc_samples["bAR"], ax=axs[1], shade=True, label="HMC") sns.kdeplot(svi_samples["bR"], svi_samples["bAR"], ax=axs[1], label="SVI (DiagNormal)") axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8)) handles, labels = axs[1].get_legend_handles_labels() fig.legend(handles, labels, loc='upper right'); from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean) svi = SVI(model, guide, optim.Adam({"lr": .01}), loss=Trace_ELBO()) is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2] pyro.clear_param_store() for i in range(num_iters): elbo = svi.step(is_cont_africa, ruggedness, log_gdp) if i % 500 == 0: logging.info("Elbo loss: {}".format(elbo))
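# Hedged follow-up sketch (not from the original notebook): once the
# AutoMultivariateNormal guide above is trained, posterior samples of the
# regression coefficients can be drawn with Predictive and compared against
# the HMC samples plotted earlier. The name svi_mvn_samples is an assumption.
from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=1000)
svi_mvn_samples = {k: v.detach().cpu().numpy()
                   for k, v in predictive(is_cont_africa, ruggedness, log_gdp).items()
                   if k != "obs"}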
def _create_guide(self, X):
    return AutoMultivariateNormal(self.model, init_scale=0.2)
def multi_norm_guide(self):
    return AutoMultivariateNormal(self.model, init_loc_fn=init_to_mean)
def hetero_multi_gp_train(X_augmented, Y_augmented, X1, Y1, X2, Y2, Y1_cens_sc,
                          y, y2, y_sc, censoring, censoring_mul, int_low,
                          noise_scale, file):
    N1 = len(y)
    pyro.clear_param_store()

    k1 = gp.kernels.RBF(input_dim=1, active_dims=[0],
                        lengthscale=torch.tensor(1.), variance=torch.tensor(1.))
    coreg = gp.kernels.Coregionalize(input_dim=X_augmented.shape[1], rank=1)
    f_rbf = gp.kernels.Product(k1, coreg)
    g0_rbf = gp.kernels.RBF(input_dim=1, lengthscale=torch.tensor(1.),
                            variance=torch.tensor(1.))
    g1_rbf = gp.kernels.RBF(input_dim=1, lengthscale=torch.tensor(1.),
                            variance=torch.tensor(1.))
    like = CensoredHeteroGaussian(censoring=censoring_mul)
    multi_het_gp = VariationalMHGP(X=X_augmented, y=Y_augmented, f_kernel=f_rbf,
                                   g0_kernel=g0_rbf, g1_kernel=g1_rbf,
                                   likelihood=like, jitter=0.005)

    guide = AutoMultivariateNormal(multi_het_gp.model)
    optimizer = pyro.optim.ClippedAdam({"lr": 0.003, "lrd": 0.99969})
    svi = SVI(multi_het_gp.model, guide, optimizer, Trace_ELBO(num_particles=60))

    num_epochs = 12000
    losses = []
    pyro.clear_param_store()
    for epoch in range(num_epochs):
        loss = svi.step(X_augmented, Y_augmented)
        losses.append(loss)
        if epoch == num_epochs - 1:
            with torch.no_grad():
                predictive = Predictive(multi_het_gp.model, guide=guide,
                                        num_samples=1000,
                                        return_sites=("f", "g0", "g1", "_RETURN"))
                samples = predictive(X_augmented)
                f_samples = samples["f"]
                f_mean = f_samples.mean(dim=0)
                g0_samples = samples["g0"]
                g1_samples = samples["g1"]
                g0_mean = g0_samples.mean(dim=0).detach().numpy()
                g1_mean = g1_samples.mean(dim=0).detach().numpy()
                f_025 = np.quantile(a=f_samples.detach().numpy(), q=0.025, axis=0)
                f_975 = np.quantile(a=f_samples.detach().numpy(), q=0.975, axis=0)

                fig = plt.figure(figsize=(20, 12))
                fig.add_subplot(221)
                plt.plot(X1.numpy(), y_sc.numpy(), linestyle="--", color="black")
                plt.plot(X1.reshape(-1).detach().numpy(),
                         f_mean.detach().numpy()[0:N1], color="black")
                plt.fill_between(
                    X1.reshape(-1).detach().numpy(),
                    f_mean.detach().numpy()[0:N1] - 1.96 * np.exp(g0_mean),
                    f_mean.detach().numpy()[0:N1] + 1.96 * np.exp(g0_mean),
                    alpha=0.3)
                plt.fill_between(X1.reshape(-1).detach().numpy(),
                                 f_025[0:N1], f_975[0:N1], alpha=0.3,
                                 label='Y1 Mean uncertainty')
                plt.scatter(X1.numpy()[censoring == 1].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 1].reshape(-1, 1),
                            marker="x", label="Censored Observations",
                            color='#348ABD')
                plt.scatter(X1.numpy()[censoring == 0].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 0].reshape(-1, 1),
                            marker="o", label="Non-Censored Observations",
                            color='#348ABD')
                plt.legend(prop={'size': 12})

                fig.add_subplot(222)
                plt.plot(X2.numpy(), y2.numpy(), linestyle="--", color="gray")
                plt.plot(X2.reshape(-1).detach().numpy(),
                         f_mean.detach().numpy()[N1:], color="black")
                plt.fill_between(
                    X2.reshape(-1).detach().numpy(),
                    f_mean.detach().numpy()[N1:] - 1.96 * np.exp(g1_mean),
                    f_mean.detach().numpy()[N1:] + 1.96 * np.exp(g1_mean),
                    alpha=0.3)
                plt.fill_between(X2.reshape(-1).detach().numpy(),
                                 f_025[N1:], f_975[N1:], alpha=0.3,
                                 label='Y2 mean uncertainty')
                plt.scatter(X2.numpy(), Y2.numpy(), label='Y2 Observed values')
                plt.legend(prop={'size': 12})

                fig.add_subplot(223)
                plt.plot(np.arange(len(g0_mean)), np.exp(g0_mean), 'b--',
                         label='Y1 inference noise')
                plt.plot(np.arange(len(noise_scale)), noise_scale, 'b-',
                         label='Y1 true noise')
                plt.legend(prop={'size': 12})

                fig.add_subplot(224)
                # NOTE: noise_scale_y2 is not a parameter of this function; it is
                # assumed to be available from the enclosing (module) scope.
                plt.plot(np.arange(len(g1_mean)), np.exp(g1_mean), 'r--',
                         label='Y2 inference noise')
                plt.plot(np.arange(len(noise_scale_y2)), noise_scale_y2, 'r-',
                         label='Y2 true noise')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/HMGP/HMGP_Synthetic_{}.png'.format(int_low))

                fig1 = plt.figure(figsize=(8, 6))
                plt.plot(losses, label='Loss')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/HMGP/HMGP_Synthetic_Loss_{}.png'.format(int_low))

                RMSE = sqrt(mean_squared_error(y_sc, f_mean.detach().numpy()[:N1]))
                NLPD = -(1 / len(y_sc) * multi_het_gp.likelihood.y_dist.log_prob(
                    torch.cat((y_sc, y2.type(torch.float64))))[:len(y_sc)].sum().item())
                file.write('\n Intensity :' + str(int_low) + ' ')
                file.write('RMSE: ' + str(RMSE) + ' ')
                file.write('NLPD: ' + str(NLPD))
def standard_gp_train(X1, Y1, Y1_cens_sc, y, y_sc, censoring, int_low,
                      noise_scale, file):
    pyro.clear_param_store()

    kern = gp.kernels.RBF(input_dim=1, active_dims=[0],
                          lengthscale=torch.tensor(1.), variance=torch.tensor(1.))
    like = HomoscedGaussian()
    sgphomo = VariationalGP(X=X1, y=Y1_cens_sc, kernel=kern, likelihood=like,
                            mean_function=None, latent_shape=None,
                            whiten=False, jitter=0.005)

    guide = AutoMultivariateNormal(sgphomo.model)
    optimizer = pyro.optim.ClippedAdam({"lr": 0.003, "lrd": 0.99969})
    svi = SVI(sgphomo.model, guide, optimizer, Trace_ELBO(num_particles=40))

    num_epochs = 4000
    losses = []
    pyro.clear_param_store()
    for epoch in range(num_epochs):
        loss = svi.step(X1, Y1_cens_sc)
        losses.append(loss)
        if epoch == num_epochs - 1:
            with torch.no_grad():
                predictive = Predictive(sgphomo.model, guide=guide,
                                        num_samples=1000,
                                        return_sites=("f", "g", "_RETURN"))
                samples = predictive(X1)
                f_samples = samples["f"]
                f_mean = f_samples.mean(dim=0)
                f_std = sgphomo.likelihood.variance.sqrt().item()
                f_025 = np.quantile(a=f_samples.detach().numpy(), q=0.025, axis=0)
                f_975 = np.quantile(a=f_samples.detach().numpy(), q=0.975, axis=0)

                fig = plt.figure(figsize=(20, 6))
                fig.add_subplot(121)
                plt.plot(X1.numpy(), y_sc.numpy(), linestyle="--", color="black")
                plt.plot(X1.detach().numpy(), f_mean.detach().numpy(), color="black")
                plt.fill_between(X1.detach().numpy(),
                                 f_mean.detach().numpy() - 1.96 * f_std,
                                 f_mean.detach().numpy() + 1.96 * f_std,
                                 alpha=0.3)
                plt.fill_between(X1.reshape(-1).detach().numpy(), f_025, f_975,
                                 alpha=0.3, label='Mean uncertainty')
                plt.scatter(X1.numpy()[censoring == 1].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 1].reshape(-1, 1),
                            marker="x", label="Censored Observations",
                            color='#348ABD')
                plt.scatter(X1.numpy()[censoring == 0].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 0].reshape(-1, 1),
                            marker="o", label="Non-Censored Observations",
                            color='#348ABD')
                plt.legend(prop={'size': 12})

                fig.add_subplot(122)
                plt.plot(np.arange(len(noise_scale)),
                         np.ones(len(noise_scale)) * sgphomo.likelihood.variance.sqrt().item(),
                         label='Estimated noise')
                plt.plot(np.arange(len(noise_scale)), noise_scale,
                         label='True noise scale')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/SGP/SGP_Synthetic_{}.png'.format(int_low))

                fig1 = plt.figure(figsize=(8, 6))
                plt.plot(losses, label='Loss')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/SGP/SGP_Synthetic_Loss_{}.png'.format(int_low))

                RMSE = sqrt(mean_squared_error(y_sc, f_mean.detach().numpy()))
                NLPD = -(1 / len(Y1) * sgphomo.likelihood.y_dist.log_prob(y_sc).sum().item())
                file.write('\n Intensity :' + str(int_low) + ' ')
                file.write('RMSE: ' + str(RMSE) + ' ')
                file.write('NLPD: ' + str(NLPD))