def test_swag_cov(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        subspace_type='covariance',
        subspace_kwargs={'max_rank': 140},
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # recover the SWAG mean, diagonal variance, and low-rank covariance factor
    mean, var = swag_model._get_mean_and_variance()
    cov_mat_sqrt = swag_model.subspace.get_space()

    true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

    # 95% quantile of the p-dimensional chi-square distribution
    test_cutoff = chi2(df=mean.numel()).ppf(0.95)

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        scaled_cov_mat = true_cov_mat * scale
        scaled_cov_inv = torch.inverse(scaled_cov_mat)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = []
        for _ in range(3000):
            swag_model.sample(scale=scale)
            curr_pars = []
            for (module, name, _) in swag_model.base_params:
                curr_pars.append(getattr(module, name))
            dev = flatten(curr_pars) - mean

            # (x - mu)^T Sigma^{-1} (x - mu)
            qform = dev.matmul(scaled_cov_inv).matmul(dev)
            all_qforms.append(qform.item())

        samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
        print(samples_in_cr)

        # between 94% and 96% of the samples should fall within the threshold;
        # this should be very loose
        self.assertTrue(0.94 * 3000 <= samples_in_cr <= 0.96 * 3000)
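# --- Minimal sketch of the `flatten` helper the tests in this file assume
# (the real one lives in the repo's utils; this version is an assumption):
# concatenate a list of parameter tensors into a single 1-D vector.
import torch

def flatten(tensor_list):
    # reshape each tensor to 1-D, then concatenate along dim 0
    return torch.cat([t.reshape(-1) for t in tensor_list], dim=0)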
def criterion(model, input, target, scale=args.prior_std):
    likelihood, output, _ = losses.cross_entropy(model, input, target)
    prior = 1 / (scale ** 2.0 * input.size(0)) * proj_params.norm()
    return likelihood + prior, output, {
        'nll': likelihood * input.size(0),
        'prior': proj_params.norm(),
    }

optimizer = torch.optim.SGD([proj_params], lr=5e-4, momentum=0.9, weight_decay=0)

swag_model.sample(0)
utils.bn_update(loaders['train'], swag_model)
print(utils.eval(loaders['test'], swag_model, criterion))

printf, logfile = utils.get_logging_print(
    os.path.join(args.dir, args.log_fname + '-%s.txt'))
print('Saving logs to: %s' % logfile)
# printf = print
columns = ['ep', 'acc', 'loss', 'prior']

for epoch in range(args.epochs):
    train_res = utils.train_epoch(loaders['train'], proj_model, criterion, optimizer)
    values = [
        '%d/%d' % (epoch + 1, args.epochs),
        train_res['accuracy'],
        train_res['loss'],
        train_res['stats']['prior'],
criterion = losses.cross_entropy

fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0, args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)

columns = ['fraction', 'swa_acc', 'swa_loss', 'swag_acc', 'swag_loss', 'time']

for i, fraction in enumerate(fractions):
    start_time = time.time()
    swag_model.load_state_dict(ckpt['state_dict'])

    swag_model.sample(0.0)
    utils.bn_update(loaders['train'], swag_model, subset=fraction)
    swa_res = utils.eval(loaders['test'], swag_model, criterion)
    swa_accuracies[i] = swa_res['accuracy']
    swa_nlls[i] = swa_res['loss']

    predictions = np.zeros((len(loaders['test'].dataset), num_classes))

    for j in range(args.S):
        swag_model.load_state_dict(ckpt['state_dict'])
        swag_model.sample(scale=0.5, cov=args.cov_mat)
        utils.bn_update(loaders['train'], swag_model, subset=fraction)
        sample_res = utils.predict(loaders['test'], swag_model)
        predictions += sample_res['predictions']
        targets = sample_res['targets']
    predictions /= args.S
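    # --- A plausible continuation (not verbatim from the repo): turn the
    # averaged SWAG predictions into the accuracy/NLL entries declared above.
    # `eps` is a hypothetical guard against log(0).
    eps = 1e-12
    swag_accuracies[i] = np.mean(np.argmax(predictions, axis=1) == targets)
    swag_nlls[i] = -np.mean(
        np.log(predictions[np.arange(len(targets)), targets.astype(int)] + eps))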
    momentum=0.9, weight_decay=1e-4)

loader = generate_dataloaders(N=10)

state_dict = None
for epoch in range(num_epochs):
    model.train()
    for x, y in loader:
        model.zero_grad()
        pred = model(x)
        loss = ((pred - y) ** 2.0).sum()
        loss.backward()
        optimizer.step()
    small_swag_model.collect_model(model)
    if epoch == 4:
        state_dict = small_swag_model.state_dict()

small_swag_model.fit()

with torch.no_grad():
    x = torch.arange(-6., 6., 1.0).unsqueeze(1)
    for i in range(10):
        small_swag_model.sample(0.5)
        small_swag_model(x)

_, _ = small_swag_model.get_space(export_cov_factor=False)
_, _, _ = small_swag_model.get_space(export_cov_factor=True)

small_swag_model.load_state_dict(state_dict)
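# --- Hypothetical stand-in for the `generate_dataloaders` helper used above:
# a toy 1-D regression set wrapped in a DataLoader. The names, shapes, and
# sine target are illustrative guesses, not the repo's implementation.
import torch
from torch.utils.data import DataLoader, TensorDataset

def generate_dataloaders(N=10, batch_size=5):
    x = torch.linspace(-5., 5., N).unsqueeze(1)    # N scalar inputs
    y = torch.sin(x) + 0.1 * torch.randn_like(x)   # noisy 1-D targets
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)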
print('Preparing model')
swag_model = SWAG(model_class, no_cov_mat=not args.cov_mat, loading=True,
                  max_num_models=20, num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
swa_res = utils.predict(loaders['test'], swag_model, verbose=True)

targets = swa_res['targets']
swa_predictions = swa_res['predictions']

swa_accuracy = np.mean(np.argmax(swa_predictions, axis=1) == targets)
swa_nll = -np.mean(
    np.log(swa_predictions[np.arange(swa_predictions.shape[0]), targets] + eps))
print('SWA. Accuracy: %.2f%% NLL: %.4f' % (swa_accuracy * 100, swa_nll))

swa_entropies = -np.sum(np.log(swa_predictions + eps) * swa_predictions, axis=1)
    **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

n_ensembled = 0.
multiswag_probs = None

for ckpt_i, ckpt in enumerate(args.swag_ckpts):
    print("Checkpoint {}".format(ckpt))
    checkpoint = torch.load(ckpt)
    swag_model.subspace.rank = torch.tensor(0)
    swag_model.load_state_dict(checkpoint['state_dict'])

    for sample in range(args.swag_samples):
        swag_model.sample(.5)
        utils.bn_update(loaders['train'], swag_model)
        res = utils.predict(loaders['test'], swag_model)
        probs = res['predictions']
        targets = res['targets']
        nll = utils.nll(probs, targets)
        acc = utils.accuracy(probs, targets)

        if multiswag_probs is None:
            multiswag_probs = probs.copy()
        else:
            # TODO: rewrite in a numerically stable way
            multiswag_probs += (probs - multiswag_probs) / (n_ensembled + 1)
        n_ensembled += 1

        ens_nll = utils.nll(multiswag_probs, targets)
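# --- One numerically stabler alternative hinted at by the TODO above (a
# sketch, not the repo's code): keep the ensemble average in log space via
# logsumexp. Assumes `all_log_probs` stacks per-model log-probabilities of
# shape [n_models, n_points, n_classes].
import numpy as np
from scipy.special import logsumexp

def ensemble_log_probs(all_log_probs):
    n_models = all_log_probs.shape[0]
    # log of the mean probability: logsumexp over models minus log(n_models)
    return logsumexp(all_log_probs, axis=0) - np.log(n_models)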
def test_swag_diag(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        no_cov_mat=True,
        max_num_models=100,
        loading=False,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # check to ensure parameters have the correct sizes
    mean_list = []
    sq_mean_list = []
    for (module, name), param in zip(swag_model.params, model.parameters()):
        mean = module.__getattr__("%s_mean" % name)
        sq_mean = module.__getattr__("%s_sq_mean" % name)

        self.assertEqual(param.size(), mean.size())
        self.assertEqual(param.size(), sq_mean.size())

        mean_list.append(mean)
        sq_mean_list.append(sq_mean)

    mean = flatten(mean_list).cuda()
    sq_mean = flatten(sq_mean_list).cuda()

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        var = scale * (sq_mean - mean ** 2)
        std = torch.sqrt(var)
        dist = torch.distributions.Normal(mean, std)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = 0
        for _ in range(20):
            swag_model.sample(scale=scale, cov=False)
            curr_pars = []
            for (module, name) in swag_model.params:
                curr_pars.append(getattr(module, name))

            curr_probs = dist.cdf(flatten(curr_pars))

            # count how many coordinates fall within the central 95% interval
            num_in_cr = ((curr_probs > 0.025) & (curr_probs < 0.975)).float().sum()
            all_qforms += num_in_cr

        # average proportion of coordinates inside the interval
        avg_prob_in_cr = all_qforms / (20 * mean.numel())
        # the CLT should hold a bit tighter here
        self.assertTrue(0.945 <= avg_prob_in_cr <= 0.955)
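# --- Standalone sanity check (not from the test suite) of the coverage logic
# above: for x ~ N(mu, sigma^2), Phi((x - mu) / sigma) is Uniform(0, 1), so
# roughly 95% of coordinates should land in (0.025, 0.975).
import torch

torch.manual_seed(0)
dist = torch.distributions.Normal(torch.zeros(10000), torch.ones(10000))
probs = dist.cdf(dist.sample())
print(((probs > 0.025) & (probs < 0.975)).float().mean())  # ~0.95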
def test_swag_cov(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        no_cov_mat=False,
        max_num_models=100,
        loading=False,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # check to ensure parameters have the correct sizes
    mean_list = []
    sq_mean_list = []
    cov_mat_sqrt_list = []
    for (module, name), param in zip(swag_model.params, model.parameters()):
        mean = module.__getattr__("%s_mean" % name)
        sq_mean = module.__getattr__("%s_sq_mean" % name)
        cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)

        self.assertEqual(param.size(), mean.size())
        self.assertEqual(param.size(), sq_mean.size())
        self.assertEqual(
            [swag_model.max_num_models, param.numel()], list(cov_mat_sqrt.size())
        )

        mean_list.append(mean)
        sq_mean_list.append(sq_mean)
        cov_mat_sqrt_list.append(cov_mat_sqrt)

    mean = flatten(mean_list).cuda()
    sq_mean = flatten(sq_mean_list).cuda()
    cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1).cuda()

    true_cov_mat = (
        1.0 / (swag_model.max_num_models - 1)
    ) * cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(sq_mean - mean ** 2)

    # 95% quantile of the p-dimensional chi-square distribution
    test_cutoff = chi2(df=mean.numel()).ppf(0.95)

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        scaled_cov_mat = true_cov_mat * scale
        scaled_cov_inv = torch.inverse(scaled_cov_mat)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = []
        for _ in range(2000):
            swag_model.sample(scale=scale, cov=True)
            curr_pars = []
            for (module, name) in swag_model.params:
                curr_pars.append(getattr(module, name))
            dev = flatten(curr_pars) - mean

            # (x - mu)^T Sigma^{-1} (x - mu)
            qform = dev.matmul(scaled_cov_inv).matmul(dev)
            all_qforms.append(qform.item())

        samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
        print(samples_in_cr)

        # between 94% and 96% of the 2000 samples should fall within the
        # threshold; this should be very loose
        self.assertTrue(1880 <= samples_in_cr <= 1920)
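# --- Standalone sanity check (not from the test suite) of the chi-square
# cutoff used above: if x ~ N(mu, Sigma) in p dimensions, then
# (x - mu)^T Sigma^{-1} (x - mu) ~ chi^2_p, so roughly 95% of draws fall
# below the 0.95 quantile. Assumes only numpy and scipy.
import numpy as np
from scipy.stats import chi2

p, n_samples = 5, 2000
rng = np.random.default_rng(0)
A = rng.standard_normal((p, p))
cov = A @ A.T + p * np.eye(p)  # a well-conditioned covariance
samples = rng.multivariate_normal(np.zeros(p), cov, size=n_samples)
qforms = np.einsum('ij,jk,ik->i', samples, np.linalg.inv(cov), samples)
print((qforms < chi2(df=p).ppf(0.95)).mean())  # should be close to 0.95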
criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(
        num_classes=num_classes, use_aleatoric=args.loss == "aleatoric"
    ).cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])

print(len(loaders["test"]))
if args.use_test:
    print("Using test dataset")
    test_loader = "test"
else:
    test_loader = "val"
targets = np.zeros((len(loaders["test"].dataset), 360, 480))
if args.loss == "aleatoric":
    scales = np.zeros((len(loaders["test"].dataset), 11, 360, 480))
else:
    scales = None
print(targets.size)

for i in range(args.N):
    print("%d/%d" % (i + 1, args.N))

    if args.method not in ["SGD", "Dropout"]:
        sample_with_cov = args.cov_mat and not args.use_diag
        with torch.no_grad():
            model.sample(scale=args.scale, cov=sample_with_cov)

    if "SWAG" in args.method:
        bn_update(loaders["fine_tune"], model)

    model.eval()
    if args.method in ["Dropout", "SWAGDrop"]:
        model.apply(train_dropout)

    k = 0
    current_predictions = np.zeros_like(predictions)
    for input, target in tqdm.tqdm(loaders["test"]):
        input = input.cuda(non_blocking=True)
        torch.manual_seed(i)

        with torch.no_grad():
for i in range(args.N):
    print('%d/%d' % (i + 1, args.N))

    if args.method == 'KFACLaplace':
        # KFAC Laplace needs one forward pass at the beginning to load the KFAC model
        model.net.load_state_dict(model.mean_state)
        if i == 0:
            model.net.train()

            loss, _ = losses.cross_entropy(model.net, t_input, t_target)
            loss.backward(create_graph=True)
            model.step(update_params=False)

    if args.method not in ['SGD', 'Dropout']:
        sample_with_cov = args.cov_mat and not args.use_diag
        model.sample(scale=args.scale, cov=sample_with_cov)

    if 'SWAG' in args.method:
        utils.bn_update(loaders['train'], model, subset=args.bn_subset)

    model.eval()
    if args.method in ['Dropout', 'SWAGDrop']:
        model.apply(train_dropout)

    k = 0
    for input, target in tqdm.tqdm(loaders['test']):
        input = input.cuda(non_blocking=True)
        torch.manual_seed(i)
cov_factor = tsvd.components_
cov_factor /= np.linalg.norm(cov_factor, axis=1, keepdims=True)
cov_factor *= scale
print(cov_factor[:, 0])

swag_model.cov_factor.copy_(
    torch.tensor(cov_factor, dtype=torch.float32, device=mean.device))

ens_predictions = np.zeros((len(loaders['test'].dataset), num_classes))
targets = np.zeros(len(loaders['test'].dataset))

columns = ['iter ens', 'acc', 'nll']

with torch.no_grad():
    for i in range(args.num_samples):
        swag_model.sample(scale=args.scale)
        utils.bn_update(loaders['train'], swag_model, subset=args.bn_subset)
        pred_res = utils.predict(loaders['test'], swag_model)
        ens_predictions += pred_res['predictions']
        targets = pred_res['targets']

        values = [
            '%3d/%3d' % (i + 1, args.num_samples),
            np.mean(np.argmax(ens_predictions, axis=1) == targets),
            nll(ens_predictions / (i + 1), targets),
        ]
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if i == 0:
            print(table)
        else:
            print(table.split('\n')[2])
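# --- Hypothetical sketch of the bare `nll` helper called above: mean negative
# log-likelihood of the true class under the averaged predictive probabilities.
# The name and signature mirror the call site; the body is an assumption.
import numpy as np

def nll(probs, targets, eps=1e-12):
    targets = targets.astype(int)
    idx = np.arange(targets.shape[0])
    return -np.mean(np.log(probs[idx, targets] + eps))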