Example #1
def compute_joint(loader, model, criterion, wd_scale=3e-4):
    # Rescale the average loss from utils.eval to a summed NLL over the dataset.
    nll_dict = utils.eval(loader=loader, model=model, criterion=criterion)
    nll = nll_dict['loss'] * len(loader.dataset)

    # Prior term: SWAG parameter norm scaled by the weight-decay coefficient.
    prior = compute_swag_param_norm(model) * wd_scale

    return nll + prior
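A minimal usage sketch (hypothetical; loaders, model, and criterion are assumed to be set up as in the later examples, and wd_scale should match the weight decay used during training):

# Evaluate the unnormalized negative log joint on the training set; lower is better.
joint = compute_joint(loaders['train'], model, criterion, wd_scale=3e-4)
print('negative log joint: %.4f' % joint)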
Example #2
def train_epoch(model,
                loaders,
                criterion,
                optimizer,
                epoch,
                end_epoch,
                eval_freq=1,
                save_freq=10,
                output_dir='./',
                lr_init=0.01):

    time_ep = time.time()

    lr = training_utils.schedule(epoch, lr_init, end_epoch, swa=False)
    training_utils.adjust_learning_rate(optimizer, lr)
    train_res = training_utils.train_epoch(loaders["train"], model, criterion,
                                           optimizer)
    if (epoch == 0 or epoch % eval_freq == eval_freq - 1
            or epoch == end_epoch - 1):
        test_res = training_utils.eval(loaders["test"], model, criterion)
    else:
        test_res = {"loss": None, "accuracy": None}

    if (epoch + 1) % save_freq == 0:
        training_utils.save_checkpoint(
            output_dir,
            epoch + 1,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )

    time_ep = time.time() - time_ep
    values = [
        epoch + 1,
        lr,
        train_res["loss"],
        train_res["accuracy"],
        test_res["loss"],
        test_res["accuracy"],
        time_ep,
    ]
    table = tabulate.tabulate([values],
                              columns,
                              tablefmt="simple",
                              floatfmt="8.4f")
    # Re-print the table header every 40 epochs; otherwise emit only the data row.
    if epoch % 40 == 0:
        table = table.split("\n")
        table = "\n".join([table[1]] + table)
    else:
        table = table.split("\n")[2]
    print(table)
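A sketch of a driver loop around train_epoch (hypothetical names; model, loaders, criterion, and optimizer are assumed to exist, and the global columns list that train_epoch tabulates against must be defined first):

columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
for epoch in range(300):
    train_epoch(model, loaders, criterion, optimizer,
                epoch, end_epoch=300, output_dir='./checkpoints')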
Example #3
    # Cross-entropy likelihood plus a norm penalty on the subspace parameters,
    # rescaled by the prior scale and the batch size.
    likelihood, output, _ = losses.cross_entropy(model, input, target)
    prior = 1 / (scale**2.0 * input.size(0)) * proj_params.norm()
    return likelihood + prior, output, {
        'nll': likelihood * input.size(0),
        'prior': proj_params.norm()
    }


optimizer = torch.optim.SGD([proj_params],
                            lr=5e-4,
                            momentum=0.9,
                            weight_decay=0)

swag_model.sample(0)
utils.bn_update(loaders['train'], swag_model)
print(utils.eval(loaders['test'], swag_model, criterion))

printf, logfile = utils.get_logging_print(
    os.path.join(args.dir, args.log_fname + '-%s.txt'))
print('Saving logs to: %s' % logfile)
columns = ['ep', 'acc', 'loss', 'prior', 'nll']

for epoch in range(args.epochs):
    train_res = utils.train_epoch(loaders['train'], proj_model, criterion,
                                  optimizer)
    values = [
        '%d/%d' % (epoch + 1, args.epochs), train_res['accuracy'],
        train_res['loss'], train_res['stats']['prior'],
        train_res['stats']['nll']
    ]
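The loop above builds a values row but the snippet is cut before anything is printed; a plausible continuation, mirroring the logging pattern of Example #11 (printf and tabulate as set up above):

    if epoch == 0:
        printf(tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f'))
    else:
        printf(tabulate.tabulate([values], columns, tablefmt='plain', floatfmt='8.4f').split('\n')[1])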
Example #4
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 20,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

swag_model.set_swa()
print("SWA:",
      utils.eval(loaders["train"], swag_model, criterion=losses.cross_entropy))

mean, var, cov_factor = swag_model.get_space()
subspace = Subspace(mean, cov_factor)

print(torch.norm(cov_factor, dim=1))

nvp_flow = construct_flow(cov_factor.shape[0],
                          device=torch.cuda.current_device())

vi_model = VINFModel(base=model_cfg.base,
                     subspace=subspace,
                     flow=nvp_flow,
                     prior_log_sigma=math.log(args.prior_std) +
                     math.log(args.temperature) / 2,
                     num_classes=num_classes,
Example #5
fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

columns = ['fraction', 'swa_acc', 'swa_loss', 'swag_acc', 'swag_loss', 'time']

for i, fraction in enumerate(fractions):
    start_time = time.time()
    swag_model.load_state_dict(ckpt['state_dict'])

    swag_model.sample(0.0)
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
        targets = sample_res['targets']
    predictions /= args.S

    swag_accuracies[i] = np.mean(np.argmax(predictions, axis=1) == targets)
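The loop is cut off before the SWAG NLL and the log row are produced; a plausible continuation under the same conventions (the explicit NLL computation and the epsilon guard are assumptions, the repo may use a helper instead):

    # NLL of the sample-averaged predictive distribution.
    swag_nlls[i] = -np.mean(
        np.log(predictions[np.arange(len(targets)), targets.astype(int)] + 1e-12))

    values = [fraction, swa_accuracies[i], swa_nlls[i],
              swag_accuracies[i], swag_nlls[i], time.time() - start_time]
    print(tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f'))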
Example #6
                                                   criterion,
                                                   optimizer,
                                                   weight_decay=args.wd,
                                                   velocity=velocity)
    else:
        #train_res = utils.train_epoch(loaders["train"], model, criterion, optimizer, weight_decay=args.wd)
        train_res, velocity = utils.train_epoch_v2(loaders["train"],
                                                   model,
                                                   criterion,
                                                   optimizer,
                                                   weight_decay=args.wd,
                                                   velocity=velocity)

    if (epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1
            or epoch == args.epochs - 1):
        test_res = utils.eval(loaders["test"], model, criterion)
    else:
        test_res = {"loss": None, "accuracy": None}

    if (args.swa and (epoch + 1) > args.swa_start
            and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0):
        # sgd_preds, sgd_targets = utils.predictions(loaders["test"], model)
        sgd_res = utils.predict(loaders["test"], model)
        sgd_preds = sgd_res["predictions"]
        sgd_targets = sgd_res["targets"]
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()  #numpy copy
        else:
            # TODO: rewrite in a numerically stable way
            # Plausible completion of the truncated update: running mean of
            # softmax predictions over the ensembled SGD iterates.
            sgd_ens_preds = sgd_ens_preds * n_ensembled / (
                n_ensembled + 1) + sgd_preds / (n_ensembled + 1)
        n_ensembled += 1
Example #7
for i, x in enumerate(xs):
    for j, y in enumerate(ys):
        t_start = time.time()
        # Point in the 2D plane spanned by directions u and v around the mean.
        w = mean + x * u + y * v

        # Write w back into the model parameters, slice by slice.
        offset = 0
        for param in model.parameters():
            size = np.prod(param.size())
            param.data.copy_(
                param.new_tensor(w[offset:offset + size].reshape(
                    param.size())))
            offset += size

        utils.bn_update(loaders['train'], model)
        train_res = utils.eval(loaders['train'], model, criterion)
        test_res = utils.eval(loaders['test'], model, criterion)

        train_acc[i, j] = train_res['accuracy']
        train_loss[i, j] = train_res['loss']
        test_acc[i, j] = test_res['accuracy']
        test_loss[i, j] = test_res['loss']

        t = time.time() - t_start
        values = [
            x, y, train_loss[i, j], train_acc[i, j], test_loss[i, j],
            test_acc[i, j], t
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
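                                  # plausible completion; the original snippet is truncated mid-call
                                  floatfmt='8.4f')
        print(table)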
Example #8
for i, x in enumerate(xs):
    for j, y in enumerate(ys):
        t_start = time.time()
        w = mean + x * u + y * v

        offset = 0
        for param in model.parameters():
            size = np.prod(param.size())
            param.data.copy_(
                param.new_tensor(w[offset:offset + size].reshape(
                    param.size())))
            offset += size

        utils.bn_update(loaders["train"], model)
        train_res = utils.eval(loaders["train"], model, criterion)
        test_res = utils.eval(loaders["test"], model, criterion)

        train_acc[i, j] = train_res["accuracy"]
        train_loss[i, j] = train_res["loss"]
        test_acc[i, j] = test_res["accuracy"]
        test_loss[i, j] = test_res["loss"]

        t = time.time() - t_start
        values = [
            x,
            y,
            train_loss[i, j],
            train_acc[i, j],
            test_loss[i, j],
            test_acc[i, j],
Example #9
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes)

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)

num_parameters = sum([p.numel() for p in model.parameters()])

offset = 0
# Checkpoint keys are stored as 'net.<name>_1'.
for name, param in model.named_parameters():
    if 'net.%s_1' % name in checkpoint['model_state']:
        param.data.copy_(checkpoint['model_state']['net.%s_1' % name])
    else:
        # PreResNet-164 fix: the checkpoint key omits one level of the module
        # name, so drop the fourth token before looking the key up.
        tokens = name.split('.')
        name_fixed = '.'.join(tokens[:3] + tokens[4:])
        param.data.copy_(checkpoint['model_state']['net.%s_1' % name_fixed])

utils.bn_update(loaders['train'], model, verbose=True)

print(utils.eval(loaders['test'], model, losses.cross_entropy))

torch.save(
    {'state_dict': model.state_dict()},
    args.save_path,
)
Example #10
            utils.save_checkpoint(
                args.dir,
                num_iterates,
                name='iter',
                state_dict=model.state_dict(),
            )

            model.to(args.swa_device)
            swag_model.collect_model(model)
            model.to(args.device)
    else:
        train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer, verbose=True)

    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        print('EPOCH %d. EVAL' % (epoch + 1))
        test_res = utils.eval(loaders['test'], model, criterion, verbose=True)
    else:
        test_res = {'loss': None, 'accuracy': None}

    if args.swa and (epoch + 1) > args.swa_start:
        if epoch == args.swa_start or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            swag_res = {'loss': None, 'accuracy': None}
            swag_model.to(args.device)
            swag_model.sample(0.0)
            print('EPOCH %d. SWAG BN' % (epoch + 1))
            utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
            print('EPOCH %d. SWAG EVAL' % (epoch + 1))
            swag_res = utils.eval(loaders['test'], swag_model, criterion, verbose=True)
            swag_model.to(args.swa_device)
        else:
            swag_res = {'loss': None, 'accuracy': None}
Example #11
printf, logfile = utils.get_logging_print(os.path.join(args.dir, args.log_fname + '-%s.txt'))
print('Saving logs to: %s' % logfile)
columns = ['iter ens', 'acc', 'nll']

for i in range(num_samples):
    # utils.bn_update(loaders['train'], model, subset=args.bn_subset)
    pyro_model.eval()
    k = 0
    # Load the i-th posterior sample into the model's subspace parameters.
    pyro_model.t.set_(samples[i, :])
    for input, target in tqdm.tqdm(loaders['test']):
        input = input.cuda(non_blocking=True)
        torch.manual_seed(i)

        with torch.no_grad():
            output = pyro_model(input)
            predictions[k:k + input.size(0)] += F.softmax(output, dim=1).cpu().numpy()
        targets[k:k + target.size(0)] = target.numpy()
        k += input.size(0)

    values = ['%d/%d' % (i + 1, num_samples), np.mean(np.argmax(predictions, axis=1) == targets), nll(predictions / (i+1), targets)]
    if i == 0:
        printf(tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f'))
    else:
        printf(tabulate.tabulate([values], columns, tablefmt='plain', floatfmt='8.4f').split('\n')[1])

pyro_model.t.set_(torch.zeros_like(pyro_model.t))
print(utils.eval(loaders["train"], pyro_model, criterion=losses.cross_entropy))

predictions /= num_samples
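A short follow-up to report the final ensemble metrics (a sketch; nll is the same helper already used inside the loop above):

print('ensemble accuracy: %.4f' % np.mean(np.argmax(predictions, axis=1) == targets))
print('ensemble nll: %.4f' % nll(predictions, targets))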
Example #12
columns = ["model", "epoch", "acc", "loss", "swa_acc", "swa_loss"]

pt_loss, pt_accuracy = list(), list()

if not args.no_ensembles:
    predictions = np.zeros((len(loaders["test"].dataset), num_classes, len(dir_locs)))
    targets = np.zeros(len(loaders["test"].dataset))


for i, ckpt in enumerate(dir_locs):

    model.load_state_dict(torch.load(ckpt)["state_dict"])
    # Parse the epoch number from the checkpoint filename
    # (assumes a name like 'checkpoint-<epoch>.pt').
    epoch = int(ckpt.replace(".", "-").split("-")[1])
    model.eval()

    res = utils.eval(loaders["test"], model, criterion)

    pt_loss.append(res["loss"])
    pt_accuracy.append(res["accuracy"])

    if not args.no_ensembles:
        k = 0
        with torch.no_grad():
            for input, target in tqdm.tqdm(loaders["test"]):
                input = input.cuda(non_blocking=True)
                torch.manual_seed(1)

                output = model(input)

                predictions[k : k + input.size(0), :, i] += (
                    F.softmax(output, dim=1).cpu().numpy()
Example #14
        lr = schedule(epoch)
        utils.adjust_learning_rate(optimizer, lr)
    else:
        lr = args.lr_init

    # Both branches of the original cov_mat conditional ran the same call,
    # so a single call suffices.
    train_res = utils.train_epoch(loaders["train"], model, criterion, optimizer, cuda=use_cuda)

    if (
        epoch == 0
        or epoch % args.eval_freq == args.eval_freq - 1
        or epoch == args.epochs - 1
    ):
        test_res = utils.eval(loaders["test"], model, criterion, cuda=use_cuda)
    else:
        test_res = {"loss": None, "accuracy": None}

    if (
        args.swa
        and (epoch + 1) > args.swa_start
        and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0
    ):
        # sgd_preds, sgd_targets = utils.predictions(loaders["test"], model)
        sgd_res = utils.predict(loaders["test"], model)
        sgd_preds = sgd_res["predictions"]
        sgd_targets = sgd_res["targets"]
        print("updating sgd_ens")
        if sgd_ens_preds is None:
            sgd_ens_preds = sgd_preds.copy()
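The snippet is truncated here; a plausible continuation of the ensembling branch, mirroring the running-mean update in Example #6:

        else:
            # Running mean of softmax predictions over collected SGD iterates.
            sgd_ens_preds = sgd_ens_preds * n_ensembled / (
                n_ensembled + 1) + sgd_preds / (n_ensembled + 1)
        n_ensembled += 1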