예제 #1
0
    def test_swag_cov(self, **kwargs):
        """Check that SWAG samples are consistent with the fitted Gaussian.

        A linear model is trained with SGD for 101 steps and collected into a
        ``SWAG`` model after every step. The reference covariance is built from
        the subspace factor (``factor^T factor``) plus the diagonal variance.
        For each ``scale`` the test draws samples, evaluates the quadratic form
        ``(x - mu)^T Sigma^{-1} (x - mu)`` for each one, and asserts that
        between 94% and 96% of them fall below the 95% chi-square quantile.
        """
        # Draws per scale. The acceptance bounds at the bottom are derived from
        # this value so the loop count and the bounds cannot drift apart.
        num_samples = 3000

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            subspace_type='covariance',
            subspace_kwargs={'max_rank': 140},
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            # `inputs` rather than `input`, to avoid shadowing the builtin
            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # build the reference covariance from the collected statistics
        mean, var = swag_model._get_mean_and_variance()
        cov_mat_sqrt = swag_model.subspace.get_space()

        true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

        # 95% quantile of the p-dimensional chi-square distribution
        test_cutoff = chi2(df=mean.numel()).ppf(0.95)

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_inv = torch.inverse(true_cov_mat * scale)

            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = []
            for _ in range(num_samples):
                swag_model.sample(scale=scale)
                curr_pars = [
                    getattr(module, name)
                    for (module, name, _) in swag_model.base_params
                ]
                dev = flatten(curr_pars) - mean

                # (x - mu)^T sigma^{-1} (x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)
                all_qforms.append(qform.item())

            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            # between 94 and 96% of the samples should fall within the threshold
            # this should be very loose
            self.assertTrue(
                0.94 * num_samples <= samples_in_cr <= 0.96 * num_samples,
                "scale=%s: %d of %d samples inside the 95%% region"
                % (scale, samples_in_cr, num_samples),
            )
예제 #2
0
def criterion(model, input, target, scale=args.prior_std):
    """Cross-entropy loss plus a norm penalty on ``proj_params``.

    Args:
        model: model forwarded to ``losses.cross_entropy``.
        input: input batch; ``input.size(0)`` is used as the batch size.
        target: targets forwarded to ``losses.cross_entropy``.
        scale: prior scale; the penalty weight is ``1 / (scale^2 * batch size)``.

    Returns:
        A ``(loss, output, stats)`` triple where ``stats`` holds the
        batch-size-scaled likelihood term under ``'nll'`` and the unweighted
        parameter norm under ``'prior'``.
    """
    nll, output, _ = losses.cross_entropy(model, input, target)
    batch_size = input.size(0)

    penalty_weight = 1 / (scale**2.0 * batch_size)
    penalty = penalty_weight * proj_params.norm()

    stats = {
        'nll': nll * batch_size,
        'prior': proj_params.norm(),
    }
    return nll + penalty, output, stats


# SGD over the single tensor `proj_params` (momentum 0.9, no weight decay).
optimizer = torch.optim.SGD([proj_params],
                            lr=5e-4,
                            momentum=0.9,
                            weight_decay=0)

# Baseline before training: sample with argument 0, run bn_update on the
# training loader (presumably recomputes batch-norm statistics -- confirm in
# utils), then print the evaluation on the test loader.
swag_model.sample(0)
utils.bn_update(loaders['train'], swag_model)
print(utils.eval(loaders['test'], swag_model, criterion))

# The path passed here still contains a literal '%s' placeholder;
# presumably get_logging_print fills it in (TODO confirm).
printf, logfile = utils.get_logging_print(
    os.path.join(args.dir, args.log_fname + '-%s.txt'))
print('Saving logs to: %s' % logfile)
#printf=print
# Column names, presumably for the per-epoch results table.
columns = ['ep', 'acc', 'loss', 'prior']

