def test(models, writer, criterion, data_loader, epoch):
    """Draw one SWAG (SWA-Gaussian) weight sample from `models`, load it into
    models[0], refresh BN statistics, and checkpoint it to args.tmp_dir.

    No evaluation happens here; the signature matches the other test()
    variants, so it returns (0, {}) for (test_acc, metrics).
    """
    # NOTE(review): only numpy's RNG is seeded, but all sampling below uses
    # torch.randn* — consider torch.manual_seed(args.seed) for reproducibility.
    np.random.seed(args.seed)
    n = args.num_models

    # Flatten each model's state dict into one parameter vector.
    vecs = [
        utils.sd_to_vector(models[i].state_dict()).clone() for i in range(n)
    ]

    # SWA mean E[w] and second moment E[w^2] across the n snapshots.
    swa_vec = sum(vecs[1:], vecs[0]) / n
    square_vec = sum((v.pow(2) for v in vecs[1:]), vecs[0].pow(2)) / n

    # Diagonal Gaussian term: sqrt(Var[w]) * eps / sqrt(2).
    # FIX: clamp the variance estimate at 0 — floating-point cancellation can
    # make (E[w^2] - E[w]^2) slightly negative, and .pow(0.5) would emit NaNs
    # that silently corrupt the sampled model.
    swa_diag_mult = ((1.0 / math.sqrt(2)) *
                     (square_vec - swa_vec.pow(2)).clamp_min(0.0).pow(0.5) *
                     torch.randn_like(swa_vec))

    # Low-rank term: random combination of per-snapshot deviations, with the
    # standard SWAG 1/sqrt(2*(K-1)) normalization.
    low_rank_term = (vecs[0] - swa_vec) * torch.randn(1).item()
    for i in range(1, n):
        low_rank_term = (low_rank_term +
                         (vecs[i] - swa_vec) * torch.randn(1).item())
    low_rank_term = (1.0 / math.sqrt(2 * (n - 1))) * low_rank_term

    out = swa_vec + swa_diag_mult + low_rank_term

    # Write the sampled vector back into models[0] and recompute BN running
    # statistics, which are invalid for freshly-combined weights.
    final_model_sd = models[0].state_dict()
    utils.vector_to_sd(out, final_model_sd)
    models[0].load_state_dict(final_model_sd)
    utils.update_bn(data_loader.train_loader, models[0], device=args.device)

    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{args.j}.pt"),
    )
    test_acc = 0
    metrics = {}
    return test_acc, metrics
def test(models, writer, criterion, data_loader, epoch):
    """Evaluate a random convex combination of the models' weights.

    Draws Dirichlet-like weights Z (normalized exponentials), collapses all
    models into models[0] *in place* (destructive), refreshes BN statistics,
    and evaluates the merged model on the validation set.

    Returns (test_acc, metrics) with metrics always empty.
    """
    for m in models:
        m.eval()
    test_loss = 0
    correct = 0
    val_loader = data_loader.val_loader
    # One global random convex weighting shared by every layer.
    Z = np.random.exponential(scale=1.0, size=args.num_models)
    Z = Z / Z.sum()
    # Walk the module trees of all models in lockstep and blend parameters
    # into the corresponding module of models[0].
    # NOTE(review): only Conv2d and BatchNorm2d are blended; Linear layers and
    # Conv biases (if any) keep models[0]'s values — confirm that is intended.
    for ms in zip(*[models[i].modules() for i in range(args.num_models)]):
        if isinstance(ms[0], nn.Conv2d):
            ms[0].weight.data = Z[0] * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += Z[i] * ms[i].weight.data
        elif isinstance(ms[0], nn.BatchNorm2d):
            ms[0].weight.data = Z[0] * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += Z[i] * ms[i].weight.data
            ms[0].bias.data = Z[0] * ms[0].bias.data
            for i in range(1, args.num_models):
                ms[0].bias.data += Z[i] * ms[i].bias.data
    model = models[0]
    # Blended weights invalidate BN running stats; recompute on train data.
    utils.update_bn(data_loader.train_loader, model, device=args.device)
    # (A commented-out forward pass over the train loader used to live here.)
    model.eval()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Loss averaged per batch; accuracy averaged per sample.
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
    metrics = {}
    return test_acc, metrics
