Example #1
import logging
import os
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from matplotlib.gridspec import GridSpec

import pyro
import pyro.poutine as poutine
from pyro.distributions.transforms import block_autoregressive, iterated
from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormalizingFlow
from pyro.infer.reparam import NeuTraReparam

# BananaShaped, model, fit_guide and run_hmc are helpers defined elsewhere in
# the full script this excerpt is taken from.


def main(args):
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[2, 0])
    ax5 = fig.add_subplot(gs[3, 0])
    ax6 = fig.add_subplot(gs[1, 1])
    ax7 = fig.add_subplot(gs[2, 1])
    ax8 = fig.add_subplot(gs[3, 1])
    xlim = tuple(int(x) for x in args.x_lim.strip().split(','))
    ylim = tuple(int(x) for x in args.y_lim.strip().split(','))
    assert len(xlim) == 2
    assert len(ylim) == 2

    # 1. Plot samples drawn from BananaShaped distribution
    x1, x2 = torch.meshgrid(
        [torch.linspace(*xlim, 100),
         torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    ax1.contourf(
        x1,
        x2,
        p,
        cmap='OrRd',
    )
    ax1.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='BananaShaped distribution: \ndensity')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    ax2.contourf(x1, x2, p, cmap='OrRd')
    ax2.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(vanilla HMC)')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax2)

    # 3(a). Fit a diagonal normal autoguide
    logging.info('\nFitting a DiagNormal autoguide ...')
    guide = AutoDiagonalNormal(model, init_scale=0.05)
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax3.contourf(x1, x2, p, cmap='OrRd')
    ax3.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(DiagNormal autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax3)

    # 3(b). Draw samples using NeuTra HMC
    logging.info(
        '\nDrawing samples using DiagNormal autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax4)
    ax4.set(xlabel='x0',
            ylabel='x1',
            title='Posterior (warped) samples \n(DiagNormal + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax5.contourf(x1, x2, p, cmap='OrRd')
    ax5.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior (transformed) \n(DiagNormal + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax5)

    # 4(a). Fit a BNAF autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(
        model, partial(iterated, args.num_flows, block_autoregressive))
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()

    ax6.contourf(x1, x2, p, cmap='OrRd')
    ax6.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior \n(BNAF autoguide)')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], ax=ax6)

    # 4(b). Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']
    sns.scatterplot(zs[:, 0], zs[:, 1], alpha=0.2, ax=ax7)
    ax7.set(xlabel='x0',
            ylabel='x1',
            title='Posterior (warped) samples \n(BNAF + NeuTra HMC)')

    samples = neutra.transform_sample(zs)
    samples = samples['x'].cpu().numpy()
    ax8.contourf(x1, x2, p, cmap='OrRd')
    ax8.set(xlabel='x0',
            ylabel='x1',
            xlim=xlim,
            ylim=ylim,
            title='Posterior (transformed) \n(BNAF + NeuTra HMC)')
    sns.kdeplot(samples[:, 0], samples[:, 1], ax=ax8)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
Example #2
import torch
import torch.nn as nn
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam


class BNN_2_layer(PyroModule):
    def __init__(self, in_features, num_hidden, learn_var = False):
        super().__init__()
        self.linear1 = PyroModule[nn.Linear](in_features, num_hidden)
        self.linear1.weight = PyroSample(dist.Normal(0., 1.).expand([num_hidden, in_features]).to_event(2))
        self.linear1.bias = PyroSample(dist.Normal(0., 10.).expand([num_hidden]).to_event(1))
        self.learn_var = learn_var
        self.linear2 = PyroModule[nn.Linear](num_hidden, num_hidden)
        self.linear2.weight = PyroSample(dist.Normal(0., 1.).expand([num_hidden, num_hidden]).to_event(2))
        self.linear2.bias = PyroSample(dist.Normal(0., 10.).expand([num_hidden]).to_event(1))
        if not self.learn_var:
            self.linear3 = PyroModule[nn.Linear](num_hidden, 2)
            self.linear3.weight = PyroSample(dist.Normal(0., 1.).expand([2, num_hidden]).to_event(2))
            self.linear3.bias = PyroSample(dist.Normal(0., 10.).expand([2]).to_event(1))
        else:
            self.linear3 = PyroModule[nn.Linear](num_hidden, 1)
            self.linear3.weight = PyroSample(dist.Normal(0., 1.).expand([1, num_hidden]).to_event(2))
            self.linear3.bias = PyroSample(dist.Normal(0., 10.).expand([1]).to_event(1))

        self.guide = AutoDiagonalNormal(self)

    def forward(self, x, y = None):
        out = self.linear1(x)
        out = F.relu(out)
        out = self.linear2(out)
        out = F.relu(out)
        out = self.linear3(out)
        if not self.learn_var:
            # the network predicts both the mean and a softplus-positive
            # standard deviation for every input point
            mean = out[:, 0]
            std = F.softplus(out[:, 1])
        else:
            # the network predicts only the mean; a single global noise scale
            # is treated as a latent variable
            mean = out.squeeze(-1)  # align shape with the "data" plate below
            sigma = pyro.sample("sigma", dist.Normal(0., 10.))
            std = F.softplus(sigma)

        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, std), obs=y)
        return out

    # note: this shadows nn.Module.train(); it only runs SVI optimization
    def train(self, x_train, y_train, num_epoch=80000,
              lr=1e-2, every_epoch_to_print=1000):
        optimizer = Adam({"lr": lr})
        svi = SVI(self, self.guide, optimizer, 
                  loss = Trace_ELBO())
        pyro.clear_param_store()
        loss_arr = []
        for j in range(num_epoch):
            # calculate the loss and take a gradient step
            loss = svi.step(x_train, y_train)
            loss_arr.append(loss)
            if j % every_epoch_to_print == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, loss))
        return loss_arr

    def summary(self, samples):
        site_stats = {}
        for k, v in samples.items():
            if (k == "_RETURN" and (not self.learn_var)):
                site_stats[k] = {
                        "mean": torch.mean(v[:, :, 0], 0),
                        "std": torch.mean(F.softplus(v[:, :, 1]), 0)
                }
            else:
                site_stats[k] = {
                        "mean": torch.mean(v, 0),
                        "std": torch.std(v, 0)
                        }
        return site_stats

    def sample(self, x_test, num_samples = 128):
        self.guide.requires_grad_(False)

        predictive = Predictive(self, guide = self.guide, 
                                num_samples = num_samples,
                                return_sites = ("obs", "_RETURN"))
        samples = predictive(x_test)
        pred_summary = self.summary(samples)
        mu = pred_summary["_RETURN"]
        y = pred_summary["obs"]
        mu_mean = mu["mean"]
        mu_std = mu["std"]
        mu_var = mu_std.pow(2)
        y_mean = y["mean"]
        y_std = y["std"]
        y_var = y_std.pow(2)
        # predictive mean/variance of the network output and of the observations
        return mu_mean, mu_var, y_mean, y_var



Example #3
pyro.clear_param_store()
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))





guide.requires_grad_(False)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))




guide.quantiles([0.25, 0.5, 0.75])





from pyro.infer import Predictive
Example #4
import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample


def bayesian_regression(x_data, y_data, num_iterations):
    # BAYESIAN REGRESSION WITH SVI

    class BayesianRegression(PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

        def forward(self, x, y=None):
            # forward() specifies the data generating process
            sigma = pyro.sample("sigma", dist.Uniform(0., 10.))  # standard deviation of the observation noise (the scale of the error term, epsilon, in the regression equation)
            mean = self.linear(x).squeeze(-1)
            with pyro.plate("data", x.shape[0]):
                obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
            return mean

    """
    Guides -- posterior distribution classes

    The guide determines a family of distributions, and SVI aims to find an 
    approximate posterior distribution from this family that has the lowest
    KL divergence from the true posterior.
    """

    model = BayesianRegression(3, 1)

    """
    Under the hood, this defines a guide that uses a Normal distribution with
    learnable parameters corresponding to each sample statement in the model.
    e.g. in our case, this distribution should have a size of (5,), corresponding
    to the 3 regression coefficients, plus 1 component each for the intercept
    term and sigma in the model (see the commented check below).
    """

    guide = AutoDiagonalNormal(model)
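
    # Rough check of the (5,) claim above: calling the guide once creates its
    # parameters, which (by Pyro's autoguide naming convention) show up in the
    # param store as "AutoDiagonalNormal.loc" / "AutoDiagonalNormal.scale".
    # guide(x_data, y_data)                              # trigger lazy setup
    # print(pyro.param("AutoDiagonalNormal.loc").shape)  # expected: torch.Size([5])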

    adam = pyro.optim.Adam({"lr": 0.03}) # note this is from Pyro's optim module, not PyTorch's 
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    """
    We do not need to pass in learnable parameters to the optimizer
    (unlike the PyTorch example above) since that is determined by the guide
    code and happens behind the scenes within the SVI class automatically.
    To take an ELBO gradient step we simply call the step method of SVI.
    The data argument we pass to SVI.step will be passed to both
    model() and guide().
    """

    pyro.clear_param_store()
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(x_data, y_data)
        if (j+1) % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))


    # We can examine the optimized parameter values by fetching from Pyro’s param store.

    guide.requires_grad_(False)  # freeze the guide's learned parameters so later calls do not track gradients

    for name, value in pyro.get_param_store().items():
        print(name, pyro.param(name))


    # This gets us quantiles from the posterior distribution
    guide.quantiles([0.25, 0.5, 0.75])

    """
    Since Bayesian models give you a posterior distribution, 
    model evaluation needs to be a combination of sampling the posterior and
    running the samples through the model.

    We generate 800 samples from our trained model. Internally, this is done
    by first generating samples for the unobserved sites in the guide, and
    then running the model forward by conditioning the sites to values sampled
    from the guide. Refer to the Model Serving section for insight on how the
    Predictive class works.
    """

    def summary(samples):
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
                "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
                "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
            }
        return site_stats

    """
    Note that in return_sites, we specify both the outcome ("obs" site) as
    well as the return value of the model ("_RETURN") which captures the
    regression line. Additionally, we would also like to capture the regression
    coefficients (given by "linear.weight") for further analysis.
    """

    predictive = Predictive(model, guide=guide, num_samples=800,
                            return_sites=("linear.weight", "obs", "_RETURN"))
    samples = predictive(x_data)
    pred_summary = summary(samples)
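
    # One possible way to use the summary (a sketch, not part of the original
    # snippet): report the posterior mean of the regression coefficients and a
    # 90% credible band for the regression line at the observed inputs.
    coef = pred_summary["linear.weight"]
    print("posterior mean of the regression coefficients:",
          coef["mean"].squeeze())
    mu = pred_summary["_RETURN"]
    print("regression line 90% band (first rows):",
          mu["5%"][:5], mu["95%"][:5])
    return pred_summary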