Example #1
    def test_inference_data_no_posterior(self, data, eight_schools_params):
        posterior_samples = data.obj.get_samples()
        model = data.obj.kernel.model
        posterior_predictive = Predictive(model, posterior_samples)(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        prior = Predictive(model, num_samples=500)(
            eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
        )
        idata = from_pyro(
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={"theta": ["school"], "eta": ["school"]},
        )
        test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
        fails = check_multiple_attrs(test_dict, idata)
        assert not fails
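This fixture assumes an eight-schools model defined elsewhere in the test suite. A minimal sketch of what such a model could look like, assuming the standard non-centered parameterization (the site names match the test_dict above; eight_schools_model is an illustrative name, not the suite's actual code):

import pyro
import pyro.distributions as dist

def eight_schools_model(J, sigma, y=None):
    # Population-level location and scale.
    mu = pyro.sample("mu", dist.Normal(0.0, 5.0))
    tau = pyro.sample("tau", dist.HalfCauchy(5.0))
    with pyro.plate("school", J):
        # Non-centered school effects.
        eta = pyro.sample("eta", dist.Normal(0.0, 1.0))
        theta = pyro.deterministic("theta", mu + tau * eta)
        return pyro.sample("obs", dist.Normal(theta, sigma), obs=y)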
Example #2
def predict(Xpred=None, data_pars={}, compute_pars=None, out_pars={}, **kw):
    global model, session

    compute_pars2 = model.compute_pars if compute_pars is None else compute_pars
    num_samples = compute_pars2.get('num_samples', 300)

    ###### Data load
    if Xpred is None:
        Xpred = get_dataset(data_pars, task_type="predict")
    cols_Xpred = list(Xpred.columns)

    max_size = compute_pars2.get('max_size', len(Xpred))

    Xpred    = Xpred.iloc[:max_size, :]
    Xpred_   = torch.tensor(Xpred.values, dtype=torch.float)

    ###### Post processing normalization
    post_process_fun = model.model_pars.get('post_process_fun', None)
    if post_process_fun is None:
        def post_process_fun(y):
            return y

    from pyro.infer import Predictive
    def summary(samples):
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
            }
        return site_stats

    # A reloaded model may not have a guide attached; fall back to None.
    guide      = getattr(model, "guide", None)
    predictive = Predictive(model.model, guide=guide, num_samples=num_samples,
                            return_sites=("linear.weight", "obs", "_RETURN"))
    pred_samples = predictive(Xpred_)
    pred_summary = summary(pred_samples)

    mu = pred_summary["_RETURN"]
    y  = pred_summary["obs"]
    dd = {
        "mu_mean": post_process_fun(mu["mean"].detach().numpy()),
        "y_mean": post_process_fun(y["mean"].detach().numpy()),
    }
    for col in cols_Xpred:
        dd[col] = Xpred[col].values

    ypred_mean = pd.DataFrame(dd)
    model.pred_summary = {'pred_mean': ypred_mean, 'pred_summary': pred_summary, 'pred_samples': pred_samples}
    print('stored in model.pred_summary')

    ypred_proba = None  ### No proba
    if compute_pars2.get("probability", False):
        ypred_proba = model.model.predict_proba(Xpred)
    return dd['y_mean'], ypred_proba
Example #3
def test_posterior_predictive_svi_manual_guide(parallel):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=1.0)),
              Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(model,
                                      guide=beta_guide,
                                      num_samples=10000,
                                      parallel=parallel,
                                      return_sites=["_RETURN"])
    marginal_return_vals = posterior_predictive.get_samples(
        num_trials)["_RETURN"]
    assert_close(marginal_return_vals.mean(dim=0),
                 torch.ones(5) * 700,
                 rtol=0.05)
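Examples #3 and #5 exercise a model/beta_guide pair defined elsewhere in Pyro's test suite. A minimal sketch of what that pair could look like (an assumption, not the repository's exact code): one success probability per trial site, guided by a learnable Beta posterior.

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(num_trials):
    with pyro.plate("data", num_trials.size(0)):
        # Uniform prior on the success probability of each Binomial.
        p = pyro.sample("p", dist.Beta(1.0, 1.0))
        return pyro.sample("obs", dist.Binomial(num_trials, p))

def beta_guide(num_trials):
    # Learnable Beta concentrations, kept positive by the constraint.
    c1 = pyro.param("c1", torch.full((num_trials.size(0),), 2.0),
                    constraint=constraints.positive)
    c0 = pyro.param("c0", torch.full((num_trials.size(0),), 2.0),
                    constraint=constraints.positive)
    with pyro.plate("data", num_trials.size(0)):
        pyro.sample("p", dist.Beta(c1, c0))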
Example #4
def sample_prior(model, num_samples, sites=None):
    return {
        k: v.detach().numpy()
        for k, v in Predictive(
            model,
            {},
            return_sites=sites,
            num_samples=num_samples
        )().items()
    }
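Hypothetical usage of sample_prior with a toy model (the model and site names below are illustrative only). Passing an empty posterior_samples dict together with num_samples makes Predictive draw from the prior:

import pyro
import pyro.distributions as dist

def toy_model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, 1.0))

prior = sample_prior(toy_model, num_samples=100, sites=["loc", "obs"])
# Each entry is a numpy array with a leading dimension of num_samples.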
Example #5
def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
    true_probs = torch.ones(5) * 0.7
    num_trials = torch.ones(5) * 1000
    num_success = dist.Binomial(num_trials, true_probs).sample()
    conditioned_model = poutine.condition(model, data={"obs": num_success})
    guide = AutoDiagonalNormal(conditioned_model)
    svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(num_trials)
    posterior_predictive = Predictive(
        model, guide=guide, num_samples=10000, parallel=True
    )
    if return_trace:
        marginal_return_vals = posterior_predictive.get_vectorized_trace(
            num_trials
        ).nodes["obs"]["value"]
    else:
        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
Example #6
def marginal(guide, num_samples=25):
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive(data)

    mu_mean = posterior_samples['mu'].detach().mean(dim=0)
    prec_mean = posterior_samples['prec'].detach().mean(dim=0)

    corr_mean = torch.zeros(T, D, D)
    for t in range(T):
        corr_mean[t, ...] = posterior_samples[f'corr_chol_{t}'].detach().mean(dim=0)

    beta_mean = posterior_samples['beta'].detach().mean(dim=0)
    weights_mean = mix_weights(beta_mean)

    centers, sigmas, _ = truncate(alpha, mu_mean, prec_mean, corr_mean,
                                  weights_mean)

    return centers, sigmas
Example #7
def predict_samples(inputs, digits, pre_trained_cvae, epoch_frac):
    predictive = Predictive(pre_trained_cvae.model,
                            guide=pre_trained_cvae.guide,
                            num_samples=1)
    preds = predictive(inputs)
    y_loc = preds['y'].squeeze().detach().cpu().numpy()
    dfs = pd.DataFrame(data=y_loc)
    dfs['digit'] = digits.numpy()
    dfs['epoch'] = epoch_frac
    return dfs
Example #8
    def get_inference_data(self, data, eight_schools_params):
        posterior_samples = data.obj.get_samples()
        model = data.obj.kernel.model
        posterior_predictive = Predictive(
            model, posterior_samples).get_samples(
                eight_schools_params["J"],
                torch.from_numpy(eight_schools_params["sigma"]).float())
        prior = Predictive(model, num_samples=500).get_samples(
            eight_schools_params["J"],
            torch.from_numpy(eight_schools_params["sigma"]).float())
        return from_pyro(
            posterior=data.obj,
            prior=prior,
            posterior_predictive=posterior_predictive,
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={
                "theta": ["school"],
                "eta": ["school"]
            },
        )
Example #9
def evaluate_pointwise_pred_density(model, posterior_samples,
                                    baseball_dataset):
    """
    Evaluate the log probability density of observing the unseen data (season hits)
    given a model and posterior distribution over the parameters.
    """
    _, test, player_names = train_test_split(baseball_dataset)
    at_bats_season, hits_season = test[:, 0], test[:, 1]
    trace = Predictive(model, posterior_samples).get_vectorized_trace(
        at_bats_season, hits_season)
    # Use LogSumExp trick to evaluate $log(1/num_samples \sum_i p(new_data | \theta^{i})) $,
    # where $\theta^{i}$ are parameter samples from the model's posterior.
    trace.compute_log_prob()
    post_loglik = trace.nodes["obs"]["log_prob"]
    # computes expected log predictive density at each data point
    exp_log_density = (post_loglik.logsumexp(0) -
                       math.log(post_loglik.shape[0])).sum()
    logging.info("\nLog pointwise predictive density")
    logging.info("--------------------------------")
    logging.info("{:.4f}\n".format(exp_log_density))
Example #10
def test_end_to_end(model):
    # Test training.
    model = AutoReparam()(model)
    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 1e-9}), Trace_ELBO())
    for step in range(3):
        svi.step()

    # Test prediction.
    predictive = Predictive(model, guide=guide, num_samples=2)
    samples = predictive()
    assert set("abc").issubset(samples.keys())
Example #11
def run_inference(data, gen_model, ode_model, method, iterations=10000, num_particles=1, num_samples=1000, warmup_steps=500, init_scale=0.1,
                  seed=12, lr=0.5, return_sites=("_RETURN",)):
    torch_data = torch.tensor(data, dtype=torch.float)
    if isinstance(ode_model, ForwardSensManualJacobians) or \
            isinstance(ode_model, ForwardSensTorchJacobians):
        ode_op = ForwardSensOp
    elif isinstance(ode_model, AdjointSensManualJacobians) or \
            isinstance(ode_model, AdjointSensTorchJacobians):
        ode_op = AdjointSensOp
    else:
        raise ValueError('Unknown sensitivity solver: Use "Forward" or "Adjoint"')
    model = gen_model(ode_op, ode_model)
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    if method == 'VI':

        guide = AutoMultivariateNormal(model, init_scale=init_scale)
        optim = AdagradRMSProp({"eta": lr})
        if num_particles == 1:
            svi = SVI(model, guide, optim, loss=Trace_ELBO())
        else:
            svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=num_particles,
                                                           vectorize_particles=True))
        loss_trace = []
        t0 = timer.time()
        for j in range(iterations):
            loss = svi.step(torch_data)
            loss_trace.append(loss)

            if j % 500 == 0:
                print("[iteration %04d] loss: %.4f" % (j + 1, np.mean(loss_trace[max(0, j - 1000):j + 1])))
        t1 = timer.time()
        print('VI time: ', t1 - t0)
        predictive = Predictive(model, guide=guide, num_samples=num_samples,
                                return_sites=return_sites)  # other sites: "ode_params", "scale"
        vb_samples = predictive(torch_data)
        return vb_samples

    elif method == 'NUTS':

        nuts_kernel = NUTS(model, adapt_step_size=True, init_strategy=init_to_median)

        # mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=2)
        mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=1)
        t0 = timer.time()
        mcmc.run(torch_data)
        t1 = timer.time()
        print('NUTS time: ', t1 - t0)
        hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
        return hmc_samples
    else:
        raise ValueError('Unknown method: Use "NUTS" or "VI"')
Example #12
    def predict(self, X):
        data = X
        observed = data.columns
        if 'class' in data.columns:
            observed = observed.drop(['class'])

        served_model = Predictive(
            model=model,
            guide=guide,
            return_sites=('class',),
            num_samples=1)
        predictions = served_model(data, self.graph, observed, self.n_categories)
        return predictions['class'].squeeze(0)
Example #13
    def predict(self, n_samples):
        predictive = Predictive(self.model,
                                guide=self.guide,
                                num_samples=n_samples,
                                return_sites=("linear.weight", "obs",
                                              "_RETURN"))
        samples = predictive(self.X_test)
        for k, v in samples.items():
            self.pred[k] = {
                "mean": torch.mean(v, 0),
                "std": torch.std(v, 0),
            }
Example #14
def predict(x, model, guide, num_samples=30):
    predictive = Predictive(model, guide=guide, num_samples=num_samples)
    # yhats has shape (num_samples, batch_size): one predicted class label
    # per posterior sample per item in the batch.
    yhats = predictive(x)["obs"].double()
    # Note: averaging discrete class labels is questionable; Example #26
    # takes the mode across samples instead.
    mean = torch.mean(yhats, dim=0)
    std = torch.std(yhats.float(), 0).cpu().numpy()
    return mean, std
Example #15
    def sample_posterior(self, num_samples=200, attempts=5):

        for _ in range(attempts):
            try:
                samples = Predictive(self.model, guide=self.guide, num_samples=num_samples,
                                     return_sites=self.model.get_varnames())()
                return {varname.split('_')[-1]: samples[varname].detach().numpy()
                        for varname in self.model.get_varnames()}
            except ValueError:
                pass

        raise ValueError('Posterior contains improper values')
Example #16
def test_broken_plates_smoke(backend):
    def model():
        with pyro.plate("i", 2):
            a = pyro.sample("a", dist.Normal(0, 1))
        pyro.sample("b", dist.Normal(a.mean(-1), 1), obs=torch.tensor(0.0))

    guide = AutoGaussian(model, backend=backend)
    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
    for step in range(2):
        with xfail_if_not_implemented():
            svi.step()
    guide()
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive()
Example #17
    def summarize_posterior(self,
                            raw_expr,
                            encoded_expr,
                            read_depth,
                            num_samples=200,
                            attempts=5):

        logging.info('Sampling posterior ...')

        self.posterior_predictive = Predictive(self.model,
                                               guide=self.guide,
                                               num_samples=num_samples,
                                               return_sites=self.var_names)

        trace = {}
        for i, batch in enumerate(
                self.epoch_batch(raw_expr,
                                 encoded_expr,
                                 read_depth,
                                 batch_size=512)):

            samples = self.posterior_predictive(*batch)

            if i == 0:
                for varname in self.global_vars:
                    new_samples = samples[varname].cpu().detach().numpy()
                    trace[varname] = new_samples

            for varname in self.local_vars:
                new_samples = samples[varname].cpu().detach().numpy()

                if varname not in trace:
                    trace[varname] = []
                trace[varname].append(new_samples)

            logging.info('Done {} batches.'.format(str(i + 1)))

        for varname in self.local_vars:
            trace[varname] = np.concatenate(trace[varname], axis=1)

        for varname in self.var_names:
            setattr(self, varname, np.mean(trace[varname], axis=0))

        self.beta = self.get_beta()
        self.gamma = self.get_gamma()
        self.bias = self.get_bias()
        self.bn_mean = self.get_bn_mean()
        self.bn_var = self.get_bn_var()

        return self
Example #18
    def sample_all1(self, init='init_1', batch_size: int = 10):

        predictive = Predictive(self.model,
                                guide=self.guide_i[init],
                                num_samples=batch_size)

        post_samples = {
            k: v.detach().cpu().numpy()
            for k, v in self.step_predictive(predictive, self.x_data,
                                             self.extra_data_train).items()
            if k != "data_target"
        }

        return post_samples
Example #19
    def sample_node1(self, node, init, batch_size: int = 10):

        predictive = Predictive(self.model,
                                guide=self.guide_i[init],
                                num_samples=batch_size)

        post_samples = {
            k: v.detach().cpu().numpy()
            for k, v in self.step_predictive(predictive, self.x_data,
                                             self.extra_data_train).items()
            if k == node
        }

        return post_samples[node]
Example #20
    def custom_l2_loss(self, model, guide, *args, **kwargs):
        # Draw posterior-predictive samples and compare their summary
        # statistics against the target mean and variance.
        X, num_samples, pred_X_mean, pred_X_var = args

        predictive = Predictive(self, guide=self.guide,
                                num_samples=num_samples,
                                return_sites=("obs", "_RETURN"))
        samples = predictive(X)
        pred_summary = self.summary(samples)
        mu = pred_summary["_RETURN"]
        mu_mean = mu["mean"]
        mu_std = mu["std"]

        return l2_loss(pred_X_mean, pred_X_var, mu_mean, mu_std)
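Examples #20 and #24 call a self.summary helper that is not shown. A minimal sketch, mirroring the local summary function in Example #2 (per-site mean and standard deviation over the sample dimension):

def summary(self, samples):
    # Reduce each sampled site across dim 0 (the num_samples dimension).
    return {
        k: {"mean": torch.mean(v, 0), "std": torch.std(v, 0)}
        for k, v in samples.items()
    }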
Example #21
    def predict(self, data=None):
        self.latest_data = data
        print("evaluate model")
        predictor = Predictive(self.model, guide=self.guide, num_samples=5000)

        if data is None:
            data = self.boards_test_data

        prediction = predictor(data)

        obs = prediction['obs'].T.detach().numpy()
        self.latest_means = obs.mean(axis=1)
        self.latest_stds = obs.std(axis=1)

        return self.latest_means
Example #22
def test_shapes(parallel):
    num_samples = 10

    def model():
        x = pyro.sample("x", dist.Normal(0, 1).expand([2]).to_event(1))
        with pyro.plate("plate", 5):
            loc, log_scale = x.unbind(-1)
            y = pyro.sample("y", dist.Normal(loc, log_scale.exp()))
        return dict(x=x, y=y)

    guide = AutoDiagonalNormal(model)

    # Compute by hand.
    vectorize = pyro.plate("_vectorize", num_samples, dim=-2)
    trace = poutine.trace(vectorize(guide)).get_trace()
    expected = poutine.replay(vectorize(model), trace)()

    # Use Predictive.
    predictive = Predictive(model, guide=guide, return_sites=["x", "y"],
                            num_samples=num_samples, parallel=parallel)
    actual = predictive.get_samples()
    assert set(actual) == set(expected)
    assert actual["x"].shape == expected["x"].shape
    assert actual["y"].shape == expected["y"].shape
Example #23
def test_pyrocov_smoke(model, Guide, backend):
    T, P, S, F = 3, 4, 5, 6
    dataset = {
        "features": torch.randn(S, F),
        "local_time": torch.randn(T, P),
        "weekly_strains": torch.randn(T, P, S).exp().round(),
    }

    guide = Guide(model, backend=backend)
    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
    for step in range(2):
        with xfail_if_not_implemented():
            svi.step(dataset)
    guide(dataset)
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive(dataset)
Example #24
    def sample(self, x_test, num_samples=128):
        self.guide.requires_grad_(False)

        predictive = Predictive(self, guide=self.guide,
                                num_samples=num_samples,
                                return_sites=("obs", "_RETURN"))
        samples = predictive(x_test)
        pred_summary = self.summary(samples)
        mu = pred_summary["_RETURN"]
        y = pred_summary["obs"]
        mu_mean = mu["mean"]
        mu_var = mu["std"].pow(2)
        y_mean = y["mean"]
        y_var = y["std"].pow(2)
        return mu_mean, mu_var, y_mean, y_var
Example #25
def test_predictive(auto_class):
    N, D = 3, 2

    class RandomLinear(nn.Linear, PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features)
            self.weight = PyroSample(
                dist.Normal(0., 1.).expand([out_features,
                                            in_features]).to_event(2))
            self.bias = PyroSample(
                dist.Normal(0., 10.).expand([out_features]).to_event(1))

    class LinearRegression(PyroModule):
        def __init__(self):
            super().__init__()
            self.linear = RandomLinear(D, 1)

        def forward(self, x, y=None):
            mean = self.linear(x).squeeze(-1)
            sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
            with pyro.plate('plate', N):
                return pyro.sample('obs', dist.Normal(mean, sigma), obs=y)

    x, y = torch.randn(N, D), torch.randn(N)
    model = LinearRegression()
    guide = auto_class(model)
    # XXX: Record `y` as observed in the prototype trace
    # Is there a better pattern to follow?
    guide(x, y=y)
    # Test predictive module
    model_trace = poutine.trace(model).get_trace(x, y=None)
    predictive = Predictive(model, guide=guide, num_samples=10)
    pyro.set_rng_seed(0)
    samples = predictive(x)
    for site in prune_subsample_sites(model_trace).stochastic_nodes:
        assert site in samples
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        traced_predictive = torch.jit.trace_module(predictive, {"call": (x, )})
    f = io.BytesIO()
    torch.jit.save(traced_predictive, f)
    f.seek(0)
    predictive_deser = torch.jit.load(f)
    pyro.set_rng_seed(0)
    samples_deser = predictive_deser.call(x)
    # Note that the site values are different in the serialized guide
    assert len(samples) == len(samples_deser)
Example #26
def predict(x, model, guide, transform=False):
    predictive = Predictive(model, guide=guide, num_samples=20)
    if transform:
        x = transform(x)
    # yhats has shape (20, batch_size): one predicted class label per
    # posterior sample per item in the batch.
    yhats = predictive(x)["obs"].double()
    # Taking a mean over discrete class labels makes little sense, so take
    # the mode across the 20 sampled models instead.
    mode = torch.mode(yhats, dim=0).values
    std = torch.std(yhats.float(), 0).numpy()
    return mode
Example #27
    def fit_params_estimate(self, data):
        if self.inference_method == "svi":
            data = torch.tensor(data).float()
            predictive = Predictive(self.model,
                                    guide=self.multi_norm_guide(),
                                    num_samples=1000)
            svi_samples = {
                k: v.reshape(1000).detach().numpy()
                for k, v in predictive(data).items() if not k.startswith("obs")
            }
            return {
                "mean": np.array([v.mean() for v in svi_samples.values()]),
                "std": np.array([v.std() for v in svi_samples.values()])
            }

        elif self.inference_method == "mcmc":
            return {"mean": self.mcmc.get_samples()["fit_params"].mean(0)}
        else:
            raise NotImplementedError
Example #28
    def predict(self, x_test, num_samples=500, percentiles=(5.0, 50.0, 95.0)):
        x_test = torch.tensor(x_test)
        posterior_predictive = Predictive(self.model, guide=self.guide,
                                          num_samples=num_samples,
                                          return_sites=("_RETURN", "obs"))(x_test)

        # Confidence interval on the latent regression function ("_RETURN")
        confidence_intervals = posterior_predictive['_RETURN'].detach().cpu().numpy()
        y_pred = np.percentile(confidence_intervals, percentiles, axis=0).T
        y_median_conf = y_pred[:, 1].reshape(-1, 1)
        y_lower_upper_quantile_conf = np.concatenate(
            (y_pred[:, 1].reshape(-1, 1) - y_pred[:, 0].reshape(-1, 1),
             y_pred[:, 2].reshape(-1, 1) - y_pred[:, 1].reshape(-1, 1)), axis=1)

        # Predictive interval on the observations ("obs")
        predictive_intervals = posterior_predictive['obs'].detach().cpu().numpy()
        y_pred = np.percentile(predictive_intervals, percentiles, axis=0).T
        y_median_pred = y_pred[:, 1].reshape(-1, 1)
        y_lower_upper_quantile_pred = np.concatenate(
            (y_pred[:, 1].reshape(-1, 1) - y_pred[:, 0].reshape(-1, 1),
             y_pred[:, 2].reshape(-1, 1) - y_pred[:, 1].reshape(-1, 1)), axis=1)

        return (y_median_conf, y_lower_upper_quantile_conf,
                y_median_pred, y_lower_upper_quantile_pred)
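Hypothetical usage of the percentile-based predict above (reg, x_grid, and the plotting are illustrative only). The quantile-offset arrays are shaped (N, 2), so transpose them for matplotlib's asymmetric yerr:

import matplotlib.pyplot as plt

median, err, pred_median, pred_err = reg.predict(x_grid)
# err.T has shape (2, N): lower and upper offsets from the median.
plt.errorbar(x_grid.ravel(), median.ravel(), yerr=err.T, fmt='o')
plt.show()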
Example #29
def test_intractable_smoke(backend):
    def model():
        i_plate = pyro.plate("i", 2, dim=-1)
        j_plate = pyro.plate("j", 3, dim=-2)
        with i_plate:
            a = pyro.sample("a", dist.Normal(0, 1))
        with j_plate:
            b = pyro.sample("b", dist.Normal(0, 1))
        with i_plate, j_plate:
            c = pyro.sample("c", dist.Normal(a + b, 1))
            pyro.sample("d", dist.Normal(c, 1), obs=torch.zeros(3, 2))

    guide = AutoGaussian(model, backend=backend)
    svi = SVI(model, guide, ClippedAdam({"lr": 1e-8}), Trace_ELBO())
    for step in range(2):
        with xfail_if_not_implemented():
            svi.step()
    guide()
    predictive = Predictive(model, guide=guide, num_samples=2)
    predictive()
Example #30
    def __init__(self, net_enc, net_dec, predict_samples=None):
        super().__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.decoder_guide = AutoDiagonalNormal(
            poutine.block(self.decoder, hide=["obs"]))
        self.optim = Adam({"lr": cfg.BAYESIAN.lr})
        self.svi = SVI(self.decoder,
                       self.decoder_guide,
                       self.optim,
                       loss=Trace_ELBO())
        predict_samples = predict_samples or cfg.BAYESIAN.predict_samples
        self.predictive = Predictive(self.decoder,
                                     guide=self.decoder_guide,
                                     num_samples=predict_samples,
                                     return_sites=['_RETURN'])

        # Freeze encoder
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad = False