Example 1
def test(models, writer, criterion, data_loader, epoch):
    np.random.seed(args.seed)

    n = args.num_models
    # turn all the models into vectors
    vecs = [
        utils.sd_to_vector(models[i].state_dict()).clone() for i in range(n)
    ]

    swa_vec = vecs[0]
    for i in range(1, n):
        swa_vec = swa_vec + vecs[i]
    swa_vec = swa_vec / n

    square_vec = vecs[0].pow(2)
    for i in range(1, n):
        square_vec = square_vec + vecs[i].pow(2)
    square_vec = square_vec / n

    # clamp at zero: the empirical variance can dip slightly negative numerically
    swa_diag_mult = (
        (1.0 / math.sqrt(2))
        * (square_vec - swa_vec.pow(2)).clamp(min=0).pow(0.5)
        * torch.randn_like(swa_vec)
    )

    low_rank_term = (vecs[0] - swa_vec) * torch.randn(1).item()
    for i in range(1, n):
        low_rank_term = (
            low_rank_term + (vecs[i] - swa_vec) * torch.randn(1).item()
        )
    low_rank_term = (1.0 / math.sqrt(2 * (n - 1))) * low_rank_term

    out = swa_vec + swa_diag_mult + low_rank_term

    final_model_sd = models[0].state_dict()
    utils.vector_to_sd(out, final_model_sd)
    models[0].load_state_dict(final_model_sd)

    utils.update_bn(data_loader.train_loader, models[0], device=args.device)

    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{args.j}.pt"),
    )

    test_acc = 0
    metrics = {}

    return test_acc, metrics
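
The function above draws one sample in the style of SWAG (SWA-Gaussian): the SWA mean of the flattened weights, plus a diagonal term scaled by the per-coordinate standard deviation, plus a low-rank term built from each model's deviation from the mean. A minimal self-contained sketch of the same sampling rule on plain tensors, assuming the inputs are flat weight vectors as produced by utils.sd_to_vector above (the function name here is illustrative, not from the codebase):

import math
import torch

def sample_swag(vecs):
    """Draw one SWAG-style sample from a list of flattened weight vectors."""
    n = len(vecs)
    stacked = torch.stack(vecs)                      # (n, d)
    mean = stacked.mean(dim=0)                       # SWA mean
    second_moment = stacked.pow(2).mean(dim=0)
    # clamp avoids NaNs when the empirical variance is slightly negative
    diag_std = (second_moment - mean.pow(2)).clamp(min=0).sqrt()
    diag_term = (1.0 / math.sqrt(2)) * diag_std * torch.randn_like(mean)
    deviations = stacked - mean                      # (n, d)
    z = torch.randn(n)
    low_rank = deviations.t().mv(z) / math.sqrt(2 * (n - 1))
    return mean + diag_term + low_rank

Stacking the vectors turns the running sums above into one-line reductions and the low-rank term into a single matrix-vector product.
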
Example 2
def test(models, writer, criterion, data_loader, epoch):

    for m in models:
        m.eval()
    test_loss = 0
    correct = 0
    val_loader = data_loader.val_loader

    Z = np.random.exponential(scale=1.0, size=args.num_models)
    Z = Z / Z.sum()

    # average the models' weights into models[0] with convex weights Z
    for ms in zip(*[models[i].modules() for i in range(args.num_models)]):
        if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
            ms[0].weight.data = Z[0] * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += Z[i] * ms[i].weight.data
            if isinstance(ms[0], nn.BatchNorm2d):
                ms[0].bias.data = Z[0] * ms[0].bias.data
                for i in range(1, args.num_models):
                    ms[0].bias.data += Z[i] * ms[i].bias.data
    model = models[0]
    utils.update_bn(data_loader.train_loader, model, device=args.device)
    model.eval()

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    metrics = {}

    return test_acc, metrics
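
The Z draw above is the standard trick for sampling convex-combination weights uniformly over the simplex: independent Exp(1) variables normalized by their sum are one draw from Dirichlet(1, ..., 1). A quick standalone check that the direct numpy route agrees on the simplex constraint:

import numpy as np

rng = np.random.default_rng(0)
k = 4

# normalized exponentials ...
z = rng.exponential(scale=1.0, size=k)
z = z / z.sum()

# ... have the same law as a direct Dirichlet(1, ..., 1) draw
z_direct = rng.dirichlet(np.ones(k))

assert np.isclose(z.sum(), 1.0) and np.isclose(z_direct.sum(), 1.0)
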
Example 3
def test(models, writer, criterion, data_loader, epoch):
    model = models[0]

    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0
    val_loader = data_loader.val_loader

    model.apply(lambda m: setattr(m, "alpha", 0.5))

    # optionally update the BN statistics during training too; note this slows things down.
    if args.train_update_bn:
        utils.update_bn(data_loader.train_loader, model, args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )
    cossim, l2 = get_stats(model)

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)
        writer.add_scalar(f"test/norm", l2, epoch)
        writer.add_scalar(f"test/cossim", cossim, epoch)

    metrics = {"norm": l2, "cossim": cossim}

    return test_acc, metrics
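
model.apply(lambda m: setattr(m, "alpha", 0.5)) only does something if the model's layers read an alpha attribute at forward time and interpolate between two sets of weights. A minimal sketch of such a layer, assuming a two-endpoint subspace (the real subspace layers in this codebase may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoPointLinear(nn.Linear):
    """Linear layer with two weight endpoints; forward evaluates their alpha-blend."""

    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        # second endpoint, initialized like nn.Linear's default weight
        self.weight1 = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_uniform_(self.weight1, a=5 ** 0.5)
        self.alpha = 0.0  # set from outside via model.apply(...)

    def forward(self, x):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return F.linear(x, w, self.bias)

With layers like this, net.apply(lambda m: setattr(m, "alpha", 0.5)) evaluates the midpoint of the two endpoints, as in the test above.
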
Example 4
    'X', 'Y', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll',
    'Test error (%)'
]

for i, alpha in enumerate(alphas):
    for j, beta in enumerate(betas):
        p = w[0] + alpha * dx * u + beta * dy * v

        offset = 0
        for parameter in base_model.parameters():
            size = np.prod(parameter.size())
            value = p[offset:offset + size].reshape(parameter.size())
            parameter.data.copy_(torch.from_numpy(value))
            offset += size

        utils.update_bn(loaders['train'], base_model)

        tr_res = utils.test(loaders['train'], base_model, criterion,
                            regularizer)
        te_res = utils.test(loaders['test'], base_model, criterion,
                            regularizer)

        tr_loss_v, tr_nll_v, tr_acc_v = tr_res['loss'], tr_res['nll'], tr_res[
            'accuracy']
        te_loss_v, te_nll_v, te_acc_v = te_res['loss'], te_res['nll'], te_res[
            'accuracy']

        c = get_xy(p, w[0], u, v)
        grid[i, j] = [alpha * dx, beta * dy]

        tr_loss[i, j] = tr_loss_v
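
The inner loop above writes a flat numpy vector p back into the model's parameters, chunk by chunk in iteration order. The same round trip in isolation (a sketch; for the all-torch case, torch.nn.utils.vector_to_parameters does the equivalent):

import numpy as np
import torch

def set_flat_params(model, flat):
    """Write a flat numpy vector into model.parameters(), in iteration order."""
    offset = 0
    for p in model.parameters():
        size = int(np.prod(p.size()))
        chunk = flat[offset:offset + size].reshape(tuple(p.size()))
        p.data.copy_(torch.from_numpy(chunk))
        offset += size
    assert offset == flat.size, "vector length must match total parameter count"
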
Example 5
def evaluate_curve(dir='/tmp/curve/',
                   ckpt=None,
                   num_points=61,
                   dataset='CIFAR10',
                   use_test=True,
                   transform='VGG',
                   data_path=None,
                   batch_size=128,
                   num_workers=4,
                   model_type=None,
                   curve_type=None,
                   num_bends=3,
                   wd=1e-4):
    args = EvalCurveArgSet(dir=dir,
                           ckpt=ckpt,
                           num_points=num_points,
                           dataset=dataset,
                           use_test=use_test,
                           transform=transform,
                           data_path=data_path,
                           batch_size=batch_size,
                           num_workers=num_workers,
                           model=model_type,
                           curve=curve_type,
                           num_bends=num_bends,
                           wd=wd)

    os.makedirs(args.dir, exist_ok=True)

    torch.backends.cudnn.benchmark = True

    loaders, num_classes = data.loaders(args.dataset,
                                        args.data_path,
                                        args.batch_size,
                                        args.num_workers,
                                        args.transform,
                                        args.use_test,
                                        shuffle_train=False)

    architecture = getattr(models, args.model)
    curve = getattr(curves, args.curve)
    model = curves.CurveNet(
        num_classes,
        curve,
        architecture.curve,
        args.num_bends,
        architecture_kwargs=architecture.kwargs,
    )
    model.cuda()
    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint['model_state'])

    criterion = F.cross_entropy
    regularizer = curves.l2_regularizer(args.wd)

    T = args.num_points
    ts = np.linspace(0.0, 1.0, T)
    tr_loss = np.zeros(T)
    tr_nll = np.zeros(T)
    tr_acc = np.zeros(T)
    te_loss = np.zeros(T)
    te_nll = np.zeros(T)
    te_acc = np.zeros(T)
    tr_err = np.zeros(T)
    te_err = np.zeros(T)
    dl = np.zeros(T)

    previous_weights = None

    columns = [
        't', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll',
        'Test error (%)', 'Distance'
    ]

    t = torch.FloatTensor([0.0]).cuda()
    for i, t_value in enumerate(ts):
        t.data.fill_(t_value)
        weights = model.weights(t)
        if previous_weights is not None:
            dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
        previous_weights = weights.copy()

        utils.update_bn(loaders['train'], model, t=t)
        tr_res = utils.test(loaders['train'],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        te_res = utils.test(loaders['test'],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        tr_loss[i] = tr_res['loss']
        tr_nll[i] = tr_res['nll']
        tr_acc[i] = tr_res['accuracy']
        tr_err[i] = 100.0 - tr_acc[i]
        te_loss[i] = te_res['loss']
        te_nll[i] = te_res['nll']
        te_acc[i] = te_res['accuracy']
        te_err[i] = 100.0 - te_acc[i]

        values = [
            t_value, tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i],
            dl[i]
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='10.4f')
        if i % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    def stats(values, dl):
        vmin = np.min(values)
        vmax = np.max(values)
        avg = np.mean(values)
        integral = np.sum(0.5 *
                          (values[:-1] + values[1:]) * dl[1:]) / np.sum(dl[1:])
        return vmin, vmax, avg, integral

    tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
    tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
    tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)

    te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
    te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
    te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)

    print('Length: %.2f' % np.sum(dl))
    print(
        tabulate.tabulate([
            [
                'train loss', tr_loss[0], tr_loss[-1], tr_loss_min,
                tr_loss_max, tr_loss_avg, tr_loss_int
            ],
            [
                'train error (%)', tr_err[0], tr_err[-1], tr_err_min,
                tr_err_max, tr_err_avg, tr_err_int
            ],
            [
                'test nll', te_nll[0], te_nll[-1], te_nll_min, te_nll_max,
                te_nll_avg, te_nll_int
            ],
            [
                'test error (%)', te_err[0], te_err[-1], te_err_min,
                te_err_max, te_err_avg, te_err_int
            ],
        ], ['', 'start', 'end', 'min', 'max', 'avg', 'int'],
                          tablefmt='simple',
                          floatfmt='10.4f'))

    np.savez(
        os.path.join(args.dir, 'curve.npz'),
        ts=ts,
        dl=dl,
        tr_loss=tr_loss,
        tr_loss_min=tr_loss_min,
        tr_loss_max=tr_loss_max,
        tr_loss_avg=tr_loss_avg,
        tr_loss_int=tr_loss_int,
        tr_nll=tr_nll,
        tr_nll_min=tr_nll_min,
        tr_nll_max=tr_nll_max,
        tr_nll_avg=tr_nll_avg,
        tr_nll_int=tr_nll_int,
        tr_acc=tr_acc,
        tr_err=tr_err,
        tr_err_min=tr_err_min,
        tr_err_max=tr_err_max,
        tr_err_avg=tr_err_avg,
        tr_err_int=tr_err_int,
        te_loss=te_loss,
        te_loss_min=te_loss_min,
        te_loss_max=te_loss_max,
        te_loss_avg=te_loss_avg,
        te_loss_int=te_loss_int,
        te_nll=te_nll,
        te_nll_min=te_nll_min,
        te_nll_max=te_nll_max,
        te_nll_avg=te_nll_avg,
        te_nll_int=te_nll_int,
        te_acc=te_acc,
        te_err=te_err,
        te_err_min=te_err_min,
        te_err_max=te_err_max,
        te_err_avg=te_err_avg,
        te_err_int=te_err_int,
    )
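
For reference, the stats helper inside evaluate_curve reduces each per-point series to min, max, mean, and a curve-length-weighted average: a trapezoidal rule over the segment lengths dl, normalized by the total curve length. The same reduction as a standalone sketch:

import numpy as np

def curve_stats(values, dl):
    """min, max, mean, and length-weighted (trapezoidal) average of a series."""
    weighted = np.sum(0.5 * (values[:-1] + values[1:]) * dl[1:]) / np.sum(dl[1:])
    return np.min(values), np.max(values), np.mean(values), weighted
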
Example 6
    'step_size': 2.0 / 255,
    'random_start': True,
    'loss_func': 'xent',
}

net = AttackPGD(model, config)

t = torch.FloatTensor([0.0]).cuda()
for i, t_value in enumerate(ts):
    t.data.fill_(t_value)
    weights = model.weights(t)
    if previous_weights is not None:
        dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
    previous_weights = weights.copy()

    utils.update_bn(loaders['train'], model, t=t)
    tr_res = utils.test(loaders['train'], model, criterion, regularizer, t=t)
    te_res = utils.test(loaders['test'], model, criterion, regularizer, t=t)
    te_example_res = utils.test_examples(loaders['test'],
                                         net,
                                         criterion,
                                         regularizer,
                                         t=t)
    #  te_example_res = utils.test_examples(loaders['test'], net, criterion, regularizer)
    tr_loss[i] = tr_res['loss']
    tr_nll[i] = tr_res['nll']
    tr_acc[i] = tr_res['accuracy']
    tr_err[i] = 100.0 - tr_acc[i]
    te_loss[i] = te_res['loss']
    te_nll[i] = te_res['nll']
    te_acc[i] = te_res['accuracy']
Example 7
def test(models, writer, criterion, data_loader, epoch):
    for m in models:
        m.eval()
    model = models[0]
    test_loss = 0
    correct = 0
    val_loader = data_loader.val_loader

    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)

    # uniform average of the models' weights, written into models[0]
    for ms in zip(*[models[i].modules() for i in range(args.num_models)]):
        if isinstance(ms[0], (nn.Conv2d, nn.BatchNorm2d)):
            ms[0].weight.data = (1.0 / args.num_models) * ms[0].weight.data
            for i in range(1, args.num_models):
                ms[0].weight.data += (1.0 / args.num_models) * ms[i].weight.data
            if isinstance(ms[0], nn.BatchNorm2d):
                ms[0].bias.data = (1.0 / args.num_models) * ms[0].bias.data
                for i in range(1, args.num_models):
                    ms[0].bias.data += (1.0 / args.num_models) * ms[i].bias.data

    utils.update_bn(data_loader.train_loader, model, device=args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            output = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            soft_out = output.softmax(dim=1)
            correct_vec = pred.eq(target.view_as(pred))

            correct += correct_vec.sum().item()

            for i in range(data.size(0)):
                conf = soft_out[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)

    metrics = {"ece": ece, "test_loss": test_loss}

    return test_acc, metrics
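
The binned bookkeeping above is the usual estimator of expected calibration error: per-sample confidences in the predicted class are bucketed into M bins, and ECE is the sum over bins of |accuracy mass - confidence mass|, divided by the dataset size. A vectorized sketch of the same computation (function and argument names are illustrative):

import torch

def expected_calibration_error(confidences, correct, num_bins=20):
    """ECE from per-sample predicted-class confidences and 0/1 correctness flags."""
    bins = (confidences * num_bins).long().clamp(max=num_bins - 1)
    acc_bm = torch.zeros(num_bins).scatter_add_(0, bins, correct.float())
    conf_bm = torch.zeros(num_bins).scatter_add_(0, bins, confidences)
    return (acc_bm - conf_bm).abs().sum().item() / confidences.numel()

Here confidences would be the values soft_out[i][pred[i]] collected over the whole validation set, and correct the matching correctness flags; the clamp reproduces min(bin, M - 1) above.
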
Example 8
    "X", "Y", "Train loss", "Train nll", "Train error (%)", "Test nll",
    "Test error (%)"
]

for i, alpha in enumerate(alphas):
    for j, beta in enumerate(betas):
        p = w[0] + alpha * dx * u + beta * dy * v

        offset = 0
        for parameter in base_model.parameters():
            size = np.prod(parameter.size())
            value = p[offset:offset + size].reshape(parameter.size())
            parameter.data.copy_(torch.from_numpy(value))
            offset += size

        utils.update_bn(loaders["train"], base_model)

        tr_res = utils.test(loaders["train"], base_model, criterion,
                            regularizer)
        te_res = utils.test(loaders["test"], base_model, criterion,
                            regularizer)

        tr_loss_v, tr_nll_v, tr_acc_v = tr_res["loss"], tr_res["nll"], tr_res[
            "accuracy"]
        te_loss_v, te_nll_v, te_acc_v = te_res["loss"], te_res["nll"], te_res[
            "accuracy"]

        c = get_xy(p, w[0], u, v)
        grid[i, j] = [alpha * dx, beta * dy]

        tr_loss[i, j] = tr_loss_v
Example 9
def test(models, writer, criterion, data_loader, epoch):

    for i, model in enumerate(models):
        model.zero_grad()
        model.eval()

        if args.layerwise:
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    Z = np.random.exponential(scale=1.0, size=args.n)
                    Z = Z / Z.sum()
                    # k, not i: i is the enumerate index over models above
                    for k in range(1, args.n):
                        setattr(m, f"t{k}", Z[k])
        else:
            Z = np.random.exponential(scale=1.0, size=args.n)
            Z = Z / Z.sum()
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for k in range(1, args.n):
                        setattr(m, f"t{k}", Z[k])

    test_loss = 0
    correct = 0
    corrects = [0 for _ in range(len(models))]

    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)

    val_loader = data_loader.val_loader

    for m in models:
        utils.update_bn(data_loader.train_loader, m, device=args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            model_output = models[0](data)
            model_pred = model_output.argmax(dim=1, keepdim=True)
            corrects[0] += (model_pred.eq(
                target.view_as(model_pred)).sum().item())
            mean_output = model_output

            for i, m in enumerate(models[1:]):
                model_output = m(data)
                model_pred = model_output.argmax(dim=1, keepdim=True)
                corrects[i + 1] += (model_pred.eq(
                    target.view_as(model_pred)).sum().item())
                mean_output += model_output

            mean_output /= len(models)
            # get the index of the max log-probability
            pred = mean_output.argmax(dim=1, keepdim=True)
            test_loss += criterion(mean_output, target).item()
            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()
            soft_output = mean_output.softmax(dim=1)

            for i in range(data.size(0)):
                conf = soft_output[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    corrects_scaled = [
        float(corrects[i]) / len(val_loader.dataset)
        for i in range(len(corrects))
    ]
    metrics = {
        f"model_{i}_acc": corrects_scaled[i]
        for i in range(len(corrects_scaled))
    }

    corrects_scaled = np.array(corrects_scaled)

    metrics["avg_model_acc"] = np.mean(corrects_scaled[corrects_scaled > 0])
    metrics["avg_model_std"] = np.std(corrects_scaled[corrects_scaled > 0])
    metrics["ece"] = ece
    metrics["test_loss"] = test_loss

    return test_acc, metrics
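
The ensemble above averages raw logits and takes the softmax of the mean. That is the same distribution as the normalized geometric mean of the per-model softmax outputs: the per-model normalizers only contribute a constant that cancels on renormalization. A quick numerical check of the identity (shapes are illustrative):

import torch

logits = torch.randn(3, 5, 10)  # 3 models, batch of 5, 10 classes

mean_logit_dist = logits.mean(dim=0).softmax(dim=1)

geo = logits.log_softmax(dim=2).mean(dim=0).exp()
geo = geo / geo.sum(dim=1, keepdim=True)

assert torch.allclose(mean_logit_dist, geo, atol=1e-5)
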
Example 10
def test(models, writer, criterion, data_loader, epoch):

    model = models[0]
    model_0 = models[1]
    model_0.eval()
    model_0.zero_grad()

    model.apply(lambda m: setattr(m, "return_feats", True))
    model_0.apply(lambda m: setattr(m, "return_feats", True))

    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0
    ensemble_correct = 0
    m0_correct = 0
    tv_dist = 0.0
    val_loader = data_loader.val_loader
    feat_cosim = 0

    model.apply(lambda m: setattr(m, "t", args.t))
    model_0.apply(lambda m: setattr(m, "t", args.baset))
    model.apply(lambda m: setattr(m, "t1", args.t))
    model_0.apply(lambda m: setattr(m, "t1", args.baset))

    if args.update_bn:
        utils.update_bn(data_loader.train_loader, model, device=args.device)
        utils.update_bn(data_loader.train_loader, model_0, device=args.device)

    M = 20
    acc_bm_m0 = torch.zeros(M)
    conf_bm_m0 = torch.zeros(M)
    count_bm_m0 = torch.zeros(M)

    acc_bm_ens = torch.zeros(M)
    conf_bm_ens = torch.zeros(M)
    count_bm_ens = torch.zeros(M)

    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            model.apply(lambda m: setattr(m, "t", args.t))
            model.apply(lambda m: setattr(m, "t1", args.t))
            output, feats = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct_vec = pred.eq(target.view_as(pred))
            correct += correct_vec.sum().item()

            # get model 0
            model_0.apply(lambda m: setattr(m, "t", args.baset))
            model_0.apply(lambda m: setattr(m, "t1", args.baset))
            model_0_output, model_0_feats = model_0(data)
            ensemble_output = (model_0_output + output) / 2
            ensemble_pred = ensemble_output.argmax(dim=1, keepdim=True)
            ensemble_correct_vec = ensemble_pred.eq(target.view_as(pred))
            ensemble_correct += ensemble_correct_vec.sum().item()

            m0_pred = model_0_output.argmax(dim=1, keepdim=True)
            m0_correct_vec = m0_pred.eq(target.view_as(pred))
            m0_correct += m0_correct_vec.sum().item()

            model_t_prob = nn.functional.softmax(output, dim=1)
            model_0_prob = nn.functional.softmax(model_0_output, dim=1)
            tv_dist += 0.5 * (model_0_prob - model_t_prob).abs().sum().item()

            feat_cosim += (torch.nn.functional.cosine_similarity(
                feats, model_0_feats, dim=1).pow(2).sum().item())

            soft_out = output.softmax(dim=1)
            soft_out_m0 = model_0_output.softmax(dim=1)
            soft_out_ens = ensemble_output.softmax(dim=1)

            # need to do ece for m0, ensemble, model
            for i in range(data.size(0)):

                conf = soft_out[i][pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0

                # each predictor's confidence is taken at its own predicted class
                conf = soft_out_ens[i][ensemble_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm_ens[bin_idx] += ensemble_correct_vec[i].float().item()
                conf_bm_ens[bin_idx] += conf.item()
                count_bm_ens[bin_idx] += 1.0

                conf = soft_out_m0[i][m0_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm_m0[bin_idx] += m0_correct_vec[i].float().item()
                conf_bm_m0[bin_idx] += conf.item()
                count_bm_m0[bin_idx] += 1.0

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    m0_acc = float(m0_correct) / len(val_loader.dataset)
    tv_dist /= len(val_loader.dataset)
    feat_cosim /= len(val_loader.dataset)
    ensemble_acc = float(ensemble_correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)

    ece_ens = 0.0
    for i in range(M):
        ece_ens += (acc_bm_ens[i] - conf_bm_ens[i]).abs().item()
    ece_ens /= len(val_loader.dataset)
    print("ece_ens is", ece_ens)

    ece_m0 = 0.0
    for i in range(M):
        ece_m0 += (acc_bm_m0[i] - conf_bm_m0[i]).abs().item()
    ece_m0 /= len(val_loader.dataset)
    print("ece_m0 is", ece_m0)

    metrics = {
        "ece": ece,
        "ece_ens": ece_ens,
        "ece_m0": ece_m0,
        "tvdist": tv_dist,
        "ensemble_acc": ensemble_acc,
        "feat_cossim": feat_cosim,
        "m0_acc": m0_acc,
    }

    return test_acc, metrics
Example 11
def test(models, writer, criterion, data_loader, epoch):

    for model in models:
        model.eval()
        model.apply(lambda m: setattr(m, "return_feats", True))
    test_loss = 0
    correct = 0
    tvdist_sum = 0
    tvdist_len = 0
    feat_cossim = 0
    percent_disagree_sum = 0
    percent_disagree_len = 0
    percent_disagree_correct_sum = 0
    percent_disagree_correct_len = 0
    val_loader = data_loader.val_loader

    if args.update_bn:
        for model in models:
            utils.update_bn(data_loader.train_loader, model, args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            output, f = models[0](data)
            probs = [nn.functional.softmax(output, dim=1)]
            feats = [f]
            for t in range(1, args.num_models):
                modelt_output, model_feats_t = models[t](data)
                feats.append(model_feats_t)
                probs.append(nn.functional.softmax(modelt_output, dim=1))
                output += modelt_output

            # output = 0
            # for p in probs:
            #     output += p.log()
            # output = (output / args.num_models).exp()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            # accumulate the loss on the averaged logits
            test_loss += criterion(output / args.num_models, target).item()
            correct += pred.eq(target.view_as(pred)).sum().item()

            # get tvdist between i and j
            for i in range(args.num_models):
                for j in range(i + 1, args.num_models):
                    feat_cossim += (nn.functional.cosine_similarity(
                        feats[i], feats[j], dim=1).sum().item())
                    pairwise_tvdist = 0.5 * (probs[i] -
                                             probs[j]).abs().sum(dim=1)
                    tvdist_len += pairwise_tvdist.size(0)
                    tvdist_sum += pairwise_tvdist.sum().item()

                    model_i_pred = probs[i].argmax(dim=1, keepdim=True)
                    model_j_pred = probs[j].argmax(dim=1, keepdim=True)
                    percent_disagree_len += data.size(0)
                    percent_disagree_sum += ((model_i_pred !=
                                              model_j_pred).sum().item())

                    percent_disagree_correct_len += data.size(0)
                    # disagree AND at least one of the pair is correct
                    percent_disagree_correct_sum += (
                        ((model_i_pred != model_j_pred) &
                         (model_i_pred.eq(target.view_as(model_i_pred)) |
                          model_j_pred.eq(target.view_as(model_j_pred)))
                         ).sum().item())

    feat_cossim = feat_cossim / tvdist_len if tvdist_len > 0 else 0
    tvdist = tvdist_sum / tvdist_len if tvdist_len > 0 else 0
    percent_disagree = (percent_disagree_sum / percent_disagree_len
                        if percent_disagree_len > 0 else 0)
    percent_disagree_correct = (percent_disagree_correct_sum /
                                percent_disagree_correct_len
                                if percent_disagree_correct_len > 0 else 0)
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f}), TVDist: ({tvdist})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    metrics = {
        "tvdist": tvdist,
        "percent_disagree": percent_disagree,
        "percent_disagree_correct": percent_disagree_correct,
        "feat_cossim": feat_cossim,
    }

    return test_acc, metrics
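
The pairwise statistics above (total variation distance and prediction disagreement between model pairs) reduce to two small helpers. A standalone sketch with illustrative names:

import torch

def tv_distance(p, q):
    """Per-sample total variation distance between rows of two (batch, classes) distributions."""
    return 0.5 * (p - q).abs().sum(dim=1)

def disagreement_rate(p, q):
    """Fraction of samples where the two models' argmax predictions differ."""
    return (p.argmax(dim=1) != q.argmax(dim=1)).float().mean().item()
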
Example 12
def test(models, writer, criterion, data_loader, epoch):

    model = models[0]
    model.eval()
    test_loss = 0
    correct0 = 0  # never incremented in this variant; m0_acc reports 0
    wa_correct = 0
    val_loader = data_loader.val_loader
    for i in range(1, args.n):
        model.apply(lambda m: setattr(m, f"t{i}", 1.0 / args.n))

    utils.update_bn(data_loader.train_loader, model, args.device)
    model.eval()
    cossim, l2 = get_stats(model)

    M = 20
    acc_bm = torch.zeros(M)
    conf_bm = torch.zeros(M)
    count_bm = torch.zeros(M)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            wa_output = model(data)
            soft_output = wa_output.softmax(dim=1)
            wa_pred = wa_output.argmax(dim=1, keepdim=True)
            correct_vec = wa_pred.eq(target.view_as(wa_pred))
            wa_correct += correct_vec.sum().item()
            test_loss += criterion(wa_output, target).item()

            for i in range(data.size(0)):
                conf = soft_output[i][wa_pred[i]]
                bin_idx = min((conf * M).int().item(), M - 1)
                acc_bm[bin_idx] += correct_vec[i].float().item()
                conf_bm[bin_idx] += conf.item()
                count_bm[bin_idx] += 1.0

    wa_acc = float(wa_correct) / len(val_loader.dataset)
    m0_acc = float(correct0) / len(val_loader.dataset)
    test_acc = wa_acc
    test_loss /= len(val_loader)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/norm", l2, epoch)
        writer.add_scalar(f"test/cossim", cossim, epoch)

        writer.add_scalar(f"test/wa_acc", wa_acc, epoch)
        writer.add_scalar(f"test/m0_acc", m0_acc, epoch)

    ece = 0.0
    for i in range(M):
        ece += (acc_bm[i] - conf_bm[i]).abs().item()
    ece /= len(val_loader.dataset)
    print("ece is", ece)

    metrics = {
        "ece": ece,
        "wa_acc": wa_acc,
        "m0_acc": m0_acc,
        "l2": l2,
        "cossim": cossim,
        "test_loss": test_loss,
    }

    return test_acc, metrics
Example 13
def main():
    """Main entry point"""
    args = parse_args()
    os.makedirs(args.dir, exist_ok=True)

    utils.torch_settings(seed=None, benchmark=True)

    loaders, num_classes = data.loaders(args.dataset,
                                        args.data_path,
                                        args.batch_size,
                                        args.num_workers,
                                        args.transform,
                                        args.use_test,
                                        shuffle_train=False)

    model = init_model(args, num_classes)
    criterion = F.cross_entropy
    regularizer = curves.l2_regularizer(args.wd)

    (ts, tr_loss, tr_nll, tr_acc, te_loss, te_nll, te_acc, tr_err, te_err,
     dl) = init_metrics(args.num_points)

    previous_weights = None

    columns = [
        "t", "Train loss", "Train nll", "Train error (%)", "Test nll",
        "Test error (%)"
    ]

    curve_metrics = metrics.TestCurve(args.num_points, columns)
    t = torch.FloatTensor([0.0]).cuda()
    for i, t_value in enumerate(ts):
        t.data.fill_(t_value)
        weights = model.weights(t)
        if previous_weights is not None:
            dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
        previous_weights = weights.copy()

        utils.update_bn(loaders["train"], model, t=t)
        tr_res = utils.test(loaders["train"],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        te_res = utils.test(loaders["test"],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        tr_loss[i] = tr_res["loss"]
        tr_nll[i] = tr_res["nll"]
        tr_acc[i] = tr_res["accuracy"]
        tr_err[i] = 100.0 - tr_acc[i]
        te_loss[i] = te_res["loss"]
        te_nll[i] = te_res["nll"]
        te_acc[i] = te_res["accuracy"]
        te_err[i] = 100.0 - te_acc[i]
        curve_metrics.add_meas(i, tr_res, te_res)

        values = [t_value, tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i]]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt="simple",
                                  floatfmt="10.4f")
        print(curve_metrics.table(i, with_header=i % 40 == 0))
        if i % 40 == 0:
            table = table.split("\n")
            table = "\n".join([table[1]] + table)

        else:
            table = table.split("\n")[2]
        print(table)

    def stats(values, dl):
        vmin = np.min(values)
        vmax = np.max(values)
        avg = np.mean(values)
        integral = np.sum(0.5 *
                          (values[:-1] + values[1:]) * dl[1:]) / np.sum(dl[1:])
        return vmin, vmax, avg, integral

    tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
    tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
    tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)

    te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
    te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
    te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)

    print("Length: %.2f" % np.sum(dl))
    print(
        tabulate.tabulate([
            [
                "train loss", tr_loss[0], tr_loss[-1], tr_loss_min,
                tr_loss_max, tr_loss_avg, tr_loss_int
            ],
            [
                "train error (%)", tr_err[0], tr_err[-1], tr_err_min,
                tr_err_max, tr_err_avg, tr_err_int
            ],
            [
                "test nll", te_nll[0], te_nll[-1], te_nll_min, te_nll_max,
                te_nll_avg, te_nll_int
            ],
            [
                "test error (%)", te_err[0], te_err[-1], te_err_min,
                te_err_max, te_err_avg, te_err_int
            ],
        ], ["", "start", "end", "min", "max", "avg", "int"],
                          tablefmt="simple",
                          floatfmt="10.4f"))

    np.savez(
        os.path.join(args.dir, "curve.npz"),
        ts=ts,
        dl=dl,
        tr_loss=tr_loss,
        tr_loss_min=tr_loss_min,
        tr_loss_max=tr_loss_max,
        tr_loss_avg=tr_loss_avg,
        tr_loss_int=tr_loss_int,
        tr_nll=tr_nll,
        tr_nll_min=tr_nll_min,
        tr_nll_max=tr_nll_max,
        tr_nll_avg=tr_nll_avg,
        tr_nll_int=tr_nll_int,
        tr_acc=tr_acc,
        tr_err=tr_err,
        tr_err_min=tr_err_min,
        tr_err_max=tr_err_max,
        tr_err_avg=tr_err_avg,
        tr_err_int=tr_err_int,
        te_loss=te_loss,
        te_loss_min=te_loss_min,
        te_loss_max=te_loss_max,
        te_loss_avg=te_loss_avg,
        te_loss_int=te_loss_int,
        te_nll=te_nll,
        te_nll_min=te_nll_min,
        te_nll_max=te_nll_max,
        te_nll_avg=te_nll_avg,
        te_nll_int=te_nll_int,
        te_acc=te_acc,
        te_err=te_err,
        te_err_min=te_err_min,
        te_err_max=te_err_max,
        te_err_avg=te_err_avg,
        te_err_int=te_err_int,
    )
Example 14
def test(models, writer, criterion, data_loader, epoch):

    model = models[0]
    model_0 = models[1]
    model_0.eval()
    model_0.zero_grad()

    model.apply(lambda m: setattr(m, "return_feats", True))
    model_0.apply(lambda m: setattr(m, "return_feats", True))

    model.zero_grad()
    model.eval()
    test_loss = 0
    correct = 0
    ensemble_correct = 0
    m0_correct = 0
    tv_dist = 0.0
    val_loader = data_loader.val_loader
    feat_cosim = 0

    model.apply(lambda m: setattr(m, "alpha", args.alpha1))
    model_0.apply(lambda m: setattr(m, "alpha", args.alpha0))

    if args.update_bn:
        utils.update_bn(data_loader.train_loader, model, device=args.device)
        utils.update_bn(data_loader.train_loader, model_0, device=args.device)

    with torch.no_grad():

        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            output, feats = model(data)
            test_loss += criterion(output, target).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()

            # get model 0
            model_0_output, model_0_feats = model_0(data)
            ensemble_pred = (model_0_output + output).argmax(
                dim=1, keepdim=True
            )
            ensemble_correct += (
                ensemble_pred.eq(target.view_as(pred)).sum().item()
            )

            m0_pred = model_0_output.argmax(dim=1, keepdim=True)
            m0_correct += m0_pred.eq(target.view_as(pred)).sum().item()

            model_t_prob = nn.functional.softmax(output, dim=1)
            model_0_prob = nn.functional.softmax(model_0_output, dim=1)
            tv_dist += 0.5 * (model_0_prob - model_t_prob).abs().sum().item()

            feat_cosim += (
                torch.nn.functional.cosine_similarity(
                    feats, model_0_feats, dim=1
                )
                .pow(2)
                .sum()
                .item()
            )

    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    m0_acc = float(m0_correct) / len(val_loader.dataset)
    tv_dist /= len(val_loader.dataset)
    feat_cosim /= len(val_loader.dataset)
    ensemble_acc = float(ensemble_correct) / len(val_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar(f"test/loss", test_loss, epoch)
        writer.add_scalar(f"test/acc", test_acc, epoch)

    metrics = {
        "tvdist": tv_dist,
        "ensemble_acc": ensemble_acc,
        "feat_cossim": feat_cosim,
        "m0_acc": m0_acc,
    }

    return test_acc, metrics
Example 15
def test(models, writer, criterion, data_loader, epoch):
    j = args.j
    n = len(models)
    print(args.t)
    print((1 - args.t) / (n - 1))
    print("--")
    for ms in zip(*[model.modules() for model in models]):
        if isinstance(ms[0], nn.Conv2d):
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (ms[i].weight.data * (1 - args.t) /
                                          (n - 1))
            print("conv", ms[0].weight[0, 0, 0, 0])
        elif isinstance(ms[0], nn.BatchNorm2d):
            if j == 0:
                ms[0].weight.data = ms[0].weight.data * args.t
            else:
                ms[0].weight.data = ms[0].weight.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].weight.data += ms[i].weight.data * args.t
                else:
                    ms[0].weight.data += (ms[i].weight.data * (1 - args.t) /
                                          (n - 1))

            if j == 0:
                ms[0].bias.data = ms[0].bias.data * args.t
            else:
                ms[0].bias.data = ms[0].bias.data * (1 - args.t) / (n - 1)

            for i in range(1, n):
                if i == j:
                    ms[0].bias.data += ms[i].bias.data * args.t
                else:
                    ms[0].bias.data += ms[i].bias.data * (1 - args.t) / (n - 1)

            print("bn", ms[0].weight[0])
            print("bn", ms[0].bias[0])

    utils.update_bn(data_loader.train_loader, models[0], device=args.device)

    # here we save the model to args.tmp_dir/model_{j}.pt
    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{j}.pt"),
    )

    test_acc = 0
    metrics = {}

    return test_acc, metrics
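
The weighting above gives model j the coefficient args.t and spreads the remaining mass (1 - t) evenly over the other n - 1 models, so the coefficients always sum to one. A two-line check of that invariant with hypothetical values for n, j, and t:

n, j, t = 4, 2, 0.7
weights = [t if i == j else (1 - t) / (n - 1) for i in range(n)]
assert abs(sum(weights) - 1.0) < 1e-12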