for epoch in range(args.epochs):
    train_res = utils.train_epoch(loaders['train'], proj_model, criterion,
                                  optimizer)
    values = [
        '%d/%d' % (epoch + 1, args.epochs), train_res['accuracy'],
        train_res['loss'], train_res['stats']['prior'],
# Evaluation criterion used by utils.eval below.
criterion = losses.cross_entropy

# args.N log-spaced values from 1 / (0.005 * train-set size) up to 1.0;
# each value is passed to bn_update as `subset`.
fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0,
                        args.N)
# Per-fraction result buffers for the scale-0 sample ("swa") and the
# averaged samples ("swag").
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

columns = ['fraction', 'swa_acc', 'swa_loss', 'swag_acc', 'swag_loss', 'time']

for i, fraction in enumerate(fractions):
    start_time = time.time()
    # Reload the checkpoint so every fraction starts from the same state.
    swag_model.load_state_dict(ckpt['state_dict'])

    # Scale 0.0 sample, evaluated after bn_update on the given subset.
    swag_model.sample(0.0)
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    # Accumulator for the predictions of args.S samples.
    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        # Reload before every draw, then sample at scale 0.5.
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
        targets = sample_res['targets']
    # Turn the running sum into the mean over the args.S samples.
    predictions /= args.S
예제 #4
0
                                momentum=0.9,
                                weight_decay=1e-4)
    # NOTE(review): the enclosing function header and the optimizer
    # construction are outside this excerpt.
    loader = generate_dataloaders(N=10)

    # Set to the SWAG state dict captured at epoch index 4 in the loop below.
    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        # One pass over the loader with a summed squared-error loss.
        for x, y in loader:
            model.zero_grad()
            pred = model(x)
            loss = ((pred - y)**2.0).sum()
            loss.backward()
            optimizer.step()
        # Collect the current model once per epoch.
        small_swag_model.collect_model(model)

        # Snapshot the SWAG state so it can be reloaded at the end.
        if epoch == 4:
            state_dict = small_swag_model.state_dict()

    small_swag_model.fit()
    # Exercise sampling and the forward pass without tracking gradients.
    with torch.no_grad():
        x = torch.arange(-6., 6., 1.0).unsqueeze(1)
        for i in range(10):
            small_swag_model.sample(0.5)
            small_swag_model(x)

    # get_space is unpacked into two values without the covariance factor
    # and into three values with it.
    _, _ = small_swag_model.get_space(export_cov_factor=False)
    _, _, _ = small_swag_model.get_space(export_cov_factor=True)
    # Reload the snapshot taken during training.
    # NOTE(review): state_dict is still None if num_epochs < 5 -- confirm.
    small_swag_model.load_state_dict(state_dict)
