+ str(len(pred_interval_x1)) \
    + ', ' \
    + str(len(pred_interval_x2)) \
    + ')...'

for k in range(num_chains):
    # For chain k, evaluate model.predictive_posterior at every point of the
    # grid spanned by pred_interval_x1 (rows) and pred_interval_x2 (columns).
    # Alongside each value, keep the dropped-sample count the model returns,
    # then write both grids to CSV files in the chain's run directory.
    grid_shape = [len(pred_interval_x1), len(pred_interval_x2)]
    pred_posterior = np.empty(grid_shape)
    nums_dropped_samples = np.empty(grid_shape, dtype=np.int64)
    chain_samples = chain_lists.vals['sample'][k]

    for i, x1_point in enumerate(pred_interval_x1):
        for j, x2_point in enumerate(pred_interval_x2):
            print(verbose_msg.format(k + 1, i + 1, j + 1))

            # A single two-coordinate input point, shaped as a 1x2 tensor.
            grid_point = torch.tensor([[x1_point, x2_point]], dtype=dtype)
            integral, num_dropped_samples = model.predictive_posterior(
                chain_samples, grid_point, torch.tensor([[1.]], dtype=dtype))

            pred_posterior[i, j] = integral.item()
            nums_dropped_samples[i, j] = num_dropped_samples

    run_path = sampler_output_run_paths[k]
    np.savetxt(run_path.joinpath('pred_posterior_on_grid.csv'),
               pred_posterior,
               delimiter=',')
    np.savetxt(
        run_path.joinpath('pred_posterior_on_grid_num_dropped_samples.csv'),
        nums_dropped_samples,
        fmt='%d',
        delimiter=',')
# Load the 'sample' chains from the sampler output directories, rebinding
# chain_lists (which was used above for the grid evaluation).
chain_lists = ChainLists.from_file(
    sampler_output_run_paths, keys=['sample'], dtype=dtype)

# %% Drop burn-in samples

chain_samples_list = chain_lists.vals['sample']
for i in range(num_chains):
    # Discard the first pred_iter_thres entries of chain i and keep the rest.
    post_burnin = chain_samples_list[i][pred_iter_thres:]
    chain_samples_list[i] = post_burnin

# %% Compute chain means

means = chain_lists.mean()

# %% Make and save predictions

for k in range(num_chains):
    # Score every element yielded by test_dataloader with the predictive
    # posterior, using the mean of chain k (wrapped in a one-element list)
    # as the sample set. integral.item() requires a single-element result,
    # so each dataloader element contributes exactly one probability.
    test_pred_probs = np.empty([len(test_dataloader)])

    for i, (x, _) in enumerate(test_dataloader):
        integral, _ = model.predictive_posterior(
            [means[k, :]], x, torch.tensor([[1.]], dtype=dtype))
        test_pred_probs[i] = integral.item()

    # Threshold at 0.5 into a boolean array; fmt='%d' writes it as 0/1.
    test_preds = np.greater(test_pred_probs, 0.5)

    preds_path = sampler_output_run_paths[k].joinpath('preds_via_mean.txt')
    np.savetxt(preds_path, test_preds, fmt='%d')