Example #1
    def update_results(self, epoch, calibration=False):
        outputs = []
        self.results["Epochs"].append(epoch)
        # Evaluate the SWA solution (the running weight average) first.
        self.trainer.swag_model.set_swa()
        logits, probs, preds, targets, acc = self.trainer.getTestOutputs()
        self.results["SWA"].append(acc)
        if calibration:
            calibration_dict = calibration_curve(probs,
                                                 targets,
                                                 num_bins=self.num_bins)
            self.results["SWA calibration"] = calibration_dict
        # Draw models from the SWAG posterior and evaluate each sample.
        for index in range(self.num_models):
            self.trainer.swag_model.sample(scale=self.scale)
            # Recompute BatchNorm running statistics for the sampled weights.
            bn_update(self.trainer.dataloaders["train"],
                      self.trainer.swag_model)
            logits, probs, preds, targets, acc = self.trainer.getTestOutputs()
            self.results["Model {}".format(index + 1)].append(acc)
            outputs.append([logits, probs, preds, targets, acc])
            if calibration:
                calibration_dict = calibration_curve(probs,
                                                     targets,
                                                     num_bins=self.num_bins)
                self.results["Model {} calibration".format(
                    index + 1)] = calibration_dict

        # Combine the sampled models into SWAG and Flow SWAG ensembles.
        en_probs, en_acc, fl_probs, fl_acc, targets = self.__get_ensembles_accuracy(
            outputs)
        self.results["SWAG"].append(en_acc)
        self.results["Flow SWAG"].append(fl_acc)
        if calibration:
            calibration_dict = calibration_curve(en_probs,
                                                 targets,
                                                 num_bins=self.num_bins)
            self.results["SWAG calibration"] = calibration_dict
            calibration_dict = calibration_curve(fl_probs,
                                                 targets,
                                                 num_bins=self.num_bins)
            self.results["Flow SWAG calibration"] = calibration_dict
Example #2
def criterion(model, input, target, scale=args.prior_std):
    # MAP-style objective: average cross-entropy plus a prior penalty
    # proportional to ||proj_params||, scaled down by the batch size.
    likelihood, output, _ = losses.cross_entropy(model, input, target)
    prior = 1 / (scale**2.0 * input.size(0)) * proj_params.norm()
    return likelihood + prior, output, {
        'nll': likelihood * input.size(0),
        'prior': proj_params.norm()
    }
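
# ---------------------------------------------------------------------------
# Illustrative aside (not in the original script): the prior weight above is
# 1 / (scale**2 * N), so it shrinks as the batch size N grows and the
# likelihood term dominates for large batches. Quick standalone check:
import torch

_theta = torch.randn(10)
for _n in (32, 512):
    print('prior term at N=%d: %.6f' % (_n, _theta.norm().item() / (1.0 ** 2 * _n)))
# ---------------------------------------------------------------------------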


optimizer = torch.optim.SGD([proj_params],
                            lr=5e-4,
                            momentum=0.9,
                            weight_decay=0)

# Baseline: evaluate the SWA mean (sample with scale 0) before subspace training.
swag_model.sample(0)
utils.bn_update(loaders['train'], swag_model)
print(utils.eval(loaders['test'], swag_model, criterion))

printf, logfile = utils.get_logging_print(
    os.path.join(args.dir, args.log_fname + '-%s.txt'))
print('Saving logs to: %s' % logfile)
#printf=print
columns = ['ep', 'acc', 'loss', 'prior', 'nll']

for epoch in range(args.epochs):
    train_res = utils.train_epoch(loaders['train'], proj_model, criterion,
                                  optimizer)
    values = [
        '%d/%d' % (epoch + 1, args.epochs), train_res['accuracy'],
        train_res['loss'], train_res['stats']['prior'],
        train_res['stats']['nll']
    ]
Example #3
fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

columns = ['fraction', 'swa_acc', 'swa_loss', 'swag_acc', 'swag_loss', 'time']

for i, fraction in enumerate(fractions):
    start_time = time.time()
    swag_model.load_state_dict(ckpt['state_dict'])

    swag_model.sample(0.0)  # scale 0.0 recovers the SWA mean
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
        targets = sample_res['targets']
    predictions /= args.S  # average the S sampled models' predictions
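
The averaged `predictions` can then be scored the way later examples score SWA/SGD outputs. A minimal sketch of filling `swag_accuracies[i]` and `swag_nlls[i]` inside the same loop (`eps` is a small log(0) guard assumed here, not taken from the source):

import numpy as np

eps = 1e-12  # assumed guard against log(0)
swag_accuracies[i] = np.mean(np.argmax(predictions, axis=1) == targets)
swag_nlls[i] = -np.mean(
    np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))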
Example #4
        sgd_res = utils.predict(loaders["test"], model)
        sgd_preds = sgd_res["predictions"]
        sgd_targets = sgd_res["targets"]
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()  # NumPy copy
        else:
            # TODO: rewrite in a numerically stable way
            sgd_ens_preds = sgd_ens_preds * n_ensembled / (
                n_ensembled + 1) + sgd_preds / (n_ensembled + 1)
        n_ensembled += 1
        swag_model.collect_model(model)
        if (epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1
                or epoch == args.epochs - 1):
            swag_model.sample(0.0)
            utils.bn_update(loaders["train"], swag_model)
            swag_res = utils.eval(loaders["test"], swag_model, criterion)
        else:
            swag_res = {"loss": None, "accuracy": None}

    # if (epoch + 1) % args.save_freq == 0:
    #     # utils.save_checkpoint(
    #     #     args.dir,
    #     #     epoch + 1,
    #     #     state_dict=model.state_dict(),
    #     #     optimizer=optimizer.state_dict(),
    #     # )
    #     if args.swa:
    #         utils.save_checkpoint(
    #             args.dir, epoch + 1, name="swag", state_dict=swag_model.state_dict()
    #         )
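
The `# TODO: rewrite in a numerically stable way` above refers to the running average: the standard fix is the incremental form (the same update `boSwag` uses in its base case below), which applies one small correction per step instead of rescaling the whole accumulator. A minimal sketch:

def running_mean_update(mean, x, n):
    """Mean of n + 1 samples given the mean of the first n.

    Algebraically equal to mean * n / (n + 1) + x / (n + 1), but with a
    single small correction term instead of two large/tiny products.
    """
    return mean + (x - mean) / (n + 1)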
Example #5
def boSwag(Pi):
    useMetric = 'nll'
    disableBo = False
    if disableBo:
        print("Computing for base case")
        # Base case = uniform weights for each SWAG and its models
        swagCount = len(args.swag_ckpts)
        n_ensembled = 0.
        multiswag_probs = None
        for ckpt_i, ckpt in enumerate(args.swag_ckpts):
            print("Checkpoint {}".format(ckpt))
            checkpoint = torch.load(ckpt)
            swag_model.subspace.rank = torch.tensor(0)
            swag_model.load_state_dict(checkpoint['state_dict'])

            for sample in range(args.swag_samples):
                swag_model.sample(0.5)  # draw one model from this SWAG posterior
                utils.bn_update(loaders['train'], swag_model)
                res = utils.predict(loaders['test'], swag_model)
                probs = res['predictions']
                targets = res['targets']
                nll = utils.nll(probs, targets)
                acc = utils.accuracy(probs, targets)

                if multiswag_probs is None:
                    multiswag_probs = probs.copy()
                else:
                    #TODO: rewrite in a numerically stable way
                    multiswag_probs += (probs -
                                        multiswag_probs) / (n_ensembled + 1)
                n_ensembled += 1

                ens_nll = utils.nll(multiswag_probs, targets)
                ens_acc = utils.accuracy(multiswag_probs, targets)
                values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
                table = tabulate.tabulate([values],
                                          columns,
                                          tablefmt='simple',
                                          floatfmt='8.4f')
                print(table)
        initialPi = [1 / swagCount] * swagCount
        return (initialPi, ens_nll, ens_acc)
    else:
        n_ensembled = 0.
        multiswag_probs = None
        for ckpt_i, ckpt in enumerate(args.swag_ckpts):
            #print("Checkpoint {}".format(ckpt))
            checkpoint = torch.load(ckpt)
            swag_model.subspace.rank = torch.tensor(0)
            swag_model.load_state_dict(checkpoint['state_dict'])
            #swagWeight = Pi[ckpt]
            #swagWeight = Pi[ckpt]/sum([Pi[i] for i in Pi])
            swagWeight = Pi[ckpt_i] / sum(Pi)  # normalize the candidate weights
            indivWeight = swagWeight / args.swag_samples

            for sample in range(args.swag_samples):
                swag_model.sample(0.5)
                utils.bn_update(loaders['train'], swag_model)
                res = utils.predict(loaders['test'], swag_model)
                probs = res['predictions']
                targets = res['targets']
                nll = utils.nll(probs, targets)
                acc = utils.accuracy(probs, targets)

                if multiswag_probs is None:
                    multiswag_probs = indivWeight * probs
                else:
                    #TODO: rewrite in a numerically stable way
                    #multiswag_probs +=  (probs - multiswag_probs)/ (n_ensembled + 1)
                    multiswag_probs += indivWeight * probs
                n_ensembled += 1

                ens_nll = utils.nll(multiswag_probs, targets)
                ens_acc = utils.accuracy(multiswag_probs, targets)
                values = [ckpt_i, sample, nll, acc, ens_nll, ens_acc]
                #table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
                #print(table)
        paramDump.append([Pi, ens_nll, ens_acc])
        print(Pi, ens_nll, ens_acc)
        if useMetric == 'nll':
            return ens_nll
        else:
            return ens_acc
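
Because `boSwag` returns a single scalar (the ensemble NLL by default), it can be handed to any black-box optimizer over the checkpoint weights `Pi`. The surrounding script presumably drives it with a Bayesian-optimization library; as a stand-in, a minimal sketch with SciPy (assumed available; not part of the source):

from scipy.optimize import minimize

k = len(args.swag_ckpts)
result = minimize(boSwag,
                  x0=[1.0 / k] * k,           # start from uniform weights
                  method='Powell',            # derivative-free; boSwag is noisy
                  bounds=[(1e-3, 1.0)] * k)   # keep raw weights positive
best_pi = list(result.x / result.x.sum())     # boSwag normalizes internally
print('best weights:', best_pi, 'best NLL:', result.fun)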
Example #6
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=not args.cov_mat,
                  loading=True,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
swa_res = utils.predict(loaders['test'], swag_model, verbose=True)

targets = swa_res['targets']
swa_predictions = swa_res['predictions']

swa_accuracy = np.mean(np.argmax(swa_predictions, axis=1) == targets)
swa_nll = -np.mean(
    np.log(swa_predictions[np.arange(swa_predictions.shape[0]), targets] +
           eps))
print('SWA. Accuracy: %.2f%% NLL: %.4f' % (swa_accuracy * 100, swa_nll))
# Predictive entropy per test point: -sum_c p_c log p_c
swa_entropies = -np.sum(np.log(swa_predictions + eps) * swa_predictions,
                        axis=1)

np.savez(args.save_path_swa,
Example #7
    printf, logfile = utils.get_logging_print(
        os.path.join(args.dir, args.log_fname + '-%s.txt'))
    print('Saving logs to: %s' % logfile)
    columns = ['iter ens', 'acc', 'nll']

    for i in range(num_samples):
        with torch.no_grad():
            w = vi_model.sample()  # flat weight vector from the VI posterior
            # Unflatten w into the evaluation model's parameters.
            offset = 0
            for param in eval_model.parameters():
                param.data.copy_(w[offset:offset + param.numel()].view(
                    param.size()).to(args.device))
                offset += param.numel()

        utils.bn_update(loaders['train'], eval_model, subset=args.bn_subset)

        pred_res = utils.predict(loaders['test'], eval_model)
        ens_predictions += pred_res['predictions']
        targets = pred_res['targets']

        # argmax is scale-invariant, so the running sum needs no normalization
        # for accuracy; the NLL divides by (i + 1) to average the ensemble.
        values = [
            '%3d/%3d' % (i + 1, num_samples),
            np.mean(np.argmax(ens_predictions, axis=1) == targets),
            nll(ens_predictions / (i + 1), targets)
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        if i == 0:
            printf(table)
        else:
            printf(table.split('\n')[2])
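
The offset loop above is the usual flatten/unflatten idiom for a flat weight vector; PyTorch ships an equivalent helper. An alternative sketch (the helper does not move `w` across devices for you, hence the explicit `.to`):

import torch

with torch.no_grad():
    torch.nn.utils.vector_to_parameters(w.to(args.device),
                                        eval_model.parameters())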
Example #8
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)

num_parameters = sum(p.numel() for p in model.parameters())

offset = 0
# Checkpoint stores weights under 'net.<name>_1' (e.g. a curve-model
# endpoint); copy them into the plain model.
for name, param in model.named_parameters():
    if 'net.%s_1' % name in checkpoint['model_state']:
        param.data.copy_(checkpoint['model_state']['net.%s_1' % name])
    else:
        # PRERESNET 164 fix: drop one level from the parameter name
        tokens = name.split('.')
        name_fixed = '.'.join(tokens[:3] + tokens[4:])
        param.data.copy_(checkpoint['model_state']['net.%s_1' % name_fixed])

utils.bn_update(loaders['train'], model, verbose=True)

print(utils.eval(loaders['test'], model, losses.cross_entropy))

torch.save(
    {'state_dict': model.state_dict()},
    args.save_path,
)
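
Nearly every example here calls `utils.bn_update` before evaluating sampled or averaged weights, because BatchNorm running statistics computed under the old weights are stale. A simplified sketch of what such a helper does (the real one also takes `subset=` and `verbose=`; this is an approximation, not the repository's code):

import torch

def bn_update_sketch(loader, model, device=None):
    """Recompute BatchNorm running statistics with one pass over `loader`."""
    was_training = model.training
    model.train()  # BN layers update running stats only in train mode
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum
            module.momentum = None  # None => cumulative (equal-weight) average
    with torch.no_grad():
        for input, _ in loader:
            if device is not None:
                input = input.to(device)
            model(input)
    for module, momentum in momenta.items():
        module.momentum = momentum  # restore the configured momentum
    model.train(was_training)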
Example #9
if args.parallel:
    print("Using Data Parallel model")
    model = torch.nn.parallel.DataParallel(model)

print("Loading checkpoint %s" % args.ckpt)
checkpoint = torch.load(args.ckpt)
state_dict = checkpoint["state_dict"]
if args.ckpt_cut_prefix:
    # Strip the "module." prefix that DataParallel prepends to parameter names.
    state_dict = {
        k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()
    }
model.load_state_dict(state_dict)

print("BN update")
utils.bn_update(loaders["train"], model, verbose=True, subset=0.1)
print("EVAL")
res = utils.predict(loaders["test"], model, verbose=True)

predictions = res["predictions"]
targets = res["targets"]


accuracy = np.mean(np.argmax(predictions, axis=1) == targets)
nll = -np.mean(np.log(predictions[np.arange(predictions.shape[0]), targets] + eps))
print("Accuracy: %.2f%% NLL: %.4f" % (accuracy * 100, nll))
entropies = -np.sum(np.log(predictions + eps) * predictions, axis=1)


np.savez(
    args.save_path,
Example #10
                                                      criterion)
        print('Val - Loss: {:.4f} | Acc: {:.4f} | IOU: {:.4f}'.format(
            val_loss, 1 - val_err, val_iou))

    time_elapsed = time.time() - since
    print('Total Time {:.0f}m {:.0f}s\n'.format(time_elapsed // 60,
                                                time_elapsed % 60))

    if args.swa and (epoch + 1) > args.swa_start and (
            epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
        print('Saving SWA model at epoch: ', epoch)
        swag_model.collect_model(model)

        if epoch % args.eval_freq == 0:
            swag_model.sample(0.0)
            bn_update(train_loader, swag_model)
            val_loss, val_err, val_iou = train_utils.test(
                swag_model, loaders['val'], criterion)
            print('SWA Val - Loss: {:.4f} | Acc: {:.4f} | IOU: {:.4f}'.format(
                val_loss, 1 - val_err, val_iou))

    ### Checkpoint ###
    if epoch % args.save_freq == 0:
        print('Saving model at Epoch: ', epoch)
        save_checkpoint(dir=args.dir,
                        epoch=epoch,
                        state_dict=model.state_dict(),
                        optimizer=optimizer.state_dict())
        if args.swa and (epoch + 1) > args.swa_start:
            save_checkpoint(
                dir=args.dir,
Example #11
# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])

print(len(loaders["test"]))
if args.use_test:
    print("Using test dataset")
    test_loader = "test"
else:
    test_loader = "val"
loss, err, mIOU, model_output_targets = test(model, loaders[test_loader],
                                             criterion, return_outputs=True,
                                             return_scale=args.loss == "aleatoric")
Example #12
    # Project (w - mean) onto the subspace: each row of cov_factor is divided
    # by its squared norm, so `coords` holds per-direction coefficients.
    coords = np.dot(
        cov_factor / np.sum(np.square(cov_factor), axis=1, keepdims=True),
        (w - mean[None, :]).T).T
    print(coords)

    theta = torch.FloatTensor(coords[2, :])

    for i in range(3):
        v = subspace(torch.FloatTensor(coords[i]))  # map coords back to weight space
        offset = 0
        for param in model.parameters():
            param.data.copy_(v[offset:offset + param.numel()].view(
                param.size()).to(args.device))
            offset += param.numel()
        utils.bn_update(loaders_bn['train'], model)
        print(utils.eval(loaders['test'], model, losses.cross_entropy))

else:
    assert len(args.checkpoint) == 1
    swag_model = SWAG(model_cfg.base,
                      num_classes=num_classes,
                      subspace_type='pca',
                      subspace_kwargs={
                          'max_rank': 20,
                          'pca_rank': args.rank,
                      },
                      *model_cfg.args,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
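
The row scaling in `coords` near the top of this example divides each row of `cov_factor` by its squared norm; when the rows are orthogonal, this recovers the exact coefficients of `w - mean` in the row basis. A small self-contained check on synthetic data (not from the source):

import numpy as np

rng = np.random.default_rng(0)
A = np.linalg.qr(rng.normal(size=(10, 3)))[0].T * 2.0  # 3 orthogonal rows, norm 2
true_coords = rng.normal(size=(1, 3))
w = true_coords @ A                                    # a point in the row span
coords = np.dot(A / np.sum(np.square(A), axis=1, keepdims=True), w.T).T
print(np.allclose(coords, true_coords))                # True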
Example #13
#criterion = nn.NLLLoss(weight=camvid.class_weight[:-1].cuda(), reduction='none').cuda()
if args.loss == 'cross_entropy':
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(model_cfg.base, no_cov_mat=False, max_num_models=20,
                 num_classes=num_classes, use_aleatoric=args.loss == 'aleatoric')
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])

    model.sample(0.0)
    bn_update(loaders['fine_tune'], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == 'aleatoric').cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    print(start_epoch)
    model.load_state_dict(checkpoint['state_dict'])

print(len(loaders['test']))
if args.use_test:
    print('Using test dataset')
    test_loader = 'test'
else:
    test_loader = 'val'
loss, err, mIOU, model_output_targets = test(model, loaders[test_loader],
                                             criterion, return_outputs=True,
                                             return_scale=args.loss == 'aleatoric')
print(loss, 1 - err, mIOU)