print('Preparing model')
# Build the SWAG wrapper; the covariance matrix is disabled when
# args.cov_mat is falsy.
swag_model = SWAG(model_class,
                  no_cov_mat=not args.cov_mat,
                  loading=True,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
# Sample at scale 0.0 for the "SWA" evaluation.
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
swa_res = utils.predict(loaders['test'], swag_model, verbose=True)

targets = swa_res['targets']
swa_predictions = swa_res['predictions']

# Accuracy: share of rows whose argmax equals the target.
swa_accuracy = np.mean(np.argmax(swa_predictions, axis=1) == targets)
# NLL: mean negative log of the prediction at the target index; `eps` is
# added inside the log.
swa_nll = -np.mean(
    np.log(swa_predictions[np.arange(swa_predictions.shape[0]), targets] +
           eps))
print('SWA. Accuracy: %.2f%% NLL: %.4f' % (swa_accuracy * 100, swa_nll))
# Per-row entropy of the predictions.
swa_entropies = -np.sum(np.log(swa_predictions + eps) * swa_predictions,
                        axis=1)
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

# Running count of ensembled samples and their running mean prediction.
n_ensembled = 0.
multiswag_probs = None

for ckpt_i, ckpt in enumerate(args.swag_ckpts):
    print("Checkpoint {}".format(ckpt))
    checkpoint = torch.load(ckpt)
    # Reset the subspace rank to 0 before loading the next checkpoint.
    swag_model.subspace.rank = torch.tensor(0)
    swag_model.load_state_dict(checkpoint['state_dict'])

    for sample in range(args.swag_samples):
        # Draw at scale 0.5, run bn_update, then predict on the test loader.
        swag_model.sample(.5)
        utils.bn_update(loaders['train'], swag_model)
        res = utils.predict(loaders['test'], swag_model)
        probs = res['predictions']
        targets = res['targets']
        nll = utils.nll(probs, targets)
        acc = utils.accuracy(probs, targets)

        if multiswag_probs is None:
            # First sample: copy so later in-place updates do not touch `probs`.
            multiswag_probs = probs.copy()
        else:
            #TODO: rewrite in a numerically stable way
            # Incremental mean: new_mean = old_mean + (x - old_mean) / (n + 1)
            multiswag_probs += (probs - multiswag_probs) / (n_ensembled + 1)
        n_ensembled += 1

        # NLL of the ensemble mean over all samples seen so far.
        ens_nll = utils.nll(multiswag_probs, targets)
예제 #7
0
    def test_swag_diag(self, **kwargs):
        """Check that diagonal SWAG samples match the per-coordinate Gaussian.

        A linear model is trained with SGD for 101 steps and collected into a
        ``SWAG`` model (``no_cov_mat=True``) after every step. The stored mean
        and second-moment buffers are checked against the parameter shapes.
        For each ``scale`` a ``Normal(mean, sqrt(scale * (sq_mean - mean^2)))``
        reference is built, and the share of sampled coordinates whose CDF
        value lies in (0.025, 0.975) must be between 94.5% and 95.5%.
        """
        # Draws per scale. Also the divisor of the coverage average, so it is
        # defined once instead of being repeated as a literal.
        num_draws = 20

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=True,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            # `inputs` rather than `input`, to avoid shadowing the builtin
            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            mean = getattr(module, "%s_mean" % name)
            sq_mean = getattr(module, "%s_sq_mean" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)

        # NOTE(review): requires a CUDA device; presumably this matches the
        # device on which sample() places the parameters -- confirm.
        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            var = scale * (sq_mean - mean ** 2)
            dist = torch.distributions.Normal(mean, torch.sqrt(var))

            # now test to ensure that sampling has the correct covariance matrix probabilistically
            num_in_cr_total = 0
            for _ in range(num_draws):
                swag_model.sample(scale=scale, cov=False)
                curr_pars = [
                    getattr(module, name) for (module, name) in swag_model.params
                ]

                curr_probs = dist.cdf(flatten(curr_pars))

                # count coordinates inside the central 95% interval
                in_cr = (curr_probs > 0.025) & (curr_probs < 0.975)
                num_in_cr_total += in_cr.float().sum()

            # average coverage over all draws and coordinates
            avg_prob_in_cr = num_in_cr_total / (num_draws * mean.numel())

            # CLT should hold a bit tighter here
            self.assertTrue(
                0.945 <= avg_prob_in_cr <= 0.955,
                "scale=%s: coverage %.4f outside [0.945, 0.955]"
                % (scale, float(avg_prob_in_cr)),
            )
예제 #8
0
    def test_swag_cov(self, **kwargs):
        """Check that full-covariance SWAG samples match the fitted Gaussian.

        A linear model is trained with SGD for 101 steps and collected into a
        ``SWAG`` model (``no_cov_mat=False``) after every step. The stored
        mean, second-moment and covariance-factor buffers are checked against
        the parameter shapes. The reference covariance is
        ``factor^T factor / (max_num_models - 1) + diag(sq_mean - mean^2)``.
        For each ``scale`` the quadratic form ``(x - mu)^T Sigma^{-1} (x - mu)``
        is evaluated for every sample, and between 94% and 96% of the samples
        must fall below the 95% chi-square quantile.
        """
        # Draws per scale. The acceptance band is derived from this count
        # (1880..1920 for 2000 draws) instead of being hard-coded separately.
        num_samples = 2000
        min_in_cr = round(0.94 * num_samples)
        max_in_cr = round(0.96 * num_samples)

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=False,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            # `inputs` rather than `input`, to avoid shadowing the builtin
            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        cov_mat_sqrt_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            mean = getattr(module, "%s_mean" % name)
            sq_mean = getattr(module, "%s_sq_mean" % name)
            cov_mat_sqrt = getattr(module, "%s_cov_mat_sqrt" % name)

            self.assertEqual(param.size(), mean.size())
            self.assertEqual(param.size(), sq_mean.size())
            self.assertEqual(
                [swag_model.max_num_models, param.numel()], list(cov_mat_sqrt.size())
            )

            mean_list.append(mean)
            sq_mean_list.append(sq_mean)
            cov_mat_sqrt_list.append(cov_mat_sqrt)

        # NOTE(review): requires a CUDA device; presumably this matches the
        # device on which sample() places the parameters -- confirm.
        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1).cuda()

        true_cov_mat = (
            1.0 / (swag_model.max_num_models - 1)
        ) * cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(sq_mean - mean ** 2)

        # 95% quantile of the p-dimensional chi-square distribution
        test_cutoff = chi2(df=mean.numel()).ppf(0.95)

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_inv = torch.inverse(true_cov_mat * scale)

            # now test to ensure that sampling has the correct covariance matrix probabilistically
            all_qforms = []
            for _ in range(num_samples):
                swag_model.sample(scale=scale, cov=True)
                curr_pars = [
                    getattr(module, name) for (module, name) in swag_model.params
                ]
                dev = flatten(curr_pars) - mean

                # (x - mu)^T sigma^{-1} (x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)
                all_qforms.append(qform.item())

            samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
            print(samples_in_cr)

            # between 94 and 96% of the samples should fall within the threshold
            # this should be very loose
            self.assertTrue(
                min_in_cr <= samples_in_cr <= max_in_cr,
                "scale=%s: %d of %d samples inside the 95%% region (expected %d..%d)"
                % (scale, samples_in_cr, num_samples, min_in_cr, max_in_cr),
            )
