Exemple #1
0
    def test_swag_cov(self, **kwargs):
        """Check the spread of SWAG samples drawn with a 'covariance' subspace.

        A linear model is trained with SGD on random data and collected into
        the SWAG model after every step. For several scales, samples are drawn
        and the quadratic form (x - mu) Sigma^{-1} (x - mu) is compared against
        the 95% chi-square quantile; between 94% and 96% of the samples must
        fall below that cutoff.

        ``**kwargs`` is accepted but not used.
        """
        num_steps = 101
        num_samples = 3000

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            subspace_type='covariance',
            subspace_kwargs={'max_rank': 140},
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(num_steps):
            model.zero_grad()

            # The inputs are drawn before the targets; keep this order so the
            # seeded random stream is consumed the same way on every run.
            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # Reference covariance: factor term from the subspace plus the
        # diagonal variance term.
        mean, var = swag_model._get_mean_and_variance()
        cov_mat_sqrt = swag_model.subspace.get_space()
        true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

        # 95% quantile of the p-dimensional chi-square distribution
        test_cutoff = chi2(df=mean.numel()).ppf(0.95)

        # Between 94% and 96% of the samples should fall within the threshold;
        # integer arithmetic keeps the bounds exact.
        lower = num_samples * 94 // 100
        upper = num_samples * 96 // 100

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_inv = torch.inverse(true_cov_mat * scale)

            # now test to ensure that sampling has the correct covariance
            # matrix probabilistically
            all_qforms = []
            for _ in range(num_samples):
                swag_model.sample(scale=scale)
                curr_pars = [
                    getattr(module, name)
                    for module, name, _ in swag_model.base_params
                ]
                dev = flatten(curr_pars) - mean

                # (x - mu) sigma^{-1} (x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)
                all_qforms.append(qform.item())

            samples_in_cr = int((np.array(all_qforms) < test_cutoff).sum())

            # this should be very loose
            self.assertTrue(
                lower <= samples_in_cr <= upper,
                msg="scale=%s: %d of %d samples fell inside the 95%% region "
                    "(expected %d to %d)"
                    % (scale, samples_in_cr, num_samples, lower, upper),
            )
Exemple #2
0
# Run on the GPU when one is available, otherwise fall back to the CPU.
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args,
                       num_classes=args.num_classes,
                       **model_cfg.kwargs)
model.to(args.device)

