if j % 100 == 0: print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data))) guide.requires_grad_(False) for name, value in pyro.get_param_store().items(): print(name, pyro.param(name)) guide.quantiles([0.25, 0.5, 0.75]) from pyro.infer import Predictive def summary(samples): site_stats = {} for k, v in samples.items(): site_stats[k] = { "mean": torch.mean(v, 0), "std": torch.std(v, 0), "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
def bayesian_regression(x_data, y_data, num_iterations):
    """Fit a Bayesian linear regression with SVI and summarise its predictions.

    The model puts Normal priors on the weight and bias of a single linear
    layer and a Uniform(0, 10) prior on the observation noise ``sigma``. An
    ``AutoDiagonalNormal`` guide is trained with ``Trace_ELBO`` and Pyro's
    Adam optimizer. After training, the learned parameters are printed and
    800 posterior-predictive samples are drawn and summarised.

    Args:
        x_data: Feature tensor. Its first dimension is used as the number of
            observations and its last dimension as the number of features.
        y_data: Observed targets passed to the ``"obs"`` site
            (presumably one value per row of ``x_data`` -- confirm with caller).
        num_iterations: Number of SVI gradient steps to take.

    NOTE(review): the visible code computes ``pred_summary`` but does not
    return it; confirm whether callers expect a return value.
    """

    # BAYESIAN REGRESSION WITH SVI
    class BayesianRegression(PyroModule):
        """Linear regression whose weight and bias are latent sample sites."""

        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            self.linear.weight = PyroSample(
                dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
            )
            self.linear.bias = PyroSample(
                dist.Normal(0., 10.).expand([out_features]).to_event(1)
            )

        def forward(self, x, y=None):
            # forward() specifies the data-generating process.
            # sigma is the scale of the error term (typically called epsilon
            # in regression equations).
            sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
            mean = self.linear(x).squeeze(-1)
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
            return mean

    # Guides -- posterior distribution classes.
    # The guide determines a family of distributions, and SVI aims to find an
    # approximate posterior distribution from this family that has the lowest
    # KL divergence from the true posterior.

    # Take the number of regression coefficients from the data instead of
    # hard-coding it; any input the previous fixed value of 3 accepted still
    # yields 3 here.
    num_features = x_data.shape[-1]
    model = BayesianRegression(num_features, 1)

    # Under the hood, this defines a guide that uses a Normal distribution
    # with learnable parameters corresponding to each sample statement in the
    # model. Its size is num_features + 2: one component per regression
    # coefficient, plus one each for the intercept term and sigma (e.g. (5,)
    # for 3 features).
    guide = AutoDiagonalNormal(model)

    # Note this is from Pyro's optim module, not PyTorch's.
    adam = pyro.optim.Adam({"lr": 0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    # We do not need to pass learnable parameters to the optimizer since that
    # is determined by the guide code and happens behind the scenes within the
    # SVI class automatically. To take an ELBO gradient step we simply call
    # the step method of SVI. The data arguments we pass to SVI.step are
    # passed to both model() and guide().
    pyro.clear_param_store()
    num_observations = len(x_data)
    for j in range(num_iterations):
        # Calculate the loss and take a gradient step.
        loss = svi.step(x_data, y_data)
        if (j + 1) % 100 == 0:
            # Normalise by the number of observations passed to this function
            # (previously referenced an undefined name, `data`).
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / num_observations))

    # Training is done: turn off gradient tracking on the guide's parameters.
    guide.requires_grad_(False)

    # We can examine the optimized parameter values by fetching them from
    # Pyro's param store.
    for name, _ in pyro.get_param_store().items():
        print(name, pyro.param(name))

    # This gets us quantiles from the posterior distribution.
    # NOTE(review): the result is discarded here; it is only displayed when
    # run as the last expression of a notebook cell.
    guide.quantiles([0.25, 0.5, 0.75])

    # Since Bayesian models give you a posterior distribution, model
    # evaluation needs to be a combination of sampling the posterior and
    # running the samples through the model. We generate 800 samples from our
    # trained model. Internally, this is done by first generating samples for
    # the unobserved sites in the guide, and then running the model forward by
    # conditioning the sites to values sampled from the guide. Refer to the
    # Model Serving section for insight on how the Predictive class works.
    def summary(samples):
        """Return mean, std and 5% / 95% order statistics per sample site.

        All statistics are taken along dimension 0 of each site's tensor.
        """
        site_stats = {}
        for site, values in samples.items():
            num_draws = len(values)
            # kthvalue is 1-indexed: clamp so that fewer than 20 draws do not
            # produce an invalid k of 0.
            lower_k = max(1, int(num_draws * 0.05))
            upper_k = max(1, int(num_draws * 0.95))
            site_stats[site] = {
                "mean": torch.mean(values, 0),
                "std": torch.std(values, 0),
                "5%": values.kthvalue(lower_k, dim=0)[0],
                "95%": values.kthvalue(upper_k, dim=0)[0],
            }
        return site_stats

    # Note that in return_sites, we specify both the outcome ("obs" site) as
    # well as the return value of the model ("_RETURN"), which captures the
    # regression line. Additionally, we would also like to capture the
    # regression coefficients (given by "linear.weight") for further analysis.
    predictive = Predictive(
        model,
        guide=guide,
        num_samples=800,
        return_sites=("linear.weight", "obs", "_RETURN"),
    )
    samples = predictive(x_data)
    pred_summary = summary(samples)