예제 #9
0
    criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    # SWAG path: wrap the base architecture, load the SWAG checkpoint,
    # sample at scale 0.0 and run bn_update on the "fine_tune" loader.
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    # Plain path: build the base model directly and load its weights from
    # args.resume; the stored epoch is printed.
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])

# Number of batches in the test loader.
print(len(loaders["test"]))
# Choose the loader key used for evaluation.
if args.use_test:
    print("Using test dataset")
    test_loader = "test"
else:
    test_loader = "val"
예제 #10
0
# Target buffer: one 360x480 array per example in the test dataset.
targets = np.zeros((len(loaders["test"].dataset), 360, 480))

# The aleatoric loss gets an extra buffer with a second dimension of 11
# (presumably the number of classes -- TODO confirm).
if args.loss == "aleatoric":
    scales = np.zeros((len(loaders["test"].dataset), 11, 360, 480))
else:
    scales = None

# Total number of elements in the target buffer.
print(targets.size)

for i in range(args.N):
    print("%d/%d" % (i + 1, args.N))

    # Every method except SGD and Dropout draws a fresh sample; the
    # covariance is used only when args.cov_mat is set and args.use_diag is not.
    if args.method not in ["SGD", "Dropout"]:
        sample_with_cov = args.cov_mat and not args.use_diag
        with torch.no_grad():
            model.sample(scale=args.scale, cov=sample_with_cov)

    # SWAG-based methods run bn_update on the "fine_tune" loader after sampling.
    if "SWAG" in args.method:
        bn_update(loaders["fine_tune"], model)

    model.eval()
    # Dropout-based methods apply train_dropout to the model after eval()
    # (presumably re-enables dropout layers -- confirm in train_dropout).
    if args.method in ["Dropout", "SWAGDrop"]:
        model.apply(train_dropout)

    k = 0
    # Per-sample buffer with the same shape as `predictions`.
    current_predictions = np.zeros_like(predictions)
    for input, target in tqdm.tqdm(loaders["test"]):
        input = input.cuda(non_blocking=True)
        # Reseed with the sample index before each batch, so every batch of
        # sample i starts from the same RNG state.
        torch.manual_seed(i)

        with torch.no_grad():
