def get_hessian_eigs(loss, model, mask=None, use_cuda=False, n_eigs=100,
                     train_x=None, train_y=None, loader=None, evals=False):
    """Estimate eigenvalues of the (optionally masked) Hessian of ``loss``.

    Args:
        loss: criterion forwarded to ``eval_hess_vec_prod`` / ``get_hessian``.
        model: network whose parameters define the Hessian.
        mask: optional flat tensor with one entry per model parameter; only
            entries equal to 1 are kept (rows and columns of the sub-Hessian).
            ``None`` keeps every parameter.
        use_cuda: move ``train_x``/``train_y`` (or the probe batch) to the GPU.
        n_eigs: number of Lanczos iterations. ``-1`` instead forms the Hessian
            explicitly with ``get_hessian`` and returns all eigenvalues of the
            masked sub-Hessian.
        train_x, train_y: full-batch data; when ``train_x`` is None the
            ``loader`` is used instead.
        loader: dataloader used when ``train_x`` is None.
        evals: in Lanczos mode, also return the Ritz vectors ``qmat @ t_evals``.

    Returns:
        Lanczos mode: ``eigs`` or ``(eigs, ritz_vectors)`` when ``evals``.
        Exact mode (``n_eigs == -1``): real parts of the sub-Hessian
        eigenvalues as a NumPy array.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())
    if mask is None:
        # No mask means "keep every parameter". Built up front so that BOTH
        # branches can use it (the exact branch used to crash on ``mask.cpu()``).
        p = next(iter(model.parameters()))
        mask = torch.ones(total_pars, dtype=p.dtype, device=p.device)

    if n_eigs == -1:
        # Form the Hessian explicitly and keep only the masked rows/columns.
        hessian = get_hessian(train_x, train_y, loss, model, use_cuda=use_cuda)
        keepers = np.flatnonzero((mask == 1).cpu().numpy())
        sub_hess = hessian[np.ix_(keepers, keepers)]
        e_val, _ = np.linalg.eig(sub_hess.cpu().detach())
        return e_val.real

    # Boolean selector is loop-invariant: compute it once, not on every HVP.
    # Counting it (rather than summing a float mask) keeps numpars exact and
    # consistent with the rows actually selected below.
    keep = mask == 1
    numpars = int(keep.sum().item())

    def hvp(rhs):
        # Scatter the (numpars, k) Lanczos probe into a full-length
        # (total_pars, k) block, leaving masked-out coordinates at zero.
        padded_rhs = torch.zeros(total_pars, rhs.shape[-1], device=rhs.device, dtype=rhs.dtype)
        padded_rhs[keep] = rhs
        padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
        eval_hess_vec_prod(padded_rhs, net=model, criterion=loss, inputs=train_x,
                           targets=train_y, dataloader=loader, use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        # Gather back to the masked coordinates only.
        return full_hvp[keep].unsqueeze(-1)

    # Lanczos vectors follow the data's dtype/device.
    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                 matrix_shape=(numpars, numpars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)
    if evals:
        return eigs, qmat @ t_evals
    return eigs
def _exact_predictive_covar_inv_quad_form_root(self, expanded_lt, test_train_covar):
    r"""
    Compute a Lanczos-based inverse root of ``expanded_lt``.

    Runs up to ``settings.max_root_decomposition_size`` Lanczos iterations on
    ``expanded_lt`` (:math:`A \approx Q T Q^{\top}`), eigendecomposes the
    tridiagonal factor :math:`T = V \Lambda V^{\top}` and returns
    :math:`S = Q V \Lambda^{-1/2}`, so that :math:`S S^{\top}` approximates
    :math:`A^{-1}` on the Krylov subspace.

    Side effects: caches ``self.gram_evecs`` (:math:`Q V`) and
    ``self.gram_evals`` (:math:`\Lambda`).

    Args:
        expanded_lt: operator exposing ``matmul``, ``shape``, ``device`` and
            ``dtype`` (presumably a LazyTensor for :math:`K_{XX} + \sigma^2 I`;
            TODO confirm with caller).
        test_train_covar: not used by this implementation.

    Returns:
        :math:`S`, with one column per Ritz pair. Leading batch dimensions of
        the Lanczos output are preserved.

    NOTE(review): the base-class contract describes the result as
    :math:`K_{X^{*}X} S`, but this override returns :math:`S` alone and
    ignores ``test_train_covar``. Confirm the caller applies the
    multiplication itself.
    """
    qmats, tmats = lanczos_tridiag(
        expanded_lt.matmul,
        max_iter=settings.max_root_decomposition_size.value(),
        matrix_shape=expanded_lt.shape,
        device=expanded_lt.device,
        dtype=expanded_lt.dtype,
    )
    evals, evecs = lanczos_tridiag_to_diag(tmats)
    self.gram_evecs = qmats @ evecs
    self.gram_evals = evals
    # Scale column j by evals[..., j] ** -0.5. This equals
    # ``gram_evecs @ torch.diag(evals ** -0.5)`` for a single matrix, but also
    # broadcasts over batch dimensions, where ``torch.diag`` would instead
    # extract a diagonal from a 2-D ``evals``. It also avoids building a k x k
    # matrix.
    covar_root = self.gram_evecs * evals.pow(-0.5).unsqueeze(-2)
    return covar_root
def test_lanczos(self):
    """Lanczos on a small SPD matrix should reconstruct it as Q T Q^T."""
    size = 100
    matrix = torch.randn(size, size)
    # Make the matrix symmetric positive definite and well scaled.
    matrix = matrix.matmul(matrix.transpose(-1, -2))
    matrix.div_(matrix.norm())
    matrix.add_(torch.ones(matrix.size(-1)).mul(1e-6).diag())
    # Use the same dtype/device/matrix_shape keyword API as every other
    # lanczos_tridiag call in this file; the old tensor_cls/n_dims keywords
    # are not part of that signature.
    q_mat, t_mat = lanczos_tridiag(
        matrix.matmul,
        max_iter=size,
        dtype=matrix.dtype,
        device=matrix.device,
        matrix_shape=matrix.shape,
    )
    approx = q_mat.matmul(t_mat).matmul(q_mat.transpose(-1, -2))
    self.assertTrue(approx_equal(approx, matrix))
def lanczos_tridiag_test(self, matrix):
    """Check that a full Lanczos run on ``matrix`` yields a valid decomposition.

    Runs as many Lanczos steps as ``matrix`` has rows, then checks:
    - factor sizes (via ``assert_valid_sizes``);
    - positivity of the tridiagonal factor (via ``assert_tridiagonally_positive``);
    - that ``Q T Q^T`` reproduces ``matrix``.
    """
    n = matrix.size(0)
    basis, tridiag = lanczos_tridiag(
        matrix.matmul,
        max_iter=n,
        dtype=matrix.dtype,
        device=matrix.device,
        matrix_shape=matrix.shape,
    )
    self.assert_valid_sizes(n, tridiag, basis)
    self.assert_tridiagonally_positive(tridiag)
    reconstructed = basis @ tridiag @ basis.transpose(-1, -2)
    self.assertTrue(approx_equal(reconstructed, matrix))
def get_hessian_evals(loss, model, use_cuda=False, n_eigs=100, train_x=None, train_y=None, loader=None):
    """Estimate Hessian eigenpairs of ``loss`` over all model parameters via Lanczos.

    Args:
        loss: criterion forwarded to ``eval_hess_vec_prod``.
        model: network whose parameters define the Hessian.
        use_cuda: move ``train_x``/``train_y`` (or the probe batch) to the GPU.
        n_eigs: number of Lanczos iterations.
        train_x, train_y: full-batch data; when ``train_x`` is None the
            ``loader`` is used instead.
        loader: dataloader used when ``train_x`` is None.

    Returns:
        ``(eigs, ritz_vectors)`` where ``ritz_vectors = qmat @ t_evals``.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())

    def hvp(rhs):
        # Reshape the flat Lanczos probe (expected (total_pars, k)) into
        # per-parameter tensors. Previously a fresh zeros tensor was unflattened
        # here and ``rhs`` was never used, so every product was H @ 0.
        rhs_by_param = unflatten_like(rhs.t(), model.parameters())
        eval_hess_vec_prod(rhs_by_param, net=model, criterion=loss, inputs=train_x,
                           targets=train_y, dataloader=loader, use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)

    # Lanczos vectors follow the data's dtype/device.
    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                 matrix_shape=(total_pars, total_pars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)
    return eigs, qmat @ t_evals
