# Example #1
    def gridsearch_var0(self, val_loader, ood_loader, interval, n_classes=10, lam=1):
        """Grid-search the prior variance ``var0`` minimizing a combined loss.

        For each candidate in ``interval`` the predictive variance is
        re-estimated, then scored by in-distribution fit (cross-entropy on the
        validation set) plus ``lam`` times an OOD penalty that pushes OOD
        predictions toward maximum uncertainty (0.5 for the binary head,
        uniform for multi-class).

        Args:
            val_loader: in-distribution validation data loader.
            ood_loader: out-of-distribution data loader.
            interval: iterable of candidate var0 values.
            n_classes: number of classes; 2 selects the binary path.
            lam: weight of the OOD loss term.

        Returns:
            The element of ``interval`` with the lowest combined loss.
        """
        losses, candidates = [], []
        pbar = tqdm(interval)

        for var0 in pbar:
            self.estimate_variance(var0)

            if n_classes == 2:
                preds_in, y_in = lutil.predict_binary(val_loader, self, 10, return_targets=True)
                preds_out = lutil.predict_binary(ood_loader, self, 10)

                loss_in = F.binary_cross_entropy(preds_in.squeeze(), y_in.float())
                # Target 0.5 everywhere (maximal binary uncertainty). Size the
                # target from preds_out itself: the previous ones_like(y_in)*0.5
                # only matched when val_loader and ood_loader happened to yield
                # the same number of samples.
                preds_out = preds_out.squeeze()
                loss_out = F.binary_cross_entropy(preds_out, torch.full_like(preds_out, 0.5))
            else:
                preds_in, y_in = lutil.predict(val_loader, self, n_samples=5, return_targets=True)
                preds_out = lutil.predict(ood_loader, self, n_samples=5)

                # 1e-8 guards log(0); NLL on log-probs == cross-entropy here.
                loss_in = F.nll_loss(torch.log(preds_in + 1e-8), y_in)
                loss_out = -torch.log(preds_out + 1e-8).mean()

            loss = loss_in + lam * loss_out

            # Store plain floats so np.argmin works directly and we don't
            # retain autograd graphs / GPU tensors across the whole sweep.
            losses.append(loss.item())
            candidates.append(var0)

            pbar.set_description(f'var0: {var0:.5f}, Loss-in: {loss_in:.3f}, Loss-out: {loss_out:.3f}, Loss: {loss:.3f}')

        return candidates[np.argmin(losses)]
# Fit the (diagonal-Laplace) Hessian on the training set with the binary head
# and record the wall-clock time for the results table.
_, time_inf = timing(lambda: model_dla.get_hessian(train_loader, binary=True))

# Grid-search the prior variance over 100 candidates in [1e-4, 1e-3],
# trading in-distribution fit against OOD uncertainty with weight lam=0.5.
interval = torch.linspace(1e-4, 1e-3, 100)
var0 = model_dla.gridsearch_var0(val_loader,
                                 ood_loader,
                                 interval,
                                 n_classes=2,
                                 lam=0.5)
# var0 = torch.tensor(0.00019090909336227924).cuda()  # optimal
print(var0.item())

# Freeze the predictive variance at the selected var0 before evaluation.
model_dla.estimate_variance(var0)

# In-distribution
py_in, time_pred = timing(
    lambda: lutil.predict_binary(test_loader, model_dla).cpu().numpy())
# Accuracy of the 0.5-thresholded binary predictions against the test targets.
acc_in = np.mean((py_in >= 0.5) == targets)
# Mean maximum confidence: max(p, 1-p) averaged over the test set.
mmc = np.maximum(py_in, 1 - py_in).mean()
# ece, mce = get_calib(py_in, targets)
ece, mce = 0, 0  # calibration metrics disabled; see the commented call above
save_res_ood(tab_ood['CIFAR10 - CIFAR10'], mmc)
# save_res_cal(tab_cal['DLA'], ece, mce)
print(
    f'[In, DiagLaplace] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-distribution: evaluate confidence on SVHN and AUROC of in- vs out-scores.
py_out = lutil.predict_binary(test_loader_SVHN, model_dla).cpu().numpy()
conf_svhn = get_confidence(py_out, binary=True)
mmc = conf_svhn.mean()
auroc = get_auroc_binary(py_in, py_out)
# Example #3
# Fit the (diagonal-Laplace) Hessian on the training set with the binary head
# and record the wall-clock time for the results table.
_, time_inf = timing(lambda: model_dla.get_hessian(train_loader, binary=True))

# Grid-search the prior variance over 100 candidates in [1e-4, 0.05]
# (wider range than the CIFAR run), with OOD weight lam=0.5.
interval = torch.linspace(1e-4, 0.05, 100)
var0 = model_dla.gridsearch_var0(val_loader,
                                 ood_loader,
                                 interval,
                                 n_classes=2,
                                 lam=0.5)
# var0 = torch.tensor(0.0037).cuda()  # optimal
print(var0)

# Freeze the predictive variance at the selected var0 before evaluation.
model_dla.estimate_variance(var0)

#In-distribution
py_in, time_pred = timing(
    lambda: lutil.predict_binary(test_loader, model_dla).cpu().numpy())
# Accuracy of the 0.5-thresholded binary predictions against the test targets.
acc_in = np.mean((py_in >= 0.5) == targets)
conf_in = get_confidence(py_in, binary=True)
mmc = conf_in.mean()
# ece, mce = get_calib(py_in, targets)
ece, mce = 0, 0  # calibration metrics disabled; see the commented call above
save_res_ood(tab_ood['MNIST - MNIST'], mmc)
# save_res_cal(tab_cal['MAP'], ece, mce)
print(
    f'[In, DLA] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-distribution EMNIST: mean confidence on the OOD set (lower is better).
py_out = lutil.predict_binary(test_loader_EMNIST, model_dla).cpu().numpy()
conf_emnist = get_confidence(py_out, binary=True)
mmc_emnist = conf_emnist.mean()