예제 #11
0
for i in range(args.N):
    print('%d/%d' % (i + 1, args.N))
    if args.method == 'KFACLaplace':
        ## KFAC Laplace needs one forwards pass to load the KFAC model at the beginning
        # Restore the stored mean weights before every sample.
        model.net.load_state_dict(model.mean_state)

        if i == 0:
            model.net.train()

            # Single forward/backward pass on (t_input, t_target), then a
            # step that does not update the parameters.
            loss, _ = losses.cross_entropy(model.net, t_input, t_target)
            loss.backward(create_graph=True)
            model.step(update_params=False)

    if args.method not in ['SGD', 'Dropout']:
        # NOTE(review): sample_with_cov is computed but never passed to
        # sample() -- confirm whether `cov=sample_with_cov` was intended.
        sample_with_cov = args.cov_mat and not args.use_diag
        model.sample(scale=args.scale)

    # SWAG-based methods run bn_update on a subset of the training loader.
    if 'SWAG' in args.method:
        utils.bn_update(loaders['train'], model, subset=args.bn_subset)

    model.eval()
    # Dropout-based methods apply train_dropout to the model after eval()
    # (presumably re-enables dropout layers -- confirm in train_dropout).
    if args.method in ['Dropout', 'SWAGDrop']:
        model.apply(train_dropout)
        #torch.manual_seed(i)
        #utils.bn_update(loaders['train'], model)

    k = 0
    for input, target in tqdm.tqdm(loaders['test']):
        input = input.cuda(non_blocking=True)
        # Reseed with the sample index before each batch, so every batch of
        # sample i starts from the same RNG state.
        torch.manual_seed(i)
예제 #12
0
    # Take the fitted components, normalise each row to unit L2 norm
    # (axis=1), then multiply by `scale`.
    cov_factor = tsvd.components_
    cov_factor /= np.linalg.norm(cov_factor, axis=1, keepdims=True)
    cov_factor *= scale

    # Debug output: first column of the scaled factor.
    print(cov_factor[:, 0])

    # Copy the factor into the model's cov_factor in place.
    # NOTE(review): torch.FloatTensor is the legacy CPU-typed constructor;
    # passing a non-CPU `device` may raise -- confirm. copy_ itself can copy
    # across devices.
    swag_model.cov_factor.copy_(torch.FloatTensor(cov_factor, device=mean.device))

# Running sum of per-sample predictions, and a placeholder for the targets.
ens_predictions = np.zeros((len(loaders['test'].dataset), num_classes))
targets = np.zeros(len(loaders['test'].dataset))

columns = ['iter ens', 'acc', 'nll']

with torch.no_grad():
    for i in range(args.num_samples):
        swag_model.sample(scale=args.scale)

        utils.bn_update(loaders['train'], swag_model, subset=args.bn_subset)

        pred_res = utils.predict(loaders['test'], swag_model)
        ens_predictions += pred_res['predictions']
        targets = pred_res['targets']

        # Accuracy uses the argmax of the running sum; the NLL uses the
        # running mean (sum divided by the number of samples so far).
        values = ['%3d/%3d' % (i + 1, args.num_samples),
                  np.mean(np.argmax(ens_predictions, axis=1) == targets),
                  nll(ens_predictions / (i + 1), targets)]
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        # Print the full table (with header) once; afterwards print only the
        # third line of the rendered table, i.e. the data row.
        if i == 0:
            print(table)
        else:
            print(table.split('\n')[2])