def hess_vec_prod(vec):
    hess_vec_prod.count += 1  # simulates a static variable
    vec = unflatten_like(vec.t(), params)

    start_time = time.time()
    eval_hess_vec_prod(vec, params, net, criterion,
                       dataloader=dataloader, use_cuda=use_cuda)
    prod_time = time.time() - start_time
    if verbose and rank == 0:
        print(" Iter: %d time: %f" % (hess_vec_prod.count, prod_time))
    # the Hessian-vector product is accumulated in the .grad buffers;
    # flatten it back out into a single column vector
    out = gradtensor_to_tensor(net)
    return out.unsqueeze(1)
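# A sketch (not the repo's driver) of how a closure like hess_vec_prod is
# typically consumed: wrap it in a SciPy LinearOperator and ask eigsh (Lanczos)
# for extreme Hessian eigenvalues. The name n_par (flattened parameter count)
# is an assumption, and the model is assumed to live on CPU here.
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

hess_vec_prod.count = 0  # the closure increments this on every call

def matvec(v):
    # eigsh supplies a NumPy vector; the closure expects an (n_par, 1) tensor
    v_t = torch.from_numpy(v).float().unsqueeze(-1)
    return hess_vec_prod(v_t).detach().cpu().numpy().ravel()

op = LinearOperator((n_par, n_par), matvec=matvec, dtype=np.float32)
top_eigs = eigsh(op, k=5, which="LA", return_eigenvectors=False)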
def hvp(rhs):
    padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                             device=rhs.device, dtype=rhs.dtype)
    # embed rhs into the zero-padded buffer before unflattening
    # (assumed layout: rhs fills the leading coordinates)
    padded_rhs[:rhs.shape[0]] = rhs
    padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
    eval_hess_vec_prod(padded_rhs, net=model, criterion=loss,
                       inputs=train_x, targets=train_y,
                       dataloader=loader, use_cuda=use_cuda)
    full_hvp = gradtensor_to_tensor(model, include_bn=True)
    return full_hvp.unsqueeze(-1)
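# A power-iteration sketch (assumed usage) for estimating the dominant Hessian
# eigenvalue from the hvp closure above; total_pars is assumed in scope. Power
# iteration converges to the eigenvalue of largest magnitude.
import torch

v = torch.randn(total_pars, 1)
v = v / v.norm()
for _ in range(100):
    w = hvp(v)
    eig_est = (v * w).sum().item()  # Rayleigh quotient v^T H v
    v = w / w.norm()
print("dominant eigenvalue estimate:", eig_est)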
def hess_vec_prod(vec):
    hess_vec_prod.count += 1  # simulates a static variable
    vec = unflatten_like(vec.t(), params)

    start_time = time.time()
    eval_hess_vec_prod(vec, params, net, criterion, inputs=inputs,
                       targets=targets, dataloader=dataloader,
                       use_cuda=use_cuda)
    prod_time = time.time() - start_time
    if verbose:
        print(" Iter: %d time: %f" % (hess_vec_prod.count, prod_time))
    out = gradtensor_to_tensor(net)
    return out.unsqueeze(1)
def hess_vec_prod(vec):
    # pad the low-dimensional vec back up to the full parameter space
    padded_rhs = torch.zeros(N, vec.shape[-1], device=vec.device, dtype=vec.dtype)
    padded_rhs[mask == 1] = vec
    print("vec shape = ", vec.shape)
    print("padded shape = ", padded_rhs.shape)

    hess_vec_prod.count += 1  # simulates a static variable
    padded_rhs = unflatten_like(padded_rhs.t(), net.parameters())

    start_time = time.time()
    eval_hess_vec_prod(padded_rhs, net=net, criterion=criterion,
                       dataloader=dataloader, use_cuda=use_cuda)
    prod_time = time.time() - start_time

    # keep only the masked coordinates of the full Hessian-vector product
    out = gradtensor_to_tensor(net, include_bn=True)
    sliced = out[mask == 1].unsqueeze(-1)
    print("sliced shape = ", sliced.shape)
    return sliced
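# Hypothetical setup for the masked closure above: N is the total parameter
# count and mask selects which coordinates the sub-problem sees; k and the
# random choice of coordinates are purely illustrative, the real mask is
# experiment-specific.
import torch

N = sum(p.numel() for p in net.parameters())
k = 100
mask = torch.zeros(N)
mask[torch.randperm(N)[:k]] = 1.  # probe k randomly chosen coordinates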
import pickle

import torch

import hess.nets
from hess import utils  # assumed layout: the HVP helpers live in hess.utils
# twospirals (the toy two-spirals data generator) is assumed to be in scope


def main():
    ##########################
    ## SET UP TRAINING DATA ##
    ##########################
    X, Y = twospirals(500, noise=1.5)
    train_x, train_y = torch.FloatTensor(X), torch.FloatTensor(Y).unsqueeze(-1)

    test_X, test_Y = twospirals(100, noise=1.5)
    test_x, test_y = torch.FloatTensor(test_X), torch.FloatTensor(test_Y).unsqueeze(-1)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(2)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        train_x, train_y = train_x.cuda(), train_y.cuda()
        test_x, test_y = test_x.cuda(), test_y.cuda()

    #############################
    ## SOME HYPERS AND STORAGE ##
    #############################
    widths = list(range(5, 6))
    loss_func = torch.nn.BCEWithLogitsLoss()
    in_dim = 2
    out_dim = 1

    hessians = []
    n_pars = []
    test_errors = []

    ###############
    ## MAIN LOOP ##
    ###############
    for width_ind, width in enumerate(widths):
        model = hess.nets.SimpleNet(in_dim, out_dim, n_hidden=5,
                                    hidden_size=width,  # sweep the hidden width
                                    activation=torch.nn.ELU(), bias=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        n_par = sum(p.numel() for p in model.parameters())
        n_pars.append(n_par)

        ## TRAIN MODEL ##
        for step in range(2000):
            optimizer.zero_grad()
            outputs = model(train_x)
            loss = loss_func(outputs, train_y)
            loss.backward()
            optimizer.step()
        print("model %i trained" % width)

        ## BUILD THE FULL HESSIAN ONE COLUMN AT A TIME ##
        hessian = None
        for pp in range(n_par):
            # HVP against the pp-th standard basis vector gives column pp
            base_vec = torch.zeros(n_par).unsqueeze(0)
            base_vec[0, pp] = 1.
            base_vec = utils.unflatten_like(base_vec, model.parameters())

            utils.eval_hess_vec_prod(base_vec, model,
                                     criterion=torch.nn.BCEWithLogitsLoss(),
                                     inputs=train_x, targets=train_y)
            output = utils.gradtensor_to_tensor(model, include_bn=True)
            if hessian is None:
                # allocate once the flattened gradient size is known
                hessian = torch.zeros(output.nelement(), output.nelement(),
                                      device="cpu")
            hessian[:, pp] = output.cpu()

        ## SAVE THOSE OUTPUTS ##
        hessians.append(hessian)
        test_errors.append(loss_func(model(test_x), test_y).item())

    ## SAVE EVERYTHING ##
    fpath = "./"
    fname = "hessians.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(hessians, fp)

    fname = "test_errors.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(test_errors, fp)

    fname = "n_pars.P"
    with open(fpath + fname, 'wb') as fp:
        pickle.dump(n_pars, fp)

    fname = "widths.pt"
    torch.save(widths, fpath + fname)
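# Once main() has run, the pickled Hessians can be inspected directly; a small
# sketch using the filenames written above. Symmetrizing first guards against
# the numerical asymmetry a column-by-column build can pick up.
import pickle
import torch

with open("./hessians.P", "rb") as fp:
    hessians = pickle.load(fp)

for h in hessians:
    h = (h + h.t()) / 2.             # symmetrize against numerical noise
    eigs = torch.linalg.eigvalsh(h)  # eigenvalues in ascending order
    print("top-5 eigenvalues:", eigs[-5:])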