def gridsearch_var0(self, val_loader, ood_loader, interval, n_classes=10, lam=1):
    """Grid-search the prior variance ``var0`` for the Laplace approximation.

    For every candidate ``var0`` in ``interval``, set the model's variance via
    ``self.estimate_variance(var0)``, then score it by an in-distribution loss
    (CE/BCE against the validation targets) plus ``lam`` times an
    out-of-distribution loss that rewards maximally uncertain predictions on
    ``ood_loader`` (uniform / 0.5 targets).

    Args:
        val_loader: in-distribution validation data loader.
        ood_loader: out-of-distribution data loader.
        interval: iterable of candidate ``var0`` values to try.
        n_classes: number of classes; ``2`` switches to the binary path.
        lam: weight of the OOD loss term.

    Returns:
        The ``var0`` from ``interval`` with the smallest combined loss.
    """
    vals, var0s = [], []
    pbar = tqdm(interval)

    for var0 in pbar:
        self.estimate_variance(var0)

        if n_classes == 2:
            preds_in, y_in = lutil.predict_binary(val_loader, self, 10, return_targets=True)
            preds_out = lutil.predict_binary(ood_loader, self, 10)

            loss_in = F.binary_cross_entropy(preds_in.squeeze(), y_in.float())
            # BUG FIX: y_in is an integer label tensor (hence the .float() above),
            # so the original `torch.ones_like(y_in) * 0.5` truncated to 0.
            # Build the intended 0.5 (max-uncertainty) targets in float.
            loss_out = F.binary_cross_entropy(
                preds_out.squeeze(), torch.ones_like(y_in).float() * 0.5)
        else:
            preds_in, y_in = lutil.predict(val_loader, self, n_samples=5, return_targets=True)
            preds_out = lutil.predict(ood_loader, self, n_samples=5)

            # 1e-8 guards log(0) on saturated probabilities.
            loss_in = F.nll_loss(torch.log(preds_in + 1e-8), y_in)
            # Encourage high predictive entropy (uniform output) on OOD data.
            loss_out = -torch.log(preds_out + 1e-8).mean()

        loss = loss_in + lam * loss_out

        # BUG FIX: store plain Python floats — np.argmin over a list of 0-dim
        # (possibly CUDA) tensors fails; .item() also detaches from the graph.
        vals.append(loss.item())
        var0s.append(var0)

        pbar.set_description(
            f'var0: {var0:.5f}, Loss-in: {loss_in:.3f}, Loss-out: {loss_out:.3f}, Loss: {loss:.3f}')

    best_var0 = var0s[np.argmin(vals)]

    return best_var0
    # Hard-coded best prior variance for the DiagLaplace (DLA) model —
    # presumably found by a previous grid search; TODO confirm provenance.
    var0 = torch.tensor(1.2328e-07).float().cuda()
    print(var0)
    model_dla.estimate_variance(var0)
    # Persist the tuned model; path is keyed by the training type flag.
    torch.save(model_dla.state_dict(), f'./pretrained_models/CIFAR10_{args.type}_dla.pt')
else:
    # Pretrained path: no inference-time tuning was run, so report 0s tuning time.
    time_inf = 0
    model_dla.load_state_dict(
        torch.load(f'./pretrained_models/CIFAR10_{args.type}_dla.pt'))

model_dla.eval()

# In-distribution evaluation on the CIFAR-10 test set.
# `timing` apparently returns (result, elapsed-seconds) — confirm its contract.
py_in, time_pred = timing(
    lambda: lutil.predict(test_loader, model_dla).cpu().numpy())
acc_in = np.mean(np.argmax(py_in, 1) == targets)
conf_in = get_confidence(py_in)
mmc = conf_in.mean()  # mean maximum confidence
ece, mce = get_calib(py_in, targets)  # expected / maximum calibration error
save_res_ood(tab_ood['CIFAR10 - CIFAR10'], mmc)
save_res_cal(tab_cal['DLA'], ece, mce)
print(
    f'[In, DiagLaplace] Time: {time_inf:.1f}/{time_pred:.1f}s; Accuracy: {acc_in:.3f}; ECE: {ece:.3f}; MCE: {mce:.3f}; MMC: {mmc:.3f}'
)

# Out-of-distribution evaluation: SVHN as the OOD set for CIFAR-10.
py_out = lutil.predict(test_loader_SVHN, model_dla).cpu().numpy()
conf_svhn = get_confidence(py_out)
mmc = conf_svhn.mean()
auroc = get_auroc(py_in, py_out)  # in- vs out-distribution separability