# Variant 1: SVGD ParticleVI for MNIST/CIFAR-10 classification, with an
# optional CKA kernel over network activations.
# Assumes repo-local helpers: datagen, extract_parameters, insert_items,
# Averager, rbf_fn, kernel_cka, uncertainty, auc_score, and a module-level
# model_id used in checkpoint paths.
import os

import numpy as np
import torch


class ParticleVI(object):

    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = rbf_fn
            return_activations = False
        elif kernel_fn == 'cka':
            self.kernel = kernel_cka
            return_activations = True
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6,
                          return_activations=return_activations).cuda()
            for _ in range(num_particles)
        ]
        self.models = models
        # Flatten every network's parameters into one [num_particles, D]
        # tensor so SVGD can treat each model as a single particle.
        param_set, state_dict = extract_parameters(self.models)
        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)
        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        self.activation_length = self.models[0].activation_length
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())

    def kernel_width(self, dist):
        """Update the kernel-width averager and return the latest width."""
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
        assert dist.ndim == 1, "dist must have dimension 1 or 2."
        # Median heuristic for the RBF bandwidth.
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def svgd_grad(self, loss_grad, params):
        """Compute particle gradients via SVGD; the empirical expectation
        is taken over the full set of particles."""
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.kernel(params2, params,
                                                 self.kernel_width)
        # Kernels without an analytic gradient (e.g. CKA) return None;
        # fall back to autograd in that case.
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(),
                                              params2)[0]
        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, test_loader, eval_loss=True):
        for model in self.models:
            model.eval()
        correct = 0
        test_loss = 0
        outputs_all = []
        for inputs, targets in test_loader:
            preds = []
            loss = 0
            inputs = inputs.cuda()
            targets = targets.cuda()
            for model in self.models:
                outputs = model(inputs)
                if self.kernel_fn == 'cka':
                    # CKA models also return activations; keep logits only.
                    outputs, _ = outputs
                if eval_loss:
                    loss += self.loss_fn(outputs, targets)
                preds.append(torch.nn.functional.softmax(outputs, dim=-1))
            pred = torch.stack(preds)
            outputs_all.append(pred)
            preds = pred.mean(0)
            vote = preds.argmax(-1).cpu()
            correct += vote.eq(
                targets.cpu().data.view_as(vote)).float().cpu().sum()
            test_loss += loss / self.num_particles
        outputs_all = torch.cat(outputs_all, dim=1)
        test_loss /= len(test_loader)
        correct /= len(test_loader.dataset)
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(self.start_epoch, epochs):
            loss_epoch = 0
            for inputs, targets in self.train_loader:
                outputs = []
                activations = torch.zeros(len(self.models), len(targets),
                                          self.activation_length).cuda()
                neglogp = torch.zeros(self.num_particles)
                # Copy the flat particle parameters back into the models.
                insert_items(self.models, self.param_set, self.state_dict)
                neglogp_grads = torch.zeros_like(self.param_set)
                for i, model in enumerate(self.models):
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    if self.kernel_fn == 'cka':
                        # Collect activations for the activation-based kernel.
                        output, activation = model(inputs)
                        activations[i, :, :] = activation
                    else:
                        output = model(inputs)
                    outputs.append(output)
                    loss = self.loss_fn(output, targets)
                    loss.backward()
                    neglogp[i] = loss.item()
                    g = []
                    for name, param in model.named_parameters():
                        g.append(param.grad.view(-1))
                    neglogp_grads[i] = torch.cat(g)
                    model.zero_grad()
                par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
                self.optimizer.zero_grad()
                self.param_set.grad = par_vi_grad
                self.optimizer.step()
                loss_step = neglogp.mean()
                loss_epoch += loss_step
            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 1 == 0:  # evaluate every epoch
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(self.val_loader)
                    outputs2, _ = self.test(self.test_loader, eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))
                # Entropy/variance uncertainties on the two splits, scored
                # against each other with AUC.
                entropy, variance = uncertainty(outputs)
                entropy2, variance2 = uncertainty(outputs2)
                auc_entropy = auc_score(entropy, entropy2)
                auc_variance = auc_score(variance, variance2)
                print('Test AUC Entropy: {}'.format(auc_entropy))
                print('Test AUC Variance: {}'.format(auc_variance))

                params = {
                    'params': self.param_set,
                    'state_dict': self.state_dict,
                    'models': [m.state_dict() for m in self.models],
                    'optimizer': self.optimizer.state_dict()
                }
                save_dir = 'saved_models/{}/{}2/'.format(
                    self.dataset, model_id)
                fn = 'model_epoch_{}.pt'.format(epoch)
                print('saving model: {}'.format(fn))
                os.makedirs(save_dir, exist_ok=True)
                torch.save(params, save_dir + fn)
            print('*' * 86)
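# svgd_grad implements the standard SVGD update
#   phi(x_i) = (1/n) * sum_j [ k(x_j, x_i) * grad_j log p(x_j)
#                              + grad_j k(x_j, x_i) ]
# with the sign flipped because the Adam optimizer minimizes. A minimal,
# self-contained sketch of the same update on a 1-D standard-normal target
# (illustrative only, not repo code):
import math

import torch

torch.manual_seed(0)
particles = torch.randn(50, 1) * 3 + 4  # start far from the target


def rbf(x, y):
    diff = x.unsqueeze(1) - y.unsqueeze(0)       # [N, N, D]
    dist_sq = diff.pow(2).sum(-1)                # [N, N]
    h = dist_sq.median() / math.log(x.shape[0])  # median heuristic
    k = torch.exp(-dist_sq / h)
    k_grad = torch.einsum('ij,ijk->ijk', k, -2 * diff / h)
    return k, k_grad


lr = 0.1
for _ in range(500):
    neglogp_grad = particles  # grad of -log N(0, 1) is x itself
    k, k_grad = rbf(particles, particles)
    drive = k.t() @ neglogp_grad / particles.shape[0]  # pulls toward density
    repulse = k_grad.mean(0)                           # keeps particles apart
    particles = particles - lr * (drive - repulse)

print(particles.mean().item(), particles.std().item())  # roughly 0 and 1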
# Variant 2: SVGD ParticleVI for toy regression/classification problems.
# Assumes repo-local helpers: toy, extract_parameters, insert_items, Averager.
import numpy as np
import torch


class ParticleVI(object):

    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        if self.dataset == 'regression':
            self.data = toy.generate_regression_data(80, 200)
            (self.train_data, self.train_targets), (self.test_data,
                                                    self.test_targets) = self.data
        elif self.dataset == 'classification':
            self.train_data, self.train_targets = toy.generate_classification_data(
                100)
            self.test_data, self.test_targets = toy.generate_classification_data(
                200)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = self.rbf_fn
        else:
            raise NotImplementedError

        models = [base_model_fn().cuda() for _ in range(num_particles)]
        self.models = models
        # Flatten every network's parameters into one [num_particles, D]
        # tensor so SVGD can treat each model as a single particle.
        param_set, state_dict = extract_parameters(self.models)
        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)
        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3
        }])

        if self.dataset == 'regression':
            self.loss_fn = torch.nn.MSELoss()
        elif self.dataset == 'classification':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        self.kernel_width_averager = Averager(shape=())

    def kernel_width(self, dist):
        """Update the kernel-width averager and return the latest width."""
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
        assert dist.ndim == 1, "dist must have dimension 1 or 2."
        # Median heuristic for the RBF bandwidth.
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def rbf_fn(self, x, y):
        Nx = x.shape[0]
        Ny = y.shape[0]
        x = x.view(Nx, -1)
        y = y.view(Ny, -1)
        Dx = x.shape[1]
        Dy = y.shape[1]
        assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, D]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self.kernel_width(dist_sq.view(-1))
        kappa = torch.exp(-dist_sq / h)  # [Nx, Ny]
        kappa_grad = torch.einsum('ij,ijk->ijk', kappa,
                                  -2 * diff / h)  # [Nx, Ny, D]
        return kappa, kappa_grad

    def svgd_grad(self, loss_grad, params):
        """Compute particle gradients via SVGD; the empirical expectation
        is taken over the full set of particles."""
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.kernel(params2, params)
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(),
                                              params2)[0]
        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, eval_loss=True):
        for model in self.models:
            model.eval()
        loss = 0
        preds = []
        test_data = self.test_data.cuda()
        test_targets = self.test_targets.cuda()
        for model in self.models:
            outputs = model(test_data)
            if eval_loss:
                loss += self.loss_fn(outputs, test_targets)
            preds.append(outputs)
        preds = torch.stack(preds)
        if self.dataset == 'classification':
            preds = torch.nn.functional.softmax(preds, dim=-1)
            preds = preds.mean(0)
            vote = preds.argmax(-1).cpu()
            correct = vote.eq(
                test_targets.cpu().data.view_as(vote)).float().cpu().sum()
            correct /= len(test_targets)
        else:
            correct = 0
        test_loss = loss / self.num_particles
        outputs_all = preds
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(epochs):
            loss_epoch = 0
            neglogp = torch.zeros(self.num_particles)
            # Copy the flat particle parameters back into the models.
            insert_items(self.models, self.param_set, self.state_dict)
            neglogp_grads = torch.zeros_like(self.param_set)
            outputs = []
            for i, model in enumerate(self.models):
                train_data = self.train_data.cuda()
                train_targets = self.train_targets.cuda()
                output = model(train_data)
                outputs.append(output)
                loss = self.loss_fn(output, train_targets)
                loss.backward()
                neglogp[i] = loss.item()
                g = []
                for name, param in model.named_parameters():
                    g.append(param.grad.view(-1))
                neglogp_grads[i] = torch.cat(g)
                model.zero_grad()
            par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
            self.optimizer.zero_grad()
            self.param_set.grad = par_vi_grad
            self.optimizer.step()
            loss_step = neglogp.mean()
            loss_epoch += loss_step
            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 100 == 0:
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))
                if self.dataset == 'regression':
                    toy.plot_regression(self.models, self.data, epoch,
                                        tag='svgd')
                if self.dataset == 'classification':
                    toy.plot_classification(self.models, epoch, tag='svgd')
                print('*' * 86)
# Variant 3: SVGD ParticleVI for MNIST/CIFAR-10 classification with optional
# wandb logging and checkpointing.
# Assumes repo-local helpers: datagen, extract_parameters, insert_items,
# Averager, MattLoss, uncertainty, auc_score, and a module-level model_id
# used in checkpoint paths.
import os

import numpy as np
import torch
import wandb


class ParticleVI(object):

    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=50,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))
        self._use_wandb = False
        self._save_model = False

        if self.dataset == 'mnist':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_mnist(
                split=True)
        elif self.dataset == 'cifar10':
            self.train_loader, self.test_loader, self.val_loader = datagen.load_cifar10(
                split=True)
        else:
            raise NotImplementedError

        if kernel_fn == 'rbf':
            self.kernel = self.rbf_fn
        else:
            raise NotImplementedError

        models = [
            base_model_fn(num_classes=6).cuda() for _ in range(num_particles)
        ]
        self.models = models
        # Flatten every network's parameters into one [num_particles, D]
        # tensor so SVGD can treat each model as a single particle.
        param_set, state_dict = extract_parameters(self.models)
        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)
        self.optimizer = torch.optim.Adam([{
            'params': self.param_set,
            'lr': 1e-3,
            'weight_decay': 1e-4
        }])

        if resume:
            print('resuming from epoch {}'.format(resume_epoch))
            d = torch.load('saved_models/{}/{}2/model_epoch_{}.pt'.format(
                self.dataset, model_id, resume_epoch))
            for model, sd in zip(self.models, d['models']):
                model.load_state_dict(sd)
            self.param_set = d['params']
            self.state_dict = d['state_dict']
            self.optimizer = torch.optim.Adam([{
                'params': self.param_set,
                'lr': resume_lr,
                'weight_decay': 1e-4
            }])
            self.start_epoch = resume_epoch
        else:
            self.start_epoch = 0

        loss_type = 'ce'  # hardcoded choice: 'ce' or 'kliep'
        if loss_type == 'ce':
            self.loss_fn = torch.nn.CrossEntropyLoss()
        elif loss_type == 'kliep':
            self.loss_fn = MattLoss().get_loss_dict()['kliep']
        self.kernel_width_averager = Averager(shape=())

        if self._use_wandb:
            wandb.init(project="open-category-experiments",
                       name="SVGD {}".format(self.dataset))
            for model in models:
                wandb.watch(model)
            config = wandb.config
            config.algo = algo
            config.dataset = dataset
            config.kernel_fn = kernel_fn
            config.num_particles = num_particles
            config.loss_fn = loss_type

    def kernel_width(self, dist):
        """Update the kernel-width averager and return the latest width."""
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
        assert dist.ndim == 1, "dist must have dimension 1 or 2."
        # Median heuristic for the RBF bandwidth.
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self.kernel_width_averager.update(width)
        return self.kernel_width_averager.get()

    def rbf_fn(self, x, y):
        Nx = x.shape[0]
        Ny = y.shape[0]
        x = x.view(Nx, -1)
        y = y.view(Ny, -1)
        Dx = x.shape[1]
        Dy = y.shape[1]
        assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, D]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self.kernel_width(dist_sq.view(-1))
        kappa = torch.exp(-dist_sq / h)  # [Nx, Ny]
        kappa_grad = torch.einsum('ij,ijk->ijk', kappa,
                                  -2 * diff / h)  # [Nx, Ny, D]
        return kappa, kappa_grad

    def svgd_grad(self, loss_grad, params):
        """Compute particle gradients via SVGD; the empirical expectation
        is taken over the full set of particles."""
        num_particles = params.shape[0]
        params2 = params.detach().requires_grad_(True)
        kernel_weight, kernel_grad = self.kernel(params2, params)
        if kernel_grad is None:
            kernel_grad = torch.autograd.grad(kernel_weight.sum(),
                                              params2)[0]
        kernel_logp = torch.matmul(kernel_weight.t().detach(),
                                   loss_grad) / num_particles
        grad = kernel_logp - kernel_grad.mean(0)
        return grad

    def test(self, test_loader, eval_loss=True):
        for model in self.models:
            model.eval()
        correct = 0
        test_loss = 0
        outputs_all = []
        for inputs, targets in test_loader:
            preds = []
            loss = 0
            inputs = inputs.cuda()
            targets = targets.cuda()
            for model in self.models:
                outputs = model(inputs)
                if eval_loss:
                    loss += self.loss_fn(outputs, targets)
                preds.append(torch.nn.functional.softmax(outputs, dim=-1))
            pred = torch.stack(preds)
            outputs_all.append(pred)
            preds = pred.mean(0)
            vote = preds.argmax(-1).cpu()
            correct += vote.eq(
                targets.cpu().data.view_as(vote)).float().cpu().sum()
            test_loss += loss / self.num_particles
        outputs_all = torch.cat(outputs_all, dim=1)
        test_loss /= len(test_loader)
        correct /= len(test_loader.dataset)
        for model in self.models:
            model.train()
        return outputs_all, (test_loss, correct)

    def train(self, epochs):
        for epoch in range(self.start_epoch, epochs):
            loss_epoch = 0
            for inputs, targets in self.train_loader:
                outputs = []
                neglogp = torch.zeros(self.num_particles)
                # Copy the flat particle parameters back into the models.
                insert_items(self.models, self.param_set, self.state_dict)
                neglogp_grads = torch.zeros_like(self.param_set)
                for i, model in enumerate(self.models):
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    output = model(inputs)
                    outputs.append(output)
                    loss = self.loss_fn(output, targets)
                    loss.backward()
                    neglogp[i] = loss.item()
                    g = []
                    for name, param in model.named_parameters():
                        g.append(param.grad.view(-1))
                    neglogp_grads[i] = torch.cat(g)
                    model.zero_grad()
                par_vi_grad = self.svgd_grad(neglogp_grads, self.param_set)
                self.optimizer.zero_grad()
                self.param_set.grad = par_vi_grad
                self.optimizer.step()
                loss_step = neglogp.mean()
                loss_epoch += loss_step
            loss_epoch /= self.num_particles
            print('Train Epoch {} [cum loss: {}]\n'.format(epoch, loss_epoch))

            if epoch % 1 == 0:  # evaluate every epoch
                insert_items(self.models, self.param_set, self.state_dict)
                with torch.no_grad():
                    outputs, stats = self.test(self.val_loader)
                    outputs2, _ = self.test(self.test_loader, eval_loss=False)
                test_loss, correct = stats
                print('Test Loss: {}'.format(test_loss))
                print('Test Acc: {}%'.format(correct * 100))
                # Entropy/variance uncertainties on the two splits, scored
                # against each other with AUC.
                entropy, variance = uncertainty(outputs)
                entropy2, variance2 = uncertainty(outputs2)
                auc_entropy = auc_score(entropy, entropy2)
                auc_variance = auc_score(variance, variance2)
                print('Test AUC Entropy: {}'.format(auc_entropy))
                print('Test AUC Variance: {}'.format(auc_variance))
                if self._use_wandb:
                    wandb.log({"Test Loss": test_loss})
                    wandb.log({"Train Loss": loss_epoch})
                    wandb.log({"Test Acc": correct * 100})
                    wandb.log({"Test AUC (entropy)": auc_entropy})
                    wandb.log({"Test AUC (variance)": auc_variance})
                if self._save_model:
                    params = {
                        'params': self.param_set,
                        'state_dict': self.state_dict,
                        'models': [m.state_dict() for m in self.models],
                        'optimizer': self.optimizer.state_dict()
                    }
                    save_dir = 'saved_models/{}/{}2/'.format(
                        self.dataset, model_id)
                    fn = 'model_epoch_{}.pt'.format(epoch)
                    print('saving model: {}'.format(fn))
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(params, save_dir + fn)
            print('*' * 86)