def test(models, writer, criterion, data_loader, epoch):
    """Evaluate models[0] at the interpolation midpoint (alpha = 0.5).

    Returns (test_acc, metrics) where metrics holds the weight-space L2 norm
    and cosine similarity reported by get_stats().
    """
    net = models[0]
    net.zero_grad()
    net.eval()

    # Evaluate at the midpoint of the parameter interpolation.
    net.apply(lambda m: setattr(m, "alpha", 0.5))

    # Optionally refresh BN running stats first; note this slows things down.
    if args.train_update_bn:
        utils.update_bn(data_loader.train_loader, net, args.device)

    loader = data_loader.val_loader
    loss_total = 0
    num_correct = 0
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(args.device)
            labels = labels.to(args.device)
            logits = net(batch)
            loss_total += criterion(logits, labels).item()
            # Index of the max log-probability per sample.
            guesses = logits.argmax(dim=1, keepdim=True)
            num_correct += guesses.eq(labels.view_as(guesses)).sum().item()

    test_loss = loss_total / len(loader)
    test_acc = float(num_correct) / len(loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    cossim, l2 = get_stats(net)
    if args.save:
        for tag, value in (("test/loss", test_loss), ("test/acc", test_acc),
                           ("test/norm", l2), ("test/cossim", cossim)):
            writer.add_scalar(tag, value, epoch)

    return test_acc, {"norm": l2, "cossim": cossim}
'X', 'Y', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll', 'Test error (%)' ] for i, alpha in enumerate(alphas): for j, beta in enumerate(betas): p = w[0] + alpha * dx * u + beta * dy * v offset = 0 for parameter in base_model.parameters(): size = np.prod(parameter.size()) value = p[offset:offset + size].reshape(parameter.size()) parameter.data.copy_(torch.from_numpy(value)) offset += size utils.update_bn(loaders['train'], base_model) tr_res = utils.test(loaders['train'], base_model, criterion, regularizer) te_res = utils.test(loaders['test'], base_model, criterion, regularizer) tr_loss_v, tr_nll_v, tr_acc_v = tr_res['loss'], tr_res['nll'], tr_res[ 'accuracy'] te_loss_v, te_nll_v, te_acc_v = te_res['loss'], te_res['nll'], te_res[ 'accuracy'] c = get_xy(p, w[0], u, v) grid[i, j] = [alpha * dx, beta * dy] tr_loss[i, j] = tr_loss_v
def evaluate_curve(dir='/tmp/curve/',
                   ckpt=None,
                   num_points=61,
                   dataset='CIFAR10',
                   use_test=True,
                   transform='VGG',
                   data_path=None,
                   batch_size=128,
                   num_workers=4,
                   model_type=None,
                   curve_type=None,
                   num_bends=3,
                   wd=1e-4):
    """Evaluate a trained mode-connectivity curve at `num_points` values of t.

    For each t in [0, 1]: materialize the curve weights, refresh BN running
    statistics, measure train/test loss, NLL and accuracy, print a running
    table, then summarize (min/max/avg/length-weighted integral) and save
    everything to <dir>/curve.npz.
    """
    # Bundle keyword arguments into an args namespace (this replaced the old
    # argparse CLI; the original parser defined --dir/--num_points/--dataset/
    # --use_test/--transform/--data_path/--batch_size/--num_workers/--model/
    # --curve/--num_bends/--ckpt/--wd and called parser.parse_args()).
    args = EvalCurveArgSet(dir=dir,
                           ckpt=ckpt,
                           num_points=num_points,
                           dataset=dataset,
                           use_test=use_test,
                           transform=transform,
                           data_path=data_path,
                           batch_size=batch_size,
                           num_workers=num_workers,
                           model=model_type,
                           curve=curve_type,
                           num_bends=num_bends,
                           wd=wd)
    if True:
        q = 1  # NOTE(review): dead placeholder left from the removed argparse block.
    os.makedirs(args.dir, exist_ok=True)
    torch.backends.cudnn.benchmark = True
    loaders, num_classes = data.loaders(args.dataset,
                                        args.data_path,
                                        args.batch_size,
                                        args.num_workers,
                                        args.transform,
                                        args.use_test,
                                        shuffle_train=False)
    architecture = getattr(models, args.model)
    curve = getattr(curves, args.curve)
    model = curves.CurveNet(
        num_classes,
        curve,
        architecture.curve,
        args.num_bends,
        architecture_kwargs=architecture.kwargs,
    )
    model.cuda()
    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint['model_state'])
    criterion = F.cross_entropy
    regularizer = curves.l2_regularizer(args.wd)
    T = args.num_points
    ts = np.linspace(0.0, 1.0, T)
    # Per-point metric buffers along the curve.
    tr_loss = np.zeros(T)
    tr_nll = np.zeros(T)
    tr_acc = np.zeros(T)
    te_loss = np.zeros(T)
    te_nll = np.zeros(T)
    te_acc = np.zeros(T)
    tr_err = np.zeros(T)
    te_err = np.zeros(T)
    # dl[i]: Euclidean distance between consecutive weight vectors.
    dl = np.zeros(T)
    previous_weights = None
    columns = [
        't', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll',
        'Test error (%)', 'Distance'
    ]
    t = torch.FloatTensor([0.0]).cuda()
    for i, t_value in enumerate(ts):
        t.data.fill_(t_value)
        weights = model.weights(t)
        if previous_weights is not None:
            dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
        previous_weights = weights.copy()
        # BN statistics depend on t; recompute them before each evaluation.
        utils.update_bn(loaders['train'], model, t=t)
        tr_res = utils.test(loaders['train'], model, criterion, regularizer,
                            t=t)
        te_res = utils.test(loaders['test'], model, criterion, regularizer,
                            t=t)
        tr_loss[i] = tr_res['loss']
        tr_nll[i] = tr_res['nll']
        tr_acc[i] = tr_res['accuracy']
        tr_err[i] = 100.0 - tr_acc[i]
        te_loss[i] = te_res['loss']
        te_nll[i] = te_res['nll']
        te_acc[i] = te_res['accuracy']
        te_err[i] = 100.0 - te_acc[i]
        values = [
            t, tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i], dl[i]
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='10.4f')
        # Re-emit the header row every 40 points; otherwise print data only.
        if i % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    def stats(values, dl):
        # Summaries over the curve; `int` is the trapezoidal average weighted
        # by segment length. NOTE(review): local names shadow builtins
        # min/max/int (harmless here, scope-local).
        min = np.min(values)
        max = np.max(values)
        avg = np.mean(values)
        int = np.sum(0.5 * (values[:-1] + values[1:]) * dl[1:]) / np.sum(
            dl[1:])
        return min, max, avg, int

    tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
    tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
    tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)
    te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
    te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
    te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)
    print('Length: %.2f' % np.sum(dl))
    print(
        tabulate.tabulate([
            [
                'train loss', tr_loss[0], tr_loss[-1], tr_loss_min,
                tr_loss_max, tr_loss_avg, tr_loss_int
            ],
            [
                'train error (%)', tr_err[0], tr_err[-1], tr_err_min,
                tr_err_max, tr_err_avg, tr_err_int
            ],
            [
                'test nll', te_nll[0], te_nll[-1], te_nll_min, te_nll_max,
                te_nll_avg, te_nll_int
            ],
            [
                'test error (%)', te_err[0], te_err[-1], te_err_min,
                te_err_max, te_err_avg, te_err_int
            ],
        ], ['', 'start', 'end', 'min', 'max', 'avg', 'int'],
                          tablefmt='simple',
                          floatfmt='10.4f'))
    # Persist raw curves plus all summary scalars for later plotting.
    np.savez(
        os.path.join(args.dir, 'curve.npz'),
        ts=ts,
        dl=dl,
        tr_loss=tr_loss,
        tr_loss_min=tr_loss_min,
        tr_loss_max=tr_loss_max,
        tr_loss_avg=tr_loss_avg,
        tr_loss_int=tr_loss_int,
        tr_nll=tr_nll,
        tr_nll_min=tr_nll_min,
        tr_nll_max=tr_nll_max,
        tr_nll_avg=tr_nll_avg,
        tr_nll_int=tr_nll_int,
        tr_acc=tr_acc,
        tr_err=tr_err,
        tr_err_min=tr_err_min,
        tr_err_max=tr_err_max,
        tr_err_avg=tr_err_avg,
        tr_err_int=tr_err_int,
        te_loss=te_loss,
        te_loss_min=te_loss_min,
        te_loss_max=te_loss_max,
        te_loss_avg=te_loss_avg,
        te_loss_int=te_loss_int,
        te_nll=te_nll,
        te_nll_min=te_nll_min,
        te_nll_max=te_nll_max,
        te_nll_avg=te_nll_avg,
        te_nll_int=te_nll_int,
        te_acc=te_acc,
        te_err=te_err,
        te_err_min=te_err_min,
        te_err_max=te_err_max,
        te_err_avg=te_err_avg,
        te_err_int=te_err_int,
    )
