def test_swag_cov(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        subspace_type='covariance',
        subspace_kwargs={'max_rank': 140},
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # check to ensure parameters have the correct sizes
    mean, var = swag_model._get_mean_and_variance()
    cov_mat_sqrt = swag_model.subspace.get_space()

    true_cov_mat = cov_mat_sqrt.t().matmul(cov_mat_sqrt) + torch.diag(var)

    test_cutoff = chi2(df=mean.numel()).ppf(0.95)  # 95% quantile of p-dimensional chi-square distribution

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        scaled_cov_mat = true_cov_mat * scale
        scaled_cov_inv = torch.inverse(scaled_cov_mat)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = []
        for _ in range(3000):
            swag_model.sample(scale=scale)
            curr_pars = []
            for (module, name, _) in swag_model.base_params:
                curr_pars.append(getattr(module, name))

            dev = flatten(curr_pars) - mean

            # (x - mu)^T sigma^{-1} (x - mu)
            qform = dev.matmul(scaled_cov_inv).matmul(dev)
            all_qforms.append(qform.item())

        samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
        print(samples_in_cr)

        # between 94 and 96% of the samples should fall within the threshold;
        # this should be very loose
        self.assertTrue(0.94 * 3000 <= samples_in_cr <= 0.96 * 3000)
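# The assertion above relies on a standard fact: if x ~ N(mu, Sigma), the
# quadratic form (x - mu)^T Sigma^{-1} (x - mu) is chi-square with p degrees
# of freedom, so ~95% of draws fall below the 0.95 quantile. A minimal
# self-contained illustration of that identity (independent of SWAG):
import numpy as np
from scipy.stats import chi2

rng = np.random.default_rng(0)
p = 10
A = rng.standard_normal((p, p))
cov = A @ A.T + np.eye(p)          # random SPD covariance
mu = rng.standard_normal(p)

L = np.linalg.cholesky(cov)
cov_inv = np.linalg.inv(cov)
cutoff = chi2(df=p).ppf(0.95)

qforms = []
for _ in range(3000):
    x = mu + L @ rng.standard_normal(p)   # x ~ N(mu, cov)
    dev = x - mu
    qforms.append(dev @ cov_inv @ dev)    # ~ chi2(p)

print((np.array(qforms) < cutoff).mean())  # close to 0.95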
def RealNVPTabularSWAG(dim_in, coupling_layers, k, nperlayer=1,
                       subspace='covariance', max_num_models=10):
    swag_model = SWAG(
        RealNVPTabular,
        subspace_type=subspace,
        subspace_kwargs={'max_rank': max_num_models},
        num_coupling_layers=coupling_layers,
        in_dim=dim_in,
        hidden_dim=k,
        num_layers=1,
        dropout=True,
    )
    return swag_model
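# Hypothetical usage of the factory above; the dimensions are illustrative
# assumptions, not values from the repo:
swag_flow = RealNVPTabularSWAG(dim_in=8, coupling_layers=4, k=64)
# During training, periodically fold the current flow weights into the
# posterior approximation:
#     swag_flow.collect_model(flow)
# At test time, draw a weight sample before evaluating log-densities:
#     swag_flow.sample(scale=0.5)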
    use_validation=not args.use_test,
    split_classes=args.split_classes
)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True

print('SWAG training')
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=args.no_cov_mat,
                  max_num_models=args.max_num_models,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)


def schedule(epoch):
    t = epoch / args.swa_start
    lr_ratio = args.swa_lr / args.lr_init
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return args.lr_init * factor
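# The schedule above holds the learning rate at lr_init for the first half of
# the run (t <= 0.5), anneals linearly down to swa_lr between 50% and 90%, and
# then holds swa_lr. A standalone sketch with illustrative values (lr_init,
# swa_lr, swa_start below are assumptions, not the script's defaults):
def swa_schedule(epoch, lr_init=0.1, swa_lr=0.01, swa_start=160):
    t = epoch / swa_start
    lr_ratio = swa_lr / lr_init
    if t <= 0.5:
        factor = 1.0                                       # constant warm phase
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4  # linear decay
    else:
        factor = lr_ratio                                  # constant at swa_lr
    return lr_init * factor

for e in [0, 80, 112, 144, 160]:
    print(e, swa_schedule(e))
# 0 -> 0.1, 80 -> 0.1, 112 -> 0.055, 144 -> 0.01, 160 -> 0.01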
    use_validation=not args.use_test,
    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.cuda()

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 140,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)


def checkpoint_num(filename):
    num = filename.split("-")[1]
    num = num.split(".")[0]
    num = int(num)
    return num


for file in os.listdir(args.dir):
print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test)

print('Preparing model')
swag_model = SWAG(model_cfg.base,
                  no_cov_mat=not args.cov_mat,
                  max_num_models=args.swag_rank,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

ckpt = torch.load(args.checkpoint)

criterion = losses.cross_entropy

fractions = np.logspace(-np.log10(0.005 * len(loaders['train'].dataset)), 0.0, args.N)
swa_accuracies = np.zeros(args.N)
swa_nlls = np.zeros(args.N)
swag_accuracies = np.zeros(args.N)
swag_nlls = np.zeros(args.N)
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
wandb.watch(model)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True

if args.swa:
    print("SWAG training")
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=args.max_num_models,
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    t = epoch / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
label_arr = np.load(args.label_arr)
print("Corruption:", (loaders['train'].dataset.targets != label_arr).mean())
loaders['train'].dataset.targets = label_arr

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)
print("Model has {} parameters".format(sum([p.numel() for p in model.parameters()])))

swag_model = SWAG(model_cfg.base,
                  args.subspace,
                  {'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

if args.warm_start:
    print('estimating initial PI using K-Means')
    kmeans_input = []
    for ckpt_i, ckpt in enumerate(args.swag_ckpts):
        # print("Checkpoint {}".format(ckpt))
        checkpoint = torch.load(ckpt)
        swag_model.subspace.rank = torch.tensor(0)
        swag_model.load_state_dict(checkpoint['state_dict'])

        mean, variance = swag_model._get_mean_and_variance()
    model_cfg.transform_test,
    use_validation=not args.use_test,
    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

if args.swag:
    print('SGLD+SWAG training')
    swag_model = SWAG(model_cfg.base,
                      subspace_type=args.subspace,
                      subspace_kwargs={'max_rank': args.max_num_models},
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print('SGLD training')

criterion = losses.cross_entropy

sgld_optimizer = SGLD(model.parameters(),
                      lr=args.lr_init,
                      weight_decay=args.wd,
                      noise_factor=args.noise_factor)

num_batches = len(loaders['train'])
num_iters = num_batches * (args.epochs - args.ens_start + 1)
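# SGLD adds Gaussian noise to each gradient step so the iterates approximately
# sample from the posterior rather than converging to a point. A minimal sketch
# of one update (an assumption about the general algorithm, not the repo's SGLD
# class, whose exact noise scaling may differ):
import torch

def sgld_step(params, lr, noise_factor=1.0, weight_decay=0.0):
    """One SGLD update: gradient step plus sqrt(2 * lr)-scaled Gaussian noise.

    Assumes .grad has already been populated by loss.backward().
    """
    with torch.no_grad():
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad + weight_decay * p
            noise = torch.randn_like(p) * noise_factor * (2.0 * lr) ** 0.5
            p.add_(-lr * grad + noise)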
    args.data_path,
    args.batch_size,
    args.num_workers,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not args.use_test)

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  no_cov_mat=False,
                  max_num_models=args.swag_rank,
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

criterion = losses.cross_entropy

W = []
num_checkpoints = len(args.checkpoint)
for path in args.checkpoint:
    print('Loading %s' % path)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    swag_model.collect_model(model)
    W.append(
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not args.use_test,
    split_classes=args.split_classes,
    shuffle_train=False)

"""if args.split_classes is not None:
    num_classes /= 2
    num_classes = int(num_classes)"""

print('Preparing model')
if args.method in ['SWAG', 'HomoNoise', 'SWAGDrop']:
    model = SWAG(model_cfg.base,
                 num_classes=num_classes,
                 subspace_type='pca',
                 subspace_kwargs={
                     'max_rank': 140,
                     'pca_rank': args.rank,
                 },
                 *model_cfg.args,
                 **model_cfg.kwargs)
elif args.method in ['SGD', 'Dropout', 'KFACLaplace']:
    model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
else:
    assert False
model.cuda()


def train_dropout(m):
    if type(m) == torch.nn.modules.dropout.Dropout:
    args.batch_size,
    args.num_workers,
)

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True

if args.swa:
    print('SWAG training')
    args.swa_device = 'cpu' if args.swa_cpu else args.device
    swag_model = SWAG(model_class,
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=20,
                      num_classes=num_classes)
    swag_model.to(args.swa_device)
    if args.pretrained:
        model.to(args.swa_device)
        swag_model.collect_model(model)
        model.to(args.device)
else:
    print('SGD training')


def schedule(epoch):
    if args.swa and epoch >= args.swa_start:
        return args.swa_lr
    else:
        return args.lr_init * (0.1 ** (epoch // 30))
args = parser.parse_args()

args.device = None
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

# torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)
num_classes = 1000

print('Preparing model')
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

print('SWAG training')
swag_model = SWAG(model_class, no_cov_mat=False, max_num_models=20, num_classes=num_classes)
swag_model.to('cpu')

for k in range(100):
    swag_model.collect_model(model)
    print(k + 1)
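# Each collect_model call folds the current weights into SWAG's running
# statistics. A minimal sketch of the first- and second-moment updates that
# diagonal SWAG maintains (the standard update from the SWAG paper, sketched
# on flat vectors rather than the repo's per-module buffers):
import torch

mean = torch.zeros(5)
sq_mean = torch.zeros(5)
n_models = 0

def collect(w):
    """Fold one weight snapshot w into the running moments."""
    global mean, sq_mean, n_models
    mean = (mean * n_models + w) / (n_models + 1)
    sq_mean = (sq_mean * n_models + w ** 2) / (n_models + 1)
    n_models += 1

for _ in range(10):
    collect(torch.randn(5))

var = torch.clamp(sq_mean - mean ** 2, min=0.0)  # diagonal posterior variance
print(mean, var)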
loaders, num_classes = data.loaders(args.data_path, args.batch_size, args.num_workers)

print("Preparing model")
model = model_class(pretrained=args.pretrained, num_classes=num_classes)
model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True

if args.swa:
    print("SWAG training")
    args.swa_device = "cpu" if args.swa_cpu else args.device
    swag_model = SWAG(
        model_class,
        no_cov_mat=args.no_cov_mat,
        max_num_models=20,
        num_classes=num_classes,
    )
    swag_model.to(args.swa_device)
    if args.pretrained:
        model.to(args.swa_device)
        swag_model.collect_model(model)
        model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    if args.swa and epoch >= args.swa_start:
        return args.swa_lr
    else:
    args.num_workers,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not args.use_test,
    split_classes=args.split_classes)

model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': args.rank,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

num_parameters = sum([p.numel() for p in model.parameters()])

offset = 0
for param in model.parameters():
    size = param.numel()
def test_swag_diag(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        no_cov_mat=True,
        max_num_models=100,
        loading=False,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # check to ensure parameters have the correct sizes
    mean_list = []
    sq_mean_list = []
    for (module, name), param in zip(swag_model.params, model.parameters()):
        mean = module.__getattr__("%s_mean" % name)
        sq_mean = module.__getattr__("%s_sq_mean" % name)

        self.assertEqual(param.size(), mean.size())
        self.assertEqual(param.size(), sq_mean.size())

        mean_list.append(mean)
        sq_mean_list.append(sq_mean)

    mean = flatten(mean_list).cuda()
    sq_mean = flatten(sq_mean_list).cuda()

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        var = scale * (sq_mean - mean ** 2)
        std = torch.sqrt(var)
        dist = torch.distributions.Normal(mean, std)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = 0
        for _ in range(20):
            swag_model.sample(scale=scale, cov=False)

            curr_pars = []
            for (module, name) in swag_model.params:
                curr_pars.append(getattr(module, name))
            curr_probs = dist.cdf(flatten(curr_pars))

            # check if within 95% CI
            num_in_cr = ((curr_probs > 0.025) & (curr_probs < 0.975)).float().sum()
            all_qforms += num_in_cr

        # now compute the average fraction inside the interval
        avg_prob_in_cr = all_qforms / (20 * mean.numel())
        # CLT should hold a bit tighter here
        self.assertTrue(0.945 <= avg_prob_in_cr <= 0.955)
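# The diagonal test uses the probability-integral transform: if each coordinate
# is N(mean_i, var_i), applying the matching Normal CDF yields Uniform(0, 1)
# values, so ~95% should land in (0.025, 0.975). A standalone numpy
# illustration of that check:
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
p = 903                       # e.g. 300 * 3 weights + 3 biases, as in the test
mean = rng.standard_normal(p)
std = rng.uniform(0.5, 2.0, size=p)

sample = mean + std * rng.standard_normal(p)    # one diagonal-Gaussian draw
probs = norm.cdf(sample, loc=mean, scale=std)   # probability-integral transform

print(((probs > 0.025) & (probs < 0.975)).mean())  # ~0.95 on average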
label_arr = np.load(args.label_arr)
print("Corruption:", (loaders['train'].dataset.targets != label_arr).mean())
loaders['train'].dataset.targets = label_arr

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)
print("Model has {} parameters".format(sum([p.numel() for p in model.parameters()])))

swag_model = SWAG(model_cfg.base,
                  args.subspace,
                  {'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

columns = ['swag', 'sample', 'te_loss', 'te_acc', 'ens_loss', 'ens_acc']

n_ensembled = 0.
multiswag_probs = None
for ckpt_i, ckpt in enumerate(args.swag_ckpts):
    print("Checkpoint {}".format(ckpt))
    checkpoint = torch.load(ckpt)
    swag_model.subspace.rank = torch.tensor(0)
    swag_model.load_state_dict(checkpoint['state_dict'])
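# The loop above grows `multiswag_probs` into a running average of the
# predictive probabilities from each SWAG checkpoint. A minimal sketch of that
# incremental-mean pattern (with numpy arrays for illustration; the script's
# exact update may differ):
import numpy as np

multiswag_probs = None
n_ensembled = 0.0

def add_member(probs):
    """Fold one member's predictive probabilities into the ensemble mean."""
    global multiswag_probs, n_ensembled
    if multiswag_probs is None:
        multiswag_probs = probs.copy()
    else:
        multiswag_probs += (probs - multiswag_probs) / (n_ensembled + 1)
    n_ensembled += 1

for _ in range(5):
    add_member(np.random.dirichlet(np.ones(10), size=4))  # fake (batch, class) probs
print(multiswag_probs.sum(axis=1))  # rows still sum to 1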
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    model_cfg.transform_train,
                                    model_cfg.transform_test,
                                    use_validation=not args.use_test,
                                    split_classes=args.split_classes,
                                    shuffle_train=False)

print('Preparing model')
if args.method in ['SWAG', 'HomoNoise', 'SWAGDrop']:
    model = SWAG(model_cfg.base,
                 no_cov_mat=not args.cov_mat,
                 max_num_models=args.max_num_models,
                 loading=True,
                 *model_cfg.args,
                 num_classes=num_classes,
                 **model_cfg.kwargs)
elif args.method in ['SGD', 'Dropout', 'KFACLaplace']:
    model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
else:
    assert False
model.cuda()


def train_dropout(m):
    if type(m) == torch.nn.modules.dropout.Dropout:
        m.train()
    args.num_workers,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not args.use_test)

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 20,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

mean, _, cov_factor = swag_model.export_numpy_parameters(True)
norms = np.linalg.norm(cov_factor, axis=1)
    joint_transform=model_cfg.joint_transform,
    ft_joint_transform=model_cfg.ft_joint_transform,
    target_transform=model_cfg.target_transform,
)

if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

print("Preparing model")
if args.method in ["SWAG", "HomoNoise", "SWAGDrop"]:
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
elif args.method in ["SGD", "Dropout"]:
    # construct and load model
    model = model_cfg.base(num_classes=num_classes, use_aleatoric=args.loss == "aleatoric")
else:
    assert False
model.cuda()


def train_dropout(m):
    if m.__module__ == torch.nn.modules.dropout.__name__:
torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_class = getattr(torchvision.models, args.model)

print('Loading ImageNet from %s' % (args.data_path))
loaders, num_classes = data.loaders(
    args.data_path,
    args.batch_size,
    args.num_workers,
)

print('Preparing model')
swag_model = SWAG(model_class,
                  no_cov_mat=not args.cov_mat,
                  loading=True,
                  max_num_models=20,
                  num_classes=num_classes)
swag_model.to(args.device)

criterion = losses.cross_entropy

print('Loading checkpoint %s' % args.ckpt)
checkpoint = torch.load(args.ckpt)
swag_model.load_state_dict(checkpoint['state_dict'])

print('SWA')
swag_model.sample(0.0)
print('SWA BN update')
utils.bn_update(loaders['train'], swag_model, verbose=True, subset=0.1)
print('SWA EVAL')
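# After swag_model.sample(0.0) sets the SWA mean weights, the BatchNorm running
# statistics no longer match, so bn_update recomputes them with forward passes
# over training data. A minimal sketch of that procedure (not the repo's
# utils.bn_update, which also supports verbose progress and data subsets):
import torch

def simple_bn_update(loader, model, device='cpu'):
    """Recompute BatchNorm running statistics for the current weights."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        return
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None   # None => cumulative moving average in PyTorch
    model.train()            # BN updates running stats only in train mode
    with torch.no_grad():
        for input, _ in loader:
            model(input.to(device))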
class RegressionRunner(RegressionModel):
    def __init__(self, base, epochs, criterion,
                 batch_size=50, lr_init=1e-2, momentum=0.9, wd=1e-4,
                 swag_lr=1e-3, swag_freq=1, swag_start=50,
                 subspace_type='pca', subspace_kwargs={'max_rank': 20},
                 use_cuda=False, use_swag=False,
                 double_bias_lr=False, model_variance=True,
                 num_samples=30, scale=0.5, const_lr=False,
                 *args, **kwargs):
        self.base = base
        self.model = base(*args, **kwargs)

        num_pars = 0
        for p in self.model.parameters():
            num_pars += p.numel()
        print('number of parameters: ', num_pars)

        if use_cuda:
            self.model.cuda()

        if use_swag:
            self.swag_model = SWAG(base,
                                   subspace_type=subspace_type,
                                   subspace_kwargs=subspace_kwargs,
                                   *args, **kwargs)
            if use_cuda:
                self.swag_model.cuda()
        else:
            self.swag_model = None
        self.use_cuda = use_cuda

        if not double_bias_lr:
            pars = self.model.parameters()
        else:
            pars = []
            for name, module in self.model.named_parameters():
                if 'bias' in str(name):
                    print('Doubling lr of ', name)
                    pars.append({'params': module, 'lr': 2.0 * lr_init})
                else:
                    pars.append({'params': module, 'lr': lr_init})

        self.optimizer = torch.optim.SGD(pars, lr=lr_init, momentum=momentum, weight_decay=wd)
        self.const_lr = const_lr

        self.batch_size = batch_size

        # TODO: set up criterions better for classification
        if model_variance:
            self.criterion = criterion(noise_var=None)
        else:
            self.criterion = criterion(noise_var=1.0)

        if self.criterion.noise_var is not None:
            self.var = self.criterion.noise_var

        self.epochs = epochs

        self.lr_init = lr_init
        self.use_swag = use_swag
        self.swag_start = swag_start
        self.swag_lr = swag_lr
        self.swag_freq = swag_freq

        self.num_samples = num_samples
        self.scale = scale

    def train(self, model, loader, optimizer, criterion,
              lr_init=1e-2, epochs=3000,
              swag_model=None, swag=False, swag_start=2000, swag_freq=50, swag_lr=1e-3,
              print_freq=100, use_cuda=False, const_lr=False):
        # adapted from Pavel's regression notebook
        if const_lr:
            lr = lr_init

        train_res_list = []
        for epoch in range(epochs):
            if not const_lr:
                t = (epoch + 1) / swag_start if swag else (epoch + 1) / epochs
                lr_ratio = swag_lr / lr_init if swag else 0.05

                if t <= 0.5:
                    factor = 1.0
                elif t <= 0.9:
                    factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
                else:
                    factor = lr_ratio

                lr = factor * lr_init
                adjust_learning_rate(optimizer, factor)

            train_res = utils.train_epoch(loader, model, criterion, optimizer,
                                          cuda=use_cuda, regression=True)
            train_res_list.append(train_res)

            if swag and epoch > swag_start:
                swag_model.collect_model(model)

            if epoch % print_freq == 0 or epoch == epochs - 1:
                print('Epoch %d. LR: %g. Loss: %.4f' % (epoch, lr, train_res['loss']))

        return train_res_list

    def fit(self, features, labels):
        self.features, self.labels = torch.FloatTensor(features), torch.FloatTensor(labels)

        # construct data loader
        self.data_loader = DataLoader(TensorDataset(self.features, self.labels),
                                      batch_size=self.batch_size)

        # now train with pre-specified options
        result = self.train(model=self.model, loader=self.data_loader,
                            optimizer=self.optimizer, criterion=self.criterion,
                            lr_init=self.lr_init, swag_model=self.swag_model,
                            swag=self.use_swag, swag_start=self.swag_start,
                            swag_freq=self.swag_freq, swag_lr=self.swag_lr,
                            use_cuda=self.use_cuda, epochs=self.epochs,
                            const_lr=self.const_lr)

        if self.criterion.noise_var is not None:
            # another forwards pass through the network to estimate noise variance
            preds, targets = utils.predictions(model=self.model,
                                               test_loader=self.data_loader,
                                               regression=True,
                                               cuda=self.use_cuda)
            self.var = np.power(np.linalg.norm(preds - targets), 2.0) / targets.shape[0]
            print(self.var)

        return result

    def predict(self, features, swag_model=None):
        """
        Default prediction uses the built-in low-rank Gaussian.
        For plain SWA, use scale = 0.0 and num_samples = 1.
        """
        swag_model = swag_model if swag_model is not None else self.swag_model

        if self.use_cuda:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        with torch.no_grad():
            if swag_model is None:
                self.model.eval()
                preds = self.model(torch.FloatTensor(features).to(device)).data.cpu()
                if preds.size(1) == 1:
                    var = torch.ones_like(preds[:, 0]).unsqueeze(1) * self.var
                else:
                    var = preds[:, 1].view(-1, 1)
                    preds = preds[:, 0].view(-1, 1)
                print(var.mean())
            else:
                prediction = 0
                sq_prediction = 0
                for _ in range(self.num_samples):
                    swag_model.sample(scale=self.scale)
                    current_prediction = swag_model(torch.FloatTensor(features).to(device)).data.cpu()
                    prediction += current_prediction

                    if current_prediction.size(1) == 2:
                        # convert to standard deviation
                        current_prediction[:, 1] = current_prediction[:, 1] ** 0.5
                    sq_prediction += current_prediction ** 2.0

                # compute mean of prediction: \mu^*
                preds = (prediction[:, 0] / self.num_samples).view(-1, 1)

                # 1/M \sum (\sigma^2(x) + \mu^2(x)) - (\mu^*)^2
                var = torch.sum(sq_prediction, 1, keepdim=True) / self.num_samples - preds.pow(2.0)

                # add noise variance if not heteroscedastic
                if prediction.size(1) == 1:
                    var = var + self.var

        return preds.numpy(), var.numpy()
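# A hypothetical end-to-end use of the runner on toy data. SomeRegressionNet
# and GaussianLikelihood are placeholder names for a base network class and a
# criterion factory accepting noise_var=; substitute the repo's actual ones.
import numpy as np

X = np.random.randn(200, 1).astype(np.float32)
y = (np.sin(3 * X) + 0.1 * np.random.randn(200, 1)).astype(np.float32)

runner = RegressionRunner(
    base=SomeRegressionNet,        # hypothetical base network class
    epochs=500,
    criterion=GaussianLikelihood,  # hypothetical criterion factory
    use_swag=True,
    swag_start=300,
)
runner.fit(X, y)
preds, var = runner.predict(X)     # predictive mean and variance per input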
        'max_rank': 20,
        'pca_rank': 'mle',
    }
}, {
    'name': 'freq_dir',
    'kwargs': {
        'max_rank': 20,
    }
}]

for item in subspaces:
    name, kwargs = item['name'], item['kwargs']
    print('Now running %s %r' % (name, kwargs))

    model = model_generator()
    small_swag_model = SWAG(base=model_generator,
                            subspace_type=name,
                            subspace_kwargs=kwargs)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    loader = generate_dataloaders(N=10)

    state_dict = None
    for epoch in range(num_epochs):
        model.train()
        for x, y in loader:
            model.zero_grad()
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, input_shape=input_shape, **model_cfg.kwargs)
# model.to(args.device)

if args.cov_mat:
    args.no_cov_mat = False
else:
    args.no_cov_mat = True

if args.swa:
    print("SWAG training")
    swag_model = SWAG(model,
                      no_cov_mat=args.no_cov_mat,
                      max_num_models=args.max_num_models,
                      *model_cfg.args,
                      num_classes=num_classes,
                      **model_cfg.kwargs)
    # swag_model.to(args.device)
else:
    print("SGD training")


def schedule(epoch):
    t = epoch / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
    target_transform=model_cfg.target_transform,
)

# criterion = nn.NLLLoss(weight=camvid.class_weight[:-1].cuda(), reduction='none').cuda()
if args.loss == "cross_entropy":
    criterion = losses.seg_cross_entropy
else:
    criterion = losses.seg_ale_cross_entropy

# construct and load model
if args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    model = SWAG(
        model_cfg.base,
        no_cov_mat=False,
        max_num_models=20,
        num_classes=num_classes,
        use_aleatoric=args.loss == "aleatoric",
    )
    model.cuda()
    model.load_state_dict(checkpoint["state_dict"])

    model.sample(0.0)
    bn_update(loaders["fine_tune"], model)
else:
    model = model_cfg.base(num_classes=num_classes,
                           use_aleatoric=args.loss == "aleatoric").cuda()
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    print(start_epoch)
    model.load_state_dict(checkpoint["state_dict"])
    args.data_path,
    args.batch_size,
    args.num_workers,
    model_cfg.transform_train,
    model_cfg.transform_test,
    use_validation=not args.use_test,
    split_classes=args.split_classes)

print('Preparing model')
print(*model_cfg.args)
swag_model = SWAG(model_cfg.base,
                  num_classes=num_classes,
                  subspace_type='pca',
                  subspace_kwargs={
                      'max_rank': 20,
                      'pca_rank': args.rank,
                  },
                  *model_cfg.args,
                  **model_cfg.kwargs)
swag_model.to(args.device)

print('Loading: %s' % args.checkpoint)
ckpt = torch.load(args.checkpoint)
swag_model.load_state_dict(ckpt['state_dict'], strict=False)

swag_model.set_swa()
print("SWA:", utils.eval(loaders["train"], swag_model, criterion=losses.cross_entropy))

mean, var, cov_factor = swag_model.get_space()
criterion = partial(criterion, weight=class_weights)

if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint

if args.swa:
    print('SWAG training')
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
                      **model_cfg.kwargs)
    swag_model.to(args.device)
else:
    print('SGD training')

if args.swa and args.swa_resume is not None:
    checkpoint = torch.load(args.swa_resume)
    swag_model = SWAG(model_cfg.base,
                      no_cov_mat=False,
                      max_num_models=20,
                      *model_cfg.args,
                      num_classes=num_classes,
                      use_aleatoric=args.loss == 'aleatoric',
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Preparing model')
print(*model_cfg.args)
model = model_cfg.base(*model_cfg.args, num_classes=args.num_classes, **model_cfg.kwargs)
model.to(args.device)

swag_model = SWAG(model_cfg.base,
                  subspace_type=args.subspace,
                  subspace_kwargs={'max_rank': args.max_num_models},
                  *model_cfg.args,
                  num_classes=args.num_classes,
                  **model_cfg.kwargs)
swag_model.to(args.device)

for path in args.checkpoint:
    print(path)
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])

    swag_model.collect_model(model)

torch.save({'state_dict': swag_model.state_dict()}, args.path)
def test_swag_cov(self, **kwargs):
    model = torch.nn.Linear(300, 3, bias=True)
    swag_model = SWAG(
        torch.nn.Linear,
        in_features=300,
        out_features=3,
        bias=True,
        no_cov_mat=False,
        max_num_models=100,
        loading=False,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # construct swag model via training
    torch.manual_seed(0)
    for _ in range(101):
        model.zero_grad()

        input = torch.randn(100, 300)
        output = model(input)
        loss = ((torch.randn(100, 3) - output) ** 2.0).sum()
        loss.backward()
        optimizer.step()

        swag_model.collect_model(model)

    # check to ensure parameters have the correct sizes
    mean_list = []
    sq_mean_list = []
    cov_mat_sqrt_list = []
    for (module, name), param in zip(swag_model.params, model.parameters()):
        mean = module.__getattr__("%s_mean" % name)
        sq_mean = module.__getattr__("%s_sq_mean" % name)
        cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)

        self.assertEqual(param.size(), mean.size())
        self.assertEqual(param.size(), sq_mean.size())
        self.assertEqual([swag_model.max_num_models, param.numel()], list(cov_mat_sqrt.size()))

        mean_list.append(mean)
        sq_mean_list.append(sq_mean)
        cov_mat_sqrt_list.append(cov_mat_sqrt)

    mean = flatten(mean_list).cuda()
    sq_mean = flatten(sq_mean_list).cuda()
    cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1).cuda()

    true_cov_mat = (1.0 / (swag_model.max_num_models - 1)) * cov_mat_sqrt.t().matmul(cov_mat_sqrt) \
        + torch.diag(sq_mean - mean ** 2)

    test_cutoff = chi2(df=mean.numel()).ppf(0.95)  # 95% quantile of p-dimensional chi-square distribution

    for scale in [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]:
        scaled_cov_mat = true_cov_mat * scale
        scaled_cov_inv = torch.inverse(scaled_cov_mat)

        # now test to ensure that sampling has the correct covariance matrix probabilistically
        all_qforms = []
        for _ in range(2000):
            swag_model.sample(scale=scale, cov=True)

            curr_pars = []
            for (module, name) in swag_model.params:
                curr_pars.append(getattr(module, name))

            dev = flatten(curr_pars) - mean

            # (x - mu)^T sigma^{-1} (x - mu)
            qform = dev.matmul(scaled_cov_inv).matmul(dev)
            all_qforms.append(qform.item())

        samples_in_cr = (np.array(all_qforms) < test_cutoff).sum()
        print(samples_in_cr)

        # between 94 and 96% of the 2000 samples should fall within the threshold;
        # this should be very loose
        self.assertTrue(1880 <= samples_in_cr <= 1920)
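# The covariance reconstructed above is SWAG's low-rank-plus-diagonal form,
#     Sigma = D^T D / (K - 1) + diag(sq_mean - mean^2),
# where the K rows of D are deviations of weight snapshots from the running
# mean. A numpy sketch of sampling from such a covariance without ever forming
# the full p x p matrix (dimensions illustrative; the repo's sample() may use
# different scaling conventions):
import numpy as np

rng = np.random.default_rng(0)
p, K = 50, 10
mean = rng.standard_normal(p)
diag_var = rng.uniform(0.1, 1.0, size=p)   # plays the role of sq_mean - mean**2
D = rng.standard_normal((K, p))            # rows: centered weight snapshots

def sample_swag(scale=1.0):
    """Draw from N(mean, scale * (D^T D / (K - 1) + diag(diag_var)))."""
    z1 = rng.standard_normal(p)            # diagonal component
    z2 = rng.standard_normal(K)            # low-rank component
    dev = np.sqrt(diag_var) * z1 + D.T @ z2 / np.sqrt(K - 1)
    return mean + np.sqrt(scale) * dev

draws = np.stack([sample_swag() for _ in range(20000)])
emp_cov = np.cov(draws, rowvar=False)
true_cov = D.T @ D / (K - 1) + np.diag(diag_var)
print(np.abs(emp_cov - true_cov).max())    # shrinks as the number of draws grows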