Example #1
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   weight_decay=args.wd,
                                                   velocity=velocity)

    if (epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1
            or epoch == args.epochs - 1):
        test_res = utils.eval(loaders["test"], model, criterion)
    else:
        test_res = {"loss": None, "accuracy": None}

    if (args.swa and (epoch + 1) > args.swa_start
            and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0):
        # sgd_preds, sgd_targets = utils.predictions(loaders["test"], model)
        sgd_res = utils.predict(loaders["test"], model)
        sgd_preds = sgd_res["predictions"]
        sgd_targets = sgd_res["targets"]
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()  # independent NumPy copy of the first sample's predictions
        else:
            # TODO: rewrite in a numerically stable way
            sgd_ens_preds = sgd_ens_preds * n_ensembled / (
                n_ensembled + 1) + sgd_preds / (n_ensembled + 1)
        n_ensembled += 1
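        # The update above keeps the running mean (n*old + new) / (n + 1);
        # the incremental form old += (new - old) / (n + 1), used in the
        # later examples, is algebraically equivalent.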
        swag_model.collect_model(model)
        if (epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1
                or epoch == args.epochs - 1):
            swag_model.sample(0.0)
            utils.bn_update(loaders["train"], swag_model)
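            # sample(0.0) loads the SWA mean (a zero-scale posterior sample);
            # bn_update then recomputes BatchNorm running statistics for the
            # averaged weights before evaluation.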
Example #2
    start_time = time.time()
    swag_model.load_state_dict(ckpt['state_dict'])

    swag_model.sample(0.0)
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
        targets = sample_res['targets']
    predictions /= args.S
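    # Average over S Monte Carlo weight samples from the SWAG posterior,
    # i.e. a simple Bayesian model average of the predicted probabilities.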

    swag_accuracies[i] = np.mean(np.argmax(predictions, axis=1) == targets)
    swag_nlls[i] = -np.mean(
        np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))
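    # predictions[np.arange(N), targets] above gathers each test point's
    # predicted probability of its true class; the NLL is the mean of
    # -log of those probabilities.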

    run_time = time.time() - start_time
    values = [
        fraction * 100.0, swa_accuracies[i], swa_nlls[i], swag_accuracies[i],
        swag_nlls[i], run_time
    ]
    table = tabulate.tabulate([values],
                              columns,
                              tablefmt='simple',
                              floatfmt='8.4f')
    print(table)
Example #3
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
swa_res = utils.predict(loaders['test'], swag_model, verbose=True)

targets = swa_res['targets']
swa_predictions = swa_res['predictions']

swa_accuracy = np.mean(np.argmax(swa_predictions, axis=1) == targets)
swa_nll = -np.mean(
    np.log(swa_predictions[np.arange(swa_predictions.shape[0]), targets] +
           eps))
print('SWA. Accuracy: %.2f%% NLL: %.4f' % (swa_accuracy * 100, swa_nll))
swa_entropies = -np.sum(np.log(swa_predictions + eps) * swa_predictions,
                        axis=1)
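# Predictive entropy H(p) = -sum_c p_c log(p_c) per test point, a per-sample
# uncertainty score saved alongside accuracy and NLL.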

np.savez(args.save_path_swa,
         accuracy=swa_accuracy,
         nll=swa_nll,
Example #4
def boSwag(Pi):
    useMetric = 'nll'
    disableBo = False
    if disableBo:
        print("Computing for base case")
        # Base case = uniform weights for each SWAG and its models
        swagCount = len(args.swag_ckpts)
        n_ensembled = 0.
        multiswag_probs = None
        for ckpt_i, ckpt in enumerate(args.swag_ckpts):
            print("Checkpoint {}".format(ckpt))
            checkpoint = torch.load(ckpt)
            swag_model.subspace.rank = torch.tensor(0)
            swag_model.load_state_dict(checkpoint['state_dict'])

            for sample in range(args.swag_samples):
                swag_model.sample(.5)
                utils.bn_update(loaders['train'], swag_model)
                res = utils.predict(loaders['test'], swag_model)
                probs = res['predictions']
                targets = res['targets']
                nll = utils.nll(probs, targets)
                acc = utils.accuracy(probs, targets)

                if multiswag_probs is None:
                    multiswag_probs = probs.copy()
                else:
                    #TODO: rewrite in a numerically stable way
                    multiswag_probs += (probs -
                                        multiswag_probs) / (n_ensembled + 1)
                n_ensembled += 1

                ens_nll = utils.nll(multiswag_probs, targets)
                ens_acc = utils.accuracy(multiswag_probs, targets)
                values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
                table = tabulate.tabulate([values],
                                          columns,
                                          tablefmt='simple',
                                          floatfmt='8.4f')
                print(table)
        initialPi = [1 / swagCount] * swagCount
        return (initialPi, ens_nll, ens_acc)
    else:
        n_ensembled = 0.
        multiswag_probs = None
        for ckpt_i, ckpt in enumerate(args.swag_ckpts):
            #print("Checkpoint {}".format(ckpt))
            checkpoint = torch.load(ckpt)
            swag_model.subspace.rank = torch.tensor(0)
            swag_model.load_state_dict(checkpoint['state_dict'])
            #swagWeight = Pi[ckpt]
            #swagWeight = Pi[ckpt]/sum([Pi[i] for i in Pi])
            swagWeight = Pi[ckpt_i] / sum(Pi)
            indivWeight = swagWeight / args.swag_samples
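            # Pi is normalized to sum to 1 across checkpoints, and each
            # checkpoint's weight is split evenly over its swag_samples
            # draws, so the final mixture weights over all samples sum to 1.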

            for sample in range(args.swag_samples):
                swag_model.sample(.5)
                utils.bn_update(loaders['train'], swag_model)
                res = utils.predict(loaders['test'], swag_model)
                probs = res['predictions']
                targets = res['targets']
                nll = utils.nll(probs, targets)
                acc = utils.accuracy(probs, targets)

                if multiswag_probs is None:
                    multiswag_probs = indivWeight * probs
                else:
                    # TODO: rewrite in a numerically stable way
                    # previously: multiswag_probs += (probs - multiswag_probs) / (n_ensembled + 1)
                    multiswag_probs += indivWeight * probs
                n_ensembled += 1

                ens_nll = utils.nll(multiswag_probs, targets)
                ens_acc = utils.accuracy(multiswag_probs, targets)
                values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
                #table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
                #print(table)
        paramDump.append([Pi, ens_nll, ens_acc])
        print(Pi, ens_nll, ens_acc)
        if useMetric == 'nll':
            return ens_nll
        else:
            return ens_acc
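# A minimal sketch of how boSwag(Pi) could be plugged into a Bayesian
# optimization driver. The driver below is an assumption for illustration
# (the snippet does not include the actual optimization loop); it uses
# scikit-optimize's gp_minimize, which minimizes a scalar objective over
# box-constrained inputs.
from skopt import gp_minimize

k = len(args.swag_ckpts)
# Optimize per-checkpoint weights; the lower bound of 0.01 avoids a
# zero-sum Pi, and boSwag normalizes Pi to a simplex itself.
bo_result = gp_minimize(boSwag, dimensions=[(0.01, 1.0)] * k, n_calls=30)
best_pi, best_nll = bo_result.x, bo_result.fun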
Example #5
        if epoch + 1 >= args.ens_start:
            n_iter += 1

    train_res = {
        'loss': loss_sum / len(loaders['train'].dataset),
        'accuracy': correct / len(loaders['train'].dataset) * 100.0,
    }

    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        test_res = utils.eval(loaders['test'], model, criterion)
    else:
        test_res = {'loss': None, 'accuracy': None}

    if epoch + 1 >= args.ens_start:
        utils.bn_update(loaders['train'], model)
        out = utils.predict(loaders['test'], model)
        cur_pred = out['predictions']
        cur_targets = out['targets']

        if sgld_ens_pred is None:
            sgld_ens_pred = cur_pred.copy()
            sgld_targets = cur_targets.copy()
        else:
            sgld_ens_pred += (cur_pred - sgld_ens_pred) / (n_ensembled + 1)
        n_ensembled += 1
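        # Online-mean recurrence: mean_{n+1} = mean_n + (x_{n+1} - mean_n) / (n + 1).
        # E.g. for 0.2, 0.4, 0.9: 0.2 -> 0.2 + (0.4 - 0.2)/2 = 0.3 ->
        # 0.3 + (0.9 - 0.3)/3 = 0.5, the arithmetic mean, without storing a sum.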

        idx = np.arange(sgld_targets.size)
        ens_loss = np.mean(-np.log(sgld_ens_pred[idx, sgld_targets]))
        ens_acc = np.mean(
            np.argmax(sgld_ens_pred, axis=-1) == sgld_targets) * 100.0
Example #6
        os.path.join(args.dir, args.log_fname + '-%s.txt'))
    print('Saving logs to: %s' % logfile)
    columns = ['iter ens', 'acc', 'nll']

    for i in range(num_samples):
        with torch.no_grad():
            w = vi_model.sample()
            offset = 0
            for param in eval_model.parameters():
                param.data.copy_(w[offset:offset + param.numel()].view(
                    param.size()).to(args.device))
                offset += param.numel()
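            # vi_model.sample() returns one flat weight vector; the loop above
            # slices numel() entries per parameter and reshapes them back to
            # each parameter's shape to load the sample into eval_model.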

        utils.bn_update(loaders['train'], eval_model, subset=args.bn_subset)

        pred_res = utils.predict(loaders['test'], eval_model)
        ens_predictions += pred_res['predictions']
        targets = pred_res['targets']

        values = [
            '%3d/%3d' % (i + 1, num_samples),
            np.mean(np.argmax(ens_predictions, axis=1) == targets),
            nll(ens_predictions / (i + 1), targets)
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        if i == 0:
            print(table)
        else:
Example #7
    print("Using Data Parallel model")
    model = torch.nn.parallel.DataParallel(model)

print("Loading checkpoint %s" % args.ckpt)
checkpoint = torch.load(args.ckpt)
state_dict = checkpoint["state_dict"]
if args.ckpt_cut_prefix:
    state_dict = {
        k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()
    }
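    # State dicts saved from a DataParallel-wrapped model prefix every key
    # with "module."; k[7:] strips those 7 characters so the checkpoint also
    # loads into an unwrapped module.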
model.load_state_dict(state_dict)

print("BN update")
utils.bn_update(loaders["train"], model, verbose=True, subset=0.1)
print("EVAL")
res = utils.predict(loaders["test"], model, verbose=True)

predictions = res["predictions"]
targets = res["targets"]


accuracy = np.mean(np.argmax(predictions, axis=1) == targets)
nll = -np.mean(np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))
print("Accuracy: %.2f%% NLL: %.4f" % (accuracy * 100, nll))
entropies = -np.sum(np.log(predictions + eps) * predictions, axis=1)


np.savez(
    args.save_path,
    accuracy=accuracy,
    nll=nll,
Example #8
    start_time = time.time()
    swag_model.load_state_dict(ckpt["state_dict"])

    swag_model.sample(0.0)
    utils.bn_update(loaders["train"], swag_model, subset=fraction)
    swa_res = utils.eval(loaders["test"], swag_model, criterion)
    swa_accuracies[i] = swa_res["accuracy"]
    swa_nlls[i] = swa_res["loss"]

    predictions = np.zeros((len(loaders["test"].dataset), num_classes))

    for j in range(args.S):
        swag_model.load_state_dict(ckpt["state_dict"])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders["train"], swag_model, subset=fraction)
        sample_res = utils.predict(loaders["test"], swag_model)
        predictions += sample_res["predictions"]
        targets = sample_res["targets"]
    predictions /= args.S

    swag_accuracies[i] = np.mean(np.argmax(predictions, axis=1) == targets)
    swag_nlls[i] = -np.mean(
        np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))

    run_time = time.time() - start_time
    values = [
        fraction * 100.0,
        swa_accuracies[i],
        swa_nlls[i],
        swag_accuracies[i],
        swag_nlls[i],
        run_time,
    ]
Example #9
    model = torch.nn.parallel.DataParallel(model)

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
state_dict = checkpoint['state_dict']
if args.ckpt_cut_prefix:
    state_dict = {
        k[7:] if k.startswith('module.') else k: v
        for k, v in state_dict.items()
    }
model.load_state_dict(state_dict)

print('BN update')
utils.bn_update(loaders['train'], model, verbose=True, subset=0.1)
print('EVAL')
res = utils.predict(loaders['test'], model, verbose=True)

predictions = res['predictions']
targets = res['targets']

accuracy = np.mean(np.argmax(predictions, axis=1) == targets)
nll = -np.mean(
    np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))
print('Accuracy: %.2f%% NLL: %.4f' % (accuracy * 100, nll))
entropies = -np.sum(np.log(predictions + eps) * predictions, axis=1)

np.savez(args.save_path,
         accuracy=accuracy,
         nll=nll,
         entropies=entropies,
         predictions=predictions,