'step_size': 2.0 / 255, 'random_start': True, 'loss_func': 'xent', } net = AttackPGD(model, config) t = torch.FloatTensor([0.0]).cuda() for i, t_value in enumerate(ts): t.data.fill_(t_value) weights = model.weights(t) if previous_weights is not None: dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights))) previous_weights = weights.copy() utils.update_bn(loaders['train'], model, t=t) tr_res = utils.test(loaders['train'], model, criterion, regularizer, t=t) te_res = utils.test(loaders['test'], model, criterion, regularizer, t=t) te_example_res = utils.test_examples(loaders['test'], net, criterion, regularizer, t=t) # te_example_res = utils.test_examples(loaders['test'], net, criterion, regularizer) tr_loss[i] = tr_res['loss'] tr_nll[i] = tr_res['nll'] tr_acc[i] = tr_res['accuracy'] tr_err[i] = 100.0 - tr_acc[i] te_loss[i] = te_res['loss'] te_nll[i] = te_res['nll'] te_acc[i] = te_res['accuracy']
def test(models, writer, criterion, data_loader, epoch):
    """Average all models' weights uniformly into models[0] (destructive),
    refresh BN statistics, and evaluate the averaged model.

    Also computes the expected calibration error (ECE) over M = 20 confidence
    bins. Returns (test_acc, {"ece": ..., "test_loss": ...}).
    """
    for m in models:
        m.eval()
    model = models[0]
    test_loss = 0
    correct = 0
    val_loader = data_loader.val_loader
    # ECE accumulators: per-bin summed accuracy, confidence, and count.
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)
    # Uniform (1/n) average of parameters, written into models[0] in place.
    # NOTE(review): only Conv2d and BatchNorm2d are averaged; other layer
    # types keep models[0]'s weights — confirm that is intended.
    for ms in zip(*[models[i].modules() for i in range(args.num_models)]):
        if isinstance(ms[0], nn.Conv2d):
            ms[0].weight.data = (1.0 / args.num_models) * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += (1.0 / args.num_models) * ms[i].weight.data
        elif isinstance(ms[0], nn.BatchNorm2d):
            ms[0].weight.data = (1.0 / args.num_models) * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += (1.0 / args.num_models) * ms[i].weight.data
            ms[0].bias.data = (1.0 / args.num_models) * ms[0].bias.data
            for i in range(1, args.num_models):
                ms[0].bias.data += (1.0 / args.num_models) * ms[i].bias.data
    # Averaged weights invalidate BN running stats; recompute them.
    utils.update_bn(data_loader.train_loader, model, device=args.device)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            soft_out = output.softmax(dim=1)
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()
            # Bin each sample by the confidence of its predicted class.
            for i in range(data.size(0)):
                conf = soft_out[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
    # ECE: sum over bins of |accuracy mass - confidence mass| / dataset size.
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)
    metrics = {"ece": ece, "test_loss": test_loss}
    return test_acc, metrics