def hessian_eigenpairs(net, criterion, inputs=None, targets=None, dataloader=None, n_eigs=20,
                       use_cuda=torch.cuda.is_available(), verbose=False):
    """Approximate Hessian eigenpairs with ``n_eigs`` Lanczos iterations.

    Only parameters with more than one dimension (``len(p.size()) > 1``) are
    included in the Hessian.

    Args:
        net: the model.
        criterion: loss function forwarded to ``eval_hess_vec_prod``.
        inputs, targets: optional full-batch data.
        dataloader: optional dataloader, forwarded to ``eval_hess_vec_prod``.
        n_eigs: number of Lanczos iterations.
        use_cuda: forwarded to ``eval_hess_vec_prod``. Note the default is
            evaluated once, when the module is imported.
        verbose: print the index and wall time of every Hessian-vector product.

    Returns:
        ``(e_vals, e_vecs)``: Ritz values, and Ritz vectors mapped back to the
        flattened parameter space (``q_mat @ tridiag_evecs``; presumably one
        vector per column).
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        # Counter must be initialised and incremented: verbose mode reads it
        # (it previously raised AttributeError because it was never set).
        hess_vec_prod.count += 1
        vec = unflatten_like(vec.t(), params)
        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, inputs=inputs, targets=targets,
                           dataloader=dataloader, use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print(" Iter: %d time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # Convert the tridiagonal T matrix to Ritz values / vectors of T ...
    e_vals, e_vecs = lanczos_tridiag_to_diag(pos_t_mat)
    # ... and map the vectors from the Krylov basis back to parameter space.
    # TODO(original author): an earlier note said "GOING TO NEED TO CHANGE
    # E_VECS HERE"; its intent is not recorded.
    e_vecs = pos_q_mat.matmul(e_vecs)
    return e_vals, e_vecs
def min_max_hessian_eigs(net, dataloader, criterion, n_top_eigs=3, n_bottom_eigs=50, use_cuda=False):
    """
    Approximate the top and bottom Hessian eigenpairs with two Lanczos runs.

    The first run uses the Hessian itself. The second uses the shifted
    operator ``shift * I - H`` (``shift = 0.51 * largest Ritz value``), whose
    largest eigenvalues correspond to the smallest eigenvalues of ``H``.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        n_top_eigs: Lanczos iterations for the top of the spectrum.
        n_bottom_eigs: Lanczos iterations on the shifted operator.
        use_cuda: use GPU

    Returns:
        ``(pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs)``:
        - Ritz values and parameter-space Ritz vectors from the unshifted run.
        - Ritz values and parameter-space Ritz vectors from the shifted run,
          with the eigenvalues already mapped back to ``H``'s scale.
    """
    # Only used to pick the device/dtype of the Lanczos vectors.
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in net.parameters())

    def hess_vec_prod(vec):
        # Lanczos passes a flat (N, k) block; reshape it per parameter. (The
        # previous all-ones mask/padding was an identity copy.)
        vec = unflatten_like(vec.t(), net.parameters())
        eval_hess_vec_prod(vec, net=net, criterion=criterion, dataloader=dataloader, use_cuda=use_cuda)
        out = gradtensor_to_tensor(net, include_bn=True)
        return out.unsqueeze(-1)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_top_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    pos_eigvecs = pos_q_mat @ pos_eigvecs

    # If the largest eigenvalue is positive, shift the matrix so that any
    # negative eigenvalue becomes the largest. We assume the smallest
    # eigenvalue is zero or less, so this shift is more than we need.
    shift = 0.51 * pos_eigvals.max().item()
    print("Pos Eigs Computed....\n")

    def shifted_hess_vec_prod(vec):
        return -hess_vec_prod(vec) + shift * vec

    # now run lanczos on the shifted operator
    neg_q_mat, neg_t_mat = lanczos_tridiag(
        shifted_hess_vec_prod,
        n_bottom_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    neg_eigvals, neg_eigvecs = lanczos_tridiag_to_diag(neg_t_mat)
    neg_eigvecs = neg_q_mat @ neg_eigvecs
    print("Neg Eigs Computed...")
    # Undo the shift: eigenvalue mu of (shift * I - H) is shift - mu for H.
    neg_eigvals = -neg_eigvals + shift
    return pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False, verbose=False, nsteps=100):
    """
    Estimate the largest Hessian eigenvalue with Lanczos.

    Only the top of the spectrum is computed. The shifted Lanczos pass for
    the smallest eigenvalue is disabled, so the ``mineig`` and ``neg_eigvals``
    slots of the result are always ``None``.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node; only rank 0 prints verbose output.
        use_cuda: use GPU
        verbose: print more information
        nsteps: number of Lanczos iterations.

    Returns:
        ``(maxeig, None, count, pos_eigvals, None, pos_t_mat)``:
        - maxeig: largest Ritz value.
        - count: number of Hessian-vector products performed.
        - pos_eigvals: all Ritz values.
        - pos_t_mat: the Lanczos tridiagonal matrix.
        Only parameters with more than one dimension are included.
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)
        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print(" Iter: %d time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if verbose and rank == 0:
        print("Rank %d: computing max eigenvalue" % rank)

    # Only the tridiagonal factor is needed: the Ritz vectors were previously
    # formed (an N x nsteps matmul) and then discarded.
    _, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    print(pos_eigvals)

    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if verbose and rank == 0:
        print("max eigenvalue = %f" % maxeig)

    return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat
def min_max_fisher_eigs(net, dataloader, criterion, rank=0, use_cuda=False, verbose=False, fvp_method='FVP_FD'):
    """
    Estimate the largest and smallest eigenvalues of the Fisher matrix with Lanczos.

    Two Lanczos runs are made:
    - 100 iterations on the Fisher-vector product itself, for the top of the
      spectrum;
    - 200 iterations on the shifted operator ``shift * I - F`` (with
      ``shift = 0.51 * maxeig``), whose largest Ritz value is mapped back to
      the smallest eigenvalue of ``F``.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: not used by this function.
        rank: rank of the working node; only rank 0 prints verbose output.
        use_cuda: forwarded to ``eval_fisher_vec_prod``.
        verbose: print per-product timing and eigenvalue estimates.
        fvp_method: ``'FVP_FD'`` or ``'FVP_AG'``; selects the Fisher-vector
            product implementation.

    Returns:
        ``(maxeig, mineig, count, pos_eigvals, neg_eigvals)``:
        - maxeig, mineig: the estimated extreme eigenvalues. They are swapped
          if ``maxeig <= 0 < mineig``.
        - count: total Fisher-vector products over both runs.
        - pos_eigvals, neg_eigvals: the raw Ritz values of the two runs.

    Raises:
        NotImplementedError: for any other ``fvp_method``.
    """
    if fvp_method not in ('FVP_FD', 'FVP_AG'):
        raise NotImplementedError(
            'Only FD and AG have been implemented so far.')
    if fvp_method == 'FVP_FD':
        print('Using FD for FVP')
        fvp_matmul = FVP_FD
    else:
        print('Using AG for FVP')
        fvp_matmul = FVP_AG

    chatty = verbose and rank == 0
    all_params = list(net.parameters())
    n_total = sum(p.numel() for p in all_params)
    calls = 0

    def fisher_vec_prod(vec):
        nonlocal calls
        calls += 1  # counts products across both Lanczos runs
        per_param = unflatten_like(vec.t(), net.parameters())
        tic = time.time()
        result = eval_fisher_vec_prod(per_param, net, dataloader, use_cuda, fvp_matmul=fvp_matmul)
        elapsed = time.time() - tic
        if chatty:
            print(" Iter: %d time: %f" % (calls, elapsed))
        return result

    if chatty:
        print("Rank %d: computing max eigenvalue" % rank)
    _, top_tmat = lanczos_tridiag(fisher_vec_prod, 100,
                                  device=all_params[0].device,
                                  dtype=all_params[0].dtype,
                                  matrix_shape=(n_total, n_total))
    pos_eigvals, _ = lanczos_tridiag_to_diag(top_tmat)
    print(pos_eigvals)
    maxeig = torch.max(pos_eigvals)  # Ritz values are not guaranteed to be sorted
    if chatty:
        print('max eigenvalue = %f' % maxeig)

    # Shift so that the bottom of F's spectrum becomes the top of (shift*I - F).
    shift = 0.51 * maxeig.item()
    print(shift)

    def shifted_fvp(vec):
        return -fisher_vec_prod(vec) + shift * vec

    if chatty:
        print("Rank %d: Computing shifted eigenvalue" % rank)
    _, bottom_tmat = lanczos_tridiag(shifted_fvp, 200,
                                     device=all_params[0].device,
                                     dtype=all_params[0].dtype,
                                     matrix_shape=(n_total, n_total))
    neg_eigvals, _ = lanczos_tridiag_to_diag(bottom_tmat)
    top_shifted = torch.max(neg_eigvals)
    print(neg_eigvals)
    mineig = -top_shifted + shift
    print(mineig)
    if chatty:
        print('min eigenvalue = ' + str(mineig))

    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig
    return maxeig, mineig, calls, pos_eigvals, neg_eigvals
def _tridiag_eig_legacy(T):
    """Eigendecompose the Lanczos tridiagonal ``T`` in the legacy ``Tensor.eig`` layout.

    Replaces ``T.eig(eigenvectors=True)``, which no longer exists in current
    PyTorch. Lanczos builds a symmetric ``T``, so a symmetric solver is used.

    Returns ``(eigvals, eigvecs)``:
    - eigvals: an ``(n, 2)`` tensor of (real, imaginary) parts, as the old API
      produced, with an all-zero imaginary column. Callers written against
      that layout keep working.
    - eigvecs: the eigenvectors, as columns, in ascending eigenvalue order.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'eigh'):
        evals, evecs = linalg.eigh(T)
    else:  # older torch without torch.linalg.eigh
        evals, evecs = torch.symeig(T, eigenvectors=True)
    eigvals = torch.stack([evals, torch.zeros_like(evals)], dim=-1)
    return eigvals, evecs


def compute_eigenspectrum(
    dataset: str,
    data_path: str,
    model: str,
    checkpoint_path: str,
    curvature_matrix: str = 'hessian_lanczos',
    use_test: bool = True,
    batch_size: int = 128,
    num_workers: int = 4,
    swag: bool = False,
    lanczos_iters: int = 100,
    num_subsamples: int = None,
    subsample_seed: int = None,
    bn_train_mode: bool = True,
    save_spectrum_path: str = None,
    save_eigvec: bool = False,
    seed: int = None,
    device: str = 'cuda',
):
    """
    Compute the Lanczos spectrum (Ritz values/vectors) of a curvature matrix of a trained network.

    The network ``model`` is rebuilt from ``models``, its checkpoint is loaded,
    and ``lanczos_iters`` Lanczos iterations are run on the selected
    curvature-vector product.

    Parameters
    ----------
    dataset: str: ['CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'*]: the dataset on which the
        curvature is evaluated. For ImageNet32 we use the downsampled 32 x 32 full ImageNet
        dataset. It is not downloaded automatically for licensing reasons; please put the
        ImageNet32 data in the 'data/' folder.
    data_path: str: the path of the dataset.
    model: str: the network architecture. All available models are listed under 'models/'.
        Example: VGG16BN, PreResNet110 (pre-activated ResNet, 110 layers).
    checkpoint_path: str: path to a checkpoint generated by train_network; its
        'state_dict' entry is loaded into the network.
    curvature_matrix: str: the curvature matrix whose spectrum is computed. One of:
        hessian_lanczos: Lanczos algorithm on the Hessian
        ggn_lanczos: Lanczos algorithm on the Generalised Gauss-Newton (GGN) matrix
        cov_grad_lanczos: Lanczos algorithm on the covariance of gradients
        Any other value fails the assertion at the start of the function.
    use_test: bool: if True, use the standard train/test split. If not, a portion of the
        training data is assigned as the validation set.
    batch_size: int: the minibatch size.
    num_workers: int: number of workers for the dataloader.
    swag: bool: if True, load the checkpoint into a SWAG model, call ``set_swa()`` and use
        its ``base_model``.
    lanczos_iters: int: number of Lanczos iterations. This also bounds the number of
        Ritz value / vector pairs generated.
    num_subsamples: int: number of training points drawn at random for the curvature
        products. If None, the entire training set is used.
    subsample_seed: int: the pseudorandom seed for the subsample draw above.
    bn_train_mode: bool: only relevant if the architecture contains batch-normalization
        layers. Toggles whether BN layers are in train or eval mode during the products.
    save_spectrum_path: str: if provided, the results are saved to this path.
    save_eigvec: bool: if True, save a torch file (``torch.save``) with 'w', 'eigvals',
        'gammas' and the Ritz vectors 'V', instead of the NumPy archive. Otherwise save a
        NumPy archive (``np.savez``) with 'w', 'eigvals' and 'gammas' only. 'V' has one
        column per model parameter, so enabling this for many experiments can take a lot
        of storage.
    seed: int: if not None, seeds torch's CPU and CUDA random number generators.
    device: ['cpu', 'cuda']: the device for the model and computations; falls back to
        'cpu' when CUDA is unavailable. 'cuda' is strongly recommended on CUDA-enabled
        Nvidia devices.

    Returns
    -------
    dict with keys:
        'w': the flattened model parameters (CPU tensor of length P).
        'eigvals': the Ritz values as an (m, 2) tensor of (real, imaginary) parts (legacy
            ``Tensor.eig`` layout; the imaginary column is zero).
        'gammas': squared first components of the eigenvectors of the Lanczos
            tridiagonal matrix, one per Ritz value.
        'V': (m, P) tensor whose rows are the Ritz vectors in parameter space.
    """
    if device == 'cuda':
        if not torch.cuda.is_available():
            device = 'cpu'
    assert curvature_matrix in [
        'hessian_lanczos',
        'ggn_lanczos',
        'cov_grad_lanczos',
    ]
    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    print('Using model %s' % model)
    model_cfg = getattr(models, model)

    # (Possibly subsampled) training set for the curvature-vector products.
    datasets, num_classes = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        train_subset=num_subsamples,
        train_subset_seed=subsample_seed,
    )
    loader = torch.utils.data.DataLoader(datasets['train'], batch_size=batch_size, shuffle=False,
                                         num_workers=num_workers, pin_memory=True)
    # Full training set, used only to refresh BN statistics below.
    full_datasets, _ = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_train,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
    )
    full_loader = torch.utils.data.DataLoader(full_datasets['train'], batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers, pin_memory=True)

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))
    if not swag:
        model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        swag_model = SWAG(model_cfg.base, subspace_type='random', *model_cfg.args, num_classes=num_classes,
                          **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
        swag_model.set_swa()
        model = swag_model.base_model
    model.to(device)

    num_parameters = sum(p.numel() for p in model.parameters())
    criterion = losses.cross_entropy

    class CurvVecProduct(object):
        """Callable curvature-vector product used as the Lanczos matmul closure."""

        def __init__(self, loader, model, criterion, curvature_matrix, full_loader=None):
            self.loader = loader
            self.full_loader = full_loader
            self.model = model
            self.criterion = criterion
            self.iters = 0
            self.timestamp = time.time()
            self.curvature_matrix = curvature_matrix

        def __call__(self, vector):
            start_time = time.time()
            if self.curvature_matrix == 'hessian_lanczos':
                output = utils.hess_vec(vector, self.loader, self.model, self.criterion,
                                        cuda=device == 'cuda', bn_train_mode=bn_train_mode)
            elif self.curvature_matrix == 'ggn_lanczos':
                output = utils.gn_vec(vector, self.loader, self.model, self.criterion,
                                      cuda=device == 'cuda', bn_train_mode=bn_train_mode)
            elif self.curvature_matrix == 'cov_grad_lanczos':
                output = utils.covgrad_vec(vector, self.loader, self.model, self.criterion,
                                           cuda=device == 'cuda', bn_train_mode=bn_train_mode)
            else:
                raise ValueError("Unrecognised curvature_matrix argument " + self.curvature_matrix)
            time_diff = time.time() - start_time
            self.iters += 1
            print('Iter %d. Time: %.2f' % (self.iters, time_diff))
            # Lanczos runs on the CPU (see lanczos_tridiag call below).
            return output.cpu().unsqueeze(1)

    w = torch.cat([param.detach().cpu().view(-1) for param in model.parameters()])

    productor = CurvVecProduct(loader, model, criterion, curvature_matrix)
    utils.bn_update(full_loader, model)
    Q, T = lanczos_tridiag(productor, lanczos_iters, dtype=torch.float32, device='cpu',
                           matrix_shape=(num_parameters, num_parameters))
    eigvals, eigvects = _tridiag_eig_legacy(T)
    gammas = eigvects[0, :] ** 2
    V = eigvects.t() @ Q.t()

    if save_spectrum_path is not None:
        # One format or the other, as documented. Previously both were written,
        # and with a '.npz' path np.savez overwrote the torch file, losing V.
        if save_eigvec:
            torch.save({'w': w, 'eigvals': eigvals, 'gammas': gammas, 'V': V}, save_spectrum_path)
        else:
            np.savez(save_spectrum_path, w=w.numpy(), eigvals=eigvals.numpy(), gammas=gammas.numpy())
    return {'w': w, 'eigvals': eigvals, 'gammas': gammas, 'V': V}
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False, verbose=False, nsteps=100,
                         return_evecs=False):
    """
    Estimate the largest Hessian eigenvalue (and optionally Ritz vectors) with Lanczos.

    Only the top of the spectrum is computed; the ``mineig`` slot of the
    result is always ``None``.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node; only rank 0 prints verbose output.
        use_cuda: use GPU
        verbose: print more information
        nsteps: number of Lanczos iterations.
        return_evecs: if True, also return the Ritz vectors in parameter space.

    Returns:
        ``(maxeig, None, count, pos_eigvals, pos_bases, pos_t_mat)``:
        - maxeig: largest Ritz value.
        - count: number of Hessian-vector products performed.
        - pos_eigvals: all Ritz values.
        - pos_bases: ``q_mat @ tridiag_evecs`` when ``return_evecs``, else ``None``.
        - pos_t_mat: the Lanczos tridiagonal matrix.
        Only parameters with more than one dimension are included.
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)
        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print(" Iter: %d time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if verbose and rank == 0:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    print(pos_eigvals)

    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if verbose and rank == 0:
        print("max eigenvalue = %f" % maxeig)

    # Forming parameter-space Ritz vectors is an (N x nsteps) matmul; only do
    # it when the caller asked for them.
    pos_bases = pos_q_mat @ pos_eigvecs if return_evecs else None
    return maxeig, None, hess_vec_prod.count, pos_eigvals, pos_bases, pos_t_mat
def criterion(current_pars, input_data, target, return_predictions=True):
    r"""
    Loss function for the linearized (Jacobian-based) Laplace model.

    Reads several names not defined here (presumably from an enclosing scope;
    TODO confirm): eval_mode, num_data, model, model_pars, bias, num_classes,
    wd, FVP_FD, Jacobian, lanczos_tridiag.

    When ``eval_mode`` is true, a sample is drawn around ``current_pars[0]``
    using a Lanczos approximation of the inverse square root of the scaled,
    jittered Fisher. Otherwise ``current_pars[0]`` is used directly.

    current_pars (list/iterable): parameter list; only ``current_pars[0]`` is
        used (assumed to be a column vector, TODO confirm).
    input_data (tensor): input data for model
    target (tensor): response
    return_predictions (bool): NOTE(review) currently unused -- the function
        always returns ``(output, predictions_reshaped)``.

    Returns:
        ``(output, predictions_reshaped)`` where ``output`` is:
        - the batch-summed loss in eval mode;
        - otherwise ``num_data * loss + wd * ||current_pars[0]||``.
    """
    if eval_mode:
        # this means prediction time
        # so do a Fisher vector product + jitter, take the tmatrix invert the cholesky decomp and sample
        # F \approx Q T Q' => F^{-1} \approx Q T^{-1} Q'
        # F^{-1/2} \approx Q T^{-1/2}
        # The Fisher estimate is rescaled from this batch to num_data points, and a unit jitter is added.
        fvp = ((num_data / input_data.shape[0]) * FVP_FD(model, input_data)).add_jitter(1.0)
        qmat, tmat = lanczos_tridiag(
            fvp.matmul,
            30,
            dtype=current_pars[0].dtype,
            device=current_pars[0].device,
            init_vecs=None,
            matrix_shape=[
                current_pars[0].shape[0], current_pars[0].shape[0]
            ],
        )
        # NOTE(review): torch.symeig is deprecated/removed in recent PyTorch;
        # torch.linalg.eigh is the replacement (same ascending order).
        eigs, evecs = torch.symeig(tmat, eigenvectors=True)
        # only consider the top half of the eigenvalues bc they're reliable
        # Despite its name, eigs_gt_zero holds the indices of the largest
        # int(k / 2) Ritz values; no sign check is made. If tmat has a single
        # row, int(0.5) == 0 and the slice [-0:] keeps every index.
        eigs_gt_zero = torch.sort(eigs)[1][-int(tmat.shape[0] / 2):]
        # update the eigendecomposition
        # note that @ is a matmul
        updated_evecs = (qmat @ evecs)[:, eigs_gt_zero]
        # approx_lz = U diag(eigs^{-1/2}) z has covariance U diag(1/eigs) U',
        # i.e. an approximation of the inverse Fisher on the kept subspace.
        z = torch.randn(eigs_gt_zero.shape[0], 1, device=tmat.device, dtype=tmat.dtype)
        approx_lz = updated_evecs @ torch.diag(
            1.0 / eigs[eigs_gt_zero].pow(0.5)) @ z
        sample = current_pars[0] + approx_lz
    else:
        sample = current_pars[0]
    rhs = sample
    if bias:
        # Offset relative to model_pars; the network output is added back below.
        rhs = sample - model_pars.view(-1, 1)
    predictions = Jacobian(model=model, data=input_data, num_outputs=1)._t_matmul(rhs)
    predictions_reshaped = predictions.reshape(target.shape[0], num_classes)
    if bias:
        predictions_reshaped = predictions_reshaped + model(input_data)
    # Mean cross-entropy times batch size, i.e. the summed loss over the batch.
    loss = (
        torch.nn.functional.cross_entropy(predictions_reshaped, target) * target.shape[0])
    regularizer = current_pars[0].norm() * wd
    if eval_mode:
        output = loss
    else:
        output = num_data * loss + regularizer
    return output, predictions_reshaped