def auto_variational_fit(model, data, num_epochs=2500, lr=0.001):
    """ Use Stochastic Variational Inference for inferring latent variables """
    guide = AutoMultivariateNormal(model)
    svi = pyro.infer.SVI(model=model,
                         guide=guide,
                         optim=pyro.optim.Adam({'lr': lr}),
                         loss=Trace_ELBO())
    losses = []
    for i in tqdm(range(num_epochs)):
        losses.append(svi.step(data))

    return losses, guide.get_posterior()
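The excerpt above omits its imports (pyro, AutoMultivariateNormal, Trace_ELBO, tqdm). Below is a minimal sketch of how it could be driven, assuming those imports and an illustrative toy model; none of this is part of the original project.

import torch
import pyro
import pyro.distributions as dist
import pyro.optim  # pyro.optim.Adam is used inside auto_variational_fit
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from tqdm import tqdm


def toy_model(data):
    # Illustrative model: a single latent location with a Normal likelihood.
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


losses, posterior = auto_variational_fit(toy_model, torch.randn(200) + 3.0,
                                         num_epochs=500)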
Example #2
def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)

    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    return guide
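A hedged usage sketch for the AutoMixed factory above, including the imports the excerpt itself appears to rely on (AutoGuideList, AutoNormal, AutoDelta, AutoMultivariateNormal, autoguide, poutine). The toy model and data are placeholders, not part of the original code.

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, autoguide
from pyro.infer.autoguide import (AutoDelta, AutoGuideList,
                                  AutoMultivariateNormal, AutoNormal)
from pyro.optim import Adam


def model_full(data):
    # Toy model: 'tau' is handled by the full-rank block of the composite
    # guide, everything else by the mean-field (or delta) block.
    tau = pyro.sample("tau", dist.LogNormal(0.0, 1.0))
    loc = pyro.sample("loc", dist.Normal(0.0, 5.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, tau), obs=data)


data = torch.randn(50) * 2.0 + 1.0
guide = AutoMixed(model_full)  # delta=None: AutoNormal over 'loc', AutoMultivariateNormal over 'tau'
svi = SVI(model_full, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(500):
    svi.step(data)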
Example #3
    def do_test_auto(self, N, reparameterized, n_steps):
        logger.debug("\nGoing to do AutoGaussianChain test...")
        pyro.clear_param_store()
        self.setUp()
        self.setup_chain(N)
        self.compute_target(N)
        self.guide = AutoMultivariateNormal(self.model)
        logger.debug("target auto_loc: {}"
                     .format(self.target_auto_mus[1:].detach().cpu().numpy()))
        logger.debug("target auto_diag_cov: {}"
                     .format(self.target_auto_diag_cov[1:].detach().cpu().numpy()))

        # TODO speed up with parallel num_particles > 1
        adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
        svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())

        for k in range(n_steps):
            loss = svi.step(reparameterized)
            assert np.isfinite(loss), loss

            if k % 1000 == 0 and k > 0 or k == n_steps - 1:
                logger.debug("[step {}] guide mean parameter: {}"
                             .format(k, self.guide.loc.detach().cpu().numpy()))
                L = self.guide.scale_tril
                diag_cov = torch.mm(L, L.t()).diag()
                logger.debug("[step {}] auto_diag_cov: {}"
                             .format(k, diag_cov.detach().cpu().numpy()))

        assert_equal(self.guide.loc.detach(), self.target_auto_mus[1:], prec=0.05,
                     msg="guide mean off")
        assert_equal(diag_cov, self.target_auto_diag_cov[1:], prec=0.07,
                     msg="guide covariance off")
Example #4
def run_inference(data, gen_model, ode_model, method, iterations=10000,
                  num_particles=1, num_samples=1000, warmup_steps=500,
                  init_scale=0.1, seed=12, lr=0.5, return_sites="_RETURN"):
    torch_data = torch.tensor(data, dtype=torch.float)
    if isinstance(ode_model, ForwardSensManualJacobians) or \
            isinstance(ode_model, ForwardSensTorchJacobians):
        ode_op = ForwardSensOp
    elif isinstance(ode_model, AdjointSensManualJacobians) or \
            isinstance(ode_model, AdjointSensTorchJacobians):
        ode_op = AdjointSensOp
    else:
        raise ValueError('Unknown sensitivity solver: Use "Forward" or "Adjoint"')
    model = gen_model(ode_op, ode_model)
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    if method == 'VI':

        guide = AutoMultivariateNormal(model, init_scale=init_scale)
        optim = AdagradRMSProp({"eta": lr})
        if num_particles == 1:
            svi = SVI(model, guide, optim, loss=Trace_ELBO())
        else:
            svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=num_particles,
                                                           vectorize_particles=True))
        loss_trace = []
        t0 = timer.time()
        for j in range(iterations):
            loss = svi.step(torch_data)
            loss_trace.append(loss)

            if j % 500 == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, np.mean(loss_trace[max(0, j - 1000):j + 1])))
        t1 = timer.time()
        print('VI time: ', t1 - t0)
        predictive = Predictive(model, guide=guide, num_samples=num_samples,
                                return_sites=return_sites)  # "ode_params", "scale",
        vb_samples = predictive(torch_data)
        return vb_samples

    elif method == 'NUTS':

        nuts_kernel = NUTS(model, adapt_step_size=True, init_strategy=init_to_median)

        # mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=2)
        mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=1)
        t0 = timer.time()
        mcmc.run(torch_data)
        t1 = timer.time()
        print('NUTS time: ', t1 - t0)
        hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
        return hmc_samples
    else:
        raise ValueError('Unknown method: Use "NUTS" or "VI"')
Example #5
File: SPIRE.py  Project: ianmbus/XID_plus
def all_bands(priors,
              lr=0.005,
              n_steps=1000,
              n_samples=1000,
              verbose=True,
              sub=1):
    from pyro.infer import Predictive

    pyro.clear_param_store()

    guide = AutoMultivariateNormal(spire_model, init_loc_fn=init_to_mean)

    svi = SVI(spire_model, guide, optim.Adam({"lr": lr}), loss=Trace_ELBO())

    loss_history = []
    for i in range(n_steps):
        loss = svi.step(priors, sub=sub)
        if (i % 100 == 0) and verbose:
            print('ELBO loss: {}'.format(loss))
        loss_history.append(loss)
    print('ELBO loss: {}'.format(loss))
    predictive = Predictive(spire_model, guide=guide, num_samples=n_samples)
    samples = {
        k: v.squeeze(-1).detach().cpu().numpy()
        for k, v in predictive(priors).items() if k != "obs"
    }
    f_low_lim = torch.tensor([p.prior_flux_lower for p in priors],
                             dtype=torch.float)
    f_up_lim = torch.tensor([p.prior_flux_upper for p in priors],
                            dtype=torch.float)
    f_vec_multi = (f_up_lim -
                   f_low_lim) * samples['src_f'][..., :, :] + f_low_lim
    samples['src_f'] = f_vec_multi.squeeze(-3).numpy()
    samples['sigma_conf'] = samples['sigma_conf'].squeeze(-1).squeeze(-2)
    samples['bkg'] = samples['bkg'].squeeze(-1).squeeze(-2)

    return {'loss_history': loss_history, 'samples': samples}
    def guide_multivariatenormal(self):
        self.guide = AutoMultivariateNormal(
            poutine.block(self.model, expose=['weights', 'locs', 'scale']))
Example #7
    def fit_svi(self, *,
                num_samples=100,
                num_steps=2000,
                num_particles=32,
                learning_rate=0.1,
                learning_rate_decay=0.01,
                betas=(0.8, 0.99),
                haar=True,
                init_scale=0.01,
                guide_rank=0,
                jit=False,
                log_every=200,
                **options):
        """
        Runs stochastic variational inference to generate posterior samples.

        This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
        attribute on completion.

        This approximate inference method is useful for quickly iterating on
        probabilistic models.

        :param int num_samples: Number of posterior samples to draw from the
            trained guide. Defaults to 100.
        :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
            particles per step.
        :param float learning_rate: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param float learning_rate_decay: Learning rate decay for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
            is decay over the entire schedule, not per-step decay.
        :param tuple betas: Momentum parameters for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
        :param int guide_rank: Rank of the auto normal guide. If zero (default)
            use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
            positive integer or None, use an
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
            If the string "full", use an
            :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
            latter two require more ``num_steps`` to fit.
        :param float init_scale: Initial scale of the
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        :param bool jit: Whether to use a jit compiled ELBO.
        :param int log_every: How often to log svi losses.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: Time series of SVI losses (useful to diagnose convergence).
        :rtype: list
        """
        # Save configuration for .predict().
        self.relaxed = True
        self.num_quant_bins = 1

        # Setup Haar wavelet transform.
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(0, self.duration, dims, supports)

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        assert not options, "unrecognized options: {}".format(", ".join(options))
        init_strategy = self._heuristic(haar, **heuristic_options)

        # Configure variational inference.
        logger.info("Running inference...")
        model = self._relaxed_model
        if haar:
            model = haar.reparam(model)
        if guide_rank == 0:
            guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
        elif guide_rank == "full":
            guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                           init_scale=init_scale)
        elif guide_rank is None or isinstance(guide_rank, int):
            guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                                  init_scale=init_scale, rank=guide_rank)
        else:
            raise ValueError("Invalid guide_rank: {}".format(guide_rank))
        Elbo = JitTrace_ELBO if jit else Trace_ELBO
        elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                    num_particles=num_particles, vectorize_particles=True,
                    ignore_jit_warnings=True)
        optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps)})
        svi = SVI(model, guide, optim, elbo)

        # Run inference.
        start_time = default_timer()
        losses = []
        for step in range(1 + num_steps):
            loss = svi.step() / self.duration
            if step % log_every == 0:
                logger.info("step {} loss = {:0.4g}".format(step, loss))
            losses.append(loss)
        elapsed = default_timer() - start_time
        logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                    .format(elapsed, (1 + num_steps) / elapsed))

        # Draw posterior samples.
        with torch.no_grad():
            particle_plate = pyro.plate("particles", num_samples,
                                        dim=-1 - self.max_plate_nesting)
            guide_trace = poutine.trace(particle_plate(guide)).get_trace()
            model_trace = poutine.trace(
                poutine.replay(particle_plate(model), guide_trace)).get_trace()
            self.samples = {name: site["value"] for name, site in model_trace.nodes.items()
                            if site["type"] == "sample"
                            if not site["is_observed"]
                            if not site_is_subsample(site)}
            if haar:
                haar.aux_to_user(self.samples)
        assert all(v.size(0) == num_samples for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return losses
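The fit_svi method above appears to come from pyro.contrib.epidemiology's CompartmentalModel. Going by its docstring, here is a hedged sketch of the guide_rank="full" path, which selects the AutoMultivariateNormal guide; the model choice, population, recovery time, and case counts are invented for illustration.

import torch
from pyro.contrib.epidemiology.models import SimpleSIRModel

cases = torch.tensor([1., 1., 2., 4., 7., 11., 16., 22.])  # synthetic case counts
model = SimpleSIRModel(population=1000, recovery_time=10.0, data=cases)

losses = model.fit_svi(num_steps=2000, num_particles=16, guide_rank="full",
                       learning_rate=0.1, log_every=500)
samples = model.samples  # posterior samples, set by fit_svi on completion

(The plotting lines that follow are a fragment of a larger HMC-vs-SVI comparison script; fig, axs, hmc_samples, and svi_samples are assumed to have been defined earlier in that script.)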
axs[0].set(xlabel="bA", ylabel="bR", xlim=(-2.5, -1.2), ylim=(-0.5, 0.1))
sns.kdeplot(hmc_samples["bR"], hmc_samples["bAR"], ax=axs[1], shade=True, label="HMC")
sns.kdeplot(svi_samples["bR"], svi_samples["bAR"], ax=axs[1], label="SVI (DiagNormal)")
axs[1].set(xlabel="bR", ylabel="bAR", xlim=(-0.45, 0.05), ylim=(-0.15, 0.8))
handles, labels = axs[1].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right');






from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean


guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=Trace_ELBO())

is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]
pyro.clear_param_store()
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))
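The kdeplot fragment above compares hmc_samples with svi_samples. A hedged sketch of how svi_samples could be drawn from the guide fitted here; the number of draws and the per-site reshaping are assumptions.

from pyro.infer import Predictive

# Draw posterior samples from the fitted guide for comparison with HMC.
predictive = Predictive(model, guide=guide, num_samples=800)
svi_samples = {k: v.reshape(-1).detach().cpu().numpy()
               for k, v in predictive(is_cont_africa, ruggedness, log_gdp).items()
               if k != "obs"}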


Example #9
    def _create_guide(self, X):
        return AutoMultivariateNormal(self.model, init_scale=0.2)

    def multi_norm_guide(self):
        return AutoMultivariateNormal(self.model, init_loc_fn=init_to_mean)
Example #11
def hetero_multi_gp_train(X_augmented, Y_augmented, X1, Y1, X2, Y2, Y1_cens_sc,
                          y, y2, y_sc, censoring, censoring_mul, int_low,
                          noise_scale, file):
    N1 = len(y)
    pyro.clear_param_store()
    k1 = gp.kernels.RBF(input_dim=1,
                        active_dims=[0],
                        lengthscale=torch.tensor(1.),
                        variance=torch.tensor(1.))
    coreg = gp.kernels.Coregionalize(input_dim=X_augmented.shape[1], rank=1)
    f_rbf = gp.kernels.Product(k1, coreg)

    g0_rbf = gp.kernels.RBF(input_dim=1,
                            lengthscale=torch.tensor(1.),
                            variance=torch.tensor(1.))
    g1_rbf = gp.kernels.RBF(input_dim=1,
                            lengthscale=torch.tensor(1.),
                            variance=torch.tensor(1.))

    like = CensoredHeteroGaussian(censoring=censoring_mul)
    multi_het_gp = VariationalMHGP(X=X_augmented,
                                   y=Y_augmented,
                                   f_kernel=f_rbf,
                                   g0_kernel=g0_rbf,
                                   g1_kernel=g1_rbf,
                                   likelihood=like,
                                   jitter=0.005)

    guide = AutoMultivariateNormal(multi_het_gp.model)

    optimizer = pyro.optim.ClippedAdam({"lr": 0.003, "lrd": 0.99969})
    svi = SVI(multi_het_gp.model, guide, optimizer,
              Trace_ELBO(num_particles=60))

    num_epochs = 12000
    losses = []
    pyro.clear_param_store()
    for epoch in range(num_epochs):
        loss = svi.step(X_augmented, Y_augmented)
        losses.append(loss)
        if epoch == num_epochs - 1:

            with torch.no_grad():
                predictive = Predictive(multi_het_gp.model,
                                        guide=guide,
                                        num_samples=1000,
                                        return_sites=("f", "g0", "g1",
                                                      "_RETURN"))
                samples = predictive(X_augmented)
                f_samples = samples["f"]
                f_mean = f_samples.mean(dim=0)
                g0_samples = samples["g0"]
                g1_samples = samples["g1"]
                g0_mean = g0_samples.mean(dim=0).detach().numpy()
                g1_mean = g1_samples.mean(dim=0).detach().numpy()
                f_025 = np.quantile(a=f_samples.detach().numpy(),
                                    q=0.025,
                                    axis=0)
                f_975 = np.quantile(a=f_samples.detach().numpy(),
                                    q=0.975,
                                    axis=0)

                fig = plt.figure(figsize=(20, 12))
                fig.add_subplot(221)
                plt.plot(X1.numpy(),
                         y_sc.numpy(),
                         linestyle="--",
                         color="black")
                plt.plot(X1.reshape(-1).detach().numpy(),
                         f_mean.detach().numpy()[0:(N1)],
                         color="black")
                plt.fill_between(
                    X1.reshape(-1).detach().numpy(),
                    f_mean.detach().numpy()[0:(N1)] - 1.96 * np.exp(g0_mean),
                    f_mean.detach().numpy()[0:(N1)] + 1.96 * np.exp(g0_mean),
                    alpha=0.3)
                plt.fill_between(X1.reshape(-1).detach().numpy(),
                                 f_025[0:N1],
                                 f_975[0:N1],
                                 alpha=0.3,
                                 label='Y1 Mean uncertainty')
                plt.scatter(X1.numpy()[censoring == 1].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 1].reshape(
                                -1, 1),
                            marker="x",
                            label="Censored Observations",
                            color='#348ABD')
                plt.scatter(X1.numpy()[censoring == 0].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 0].reshape(
                                -1, 1),
                            marker="o",
                            label="Non-Censored Observations",
                            color='#348ABD')
                plt.legend(prop={'size': 12})

                fig.add_subplot(222)
                plt.plot(X2.numpy(), y2.numpy(), linestyle="--", color="gray")
                plt.plot(X2.reshape(-1).detach().numpy(),
                         f_mean.detach().numpy()[(N1):],
                         color="black")
                plt.fill_between(
                    X2.reshape(-1).detach().numpy(),
                    f_mean.detach().numpy()[(N1):] - 1.96 * np.exp(g1_mean),
                    f_mean.detach().numpy()[(N1):] + 1.96 * np.exp(g1_mean),
                    alpha=0.3)
                plt.fill_between(X2.reshape(-1).detach().numpy(),
                                 f_025[N1:],
                                 f_975[N1:],
                                 alpha=0.3,
                                 label='Y2 mean uncertainty')
                plt.scatter(X2.numpy(), Y2.numpy(), label='Y2 Observed values')
                plt.legend(prop={'size': 12})

                fig.add_subplot(223)
                plt.plot(np.arange(len(g0_mean)),
                         np.exp(g0_mean),
                         'b--',
                         label='Y1 inference noise')
                plt.plot(np.arange(len(noise_scale)),
                         noise_scale,
                         'b-',
                         label='Y1 true noise')
                plt.legend(prop={'size': 12})

                fig.add_subplot(224)
                plt.plot(np.arange(len(g1_mean)),
                         np.exp(g1_mean),
                         'r--',
                         label='Y2 inference noise')
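                # NOTE: noise_scale_y2 below is not defined in this function or its
                # signature; it is presumably a module-level variable in the original
                # project (analogous to noise_scale for Y1).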
                plt.plot(np.arange(len(noise_scale_y2)),
                         noise_scale_y2,
                         'r-',
                         label='Y2 true noise')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/HMGP/HMGP_Synthetic_{}.png'.format(
                        int_low))

    fig1 = plt.figure(figsize=(8, 6))
    plt.plot(losses, label='Loss')
    plt.legend(prop={'size': 12})
    plt.savefig('Experiments/Synthetic/HMGP/HMGP_Synthetic_Loss_{}.png'.format(
        int_low))

    RMSE = sqrt(mean_squared_error(y_sc, f_mean.detach().numpy()[:(N1)]))
    NLPD = -(1 / len(y_sc) * multi_het_gp.likelihood.y_dist.log_prob(
        torch.cat((y_sc, y2.type(torch.float64))))[:(len(y_sc))].sum().item())

    file.write('\n Intensity :' + str(int_low) + ' ')
    file.write('RMSE: ' + str(RMSE) + ' ')
    file.write('NLPD: ' + str(NLPD))
Example #12
def standard_gp_train(X1, Y1, Y1_cens_sc, y, y_sc, censoring, int_low,
                      noise_scale, file):

    pyro.clear_param_store()
    kern = gp.kernels.RBF(input_dim=1,
                          active_dims=[0],
                          lengthscale=torch.tensor(1.),
                          variance=torch.tensor(1.))
    like = HomoscedGaussian()
    sgphomo = VariationalGP(X=X1,
                            y=Y1_cens_sc,
                            kernel=kern,
                            likelihood=like,
                            mean_function=None,
                            latent_shape=None,
                            whiten=False,
                            jitter=0.005)
    guide = AutoMultivariateNormal(sgphomo.model)

    optimizer = pyro.optim.ClippedAdam({"lr": 0.003, "lrd": 0.99969})
    svi = SVI(sgphomo.model, guide, optimizer, Trace_ELBO(num_particles=40))

    num_epochs = 4000
    losses = []
    pyro.clear_param_store()
    for epoch in range(num_epochs):
        loss = svi.step(X1, Y1_cens_sc)
        losses.append(loss)
        if epoch == num_epochs - 1:

            with torch.no_grad():
                predictive = Predictive(sgphomo.model,
                                        guide=guide,
                                        num_samples=1000,
                                        return_sites=("f", "g", "_RETURN"))
                samples = predictive(X1)
                f_samples = samples["f"]
                f_mean = f_samples.mean(dim=0)
                f_std = sgphomo.likelihood.variance.sqrt().item()
                f_025 = np.quantile(a=f_samples.detach().numpy(),
                                    q=0.025,
                                    axis=0)
                f_975 = np.quantile(a=f_samples.detach().numpy(),
                                    q=0.975,
                                    axis=0)
                fig = plt.figure(figsize=(20, 6))
                fig.add_subplot(121)
                plt.plot(X1.numpy(),
                         y_sc.numpy(),
                         linestyle="--",
                         color="black")
                plt.plot(X1.detach().numpy(),
                         f_mean.detach().numpy(),
                         color="black")
                plt.fill_between(X1.detach().numpy(),
                                 f_mean.detach().numpy() - 1.96 * f_std,
                                 f_mean.detach().numpy() + 1.96 * f_std,
                                 alpha=0.3)
                plt.fill_between(X1.reshape(-1).detach().numpy(),
                                 f_025,
                                 f_975,
                                 alpha=0.3,
                                 label='Mean uncertainty')
                plt.scatter(X1.numpy()[censoring == 1].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 1].reshape(
                                -1, 1),
                            marker="x",
                            label="Censored Observations",
                            color='#348ABD')
                plt.scatter(X1.numpy()[censoring == 0].reshape(-1, 1),
                            y=Y1_cens_sc.numpy()[censoring == 0].reshape(
                                -1, 1),
                            marker="o",
                            label="Non-Censored Observations",
                            color='#348ABD')
                plt.legend(prop={'size': 12})

                fig.add_subplot(122)
                plt.plot(np.arange(len(noise_scale)),
                         np.ones(len(noise_scale)) *
                         sgphomo.likelihood.variance.sqrt().item(),
                         label='Estimated noise')
                plt.plot(np.arange(len(noise_scale)),
                         noise_scale,
                         label='True noise scale')
                plt.legend(prop={'size': 12})
                plt.savefig(
                    'Experiments/Synthetic/SGP/SGP_Synthetic_{}.png'.format(
                        int_low))

    fig1 = plt.figure(figsize=(8, 6))
    plt.plot(losses, label='Loss')
    plt.legend(prop={'size': 12})
    plt.savefig(
        'Experiments/Synthetic/SGP/SGP_Synthetic_Loss_{}.png'.format(int_low))

    RMSE = sqrt(mean_squared_error(y_sc, f_mean.detach().numpy()))
    NLPD = -(1 / len(Y1) *
             sgphomo.likelihood.y_dist.log_prob(y_sc).sum().item())

    file.write('\n Intensity :' + str(int_low) + ' ')
    file.write('RMSE: ' + str(RMSE) + ' ')
    file.write('NLPD: ' + str(NLPD))