"X", "Y", "Train loss", "Train nll", "Train error (%)", "Test nll", "Test error (%)" ] for i, alpha in enumerate(alphas): for j, beta in enumerate(betas): p = w[0] + alpha * dx * u + beta * dy * v offset = 0 for parameter in base_model.parameters(): size = np.prod(parameter.size()) value = p[offset:offset + size].reshape(parameter.size()) parameter.data.copy_(torch.from_numpy(value)) offset += size utils.update_bn(loaders["train"], base_model) tr_res = utils.test(loaders["train"], base_model, criterion, regularizer) te_res = utils.test(loaders["test"], base_model, criterion, regularizer) tr_loss_v, tr_nll_v, tr_acc_v = tr_res["loss"], tr_res["nll"], tr_res[ "accuracy"] te_loss_v, te_nll_v, te_acc_v = te_res["loss"], te_res["nll"], te_res[ "accuracy"] c = get_xy(p, w[0], u, v) grid[i, j] = [alpha * dx, beta * dy] tr_loss[i, j] = tr_loss_v
def test(models, writer, criterion, data_loader, epoch):
    """Evaluate the logit-averaged ensemble of `models` on the validation set.

    Each model first gets random convex interpolation coefficients t1..t{n-1}
    (per-layer when args.layerwise, else per-model), then per-model accuracies,
    ensemble accuracy, ensemble loss, and ECE (M = 20 bins) are computed.

    Returns (test_acc, metrics) with per-model accuracies, their mean/std over
    models with non-zero accuracy, "ece", and "test_loss".
    """
    for model in models:
        model.zero_grad()
        model.eval()
        if args.layerwise:
            # Fresh random convex weights for every Conv/BN layer.
            for m in model.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
                    Z = np.random.exponential(scale=1.0, size=args.n)
                    Z = Z / Z.sum()
                    # k (not i) to avoid shadowing any outer loop index.
                    for k in range(1, args.n):
                        setattr(m, f"t{k}", Z[k])
        else:
            # One random convex weighting shared by all layers of this model.
            Z = np.random.exponential(scale=1.0, size=args.n)
            Z = Z / Z.sum()
            for m in model.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
                    for k in range(1, args.n):
                        setattr(m, f"t{k}", Z[k])
    test_loss = 0
    correct = 0
    # FIX: was hard-coded to 10 slots, which crashes for >10 models and emits
    # spurious zero "model_k_acc" metrics for <10 models.
    corrects = [0] * len(models)
    # ECE accumulators: per-bin summed accuracy, confidence, and count.
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)
    val_loader = data_loader.val_loader
    # New interpolation coefficients invalidate BN running stats.
    for m in models:
        utils.update_bn(data_loader.train_loader, m, device=args.device)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            model_output = models[0](data)
            model_pred = model_output.argmax(dim=1, keepdim=True)
            corrects[0] += (model_pred.eq(
                target.view_as(model_pred)).sum().item())
            mean_output = model_output
            for i, m in enumerate(models[1:]):
                model_output = m(data)
                model_pred = model_output.argmax(dim=1, keepdim=True)
                corrects[i + 1] += (model_pred.eq(
                    target.view_as(model_pred)).sum().item())
                mean_output += model_output
            mean_output /= len(models)
            # get the index of the max log-probability
            pred = mean_output.argmax(dim=1, keepdim=True)
            test_loss += criterion(mean_output, target).item()
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()
            soft_output = mean_output.softmax(dim=1)
            # Bin each sample by the confidence of its predicted class.
            for i in range(data.size(0)):
                conf = soft_output[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
    corrects_scaled = [
        float(corrects[i]) / len(val_loader.dataset)
        for i in range(len(corrects))
    ]
    metrics = {
        f"model_{i}_acc": corrects_scaled[i]
        for i in range(len(corrects_scaled))
    }
    corrects_scaled = np.array(corrects_scaled)
    # Mean/std over models that got at least one sample right (filters zeros).
    metrics["avg_model_acc"] = np.mean(corrects_scaled[corrects_scaled > 0])
    metrics["avg_model_std"] = np.std(corrects_scaled[corrects_scaled > 0])
    metrics["ece"] = ece
    metrics["test_loss"] = test_loss
    return test_acc, metrics
def test(models, writer, criterion, data_loader, epoch):
    """Compare model (at t = args.t) against model_0 (at t = args.baset).

    Measures: accuracy and ECE of each, of their logit-averaged ensemble, the
    mean TV distance between their predictive distributions, and the mean
    squared cosine similarity of their features.

    Returns (test_acc, metrics) with keys ece / ece_ens / ece_m0 / tvdist /
    ensemble_acc / feat_cossim / m0_acc.
    """
    model = models[0]
    model_0 = models[1]
    model_0.eval()
    model_0.zero_grad()
    # Both models must return (logits, features) for the cosine-similarity
    # computation below.
    model.apply(lambda m: setattr(m, "return_feats", True))
    model_0.apply(lambda m: setattr(m, "return_feats", True))
    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0
    ensemble_correct = 0
    m0_correct = 0
    tv_dist = 0.0
    val_loader = data_loader.val_loader
    feat_cosim = 0
    # Interpolation coefficients: args.t for model, args.baset for model_0.
    model.apply(lambda m: setattr(m, "t", args.t))
    model_0.apply(lambda m: setattr(m, "t", args.baset))
    model.apply(lambda m: setattr(m, "t1", args.t))
    model_0.apply(lambda m: setattr(m, "t1", args.baset))
    if args.update_bn:
        utils.update_bn(data_loader.train_loader, model, device=args.device)
        utils.update_bn(data_loader.train_loader, model_0, device=args.device)
    # ECE accumulators (M = 20 bins) for model, ensemble, and model_0.
    M = 20
    acc_bm_m0 = torch.zeros(M)
    conf_bm_m0 = torch.zeros(M)
    count_bm_m0 = torch.zeros(M)
    acc_bm_ens = torch.zeros(M)
    conf_bm_ens = torch.zeros(M)
    count_bm_ens = torch.zeros(M)
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            # Re-apply t each batch in case a previous forward mutated it.
            model.apply(lambda m: setattr(m, "t", args.t))
            model.apply(lambda m: setattr(m, "t1", args.t))
            output, feats = model(data)
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()
            # get model 0
            model_0.apply(lambda m: setattr(m, "t", args.baset))
            model_0.apply(lambda m: setattr(m, "t1", args.baset))
            model_0_output, model_0_feats = model_0(data)
            ensemble_output = (model_0_output + output) / 2
            ensemble_pred = ensemble_output.argmax(dim=1, keepdim=True)
            ensemble_correct_vec = ensemble_pred.eq(target.view_as(pred))
            ensemble_correct += ensemble_correct_vec.sum().item()
            m0_pred = model_0_output.argmax(dim=1, keepdim=True)
            m0_correct_vec = m0_pred.eq(target.view_as(pred))
            m0_correct += m0_correct_vec.sum().item()
            model_t_prob = nn.functional.softmax(output, dim=1)
            model_0_prob = nn.functional.softmax(model_0_output, dim=1)
            tv_dist += 0.5 * (model_0_prob - model_t_prob).abs().sum().item()
            feat_cosim += (torch.nn.functional.cosine_similarity(
                feats, model_0_feats, dim=1).pow(2).sum().item())
            soft_out = output.softmax(dim=1)
            soft_out_m0 = model_0_output.softmax(dim=1)
            soft_out_ens = ensemble_output.softmax(dim=1)
            # ECE binning for model, ensemble, and model_0.
            for i in range(data.size(0)):
                conf = soft_out[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0
                # FIX: the ensemble bin was indexed with the main model's
                # pred[i]; it must use the ensemble's own prediction so the
                # confidence matches ensemble_correct_vec.
                conf = soft_out_ens[i][ensemble_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm_ens[bin_idx] += ensemble_correct_vec[i].float().item()
                conf_bm_ens[bin_idx] += conf.item()
                count_bm_ens[bin_idx] += 1.0
                # FIX: same issue for model_0 — use m0_pred, not pred.
                conf = soft_out_m0[i][m0_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm_m0[bin_idx] += m0_correct_vec[i].float().item()
                conf_bm_m0[bin_idx] += conf.item()
                count_bm_m0[bin_idx] += 1.0
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    m0_acc = float(m0_correct) / len(val_loader.dataset)
    tv_dist /= len(val_loader.dataset)
    feat_cosim /= len(val_loader.dataset)
    ensemble_acc = float(ensemble_correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)
    ece_ens = 0.0
    for i in range(M):
        ece_ens += (acc_bm_ens[i] - conf_bm_ens[i]).abs().item()
    ece_ens /= len(val_loader.dataset)
    print("ece_ens is", ece_ens)
    ece_m0 = 0.0
    for i in range(M):
        ece_m0 += (acc_bm_m0[i] - conf_bm_m0[i]).abs().item()
    ece_m0 /= len(val_loader.dataset)
    print("ece_m0 is", ece_m0)
    metrics = {
        "ece": ece,
        "ece_ens": ece_ens,
        "ece_m0": ece_m0,
        "tvdist": tv_dist,
        "ensemble_acc": ensemble_acc,
        "feat_cossim": feat_cosim,
        "m0_acc": m0_acc,
    }
    return test_acc, metrics
def test(models, writer, criterion, data_loader, epoch):
    """Evaluate the logit-sum ensemble of `models` and pairwise diversity.

    Diversity statistics over all model pairs (i, j): mean feature cosine
    similarity, mean TV distance between predictive distributions, fraction of
    samples where the pair disagrees, and fraction where they disagree while
    at least one is correct.

    Returns (test_acc, metrics).
    """
    for model in models:
        model.eval()
        # Models must return (logits, features) for the cosine similarity.
        model.apply(lambda m: setattr(m, "return_feats", True))
    test_loss = 0
    correct = 0
    tvdist_sum = 0
    tvdist_len = 0
    feat_cossim = 0
    percent_disagree_sum = 0
    percent_disagree_len = 0
    percent_disagree_correct_sum = 0
    percent_disagree_correct_len = 0
    val_loader = data_loader.val_loader
    if args.update_bn:
        for model in models:
            utils.update_bn(data_loader.train_loader, model, args.device)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output, f = models[0](data)
            probs = [nn.functional.softmax(output, dim=1)]
            feats = [f]
            for t in range(1, args.num_models):
                modelt_output, model_feats_t = models[t](data)
                feats.append(model_feats_t)
                probs.append(nn.functional.softmax(modelt_output, dim=1))
                # Accumulate logits; argmax of the sum == argmax of the mean.
                output += modelt_output
            # (A geometric mean of probs was tried here and commented out.)
            # FIX: test_loss was initialized but never accumulated, so the
            # printed average loss was always 0. Evaluate the criterion on the
            # mean ensemble logits.
            test_loss += criterion(output / args.num_models, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Pairwise diversity statistics between models i and j.
            for i in range(args.num_models):
                for j in range(i + 1, args.num_models):
                    feat_cossim += (nn.functional.cosine_similarity(
                        feats[i], feats[j], dim=1).sum().item())
                    pairwise_tvdist = 0.5 * (probs[i] -
                                             probs[j]).abs().sum(dim=1)
                    tvdist_len += pairwise_tvdist.size(0)
                    tvdist_sum += pairwise_tvdist.sum().item()
                    model_i_pred = probs[i].argmax(dim=1, keepdim=True)
                    model_j_pred = probs[j].argmax(dim=1, keepdim=True)
                    percent_disagree_len += data.size(0)
                    percent_disagree_sum += (
                        (model_i_pred != model_j_pred).sum().item())
                    percent_disagree_correct_len += data.size(0)
                    # Bool arithmetic: '+' acts as OR, '*' as AND, so this
                    # counts "disagree AND (i correct OR j correct)".
                    percent_disagree_correct_sum += (
                        ((model_i_pred != model_j_pred) *
                         (model_i_pred.eq(target.view_as(model_i_pred)) +
                          model_j_pred.eq(target.view_as(model_j_pred)))
                         ).sum().item())
    feat_cossim = feat_cossim / tvdist_len if tvdist_len > 0 else 0
    tvdist = tvdist_sum / tvdist_len if tvdist_len > 0 else 0
    percent_disagree = (percent_disagree_sum /
                        percent_disagree_len if percent_disagree_len > 0 else 0)
    percent_disagree_correct = (percent_disagree_correct_sum /
                                percent_disagree_correct_len
                                if percent_disagree_correct_len > 0 else 0)
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f}), TVDist: ({tvdist})\n"
    )
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
    metrics = {
        "tvdist": tvdist,
        "percent_disagree": percent_disagree,
        "percent_disagree_correct": percent_disagree_correct,
        "feat_cossim": feat_cossim,
    }
    return test_acc, metrics
def test(models, writer, criterion, data_loader, epoch):
    """Evaluate models[0] with uniform interpolation weights t1..t{n-1} = 1/n.

    Refreshes BN statistics, evaluates on the validation set, and computes ECE
    over M = 20 confidence bins plus weight-space stats from get_stats().

    Returns (test_acc, metrics).
    """
    model = models[0]
    model.eval()
    test_loss = 0
    # NOTE(review): correct0 is never incremented, so m0_acc below is always
    # 0 — it looks like the base-model evaluation loop was removed.
    correct0 = 0
    wa_correct = 0
    val_loader = data_loader.val_loader
    # Uniform averaging coefficients (the lambda runs immediately inside
    # apply(), so each iteration's i is used — no late-binding issue).
    for i in range(1, args.n):
        model.apply(lambda m: setattr(m, f"t{i}", 1.0 / args.n))
    # Averaged coefficients invalidate BN running stats; recompute.
    utils.update_bn(data_loader.train_loader, model, args.device)
    model.eval()
    cossim, l2 = get_stats(model)
    # ECE accumulators: per-bin summed accuracy, confidence, and count.
    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            wa_output = model(data)
            soft_output = wa_output.softmax(dim=1)
            wa_pred = wa_output.argmax(dim=1, keepdim=True)
            correct_vec = wa_pred.eq(target.view_as(wa_pred))
            wa_correct += correct_vec.sum().item()
            test_loss += criterion(wa_output, target).item()
            # Bin each sample by the confidence of its predicted class.
            for i in range(data.size(0)):
                conf = soft_output[i][wa_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0
    wa_acc = float(wa_correct) / len(val_loader.dataset)
    m0_acc = float(correct0) / len(val_loader.dataset)
    test_acc = wa_acc
    test_loss /= len(val_loader)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/norm", l2, epoch)
        writer.add_scalar(f"test/cossim", cossim, epoch)
        writer.add_scalar(f"test/wa_acc", wa_acc, epoch)
        writer.add_scalar(f"test/m0_acc", m0_acc, epoch)
    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)
    metrics = {
        "ece": ece,
        "wa_acc": wa_acc,
        "m0_acc": m0_acc,
        "l2": l2,
        "cossim": cossim,
        "test_loss": test_loss,
    }
    return test_acc, metrics
def main():
    """Main entry point: evaluate a trained curve at args.num_points values
    of t, print running tables, summarize, and save results to curve.npz."""
    args = parse_args()
    os.makedirs(args.dir, exist_ok=True)
    utils.torch_settings(seed=None, benchmark=True)
    loaders, num_classes = data.loaders(args.dataset,
                                        args.data_path,
                                        args.batch_size,
                                        args.num_workers,
                                        args.transform,
                                        args.use_test,
                                        shuffle_train=False)
    model = init_model(args, num_classes)
    criterion = F.cross_entropy
    regularizer = curves.l2_regularizer(args.wd)
    # Pre-allocated per-point metric buffers along the curve.
    (ts, tr_loss, tr_nll, tr_acc, te_loss, te_nll, te_acc, tr_err, te_err,
     dl) = init_metrics(args.num_points)
    previous_weights = None
    columns = [
        "t", "Train loss", "Train nll", "Train error (%)", "Test nll",
        "Test error (%)"
    ]
    curve_metrics = metrics.TestCurve(args.num_points, columns)
    t = torch.FloatTensor([0.0]).cuda()
    for i, t_value in enumerate(ts):
        t.data.fill_(t_value)
        weights = model.weights(t)
        # dl[i]: Euclidean distance between consecutive weight vectors.
        if previous_weights is not None:
            dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
        previous_weights = weights.copy()
        # BN statistics depend on t; recompute before each evaluation.
        utils.update_bn(loaders["train"], model, t=t)
        tr_res = utils.test(loaders["train"], model, criterion, regularizer,
                            t=t)
        te_res = utils.test(loaders["test"], model, criterion, regularizer,
                            t=t)
        tr_loss[i] = tr_res["loss"]
        tr_nll[i] = tr_res["nll"]
        tr_acc[i] = tr_res["accuracy"]
        tr_err[i] = 100.0 - tr_acc[i]
        te_loss[i] = te_res["loss"]
        te_nll[i] = te_res["nll"]
        te_acc[i] = te_res["accuracy"]
        te_err[i] = 100.0 - te_acc[i]
        curve_metrics.add_meas(i, tr_res, te_res)
        values = [t, tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i]]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt="simple",
                                  floatfmt="10.4f")
        # NOTE(review): both the TestCurve table and the manual tabulate table
        # are printed each iteration — likely a transitional duplication.
        print(curve_metrics.table(i, with_header=i % 40 == 0))
        # Re-emit the header row every 40 points; otherwise data row only.
        if i % 40 == 0:
            table = table.split("\n")
            table = "\n".join([table[1]] + table)
        else:
            table = table.split("\n")[2]
        print(table)

    def stats(values, dl):
        # Summaries over the curve; `int` is the trapezoidal average weighted
        # by segment length. NOTE(review): local names shadow builtins
        # min/max/int (harmless here, scope-local).
        min = np.min(values)
        max = np.max(values)
        avg = np.mean(values)
        int = np.sum(0.5 * (values[:-1] + values[1:]) * dl[1:]) / np.sum(
            dl[1:])
        return min, max, avg, int

    tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
    tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
    tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)
    te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
    te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
    te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)
    print("Length: %.2f" % np.sum(dl))
    print(
        tabulate.tabulate([
            [
                "train loss", tr_loss[0], tr_loss[-1], tr_loss_min,
                tr_loss_max, tr_loss_avg, tr_loss_int
            ],
            [
                "train error (%)", tr_err[0], tr_err[-1], tr_err_min,
                tr_err_max, tr_err_avg, tr_err_int
            ],
            [
                "test nll", te_nll[0], te_nll[-1], te_nll_min, te_nll_max,
                te_nll_avg, te_nll_int
            ],
            [
                "test error (%)", te_err[0], te_err[-1], te_err_min,
                te_err_max, te_err_avg, te_err_int
            ],
        ], ["", "start", "end", "min", "max", "avg", "int"],
                          tablefmt="simple",
                          floatfmt="10.4f"))
    # Persist raw curves plus all summary scalars for later plotting.
    np.savez(
        os.path.join(args.dir, "curve.npz"),
        ts=ts,
        dl=dl,
        tr_loss=tr_loss,
        tr_loss_min=tr_loss_min,
        tr_loss_max=tr_loss_max,
        tr_loss_avg=tr_loss_avg,
        tr_loss_int=tr_loss_int,
        tr_nll=tr_nll,
        tr_nll_min=tr_nll_min,
        tr_nll_max=tr_nll_max,
        tr_nll_avg=tr_nll_avg,
        tr_nll_int=tr_nll_int,
        tr_acc=tr_acc,
        tr_err=tr_err,
        tr_err_min=tr_err_min,
        tr_err_max=tr_err_max,
        tr_err_avg=tr_err_avg,
        tr_err_int=tr_err_int,
        te_loss=te_loss,
        te_loss_min=te_loss_min,
        te_loss_max=te_loss_max,
        te_loss_avg=te_loss_avg,
        te_loss_int=te_loss_int,
        te_nll=te_nll,
        te_nll_min=te_nll_min,
        te_nll_max=te_nll_max,
        te_nll_avg=te_nll_avg,
        te_nll_int=te_nll_int,
        te_acc=te_acc,
        te_err=te_err,
        te_err_min=te_err_min,
        te_err_max=te_err_max,
        te_err_avg=te_err_avg,
        te_err_int=te_err_int,
    )
