Example #1
0
    def test_swag_cov(self, **kwargs):
        """Check probabilistically that SWAG samples match the fitted covariance.

        A linear model is trained with SGD on random data and every iterate is
        collected into a SWAG model that uses the 'covariance' subspace. For
        each sampling scale, 3000 samples are drawn and the quadratic form
        (x - mu) Sigma^{-1} (x - mu) is computed, where Sigma is
        scale * (S^T S + diag(var)) and S is the matrix returned by the
        subspace. Between 94% and 96% of the quadratic forms must lie below
        the 95% quantile of a chi-square distribution whose degrees of freedom
        equal the number of entries in the mean vector.
        """
        n_features, n_outputs, batch_size = 300, 3, 100
        n_draws = 3000

        model = torch.nn.Linear(n_features, n_outputs, bias=True)
        swag_model = SWAG(
            torch.nn.Linear,
            in_features=n_features,
            out_features=n_outputs,
            bias=True,
            subspace_type='covariance',
            subspace_kwargs={'max_rank': 140},
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # Build up the SWAG statistics by collecting every SGD iterate.
        torch.manual_seed(0)
        for _step in range(101):
            model.zero_grad()

            features = torch.randn(batch_size, n_features)
            prediction = model(features)
            # Regression targets are pure noise; drawn after the inputs.
            noise_target = torch.randn(batch_size, n_outputs)
            loss = ((noise_target - prediction) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # Reference covariance assembled from the fitted SWAG statistics.
        mean, var = swag_model._get_mean_and_variance()
        cov_factor = swag_model.subspace.get_space()
        base_cov = cov_factor.t() @ cov_factor + torch.diag(var)

        # 95% quantile of the p-dimensional chi-square distribution.
        chi2_cutoff = chi2(df=mean.numel()).ppf(0.95)

        # Accept an empirical coverage between 94% and 96% of the draws.
        lower_bound, upper_bound = 0.94 * n_draws, 0.96 * n_draws

        for scale in (0.01, 0.1, 0.5, 1.0, 2.0, 5.0):
            precision = torch.inverse(base_cov * scale)

            quad_forms = []
            for _draw in range(n_draws):
                # sample() writes the drawn weights into the base parameters.
                swag_model.sample(scale=scale)
                sampled = [
                    getattr(module, name)
                    for module, name, _ in swag_model.base_params
                ]
                deviation = flatten(sampled) - mean

                # (x - mu) sigma^{-1} (x - mu)
                quad_forms.append((deviation @ precision @ deviation).item())

            samples_in_cr = np.sum(np.array(quad_forms) < chi2_cutoff)
            print(samples_in_cr)

            self.assertTrue(lower_bound <= samples_in_cr <= upper_bound)
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
# Move the SWAG model onto the device selected on the command line.
swag_model.to(args.device)

# NOTE(review): presumably the column headers of a results table printed
# further down in the script -- confirm against its usage.
columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

# Warm start: gather the mean and variance of every SWAG checkpoint.
# `kmeans_input` is only defined when `args.warm_start` is truthy.
if args.warm_start:
    print('estimating initial PI using K-Means')
    kmeans_input = []
    for ckpt_i, ckpt in enumerate(args.swag_ckpts):
        #print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        # Reset the subspace rank before loading each checkpoint.
        # NOTE(review): presumably required so load_state_dict() starts from
        # an empty subspace -- confirm against the subspace implementation.
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])
        mean, variance = swag_model._get_mean_and_variance()
        # One record per checkpoint, keyed by the checkpoint path.
        kmeans_input.append({
            'model': ckpt,
            'mean': mean,
            'variance': variance
        })

# Module-level list, empty at this point; it is not used within this section.
paramDump = []


def boSwag(Pi):
    useMetric = 'nll'
    disableBo = False
    if disableBo == True:
        print("Computing for base case")
        # Base case = uniform weights for each SWAG and its models