def gridsearch_var0(self, val_loader, ood_loader, interval, n_classes=10, lam=1):
    """Grid-search the prior variance ``var0`` for the Laplace approximation.

    For every candidate in ``interval`` the model's predictive variance is
    re-estimated, then a combined objective is computed:
    in-distribution fit (CE/BCE on ``val_loader``) plus ``lam`` times an
    OOD term that pushes predictions on ``ood_loader`` towards maximum
    uncertainty. The candidate with the lowest combined loss wins.

    Args:
        val_loader: in-distribution validation loader (yields inputs, targets).
        ood_loader: out-of-distribution loader (targets unused).
        interval: iterable of candidate var0 values (e.g. ``torch.linspace``).
        n_classes: number of classes; ``2`` switches to the binary path.
        lam: weight of the OOD loss term.

    Returns:
        The var0 value from ``interval`` minimizing the combined loss.

    Side effects: leaves ``self`` with the variance of the *last* candidate
    evaluated (callers re-run ``estimate_variance`` with the returned value).
    """
    vals, var0s = [], []
    pbar = tqdm(interval)
    for var0 in pbar:
        # Mutates self: recompute the posterior predictive for this var0.
        self.estimate_variance(var0)

        if n_classes == 2:
            # 10 MC samples for the binary predictive probabilities.
            preds_in, y_in = lutil.predict_binary(val_loader, self, 10, return_targets=True)
            preds_out = lutil.predict_binary(ood_loader, self, 10)

            loss_in = F.binary_cross_entropy(preds_in.squeeze(), y_in.float())
            # Target 0.5 = maximal uncertainty on OOD data.
            # NOTE(review): assumes ood_loader yields the same total number of
            # examples as val_loader (target is shaped like y_in) — TODO confirm.
            loss_out = F.binary_cross_entropy(preds_out.squeeze(), torch.ones_like(y_in)*0.5)
        else:
            preds_in, y_in = lutil.predict(val_loader, self, n_samples=5, return_targets=True)
            preds_out = lutil.predict(ood_loader, self, n_samples=5)

            # 1e-8 guards log(0) on saturated probabilities.
            loss_in = F.nll_loss(torch.log(preds_in + 1e-8), y_in)
            # Mean negative log-probability over all OOD classes: maximized
            # entropy pulls the OOD predictive towards uniform.
            loss_out = -torch.log(preds_out + 1e-8).mean()

        loss = loss_in + lam * loss_out

        # Store a plain Python float: np.argmin below cannot convert
        # CUDA tensors to NumPy (raises TypeError), and keeping tensors
        # would also needlessly retain the graph/device memory.
        vals.append(loss.item())
        var0s.append(var0)
        pbar.set_description(f'var0: {var0:.5f}, Loss-in: {loss_in:.3f}, Loss-out: {loss_out:.3f}, Loss: {loss:.3f}')

    best_var0 = var0s[np.argmin(vals)]

    return best_var0
# --- Diagonal Laplace (DLA) on CIFAR10: fit, tune var0, evaluate in/out-of-distribution ---

# Fit the diagonal Hessian on the training set; timing() returns (result, seconds).
_, time_inf = timing(lambda: model_dla.get_hessian(train_loader, binary=True))

# Grid-search the prior variance over 100 candidates in [1e-4, 1e-3].
interval = torch.linspace(1e-4, 1e-3, 100)
var0 = model_dla.gridsearch_var0(val_loader, ood_loader, interval, n_classes=2, lam=0.5)
# var0 = torch.tensor(0.00019090909336227924).cuda()  # previously found optimum, kept for reproducibility
print(var0.item())
model_dla.estimate_variance(var0)

# In-distribution (CIFAR10 test set)
py_in, time_pred = timing(
    lambda: lutil.predict_binary(test_loader, model_dla).cpu().numpy())
acc_in = np.mean((py_in >= 0.5) == targets)
# Mean maximum confidence: max(p, 1-p) averaged over the test set.
mmc = np.maximum(py_in, 1 - py_in).mean()
# ece, mce = get_calib(py_in, targets)
ece, mce = 0, 0  # calibration metrics disabled; placeholders keep the print format valid
save_res_ood(tab_ood['CIFAR10 - CIFAR10'], mmc)
# save_res_cal(tab_cal['DLA'], ece, mce)
print(
    f'[In, DiagLaplace] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-of-distribution (SVHN): confidence should drop, AUROC separates in/out.
py_out = lutil.predict_binary(test_loader_SVHN, model_dla).cpu().numpy()
conf_svhn = get_confidence(py_out, binary=True)
mmc = conf_svhn.mean()  # NOTE(review): rebinds mmc from the in-distribution value above
auroc = get_auroc_binary(py_in, py_out)
# --- Diagonal Laplace (DLA) on MNIST: fit, tune var0, evaluate in/out-of-distribution ---

# Fit the diagonal Hessian on the training set; timing() returns (result, seconds).
_, time_inf = timing(lambda: model_dla.get_hessian(train_loader, binary=True))

# Grid-search the prior variance over 100 candidates in [1e-4, 0.05]
# (wider range than the CIFAR10 run).
interval = torch.linspace(1e-4, 0.05, 100)
var0 = model_dla.gridsearch_var0(val_loader, ood_loader, interval, n_classes=2, lam=0.5)
# var0 = torch.tensor(0.0037).cuda()  # previously found optimum, kept for reproducibility
print(var0)
model_dla.estimate_variance(var0)

# In-distribution (MNIST test set)
py_in, time_pred = timing(
    lambda: lutil.predict_binary(test_loader, model_dla).cpu().numpy())
acc_in = np.mean((py_in >= 0.5) == targets)
# Mean maximum confidence via the shared helper (cf. the CIFAR10 chunk,
# which computes np.maximum(py, 1-py) inline).
conf_in = get_confidence(py_in, binary=True)
mmc = conf_in.mean()
# ece, mce = get_calib(py_in, targets)
ece, mce = 0, 0  # calibration metrics disabled; placeholders keep the print format valid
save_res_ood(tab_ood['MNIST - MNIST'], mmc)
# save_res_cal(tab_cal['MAP'], ece, mce)
print(
    f'[In, DLA] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-of-distribution (EMNIST): mean confidence should be lower than in-distribution.
py_out = lutil.predict_binary(test_loader_EMNIST, model_dla).cpu().numpy()
conf_emnist = get_confidence(py_out, binary=True)
mmc_emnist = conf_emnist.mean()