def test(models, writer, criterion, data_loader, epoch):
    """Compare models[0] (alpha = args.alpha1) with models[1] (args.alpha0).

    Reports the main model's accuracy/loss plus: base-model accuracy, their
    logit-sum ensemble accuracy, mean TV distance between the two predictive
    distributions, and mean squared feature cosine similarity.

    Returns (test_acc, metrics).
    """
    net, base = models[0], models[1]

    base.eval()
    base.zero_grad()
    # Both nets must return (logits, features) for the similarity stats.
    net.apply(lambda m: setattr(m, "return_feats", True))
    base.apply(lambda m: setattr(m, "return_feats", True))
    net.zero_grad()
    net.eval()

    loss_total = 0
    hits = 0
    ens_hits = 0
    base_hits = 0
    tv_total = 0.0
    loader = data_loader.val_loader
    cos_total = 0

    # Each network is evaluated at its own interpolation coefficient.
    net.apply(lambda m: setattr(m, "alpha", args.alpha1))
    base.apply(lambda m: setattr(m, "alpha", args.alpha0))
    if args.update_bn:
        # New alphas invalidate BN running stats; recompute for both nets.
        utils.update_bn(data_loader.train_loader, net, device=args.device)
        utils.update_bn(data_loader.train_loader, base, device=args.device)

    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(args.device)
            labels = labels.to(args.device)

            logits, feat = net(batch)
            loss_total += criterion(logits, labels).item()
            # Index of the max log-probability per sample.
            guess = logits.argmax(dim=1, keepdim=True)
            hits += guess.eq(labels.view_as(guess)).sum().item()

            # Base model forward pass and logit-sum ensemble.
            base_logits, base_feat = base(batch)
            ens_guess = (base_logits + logits).argmax(dim=1, keepdim=True)
            ens_hits += ens_guess.eq(labels.view_as(guess)).sum().item()
            base_guess = base_logits.argmax(dim=1, keepdim=True)
            base_hits += base_guess.eq(labels.view_as(guess)).sum().item()

            # TV distance between the two predictive distributions.
            p_net = nn.functional.softmax(logits, dim=1)
            p_base = nn.functional.softmax(base_logits, dim=1)
            tv_total += 0.5 * (p_base - p_net).abs().sum().item()
            # Squared cosine similarity of the feature vectors.
            cos_total += (torch.nn.functional.cosine_similarity(
                feat, base_feat, dim=1).pow(2).sum().item())

    dataset_size = len(loader.dataset)
    test_loss = loss_total / len(loader)
    test_acc = float(hits) / dataset_size
    m0_acc = float(base_hits) / dataset_size
    tv_dist = tv_total / dataset_size
    feat_cosim = cos_total / dataset_size
    ensemble_acc = float(ens_hits) / dataset_size

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    return test_acc, {
        "tvdist": tv_dist,
        "ensemble_acc": ensemble_acc,
        "feat_cossim": feat_cosim,
        "m0_acc": m0_acc,
    }