# Positional model arguments are listed ahead of the keyword arguments; this
# is the same call as before, just in the order Python actually binds them.
swag_model = SWAG(model_cfg.base,
                  *model_cfg.args,
                  subspace_type=args.subspace,
                  subspace_kwargs={'max_rank': args.max_num_models},
                  num_classes=args.num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

# Load every checkpoint into the base model and collect it into the SWAG model.
for path in args.checkpoint:
    print(path)
    # map_location remaps stored tensors onto the selected device, so
    # checkpoints saved on a GPU still load when running on the CPU fallback.
    ckpt = torch.load(path, map_location=args.device)
    model.load_state_dict(ckpt['state_dict'])
    swag_model.collect_model(model)

# Persist the aggregated SWAG state under the same 'state_dict' key layout.
torch.save({'state_dict': swag_model.state_dict()}, args.path)
Exemple #3
0
                                momentum=0.9,
                                weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    state_dict = None

    for epoch in range(num_epochs):
        model.train()

        for x, y in loader:
            model.zero_grad()
            pred = model(x)
            loss = ((pred - y)**2.0).sum()
            loss.backward()
            optimizer.step()
        small_swag_model.collect_model(model)

        if epoch == 4:
            state_dict = small_swag_model.state_dict()

    small_swag_model.fit()
    with torch.no_grad():
        x = torch.arange(-6., 6., 1.0).unsqueeze(1)
        for i in range(10):
            small_swag_model.sample(0.5)
            small_swag_model(x)

    _, _ = small_swag_model.get_space(export_cov_factor=False)
    _, _, _ = small_swag_model.get_space(export_cov_factor=True)
    small_swag_model.load_state_dict(state_dict)
    def test_swag_diag(self, **kwargs):
        """Check diagonal SWAG sampling against per-coordinate normal intervals.

        A linear model is trained with SGD on random data and collected into
        a SWAG model built with ``no_cov_mat=True`` after every step. The
        collected mean and squared-mean buffers must match the parameter
        shapes. For several scales, samples are drawn with ``cov=False`` and
        the fraction of coordinates inside the central 95% interval of
        Normal(mean, sqrt(scale * (sq_mean - mean^2))) must lie in
        [0.945, 0.955].

        ``**kwargs`` is accepted but not used.
        """
        num_draws = 20

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=True,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            param_mean = getattr(module, "%s_mean" % name)
            param_sq_mean = getattr(module, "%s_sq_mean" % name)

            self.assertEqual(param.size(), param_mean.size())
            self.assertEqual(param.size(), param_sq_mean.size())

            mean_list.append(param_mean)
            sq_mean_list.append(param_sq_mean)

        # NOTE(review): .cuda() makes this test require a CUDA device;
        # presumably SWAG.sample() leaves the sampled parameters on the GPU --
        # confirm against the SWAG implementation.
        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()

        total_coords = num_draws * mean.numel()

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            var = scale * (sq_mean - mean ** 2)

            std = torch.sqrt(var)
            dist = torch.distributions.Normal(mean, std)

            # now test to ensure that sampling has the correct covariance
            # matrix probabilistically
            coords_in_cr = 0
            for _ in range(num_draws):
                swag_model.sample(scale=scale, cov=False)
                curr_pars = [
                    getattr(module, name) for module, name in swag_model.params
                ]

                curr_probs = dist.cdf(flatten(curr_pars))

                # check if within 95% CI
                coords_in_cr += (
                    (curr_probs > 0.025) & (curr_probs < 0.975)
                ).float().sum()

            # now compute average
            avg_prob_in_cr = coords_in_cr / total_coords

            # CLT should hold a bit tighter here
            self.assertTrue(
                0.945 <= avg_prob_in_cr <= 0.955,
                msg="scale=%s: fraction of coordinates inside the 95%% "
                    "interval was %.4f (expected 0.945 to 0.955)"
                    % (scale, float(avg_prob_in_cr)),
            )
    def test_swag_cov(self, **kwargs):
        """Check the spread of SWAG samples drawn with ``cov=True``.

        A linear model is trained with SGD on random data and collected into
        a SWAG model built with ``no_cov_mat=False`` after every step. The
        mean, squared-mean and covariance-factor buffers must have the
        expected shapes. For several scales, samples are drawn and the
        quadratic form (x - mu) Sigma^{-1} (x - mu) is compared against the
        95% chi-square quantile; between 94% and 96% of the samples must fall
        below that cutoff.

        ``**kwargs`` is accepted but not used.
        """
        num_samples = 2000

        model = torch.nn.Linear(300, 3, bias=True)

        swag_model = SWAG(
            torch.nn.Linear,
            in_features=300,
            out_features=3,
            bias=True,
            no_cov_mat=False,
            max_num_models=100,
            loading=False,
        )

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # construct swag model via training
        torch.manual_seed(0)
        for _ in range(101):
            model.zero_grad()

            inputs = torch.randn(100, 300)
            output = model(inputs)
            loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
            loss.backward()

            optimizer.step()

            swag_model.collect_model(model)

        # check to ensure parameters have the correct sizes
        mean_list = []
        sq_mean_list = []
        cov_mat_sqrt_list = []
        for (module, name), param in zip(swag_model.params, model.parameters()):
            param_mean = getattr(module, "%s_mean" % name)
            param_sq_mean = getattr(module, "%s_sq_mean" % name)
            param_cov_sqrt = getattr(module, "%s_cov_mat_sqrt" % name)

            self.assertEqual(param.size(), param_mean.size())
            self.assertEqual(param.size(), param_sq_mean.size())
            self.assertEqual(
                [swag_model.max_num_models, param.numel()],
                list(param_cov_sqrt.size()),
            )

            mean_list.append(param_mean)
            sq_mean_list.append(param_sq_mean)
            cov_mat_sqrt_list.append(param_cov_sqrt)

        # NOTE(review): .cuda() makes this test require a CUDA device;
        # presumably SWAG.sample() leaves the sampled parameters on the GPU --
        # confirm against the SWAG implementation.
        mean = flatten(mean_list).cuda()
        sq_mean = flatten(sq_mean_list).cuda()
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1).cuda()

        # Reference covariance: scaled factor term plus the diagonal variance.
        factor_scale = 1.0 / (swag_model.max_num_models - 1)
        true_cov_mat = factor_scale * cov_mat_sqrt.t().matmul(
            cov_mat_sqrt
        ) + torch.diag(sq_mean - mean ** 2)

        # 95% quantile of p dimensional chi-square distribution
        test_cutoff = chi2(df=mean.numel()).ppf(0.95)

        # between 94 and 96% of the samples should fall within the threshold;
        # integer arithmetic keeps the bounds exact (1880 and 1920 for 2000).
        lower = num_samples * 94 // 100
        upper = num_samples * 96 // 100

        for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
            scaled_cov_inv = torch.inverse(true_cov_mat * scale)

            # now test to ensure that sampling has the correct covariance
            # matrix probabilistically
            all_qforms = []
            for _ in range(num_samples):
                swag_model.sample(scale=scale, cov=True)
                curr_pars = [
                    getattr(module, name) for module, name in swag_model.params
                ]
                dev = flatten(curr_pars) - mean

                # (x - mu)sigma^{-1}(x - mu)
                qform = dev.matmul(scaled_cov_inv).matmul(dev)
                all_qforms.append(qform.item())

            samples_in_cr = int((np.array(all_qforms) < test_cutoff).sum())

            # this should be very loose
            self.assertTrue(
                lower <= samples_in_cr <= upper,
                msg="scale=%s: %d of %d samples fell inside the 95%% region "
                    "(expected %d to %d)"
                    % (scale, samples_in_cr, num_samples, lower, upper),
            )