def perturb_model(model, sigma, n_pars, use_cuda):
    """Add isotropic Gaussian noise with std `sigma` to every parameter in-place."""
    perturb = torch.randn(n_pars) * sigma
    if use_cuda:
        perturb = perturb.cuda()

    perturb = utils.unflatten_like(perturb.unsqueeze(0), model.parameters())
    for i, par in enumerate(model.parameters()):
        par.data = par.data + perturb[i]

    return
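# Usage sketch for perturb_model above (an assumption-laden sketch, not code
# from the original script): the save/restore pattern around the call is
# assumed so that perturbations do not accumulate across trials.
import copy

def perturb_and_restore(model, sigma, use_cuda=False):
    n_pars = sum(p.numel() for p in model.parameters())
    saved_state = copy.deepcopy(model.state_dict())
    perturb_model(model, sigma, n_pars, use_cuda)
    # ... evaluate the perturbed model here ...
    model.load_state_dict(saved_state)  # undo the perturbation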
def hess_vec_prod(vec):
    hess_vec_prod.count += 1  # simulates a static variable
    vec = unflatten_like(vec.t(), params)

    start_time = time.time()
    eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
    prod_time = time.time() - start_time
    if verbose and rank == 0:
        print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))

    out = gradtensor_to_tensor(net)
    return out.unsqueeze(1)
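# Sketch of how a closure like hess_vec_prod is typically consumed: wrapped as
# a matrix-free operator and handed to a Lanczos eigensolver. The scipy usage
# here is an illustration under assumptions (N is the total parameter count;
# scipy is available), not necessarily how this repo drives the closure.
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

def np_hvp(v):
    # hess_vec_prod expects a column vector of shape (N, 1) and returns one.
    v_col = torch.from_numpy(v).float().unsqueeze(-1)
    return hess_vec_prod(v_col).detach().cpu().numpy().ravel()

# op = LinearOperator((N, N), matvec=np_hvp)
# top_evals, top_evecs = eigsh(op, k=10, which='LM')  # 10 largest-|lambda| pairs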
def hvp(rhs):
    # Pad the right-hand side out to the full parameter count before
    # unflattening. NOTE: the flattened original never copied `rhs` into the
    # zero buffer (so the product was always against zero); copying into the
    # leading rows is an assumed fix.
    padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                             device=rhs.device, dtype=rhs.dtype)
    padded_rhs[:rhs.shape[0]] = rhs
    padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())

    eval_hess_vec_prod(padded_rhs, net=model, criterion=loss,
                       inputs=train_x, targets=train_y,
                       dataloader=loader, use_cuda=use_cuda)
    full_hvp = gradtensor_to_tensor(model, include_bn=True)
    return full_hvp.unsqueeze(-1)
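# For reference, a self-contained double-backward Hessian-vector product on a
# tiny stand-in model. This sketches the same mechanics eval_hess_vec_prod
# relies on (differentiate grad . v a second time); it is not the repo's
# implementation.
import torch

tiny = torch.nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss_val = torch.nn.functional.mse_loss(tiny(x), y)

params = list(tiny.parameters())
vec = [torch.randn_like(p) for p in params]

# First backward pass keeps the graph so grad . v can itself be differentiated.
grads = torch.autograd.grad(loss_val, params, create_graph=True)
dot = sum((g * v).sum() for g, v in zip(grads, vec))
hvp_list = torch.autograd.grad(dot, params)  # one H.v block per parameter tensor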
def hess_vec_prod(vec):
    hess_vec_prod.count += 1  # simulates a static variable
    vec = unflatten_like(vec.t(), params)

    start_time = time.time()
    eval_hess_vec_prod(vec, params, net, criterion, inputs=inputs,
                       targets=targets, dataloader=dataloader,
                       use_cuda=use_cuda)
    prod_time = time.time() - start_time
    if verbose:
        print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))

    out = gradtensor_to_tensor(net)
    return out.unsqueeze(1)
def hess_vec_prod(vec):
    # Scatter the low-dimensional input back into the full parameter space,
    # apply the Hessian, then slice the masked entries back out.
    padded_rhs = torch.zeros(N, vec.shape[-1],
                             device=vec.device, dtype=vec.dtype)
    padded_rhs[mask == 1] = vec

    hess_vec_prod.count += 1  # simulates a static variable
    padded_rhs = unflatten_like(padded_rhs.t(), net.parameters())

    start_time = time.time()
    eval_hess_vec_prod(padded_rhs, net=net, criterion=criterion,
                       dataloader=dataloader, use_cuda=use_cuda)
    prod_time = time.time() - start_time

    out = gradtensor_to_tensor(net, include_bn=True)
    return out[mask == 1].unsqueeze(-1)
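# One plausible way to build the `mask` used above (an assumption: the original
# mask construction is not shown in this section). This version keeps only
# weight matrices, zeroing out biases and BatchNorm parameters in the same
# flattened ordering that unflatten_like / gradtensor_to_tensor use.
import torch

def make_weight_mask(net):
    chunks = []
    for name, p in net.named_parameters():
        keep = ("weight" in name) and (p.dim() > 1)  # skip biases and BN scales
        chunks.append(torch.full((p.numel(),), 1.0 if keep else 0.0))
    return torch.cat(chunks)

# mask = make_weight_mask(net); N = mask.numel()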
def get_loss_surface(basis, model, dataloader, criterion,
                     rng=0.1, n_pts=25, use_cuda=False):
    """Evaluate the loss on an n_pts x n_pts grid in the plane spanned by the
    two directions extracted from `basis`. Note that `loss_getter` should be a
    function taking (model, dataloader, criterion, use_cuda).
    """
    start_pars = model.state_dict()

    ## get out the plane ##
    dir1, dir2 = get_plane(basis)

    ## init loss surface and the vector multipliers ##
    loss_surf = torch.zeros(n_pts, n_pts)
    vec_len = torch.linspace(-rng / 2., rng / 2., n_pts)

    ## loop and get loss at each point ##
    for ii in range(n_pts):
        for jj in range(n_pts):
            perturb = dir1.mul(vec_len[ii]) + dir2.mul(vec_len[jj])
            perturb = utils.unflatten_like(perturb.t(), model.parameters())
            for i, par in enumerate(model.parameters()):
                if use_cuda:
                    par.data = par.data + perturb[i].cuda()
                else:
                    par.data = par.data + perturb[i]

            loss_surf[ii, jj] = loss_getter(model, dataloader, criterion, use_cuda)

            ## reset pars after each grid point so perturbations don't accumulate ##
            model.load_state_dict(start_pars)

    return loss_surf
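# Sketch of plotting the returned surface (the matplotlib usage is an
# assumption, not part of the original script; `basis`, `model`, `dataloader`,
# and `criterion` come from the caller's setup).
import matplotlib.pyplot as plt

# loss_surf = get_loss_surface(basis, model, dataloader, criterion,
#                              rng=0.1, n_pts=25, use_cuda=use_cuda)
# grid = torch.linspace(-0.05, 0.05, 25)  # matches rng=0.1, n_pts=25
# plt.contourf(grid.numpy(), grid.numpy(), loss_surf.numpy().T, levels=30)
# plt.colorbar(label="loss")
# plt.xlabel("direction 1")
# plt.ylabel("direction 2")
# plt.show()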
def main():
    ## generate model and load in trained instance ##
    use_cuda = torch.cuda.is_available()
    model = Net()
    saved_model = torch.load("./model.pt", map_location="cpu")
    model.load_state_dict(saved_model)
    if use_cuda:
        model = model.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='/datasets/cifar10/', train=True,
                                            download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                              shuffle=False, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='/datasets/cifar10/', train=False,
                                           download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                             shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    ## load in eigenpairs and clean up ##
    fpath = "./"
    fname = "cifar_evecs_200.pt"
    evecs = torch.load(fpath + fname, map_location="cpu").squeeze()
    fname = "cifar_evals_200.pt"
    evals = torch.load(fpath + fname, map_location="cpu")

    keep = np.where(evals != 1)[0]
    n_evals = keep.size
    evals = evals[keep].numpy()
    evecs = evecs[:, keep].numpy()
    idx = np.abs(evals).argsort()[::-1]
    evals = torch.FloatTensor(evals[idx])
    evecs = torch.FloatTensor(evecs[:, idx])

    pars = utils.flatten(model.parameters())
    n_par = pars.numel()
    par_len = pars.norm()
    criterion = nn.CrossEntropyLoss()

    ## going to need the original model predictions ##
    model_preds = []
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        model_preds.append(predicted.cpu())

    n_scale = 20
    n_trials = 10
    scales = torch.linspace(0, 1., n_scale)

    ## Test high curvature directions ##
    high_curve_losses = torch.zeros(n_scale, n_trials)
    n_diff_high = torch.zeros(n_scale, n_trials)
    for ii in range(n_scale):
        for tt in range(n_trials):
            alpha = torch.randn(n_evals)
            pert = evecs.matmul(alpha.unsqueeze(-1)).t()
            pert = scales[ii] * pert.div(pert.norm())
            if use_cuda:
                pert = pert.cuda()
            pert = utils.unflatten_like(pert, model.parameters())

            ## perturb ##
            for i, par in enumerate(model.parameters()):
                par.data = par.data + pert[i]

            ## compute the loss and label diffs ##
            train_loss, train_diff = compute_loss_differences(
                model, trainloader, criterion, model_preds)
            high_curve_losses[ii, tt] = train_loss
            n_diff_high[ii, tt] = train_diff

            ## need to reload pars after each perturbation ##
            model.load_state_dict(saved_model)

        ## just to track progress ##
        print("high curve scale {} of {} done".format(ii, n_scale))

    ## save the high curvature results ##
    fpath = "./"
    fname = "high_curve_losses.pt"
    torch.save(high_curve_losses, fpath + fname)
    fname = "n_diff_high.pt"
    torch.save(n_diff_high, fpath + fname)
    print("all high curvature done \n\n")

    ## go through the low curvature directions ##
    low_curve_losses = torch.zeros(n_scale, n_trials)
    n_diff_low = torch.zeros(n_scale, n_trials)
    for ii in range(n_scale):
        for tt in range(n_trials):
            alpha = torch.randn(n_par)  # random direction in parameter space
            pert = gram_schmidt(alpha, evecs).unsqueeze(-1).t()  # orthogonal to evecs
            pert = scales[ii] * pert.div(pert.norm())  # scaled correctly
            if use_cuda:
                pert = pert.cuda()
            pert = utils.unflatten_like(pert, model.parameters())

            ## perturb, then track losses/differences in preds ##
            for i, par in enumerate(model.parameters()):
                par.data = par.data + pert[i]

            ## compute the loss and label diffs ##
            train_loss, train_diff = compute_loss_differences(
                model, trainloader, criterion, model_preds)
            low_curve_losses[ii, tt] = train_loss
            n_diff_low[ii, tt] = train_diff

            model.load_state_dict(saved_model)

        print("low curve scale {} of {} done".format(ii, n_scale))

    ## save the low curvature results ##
    fpath = "./"
    fname = "low_curve_losses.pt"
    torch.save(low_curve_losses, fpath + fname)
    fname = "n_diff_low.pt"
    torch.save(n_diff_low, fpath + fname)

    print("\n\n Gucci.")
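# The script above assumes a compute_loss_differences helper. A minimal sketch
# of what it plausibly does (hypothetical; the original implementation is not
# shown here): average loss over the loader plus a count of predictions that
# changed relative to the stored unperturbed predictions.
import torch

def compute_loss_differences(model, loader, criterion, saved_preds):
    model.eval()
    total_loss, n_diff, n_batches = 0.0, 0, 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(loader):
            if next(model.parameters()).is_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            n_diff += (predicted.cpu() != saved_preds[i]).sum().item()
            n_batches += 1
    return total_loss / n_batches, n_diff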
def main():
    ##########################
    ## SET UP TRAINING DATA ##
    ##########################
    X, Y = twospirals(500, noise=1.5)
    train_x, train_y = torch.FloatTensor(X), torch.FloatTensor(Y).unsqueeze(-1)

    test_X, test_Y = twospirals(100, 1.5)
    test_x, test_y = torch.FloatTensor(test_X), torch.FloatTensor(test_Y).unsqueeze(-1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(2)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        train_x, train_y = train_x.cuda(), train_y.cuda()

    #############################
    ## SOME HYPERS AND STORAGE ##
    #############################
    widths = [i for i in range(5, 6)]
    loss_func = torch.nn.BCEWithLogitsLoss()
    in_dim = 2
    out_dim = 1
    hessians = []
    n_pars = []
    test_errors = []

    ###############
    ## MAIN LOOP ##
    ###############
    for width_ind, width in enumerate(widths):
        model = hess.nets.SimpleNet(in_dim, out_dim, n_hidden=5,
                                    hidden_size=20,
                                    activation=torch.nn.ELU(), bias=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        n_par = sum(p.numel() for p in model.parameters())
        n_pars.append(n_par)

        ## TRAIN MODEL ##
        for step in range(2000):
            optimizer.zero_grad()
            outputs = model(train_x)
            loss = loss_func(outputs, train_y)
            loss.backward()
            optimizer.step()
        print("model %i trained" % width)

        ## build the full Hessian column by column via HVPs with basis vectors ##
        hessian = torch.zeros(n_par, n_par)
        for pp in range(n_par):
            base_vec = torch.zeros(n_par).unsqueeze(0)
            base_vec[0, pp] = 1.

            base_vec = utils.unflatten_like(base_vec, model.parameters())
            utils.eval_hess_vec_prod(base_vec, model,
                                     criterion=torch.nn.BCEWithLogitsLoss(),
                                     inputs=train_x, targets=train_y)
            hessian[:, pp] = utils.gradtensor_to_tensor(model, include_bn=True).cpu()

        ## SAVE THOSE OUTPUTS ##
        hessians.append(hessian)
        test_errors.append(loss_func(model(test_x), test_y).item())

    ## SAVE EVERYTHING ##
    fpath = "./"
    fname = "hessians.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(hessians, fp)

    fname = "test_errors.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(test_errors, fp)

    fname = "n_pars.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(n_pars, fp)

    fname = "widths.pt"
    torch.save(widths, fpath + fname)
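# Follow-up sketch for analyzing the pickled Hessians (reads back the files
# written above; the eigenvalue summary itself is an illustration, not part of
# the original script).
import pickle
import torch

with open("./hessians.P", "rb") as fp:
    hessians = pickle.load(fp)

for width, H in zip(torch.load("./widths.pt"), hessians):
    H_sym = 0.5 * (H + H.t())             # symmetrize against numerical noise
    evals = torch.linalg.eigvalsh(H_sym)  # eigenvalues in ascending order
    print("width %d: min eig %.3e, max eig %.3e"
          % (width, evals[0].item(), evals[-1].item()))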