def test(models, writer, criterion, data_loader, epoch):
    """Build a weighted average of all models' weights in models[0] and save.

    Model args.j gets weight args.t; every other model gets the remaining
    (1 - t) / (n - 1). The blend is written into models[0] *in place*
    (destructive), BN stats are refreshed, and the result is checkpointed to
    args.tmp_dir/model_{j}.pt. No evaluation; returns (0, {}).
    """
    j = args.j
    n = len(models)
    # Debug: show the weight given to model j and to each remaining model.
    print(args.t)
    print((1 - args.t) / (n - 1))
    print("--")
    # Walk all models' module trees in lockstep, blending into models[0].
    # NOTE(review): only Conv2d and BatchNorm2d are blended; other layer types
    # keep models[0]'s weights — confirm that is intended.
    for ms in zip(*[model.modules() for model in models]):
        if isinstance(ms[0], nn.Conv2d):
            # Seed the accumulator with models[0]'s own contribution.
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)
            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (ms[i].weight.data * (1 - args.t) /
                                          (n - 1))
            print("conv", ms[0].weight[0, 0, 0, 0])
        elif isinstance(ms[0], nn.BatchNorm2d):
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)
            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (ms[i].weight.data * (1 - args.t) /
                                          (n - 1))
            # BN biases are blended with the same weighting as the weights.
            if j == 0:
                ms[0].bias.data = ms[0].bias.data * args.t
            else:
                ms[0].bias.data = ms[0].bias.data * (1 - args.t) / (n - 1)
            for i in range(1, n):
                if i == j:
                    ms[0].bias.data += ms[i].bias.data * args.t
                else:
                    ms[0].bias.data += ms[i].bias.data * (1 - args.t) / (n - 1)
            print("bn", ms[0].weight[0])
            print("bn", ms[0].bias[0])
    # Blended weights invalidate BN running stats; recompute on train data.
    utils.update_bn(data_loader.train_loader, models[0], device=args.device)
    # Save the blended model to args.tmp_dir/model_{j}.pt.
    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{j}.pt"),
    )
    test_acc = 0
    metrics = {}
    return test_acc, metrics