Example #1
def test_information_criterion():
    # milk dataset: https://github.com/rmcelreath/rethinking/blob/master/data/milk.csv
    kcal = torch.tensor([
        0.49, 0.47, 0.56, 0.89, 0.92, 0.8, 0.46, 0.71, 0.68, 0.97, 0.84, 0.62,
        0.54, 0.49, 0.48, 0.55, 0.71
    ])
    kcal_mean = kcal.mean()
    kcal_logstd = kcal.std().log()

    def model():
        mu = pyro.sample("mu", dist.Normal(kcal_mean, 1))
        log_sigma = pyro.sample("log_sigma", dist.Normal(kcal_logstd, 1))
        with pyro.plate("plate"):
            pyro.sample("kcal", dist.Normal(mu, log_sigma.exp()), obs=kcal)

    delta_guide = AutoLaplaceApproximation(model)

    svi = SVI(model,
              delta_guide,
              optim.Adam({"lr": 0.05}),
              loss=Trace_ELBO(),
              num_samples=3000)
    for i in range(100):
        svi.step()

    # swap in the full Laplace (multivariate normal) approximation around the MAP
    svi.guide = delta_guide.laplace_approximation()
    posterior = svi.run()

    ic = posterior.information_criterion()
    assert_equal(ic["waic"], torch.tensor(-8.3), prec=0.2)
    assert_equal(ic["p_waic"], torch.tensor(1.8), prec=0.2)
Example #2
def get_svi_posterior(data,
                      demand,
                      svi=None,
                      model=None,
                      guide=None,
                      num_samples=100,
                      filename=''):
    """
    Extract posterior

    :param data: data to be passed to model, guide
    :param demand: demand to be passed to model, guide
    :param svi: svi object
    :param model: pyro model
    :param guide: pyro guide
    :param num_samples: number of samples to generate
    :param filename: param store to load
    :return: posterior
    """

    if svi is None and filename and model and guide:
        pyro.get_param_store().load(filename)

        svi = SVI(model,
                  guide,
                  optim.Adam({"lr": .005}),
                  loss=JitTrace_ELBO(),
                  num_samples=num_samples)

        svi.run(data, demand)

        return svi
    elif svi:
        svi.run(data, demand)
        return svi
    else:
        raise ValueError('Provide svi object or model/guide and filename')
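A hypothetical call pattern for the helper above, assuming the parameter store was saved earlier and that model, guide, data, and demand come from the surrounding project ('demand_params.pt' is a made-up filename):

pyro.get_param_store().save('demand_params.pt')  # after training elsewhere
svi = get_svi_posterior(data, demand,
                        model=model, guide=guide,
                        num_samples=100, filename='demand_params.pt')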
Example #3
def inference(train_x, train_y, num_epochs=2000):
    svi = SVI(model,
              guide,
              optim.Adam({'lr': 0.005}),
              loss=Trace_ELBO(),
              num_samples=1000)

    for i in range(num_epochs):
        elbo = svi.step(train_x, train_y)
        if i % 200 == 0:
            print('Elbo loss: {}'.format(elbo))

    svi_posterior = svi.run(train_x, train_y)
    sites = ['w', 'b', 'sigma']
    for site, values in summary(svi_posterior, sites).items():
        print("Site: {}".format(site))
        print(values, "\n")
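The summary helper called above is defined elsewhere in the script; fuller versions appear in Examples #6 and #10. A minimal sketch of what it does, using the same EmpiricalMarginal pattern as those examples:

from pyro.infer import EmpiricalMarginal

def summary(posterior, sites):
    stats = {}
    for site in sites:
        samples = EmpiricalMarginal(posterior, sites=site) \
            ._get_samples_and_weights()[0].detach().cpu().numpy()
        stats[site] = {'mean': samples.mean(), 'std': samples.std()}
    return stats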
Example #4
train = torch.tensor(df.values, dtype=torch.float)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .005}),
          loss=Trace_ELBO(),
          num_samples=1000)
x_data, y_data = train[:, :-1], train[:, 2]
pyro.clear_param_store()
num_iters = 8000 if not smoke_test else 2
for i in range(num_iters):
    elbo = svi.step(x_data, y_data)
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

posterior = svi.run(x_data, y_data)

sites = ["a", "bA", "bR", "bAR", "sigma"]

for site, values in summary(posterior, sites).items():
    print("Site: {}".format(site))
    print(values, "\n")


def wrapped_model(x_data, y_data):
    pyro.sample("prediction", dist.Delta(model(x_data, y_data)))



# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model, posterior, num_samples=1000)
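Note that SVI(num_samples=...), svi.run(), and TracePredictive belong to the pre-1.0 Pyro API; in Pyro 1.0+ they were replaced by pyro.infer.Predictive. A rough modern equivalent of the predictive step, as a sketch:

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=1000)
samples = predictive(x_data, None)  # dict mapping site names to sample tensors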
Example #5
train = torch.tensor(df.values, dtype=torch.float)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .005}),
          loss=Trace_ELBO(),
          num_samples=1000)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]
pyro.clear_param_store()
num_iters = 8000 if not smoke_test else 2
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

posterior = svi.run(is_cont_africa, ruggedness, log_gdp)

sites = ["a", "bA", "bR", "bAR", "sigma"]

for site, values in summary(posterior, sites).items():
    print("Site: {}".format(site))
    print(values, "\n")


def wrapped_model(is_cont_africa, ruggedness, log_gdp):
    pyro.sample("prediction", Delta(model(is_cont_africa, ruggedness,
                                          log_gdp)))


# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model, posterior, num_samples=1000)
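The predictive traces are then generated by replaying the model with the observed site set to None, following the pattern of the later examples (sketch):

post_pred = trace_pred.run(is_cont_africa, ruggedness, None)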
Example #6
    site_stats = {}
    for i in range(marginal.shape[1]):
        site_name = sites[i]
        marginal_site = pd.DataFrame(marginal[:, i]).transpose()
        describe = partial(pd.Series.describe,
                           percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
        site_stats[site_name] = marginal_site.apply(describe, axis=1) \
            [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats


def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))


posterior = svi.run(data[0], data[1][:, -1])

# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model, posterior, num_samples=100)
post_pred = trace_pred.run(data[0], None)
post_summary = summary(post_pred, sites=['prediction', 'obs'])
mu = post_summary["prediction"]
y = post_summary["obs"]
y.insert(0, 'true', data[1].cpu().numpy())

print("sample y data:")
print(y.head(10))

df = pd.DataFrame(y)
nx = df.reset_index()  # move the index into a regular column
nx = nx.values  # convert the DataFrame to a numpy array
Example #7
beta_error = np.abs(true_beta - beta_params).mean()
beta_rmse = np.sqrt(np.mean((true_beta - beta_params)**2))
params_resps_error = np.abs(betaInd_tmp - params_resps).mean()
params_resps_rmse = np.sqrt(np.mean((betaInd_tmp - params_resps)**2))

loglik, acc = loglikelihood(alt_attributes, true_choices, alt_av_mat,
                            alpha_params, beta_params, params_resps)

loglik_hyp, _ = loglikelihood(alt_attributes, true_choices, alt_av_mat,
                              alpha_params, beta_params,
                              np.tile(beta_params, [N, T]))

try:
    svi_posterior = svi.run(train_x, train_y, alt_av_mat_cuda, alt_ids_cuda)
except Exception:
    pass

L_omega_posterior = EmpiricalMarginal(
    svi, sites=["L_omega"])._get_samples_and_weights()[0]
L_omega = L_omega_posterior.mean(axis=0)[0].detach().cpu().numpy()
theta_posterior = EmpiricalMarginal(
    svi, sites=["theta"])._get_samples_and_weights()[0]
L_Omega = torch.mm(torch.diag(theta_posterior.mean(axis=0)[0].sqrt()),
                   L_omega_posterior.mean(axis=0)[0])
L_Omega = L_Omega.detach().cpu().numpy()
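The last lines rebuild the Cholesky factor of a covariance matrix from an LKJ-style decomposition (per-dimension scales theta and a correlation Cholesky factor L_omega). A sketch of recovering the implied covariance and correlation matrices from L_Omega above:

import numpy as np

Omega = L_Omega @ L_Omega.T        # implied covariance matrix
std = np.sqrt(np.diag(Omega))      # per-dimension standard deviations
corr = Omega / np.outer(std, std)  # correlation matrix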
Example #8
    for batch_id, data_train in enumerate(training_generator):
        # calculate the loss and take a gradient step
        loss += svi.step(data_train[0], data_train[1][:, -1])
        #loss += svi.step(x, y)
    normalizer_train = len(training_generator.dataset)
    total_epoch_loss_train = loss / normalizer_train

    losses.append(total_epoch_loss_train)
    print("Epoch ", j, " Loss ", total_epoch_loss_train)

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("Epoch loss")

posterior = svi.run(data_train[0], data_train[1][:, -1])

# Break
# import IPython; IPython.embed()

# Save parameters
pyro.get_param_store().save(f'{experiment_id}_params.pt')

torch.save(model, os.path.join(CHECKPOINT_DIR, f'{experiment_id}_latest'))

# Save parameters: preferred method
torch.save(net.state_dict(), f'{experiment_id}_state.pt')
# Save everything
torch.save(net, f'{experiment_id}_full.pt')
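The matching load side, as a sketch (net and experiment_id are assumed from the surrounding script):

pyro.get_param_store().load(f'{experiment_id}_params.pt')
net.load_state_dict(torch.load(f'{experiment_id}_state.pt'))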

output = {
Example #9
def pyro_bayesian(regression_model, y_data):
    def summary(traces, sites):
        marginal = get_marginal(traces, sites)
        site_stats = {}
        for i in range(marginal.shape[1]):
            site_name = sites[i]
            marginal_site = pd.DataFrame(marginal[:, i]).transpose()
            describe = partial(pd.Series.describe,
                               percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
            site_stats[site_name] = marginal_site.apply(describe, axis=1) \
                [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
        return site_stats

    # CI testing
    assert pyro.__version__.startswith('0.3.0')
    pyro.enable_validation(True)
    pyro.set_rng_seed(1)

    from pyro.contrib.autoguide import AutoDiagonalNormal
    guide = AutoDiagonalNormal(model)

    optim = Adam({"lr": 0.03})
    svi = SVI(model, guide, optim, loss=Trace_ELBO(), num_samples=1000)

    train(svi, x_data, y_data, num_iterations, regression_model)

    for name, value in pyro.get_param_store().items():
        print(name, pyro.param(name))

    get_marginal = lambda traces, sites: EmpiricalMarginal(
        traces, sites)._get_samples_and_weights()[0].detach().cpu().numpy()

    posterior = svi.run(x_data, y_data, regression_model)

    # posterior predictive distribution we can get samples from
    trace_pred = TracePredictive(wrapped_model, posterior, num_samples=1000)
    post_pred = trace_pred.run(x_data, None, regression_model)
    post_summary = summary(post_pred, sites=['prediction', 'obs'])
    mu = post_summary["prediction"]
    y = post_summary["obs"]
    predictions = pd.DataFrame({
        "x0": x_data[:, 0],
        "x1": x_data[:, 1],
        "mu_mean": mu["mean"],
        "mu_perc_5": mu["5%"],
        "mu_perc_95": mu["95%"],
        "y_mean": y["mean"],
        "y_perc_5": y["5%"],
        "y_perc_95": y["95%"],
        "true_gdp": y_data,
    })
    # print("predictions=", predictions)
    """we need to prepend `module$$$` to all parameters of nn.Modules since
    # that is how they are stored in the ParamStore
    """
    weight = get_marginal(posterior,
                          ['module$$$linear.weight']).squeeze(1).squeeze(1)
    factor = get_marginal(posterior, ['module$$$factor'])

    # x0, x1, x2"-home_page, x1*x2-factor
    print("weight shape=", weight.shape)
    print("factor shape=", factor.shape)

    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 6), sharey=True)
    ax[0].hist(weight[:, 0])
    ax[1].hist(weight[:, 1])
    ax[2].hist(factor.squeeze(1))
    plt.show()
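A quick way to see the module$$$-prefixed names mentioned above is to list the param store after the module has been lifted (sketch, reusing the iteration pattern from this example):

for name, value in pyro.get_param_store().items():
    if name.startswith('module$$$'):
        print(name, tuple(value.shape))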
Example #10
File: bnn.py Project: galfaroi/Insight
def summary(traces, sites):
    marginal = get_marginal(traces, sites)
    site_stats = {}
    for i in range(marginal.shape[1]):
        site_name = sites[i]
        marginal_site = pd.DataFrame(marginal[:, i]).transpose()
        describe = partial(pd.Series.describe, percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
        site_stats[site_name] = marginal_site.apply(describe, axis=1) \
            [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))

posterior = svi.run(x_test, y_test)
print(posterior)
trace_pred = TracePredictive(wrapped_model,
                             posterior,
                             num_samples=1000)
post_pred = trace_pred.run(x_test, y_test)
post_summary = summary(post_pred, sites=['prediction', 'obs'])
mu = post_summary["prediction"]
y = post_summary["obs"]
len(y)
mu[:5]
y_test
mu.head()
y.head()
preds = []
for i in range(100):
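The final loop is cut off by the snippet; one way to collect raw predictive draws without looping, using the same EmpiricalMarginal pattern as the other examples (sketch):

preds = EmpiricalMarginal(post_pred, sites='prediction') \
    ._get_samples_and_weights()[0].detach().cpu().numpy()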
Example #11
        site_stats[site_name] = marginal_site.apply(describe, axis=1) \
            [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats


def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))


for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name).cpu().detach().numpy().mean())

posterior = svi.run(Xtrain, Ytrain)

# Break
#pdb.set_trace()

# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model, posterior, num_samples=100)
post_pred = trace_pred.run(Xtrain, None)  # inputting PCA components?
post_summary = summary(post_pred, sites=['prediction', 'obs'])
meuw = post_summary["prediction"]
y = post_summary["obs"]
meuw.insert(0, 'true', np.array(Ytrain.cpu()))
y.insert(0, 'true', np.array(Ytrain.cpu()))

print("sample meuw data:")
print(meuw.head(10))
Example #12
class GMM(object):
    # Set device to CPU
    device = torch.device('cpu')

    def __init__(self, n_comp=3, infr='svi', n_itr=100, subsample=False):
        assert infr == 'svi' or infr == 'mcmc', 'Only svi and mcmc supported'
        # Load data
        # df = read_data(data_type='nyse')
        data_train, _, data_test, _ = preprocess(ret_type='tensor')
        self.tensor_train = data_train.type(torch.FloatTensor)
        self.tensor_test = data_test.type(torch.FloatTensor)
        self.n_comp = n_comp
        self.infr = infr
        self.shape = self.tensor_train.shape
        self.params = None
        self.weights = None
        self.locs = None
        self.scale = None
        self.mcmc_time = None
        self.svi_time = None
        print(f'Initializing object for inference method {self.infr}')
        if self.infr == 'svi':
            self.guide = None
            self.optim = Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
            self.svi = None
            self.svi_itr = n_itr
            self.elbo_loss = TraceEnum_ELBO(max_plate_nesting=1)
            self.posterior_est = None
            self.resp = None
        else:
            self.num_samples = 250
            self.mcmc = None
            self.warmup_steps = 50
            if subsample:
                self.mcmc_subsample = 0.1
                self.n_obs = int(self.shape[0] * self.mcmc_subsample)
            else:
                self.n_obs = self.shape[0]
            # Need to subsample in numpy array because
            # sampling using multinomial takes ages
            # self.tensor = torch.from_numpy(np.random.choice(data, self.n_obs)
            #                                ).type(torch.FloatTensor)

        # Initialize model
        self.model()

    ##################
    # Model definition
    ##################
    @config_enumerate
    def model(self):
        # Global variables
        weights = pyro.sample('weights',
                              dist.Dirichlet(0.5 * torch.ones(self.n_comp)))

        with pyro.plate('components', self.n_comp):
            locs = pyro.sample('locs',
                               dist.MultivariateNormal(
                                   torch.zeros(self.shape[1]),
                                   torch.eye(self.shape[1]))
                               )
            scale = pyro.sample('scale', dist.LogNormal(0., 2.))

        lis = []
        for i in range(self.n_comp):
            t = torch.eye(self.shape[1]) * scale[i]
            lis.append(t)
        f = torch.stack(lis)

        with pyro.plate('data', self.shape[0]):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.MultivariateNormal(locs[assignment],
                                                       f[assignment]),
                        obs=self.tensor_train)

    ##################
    # SVI
    ##################
    def guide_autodelta(self):
        self.guide = AutoDelta(poutine.block(self.model,
                                             expose=['weights',
                                                     'locs',
                                                     'scale']))

    def guide_autodiagnorm(self):
        self.guide = AutoDiagonalNormal(poutine.block(self.model,
                                                      expose=['weights',
                                                              'locs',
                                                              'scale']))

    def guide_multivariatenormal(self):
        self.guide = AutoMultivariateNormal(poutine.block(self.model,
                                                          expose=['weights',
                                                                  'locs',
                                                                  'scale']))

    def guide_manual(self):
        # Define priors
        weights_alpha = pyro.param('weights_alpha',
                                   (1. / self.n_comp) * torch.ones(
                                       self.n_comp),
                                   constraint=constraints.simplex)
        scale_loc = pyro.param('scale_loc',
                               torch.rand(1).expand([self.n_comp]),
                               constraint=constraints.positive)
        scale_scale = pyro.param('scale_scale',
                                 torch.rand(1),
                                 constraint=constraints.positive)
        # the component means are unconstrained reals; a positive constraint
        # on a zero-initialized param would map to -inf in unconstrained space
        loc_loc = pyro.param('loc_loc',
                             torch.zeros(self.shape[1]))
        loc_scale = pyro.param('loc_scale',
                               torch.ones(1),
                               constraint=constraints.positive)

        # Global variables
        weights = pyro.sample('weights', dist.Dirichlet(weights_alpha))
        with pyro.plate('components', self.n_comp):
            locs = pyro.sample('locs',
                               dist.MultivariateNormal(loc_loc, torch.eye(
                                   self.shape[1]) * loc_scale))
            scale = pyro.sample('scale',
                                dist.LogNormal(scale_loc, scale_scale))

        with pyro.plate('data', self.shape[0]):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))

        return locs, scale, assignment

    def optimizer(self):
        self.optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})

    def initialize(self, seed):
        self.set_seed(seed)
        self.clear_params()
        return self.run_svi()

    def init_svi(self):
        self.svi = SVI(self.model, self.guide, self.optim, loss=self.elbo_loss)

    def run_svi(self):
        if self.guide is None:
            self.guide = self.guide_manual
        self.init_svi()
        loss = self.svi.loss(self.model, self.guide)
        return loss

    def best_start(self):
        # Choose the best among 100 random initializations.
        print("Determining best seed for initialization")
        loss, seed = min((self.initialize(seed), seed) for seed in range(100))
        self.initialize(seed)
        print("Best seed determined after 100 random initializations:")
        print('seed = {}, initial_loss = {}'.format(seed, loss))

    def get_params(self):
        # note: a method named `params` would be shadowed by the
        # self.params attribute set in __init__
        self.params = pyro.get_param_store()
        return self.params

    def register_params(self):
        gradient_norms = defaultdict(list)
        for nam, value in self.params.named_parameters():
            # bind the current name as a default argument so each hook
            # records under its own site name
            value.register_hook(lambda g, name=nam: gradient_norms[name]
                                .append(g.norm().item()))
        return gradient_norms

    def get_svi_estimates_auto_guide(self):
        estimates = self.guide()
        self.weights = estimates['weights']
        self.locs = estimates['locs']
        self.scale = estimates['scale']
        return self.weights, self.locs, self.scale

    def get_mean_svi_est_manual_guide(self):
        svi_posterior = self.svi.run()

        sites = ["weights", "scale", "locs"]
        svi_samples = {
            site: EmpiricalMarginal(svi_posterior, sites=site).
            enumerate_support().detach().cpu().numpy() for site in sites
        }

        self.posterior_est = dict()
        for item in svi_samples.keys():
            self.posterior_est[item] = torch.tensor(
                svi_samples[item].mean(axis=0))

        return self.posterior_est

    ##################
    # MCMC
    ##################
    def init_mcmc(self, seed=42):
        self.set_seed(seed)
        kernel = NUTS(self.model)
        self.mcmc = MCMC(kernel, num_samples=self.num_samples,
                         warmup_steps=self.warmup_steps)
        print("Initialized MCMC with NUTS kernal")

    def run_mcmc(self):
        self.clear_params()
        print("Initializing MCMC")
        self.init_mcmc()
        print(f'Running MCMC using NUTS with num_obs = {self.n_obs}')
        self.mcmc.run()

    def get_mcmc_samples(self):
        return self.mcmc.get_samples()

    ##################
    # Inference
    ##################
    def inference(self):
        if self.infr == 'svi':
            start = time.time()
            # Initialize with best seed
            self.best_start()
            # Run SVI iterations
            print("Running SVI iterations")
            losses = []
            for i in range(self.svi_itr):
                loss = self.svi.step()
                losses.append(loss)
                # print('.' if i % 100 else '\n', end='')
                end = time.time()
                self.svi_time = (end - start)
            return losses
        else:
            start = time.time()
            self.run_mcmc()
            end = time.time()
            self.mcmc_time = (end - start)
            return self.get_mcmc_samples()

    # Get posterior responsibilities
    def get_posterior_resp(self):
        '''
        Formula:
        k: cluster index
        p(c=k|x) = w_k * N(x|mu_k, sigma_k) / sum(w_k * N(x|mu_k, sigma_k))
        '''
        w = self.posterior_est['weights']
        lo = self.posterior_est['locs']
        s = self.posterior_est['scale']
        prob_list = []
        lis = []
        for i in range(self.n_comp):
            t = torch.eye(self.shape[1]) * s[i]
            lis.append(t)
        f = torch.stack(lis)
        distri = dist.MultivariateNormal(lo, f)
        for d in self.tensor_test:
            numerator = w * torch.exp(distri.log_prob(d))
            denom = numerator.sum()
            probs = numerator / denom
            prob_list.append(probs)

        self.resp = torch.stack(prob_list)

        return self.resp

    ##################
    # Generate stats
    ##################
    def generate_stats(self):
        if self.svi is not None:
            svi_stats = dict({'num_samples': self.shape[0],
                              'num_iterations': self.svi_itr,
                              'exec_time': self.svi_time})
        else:
            svi_stats = None

        if self.mcmc is not None:
            mcmc_stats = dict(
                {'num_sampl': self.shape[0] * self.mcmc_subsample,
                 'exec_time': self.mcmc_time,
                 'num_samples_generated': self.num_samples,
                 'warmup_steps': self.warmup_steps})
        else:
            mcmc_stats = None

        return [svi_stats, mcmc_stats]

    ##################
    # Static Methods
    #################
    @staticmethod
    def set_seed(seed):
        pyro.set_rng_seed(seed)

    @staticmethod
    def clear_params():
        pyro.clear_param_store()

    @staticmethod
    def plot_svi_convergence(losses):
        plt.figure(figsize=(10, 3), dpi=100).set_facecolor('white')
        plt.plot(losses)
        plt.xlabel('iters')
        plt.ylabel('loss')
        plt.title('Convergence of SVI')
        plt.show()
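A hypothetical end-to-end use of the GMM wrapper above (sketch; assumes the preprocess() data pipeline it relies on is available):

gmm = GMM(n_comp=3, infr='svi', n_itr=200)
gmm.guide = gmm.guide_manual               # choose the hand-written guide
losses = gmm.inference()                   # best-of-100 seed init, then SVI steps
GMM.plot_svi_convergence(losses)
est = gmm.get_mean_svi_est_manual_guide()  # posterior mean per site
resp = gmm.get_posterior_resp()            # cluster responsibilities on test data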
Example #13
    site_stats = {}
    for i in range(marginal.shape[1]):
        site_name = sites[i]
        marginal_site = pd.DataFrame(marginal[:, i]).transpose()
        describe = partial(pd.Series.describe,
                           percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
        site_stats[site_name] = marginal_site.apply(describe, axis=1) \
            [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats


def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))


posterior = svi.run(data[0], data[1])

# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model, posterior, num_samples=1000)
post_pred = trace_pred.run(data[0], None)

# works up to here
post_summary = summary(post_pred, sites=['prediction', 'obs'])
mu = post_summary["prediction"]
y = post_summary["obs"]
y.insert(0, 'true', data[1].cpu().numpy())

print("sample y data:")
print(y.head(10))

print("mu_mean")
Example #14
k = X_tensor.shape[1] - 1
frequentist_model = TorchLogisticRegression(k)
q = AutoDiagonalNormal(bayes_logistic)
svi = SVI(bayes_logistic,
          q,
          Adam({"lr": 1e-2}),
          loss=Trace_ELBO(),
          num_samples=1000)

pyro.clear_param_store()
for i in range(3000):
    elbo = svi.step(X_tensor, y_tensor)
    if not i % 100:
        print(elbo / X_tensor.size(0))

svi_meanfield_posterior = svi.run(X_tensor, y_tensor)
new_sites = [f"parameter_{i}" for i in X_train.columns]
sites = ["bayes_logistic$$$linear.weight", "bayes_logistic$$$linear.bias"]


old_svi_samples = \
    {site: EmpiricalMarginal(svi_meanfield_posterior, sites=site)
     .enumerate_support().detach().cpu() for site in sites}
svi_samples = seperate_posteriors(old_svi_samples, new_sites)

fig, axs = plt.subplots(nrows=5,
                        ncols=2,
                        figsize=(12, 10),
                        sharex=True,
                        sharey=True)
fig.suptitle("Marginal Posterior Distributions", fontsize=16)
Example #15
    site_stats = {}
    for i in range(marginal.shape[1]):
        site_name = sites[i]
        marginal_site = pd.DataFrame(marginal[:, i]).transpose()
        describe = partial(pd.Series.describe, percentiles=[.05, 0.25, 0.5, 0.75, 0.95])
        site_stats[site_name] = marginal_site.apply(describe, axis=1) \
            [["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))




posterior = svi.run(data_train[0], data_train[1][:, -1])


# posterior predictive distribution we can get samples from
trace_pred = TracePredictive(wrapped_model,
                             posterior,
                             num_samples=1000)
post_pred = trace_pred.run(data_train[0], None)  # check: why data_train[0]?
post_summary = summary(post_pred, sites=['prediction', 'obs'])
mu = post_summary["prediction"]
y = post_summary["obs"]
y.insert(0, 'true', data_train[1].cpu().numpy())

print("sample y data:")
print(y.head(10))
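A quick sanity check on the summary frame printed above (sketch): for a well-calibrated model, roughly 90% of the true targets should fall inside the 5%-95% predictive band.

inside = ((y['true'] >= y['5%']) & (y['true'] <= y['95%'])).mean()
print(f'fraction inside the 90% band: {inside:.2f}')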