import time

import numpy as np
import torch
from gpytorch import settings
from gpytorch.utils.lanczos import lanczos_tridiag, lanczos_tridiag_to_diag

# NOTE: unflatten_like, eval_hess_vec_prod, gradtensor_to_tensor, get_hessian,
# eval_fisher_vec_prod, FVP_FD, and FVP_AG are local helpers assumed to be
# importable from elsewhere in this repo.


def get_hessian_eigs(loss, model, mask=None, use_cuda=False, n_eigs=100,
                     train_x=None, train_y=None, loader=None, evals=False):
    """Compute Hessian eigenvalues of `model` under `loss`.

    If n_eigs != -1, runs Lanczos with Hessian-vector products restricted to
    the parameters selected by `mask`; otherwise forms the full Hessian and
    eigendecomposes the masked sub-block exactly.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(p.numel() for p in model.parameters())

    if n_eigs != -1:
        if mask is not None:
            numpars = int(mask.sum().item())
        else:
            numpars = total_pars
            p = next(iter(model.parameters()))
            mask = torch.ones(total_pars, dtype=p.dtype, device=p.device)

        def hvp(rhs):
            # zero-pad the rhs up to the full parameter space so the
            # Hessian-vector product only "sees" the masked coordinates
            padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                                     device=rhs.device, dtype=rhs.dtype)
            padded_rhs[mask == 1] = rhs
            padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
            eval_hess_vec_prod(padded_rhs, net=model, criterion=loss,
                               inputs=train_x, targets=train_y,
                               dataloader=loader, use_cuda=use_cuda)
            full_hvp = gradtensor_to_tensor(model, include_bn=True)
            # slice the product back down to the masked coordinates
            return full_hvp[mask == 1].unsqueeze(-1)

        if train_x is None:
            data = next(iter(loader))[0]
            if use_cuda:
                data = data.cuda()
            dtype, device = data.dtype, data.device
        else:
            dtype, device = train_x.dtype, train_x.device

        qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                     matrix_shape=(numpars, numpars))
        eigs, t_evecs = lanczos_tridiag_to_diag(tmat)
        if evals:
            return eigs, qmat @ t_evecs
        return eigs
    else:
        # form the full Hessian and extract the masked sub-block
        hessian = get_hessian(train_x, train_y, loss, model, use_cuda=use_cuda)
        keepers = np.where(mask.cpu().numpy() == 1)[0]
        sub_hess = hessian[np.ix_(keepers, keepers)]
        e_val, _ = np.linalg.eig(sub_hess.cpu().detach().numpy())
        return e_val.real

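# Usage sketch for get_hessian_eigs. This is a minimal, hypothetical example:
# the MLP, data shapes, and mask below are illustrative, not from this repo.
# With a mask, Lanczos runs on the sub-Hessian over the selected coordinates.
def _demo_get_hessian_eigs():
    model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Tanh(),
                                torch.nn.Linear(32, 1))
    train_x, train_y = torch.randn(64, 10), torch.randn(64, 1)
    loss = torch.nn.MSELoss()

    n_par = sum(p.numel() for p in model.parameters())
    # keep a random half of the parameters in the sub-Hessian
    mask = (torch.rand(n_par) > 0.5).float()

    eigs = get_hessian_eigs(loss, model, mask=mask, n_eigs=10,
                            train_x=train_x, train_y=train_y)
    print(eigs)
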
def _exact_predictive_covar_inv_quad_form_root(self, expanded_lt, test_train_covar):
    r"""
    Computes a low-rank root :math:`S` of the inverse of `expanded_lt`, i.e.
    a tensor such that :math:`S S^{\top} \approx (K_{XX} + \sigma^2 I)^{-1}`.

    Lanczos gives :math:`K_{XX} + \sigma^2 I \approx Q V \Lambda V^{\top} Q^{\top}`,
    so :math:`S = Q V \Lambda^{-1/2}`. The eigenpairs are cached on `self`
    as `gram_evals` and `gram_evecs`.

    Args:
        expanded_lt (:obj:`~gpytorch.lazy.LazyTensor`): the train-train
            covariance plus noise, :math:`K_{XX} + \sigma^2 I`
        test_train_covar (:obj:`torch.tensor`): the test-train covariance
            :math:`K_{X^{*}X}` (unused here; kept for interface compatibility)

    Returns:
        :obj:`torch.tensor`: the root :math:`S`
    """
    qmats, tmats = lanczos_tridiag(
        expanded_lt.matmul,
        max_iter=settings.max_root_decomposition_size.value(),
        matrix_shape=expanded_lt.shape,
        device=expanded_lt.device,
        dtype=expanded_lt.dtype,
    )
    evals, evecs = lanczos_tridiag_to_diag(tmats)
    self.gram_evecs = qmats @ evecs
    self.gram_evals = evals
    # S = (Q V) diag(evals)^{-1/2}, so S S^T = Q V Lambda^{-1} V^T Q^T ~ A^{-1}
    covar_root = self.gram_evecs @ torch.diag(evals.pow(-0.5))
    return covar_root

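# Sanity-check sketch for the root construction above: on a small random SPD
# matrix A, the Lanczos root S = (Q V) diag(evals)^{-1/2} should satisfy
# S S^T ~ A^{-1}. Standalone, assuming only torch and the same gpytorch
# lanczos utilities already used in this file.
def _check_lanczos_inverse_root(n=50, max_iter=50):
    B = torch.randn(n, n, dtype=torch.float64)
    A = B @ B.t() + n * torch.eye(n, dtype=torch.float64)  # SPD, well conditioned

    qmat, tmat = lanczos_tridiag(A.matmul, max_iter, dtype=A.dtype,
                                 device=A.device, matrix_shape=A.shape)
    evals, evecs = lanczos_tridiag_to_diag(tmat)
    root = (qmat @ evecs) @ torch.diag(evals.pow(-0.5))

    A_inv = torch.inverse(A)
    err = (root @ root.t() - A_inv).norm() / A_inv.norm()
    print("relative error of S S^T vs A^{-1}:", err.item())
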
def get_hessian_evals(loss, model, use_cuda=False, n_eigs=100,
                      train_x=None, train_y=None, loader=None):
    """Lanczos eigenvalues and Ritz vectors of the full (unmasked) Hessian."""
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(p.numel() for p in model.parameters())

    def hvp(rhs):
        # no mask here, so the rhs can be reshaped directly onto the parameters
        vec = unflatten_like(rhs.t(), model.parameters())
        eval_hess_vec_prod(vec, net=model, criterion=loss,
                           inputs=train_x, targets=train_y,
                           dataloader=loader, use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)

    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                 matrix_shape=(total_pars, total_pars))
    eigs, t_evecs = lanczos_tridiag_to_diag(tmat)
    return eigs, qmat @ t_evecs

def hessian_eigenpairs(net, criterion, inputs=None, targets=None,
                       dataloader=None, n_eigs=20,
                       use_cuda=torch.cuda.is_available(), verbose=False):
    """Top Hessian eigenvalues and eigenvectors of `net` via Lanczos."""
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, inputs=inputs,
                           targets=targets, dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("    Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))

        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to eigenvalues and eigenvectors
    e_vals, e_vecs = lanczos_tridiag_to_diag(pos_t_mat)
    # map the Ritz vectors back into parameter space
    e_vecs = pos_q_mat.matmul(e_vecs)
    return e_vals, e_vecs

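# Verification sketch for hessian_eigenpairs: check the eigen residual
# ||H v - lambda v|| for the top Ritz pair with one extra Hessian-vector
# product. Hypothetical example; `net`, `criterion`, and the data are
# placeholders for whatever the caller already has.
def _check_top_eigenpair(net, criterion, inputs, targets):
    e_vals, e_vecs = hessian_eigenpairs(net, criterion, inputs=inputs,
                                        targets=targets, n_eigs=10)
    idx = torch.argmax(e_vals)
    lam, v = e_vals[idx], e_vecs[:, idx]

    params = [p for p in net.parameters() if len(p.size()) > 1]
    vec = unflatten_like(v.unsqueeze(0), params)
    eval_hess_vec_prod(vec, params, net, criterion,
                       inputs=inputs, targets=targets)
    hv = gradtensor_to_tensor(net)
    print("eigen residual:", (hv - lam * v).norm().item())
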
def min_max_hessian_eigs(net, dataloader, criterion, n_top_eigs=3,
                         n_bottom_eigs=50, use_cuda=False):
    """
    Compute the largest and smallest eigenvalues of the Hessian matrix.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        n_top_eigs: number of Lanczos steps for the top of the spectrum.
        n_bottom_eigs: number of Lanczos steps for the bottom of the spectrum.
        use_cuda: use GPU
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in net.parameters())

    p = next(iter(net.parameters()))
    mask = torch.ones(N, dtype=p.dtype, device=p.device)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        padded_rhs = torch.zeros(N, vec.shape[-1], device=vec.device, dtype=vec.dtype)
        padded_rhs[mask == 1] = vec
        padded_rhs = unflatten_like(padded_rhs.t(), net.parameters())

        eval_hess_vec_prod(padded_rhs, net=net, criterion=criterion,
                           dataloader=dataloader, use_cuda=use_cuda)

        out = gradtensor_to_tensor(net, include_bn=True)
        return out[mask == 1].unsqueeze(-1)

    hess_vec_prod.count = 0

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_top_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    pos_eigvecs = pos_q_mat @ pos_eigvecs

    # Shift the matrix so that the most negative eigenvalue of H becomes the
    # largest eigenvalue of (shift * I - H). Assuming the smallest eigenvalue
    # is zero or less, shift = 0.51 * maxeig is more than we need.
    shift = 0.51 * pos_eigvals.max().item()

    def shifted_hess_vec_prod(vec):
        hvp = hess_vec_prod(vec)
        return -hvp + shift * vec

    # now run lanczos on the shifted operator
    neg_q_mat, neg_t_mat = lanczos_tridiag(
        shifted_hess_vec_prod,
        n_bottom_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    neg_eigvals, neg_eigvecs = lanczos_tridiag_to_diag(neg_t_mat)
    neg_eigvecs = neg_q_mat @ neg_eigvecs
    # undo the shift: mu = shift - lambda  =>  lambda = shift - mu
    neg_eigvals = -neg_eigvals + shift

    return pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs

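# Numerical sketch of the shift trick used above: for a symmetric H with
# eigenvalues {lambda_i}, the operator (shift * I - H) has eigenvalues
# {shift - lambda_i}, so its top eigenvalue recovers the most negative
# lambda. Small dense torch example, independent of the functions above.
def _demo_shift_trick():
    B = torch.randn(6, 6)
    H = 0.5 * (B + B.t())                # symmetric test matrix
    lams = torch.linalg.eigvalsh(H)      # sorted ascending

    shift = 0.51 * lams.max().item()
    shifted = shift * torch.eye(6) - H
    mu_max = torch.linalg.eigvalsh(shifted).max()

    # shift - mu_max should match the smallest eigenvalue of H
    print(lams.min().item(), (shift - mu_max).item())
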
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False,
                         verbose=False, nsteps=100):
    """
    Compute the largest eigenvalue of the Hessian matrix.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node.
        use_cuda: use GPU
        verbose: print more information

    Returns:
        maxeig: max eigenvalue
        mineig: None (the shifted Lanczos pass for the most negative
            eigenvalues is disabled in this variant; see the versions
            above and below)
        hess_vec_prod.count: number of Hessian-vector products performed
        pos_eigvals: Ritz values from the Lanczos run
        neg_eigvals: None
        pos_t_mat: the Lanczos tridiagonal matrix
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print("    Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))

        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    if verbose and rank == 0:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    # Ritz vectors in parameter space (not returned here; see the
    # return_evecs variant below)
    pos_bases = pos_q_mat @ pos_eigvecs
    if verbose and rank == 0:
        print("max eigenvalue = %f" % maxeig)

    return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat

def min_max_fisher_eigs(net, dataloader, criterion, rank=0, use_cuda=False,
                        verbose=False, fvp_method='FVP_FD'):
    """
    Compute the largest and the smallest eigenvalues of the Fisher
    information matrix.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node.
        use_cuda: use GPU
        verbose: print more information
        fvp_method: 'FVP_FD' (finite differences) or 'FVP_AG' (autograd)

    Returns:
        maxeig: max eigenvalue
        mineig: min eigenvalue
        fisher_vec_prod.count: number of Fisher-vector products performed
        pos_eigvals, neg_eigvals: Ritz values from the two Lanczos runs
    """
    if fvp_method == 'FVP_FD':
        print('Using FD for FVP')
        fvp_matmul = FVP_FD
    elif fvp_method == 'FVP_AG':
        print('Using AG for FVP')
        fvp_matmul = FVP_AG
    else:
        raise NotImplementedError('Only FD and AG have been implemented so far.')

    params = [p for p in net.parameters()]
    N = sum(p.numel() for p in params)

    def fisher_vec_prod(vec):
        fisher_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), net.parameters())

        start_time = time.time()
        out = eval_fisher_vec_prod(vec, net, dataloader, use_cuda,
                                   fvp_matmul=fvp_matmul)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print("    Iter: %d  time: %f" % (fisher_vec_prod.count, prod_time))
        return out

    fisher_vec_prod.count = 0

    if verbose and rank == 0:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    _, pos_t_mat = lanczos_tridiag(fisher_vec_prod, 100,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if verbose and rank == 0:
        print('max eigenvalue = %f' % maxeig)

    # Shift the matrix so that the smallest eigenvalue of F becomes the
    # largest eigenvalue of (shift * I - F). Assuming the smallest eigenvalue
    # is zero or less, shift = 0.51 * maxeig is more than we need.
    shift = 0.51 * maxeig.item()

    def shifted_fisher_vec_prod(vec):
        fvp = fisher_vec_prod(vec)
        return -fvp + shift * vec

    if verbose and rank == 0:
        print("Rank %d: computing shifted eigenvalue" % rank)

    # now run lanczos on the shifted operator
    _, neg_t_mat = lanczos_tridiag(shifted_fisher_vec_prod, 200,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    neg_eigvals, _ = lanczos_tridiag_to_diag(neg_t_mat)
    # undo the shift: mu = shift - lambda  =>  lambda = shift - mu
    mineig = -torch.max(neg_eigvals) + shift
    if verbose and rank == 0:
        print('min eigenvalue = ' + str(mineig))

    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig

    return maxeig, mineig, fisher_vec_prod.count, pos_eigvals, neg_eigvals

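# Usage sketch for min_max_fisher_eigs. The Fisher matrix is positive
# semi-definite, so mineig should come out near zero (or slightly negative
# from Lanczos / finite-difference error). `net`, `loader`, and `criterion`
# are placeholders for whatever the caller already has.
def _demo_fisher_spectrum(net, loader, criterion):
    maxeig, mineig, n_fvps, pos_eigvals, neg_eigvals = min_max_fisher_eigs(
        net, loader, criterion, verbose=True, fvp_method='FVP_FD')
    print("Fisher spectral range: [%f, %f] (%d FVPs)" % (mineig, maxeig, n_fvps))
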
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False,
                         verbose=False, nsteps=100, return_evecs=False):
    """
    Compute the largest eigenvalue (and optionally the top eigenvectors)
    of the Hessian matrix.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node.
        use_cuda: use GPU
        verbose: print more information
        nsteps: number of Lanczos steps
        return_evecs: also return the Ritz vectors mapped to parameter space

    Returns:
        maxeig: max eigenvalue
        mineig: None (kept for interface compatibility)
        hess_vec_prod.count: number of Hessian-vector products performed
        pos_eigvals: Ritz values from the Lanczos run
        pos_bases: Ritz vectors if return_evecs, else None
        pos_t_mat: the Lanczos tridiagonal matrix
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if verbose and rank == 0:
            print("    Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))

        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    if verbose and rank == 0:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if verbose and rank == 0:
        print("max eigenvalue = %f" % maxeig)

    if not return_evecs:
        return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat
    # map the Ritz vectors back into parameter space
    pos_bases = pos_q_mat @ pos_eigvecs
    return maxeig, None, hess_vec_prod.count, pos_eigvals, pos_bases, pos_t_mat