def get_conf(dataloader, delta, model, mu, S, T=1, laplace=True):
    """Compute binary-classification confidences for every example in a loader.

    Args:
        dataloader: Data to predict on; forwarded unchanged to the predictor.
        delta: Forwarded as ``delta=`` to whichever predictor is used (its
            meaning is defined by those helpers, not visible here).
        model: Network to evaluate.
        mu, S: Laplace posterior parameters, e.g. the pair returned by
            ``llla_binary.estimate_variance``; only used when ``laplace`` is true.
        T: Forwarded as ``T=`` to ``predict_binary`` (presumably a temperature
            — confirm); ignored on the Laplace path.
        laplace: If true, predict with ``llla_binary.predict``; otherwise use
            the plain ``predict_binary`` path.

    Returns:
        The result of ``get_confidence(..., binary=True)`` applied to the
        predictor output (requested with ``apply_sigm=True``) after moving it
        to CPU and converting it to a NumPy array.
    """
    if not laplace:
        # Non-Laplace path: the only branch that consumes T.
        probs = predict_binary(dataloader, model, delta=delta, apply_sigm=True, T=T)
    else:
        # Laplace path: uses the posterior (mu, S); T is not forwarded here.
        probs = llla_binary.predict(
            dataloader, model, mu, S, apply_sigm=True, delta=delta)

    probs_np = probs.cpu().numpy()
    return get_confidence(probs_np, binary=True)
model = load_model()
# timing() returns (result, elapsed_time); the Hessian computation time
# (time_inf) is not used in this part of the script.
hessians, time_inf = timing(
    lambda: llla_binary.get_hessian(model, train_loader, mnist=False))

# var0 (presumably the prior variance) was tuned offline with the grid search
# below and is hard-coded so the search does not rerun each time. Uncomment the
# two lines below to re-tune it.
# interval = torch.tensor(np.linspace(1, 1000, 100)).cuda()
# var0 = llla_binary.gridsearch_var0(model, hessians, val_loader, ood_loader, interval, lam=0.5)
var0 = 293.6364  # optimal var0 (result of the grid search above)
print(var0)

# mu, S are what llla_binary.predict consumes below (presumably the posterior
# mean and covariance of the Laplace approximation — confirm).
mu, S = llla_binary.estimate_variance(var0, hessians)
# Optional sanity check: smallest eigenvalue of S (presumably to confirm S is
# positive definite).
# print(np.linalg.eigvalsh(S.cpu().numpy()).min())

# In-distribution: evaluate the LLLA predictive on test_loader.
# predict() is called without apply_sigm here; the 0.5 threshold below assumes
# it still returns probabilities by default — confirm.
py_in, time_pred = timing(
    lambda: llla_binary.predict(test_loader, model, mu, S).cpu().numpy())
# Hard labels by thresholding at 0.5.
# NOTE(review): if py_in is (N, 1) while targets is (N,), `==` broadcasts to
# (N, N) and this "accuracy" is meaningless — confirm both shapes match.
acc_in = np.mean((py_in >= 0.5) == targets)
conf_in = get_confidence(py_in, binary=True)
mmc = conf_in.mean()  # MMC: presumably mean maximum confidence
# Calibration metrics are disabled; ece/mce are zero placeholders, so the
# "ECE: 0.000; MCE: 0.000" printed below are NOT measured values.
# ece, mce = get_calib(py_in, targets)
ece, mce = 0, 0
# Record the in-distribution MMC in the OOD results table.
save_res_ood(tab_ood['CIFAR10 - CIFAR10'], mmc)
# save_res_cal(tab_cal['MAP'], ece, mce)
# "Time: NA/..." : only prediction time is reported; time_inf is not printed.
print(
    f'[In, LLLA] Time: NA/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-distribution - SVHN: confidences of the same LLLA predictive on SVHN.
py_out = llla_binary.predict(test_loader_SVHN, model, mu, S).cpu().numpy()
conf_svhn = get_confidence(py_out, binary=True)
# Reassigns mmc, overwriting the in-distribution value computed above;
# presumably consumed by code following this chunk.
mmc = conf_svhn.mean()