示例#1
0
    def gridsearch_var0(self, val_loader, ood_loader, interval, n_classes=10, lam=1):
        """Grid-search ``var0`` by minimising an in-/out-of-distribution loss.

        For every candidate in ``interval`` this calls
        ``self.estimate_variance(var0)``, predicts on both loaders and scores
        the candidate with ``loss_in + lam * loss_out``.

        Args:
            val_loader: Loader for the in-distribution validation data; it is
                passed to ``lutil.predict`` / ``lutil.predict_binary`` with
                ``return_targets=True``.
            ood_loader: Loader for the out-of-distribution data.
            interval: Iterable of candidate ``var0`` values.
            n_classes: ``2`` selects the binary branch (binary cross-entropy);
                any other value selects the multi-class branch.
            lam: Weight of the out-of-distribution loss term.

        Returns:
            The candidate with the smallest total loss (the first one on ties).

        Raises:
            ValueError: If ``interval`` yields no candidates.

        Note:
            The variance estimate is left at the *last* candidate tried, not
            the best one, so callers must call ``estimate_variance`` again
            with the returned value.
        """
        vals, var0s = [], []
        pbar = tqdm(interval)

        for var0 in pbar:
            self.estimate_variance(var0)

            if n_classes == 2:
                # NOTE(review): the positional ``10`` is presumably the number
                # of samples, mirroring ``n_samples`` below -- confirm in lutil.
                preds_in, y_in = lutil.predict_binary(val_loader, self, 10, return_targets=True)
                preds_out = lutil.predict_binary(ood_loader, self, 10)

                probs_out = preds_out.squeeze()

                loss_in = F.binary_cross_entropy(preds_in.squeeze(), y_in.float())
                # The OOD target is 0.5 everywhere. Build it from the OOD
                # predictions themselves: shaping it after ``y_in`` breaks as
                # soon as the validation and OOD sets differ in size.
                loss_out = F.binary_cross_entropy(probs_out, torch.full_like(probs_out, 0.5))
            else:
                preds_in, y_in = lutil.predict(val_loader, self, n_samples=5, return_targets=True)
                preds_out = lutil.predict(ood_loader, self, n_samples=5)

                # The epsilon keeps log() finite for zero probabilities.
                loss_in = F.nll_loss(torch.log(preds_in + 1e-8), y_in)
                # Negative mean log-probability over all entries of preds_out.
                loss_out = -torch.log(preds_out + 1e-8).mean()

            loss = loss_in + lam * loss_out

            # Store a plain float: np.argmin cannot convert CUDA tensors, and
            # keeping tensors alive would hold on to their memory needlessly.
            vals.append(float(loss))
            var0s.append(var0)

            pbar.set_description(f'var0: {var0:.5f}, Loss-in: {loss_in:.3f}, Loss-out: {loss_out:.3f}, Loss: {loss:.3f}')

        if not vals:
            raise ValueError('interval must contain at least one var0 candidate')

        best_var0 = var0s[int(np.argmin(vals))]

        return best_var0
示例#2
0
        var0 = torch.tensor(1.2328e-07).float().cuda()

    print(var0)

    model_dla.estimate_variance(var0)
    torch.save(model_dla.state_dict(),
               f'./pretrained_models/CIFAR10_{args.type}_dla.pt')
else:
    time_inf = 0
    model_dla.load_state_dict(
        torch.load(f'./pretrained_models/CIFAR10_{args.type}_dla.pt'))
    model_dla.eval()

# In-distribution: predict on `test_loader` with `model_dla`; `timing` returns
# the callable's result together with the elapsed time.
py_in, time_pred = timing(
    lambda: lutil.predict(test_loader, model_dla).cpu().numpy())
# Accuracy: fraction of rows whose argmax along axis 1 equals `targets`.
acc_in = np.mean(np.argmax(py_in, 1) == targets)
conf_in = get_confidence(py_in)
# Mean of the per-example confidences (presumably "mean maximum confidence").
mmc = conf_in.mean()
# NOTE(review): presumably expected / maximum calibration error -- confirm
# against get_calib.
ece, mce = get_calib(py_in, targets)
save_res_ood(tab_ood['CIFAR10 - CIFAR10'], mmc)
save_res_cal(tab_cal['DLA'], ece, mce)
# `time_inf` is set in the branch above (0 when the model was loaded from disk).
print(
    f'[In, DiagLaplace] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-distribution: same model, predictions on `test_loader_SVHN`.
py_out = lutil.predict(test_loader_SVHN, model_dla).cpu().numpy()
conf_svhn = get_confidence(py_out)
# `mmc` is rebound here and no longer holds the in-distribution value.
mmc = conf_svhn.mean()
# AUROC computed by get_auroc from the in- and out-distribution predictions.
auroc = get_auroc(py_in, py_out)