Esempio n. 1
0
def get_hessian_eigs(loss, model, mask=None,
                     use_cuda=False, n_eigs=100, train_x=None, train_y=None,
                     loader=None, evals=False):
    """Estimate eigenvalues of the (optionally masked) Hessian of ``loss``.

    If ``n_eigs != -1`` a Lanczos run of ``n_eigs`` iterations is performed on
    the sub-Hessian restricted to the flat parameter coordinates where
    ``mask == 1``. Hessian-vector products are delegated to
    ``eval_hess_vec_prod``. If ``n_eigs == -1`` the dense Hessian is formed
    with ``get_hessian`` and the kept block is fully eigendecomposed.

    Args:
        loss: criterion passed to the Hessian helpers.
        model: network whose parameters define the Hessian.
        mask: optional flat tensor over all parameters. Entries equal to 1
            select the coordinates kept. ``None`` keeps every coordinate.
        use_cuda: move ``train_x``/``train_y`` (or the probe batch) to GPU.
        n_eigs: number of Lanczos iterations, or -1 for a dense solve.
        train_x, train_y: in-memory data. If ``train_x`` is None, ``loader``
            is used instead.
        loader: dataloader used when ``train_x`` is None.
        evals: Lanczos path only. If True, also return ``qmat @ t_evals``
            (the Ritz vectors).

    Returns:
        Lanczos path: the Ritz values, plus the Ritz vectors when ``evals``.
        Dense path: real parts of the eigenvalues of the kept sub-Hessian.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        if train_y is not None:
            train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())

    if n_eigs == -1:
        # Dense path: form the full Hessian and eigendecompose the kept block.
        hessian = get_hessian(train_x, train_y, loss, model, use_cuda=use_cuda)
        if mask is None:
            sub_hess = hessian
        else:
            keepers = np.array(np.where(mask.cpu() == 1))[0]
            sub_hess = hessian[np.ix_(keepers, keepers)]
        e_val, _ = np.linalg.eig(sub_hess.cpu().detach())
        return e_val.real

    if mask is None:
        p = next(iter(model.parameters()))
        mask = torch.ones(total_pars, dtype=p.dtype, device=p.device)
    # Boolean selector, computed once: the mask does not change between products.
    keep = mask == 1
    numpars = int(keep.sum().item())

    def hvp(rhs):
        # Scatter the reduced vector(s) into full parameter space, compute
        # H @ v there, then gather the kept coordinates back out.
        padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                                 device=rhs.device, dtype=rhs.dtype)
        padded_rhs[keep] = rhs
        padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
        eval_hess_vec_prod(padded_rhs, net=model,
                           criterion=loss, inputs=train_x,
                           targets=train_y, dataloader=loader, use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp[keep].unsqueeze(-1)

    # Lanczos needs a dtype/device for its probe vectors; take them from the data.
    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                 matrix_shape=(numpars, numpars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)
    if evals:
        return eigs, qmat @ t_evals
    return eigs
Esempio n. 2
0
    def _exact_predictive_covar_inv_quad_form_root(self, expanded_lt,
                                                   test_train_covar):
        r"""
        Compute a root :math:`S` of the inverse of ``expanded_lt`` via Lanczos.

        ``expanded_lt`` (assumed symmetric positive definite, e.g.
        :math:`K_{XX} + \sigma^2 I` -- TODO confirm with callers) is
        tridiagonalised as :math:`Q T Q^{\top}`, and :math:`T = V \Lambda V^{\top}`.
        The result is :math:`S = Q V \Lambda^{-1/2}`, so that
        :math:`S S^{\top} = Q T^{-1} Q^{\top}` approximates the inverse.

        Side effects:
            Stores ``self.gram_evecs`` (:math:`Q V`) and ``self.gram_evals``
            (:math:`\Lambda`).

        Args:
            expanded_lt: (possibly batched) tensor or lazy tensor exposing
                ``matmul``, ``shape``, ``device`` and ``dtype``.
            test_train_covar: not used by this implementation. It is kept
                for interface compatibility.

        Returns:
            Tensor of shape ``(..., n, k)``: the root :math:`Q V \Lambda^{-1/2}`.
        """
        qmats, tmats = lanczos_tridiag(
            expanded_lt.matmul,
            max_iter=settings.max_root_decomposition_size.value(),
            matrix_shape=expanded_lt.shape,
            device=expanded_lt.device,
            dtype=expanded_lt.dtype,
        )
        evals, evecs = lanczos_tridiag_to_diag(tmats)

        self.gram_evecs = qmats @ evecs
        self.gram_evals = evals

        # Scale column j by evals[..., j] ** -0.5. For 1-D evals this equals
        # ``@ torch.diag(evals.pow(-0.5))``. It also stays correct for batched
        # evals, where torch.diag would extract a diagonal instead of building one.
        covar_root = self.gram_evecs * evals.pow(-0.5).unsqueeze(-2)

        return covar_root
Esempio n. 3
0
    def test_lanczos(self):
        """Lanczos on a random SPD matrix should reconstruct it as Q T Q^T."""
        n = 100
        base = torch.randn(n, n)
        # Gram matrix -> symmetric PSD. Normalise it, then add a tiny diagonal
        # jitter so it is strictly positive definite.
        spd = base.matmul(base.transpose(-1, -2))
        spd.div_(spd.norm())
        jitter = torch.ones(spd.size(-1)).mul(1e-6).diag()
        spd.add_(jitter)

        q_mat, t_mat = lanczos_tridiag(
            spd.matmul,
            max_iter=n,
            tensor_cls=spd.new,
            n_dims=spd.size(-1),
        )

        q_t = q_mat.transpose(-1, -2)
        reconstructed = q_mat.matmul(t_mat).matmul(q_t)
        self.assertTrue(approx_equal(reconstructed, spd))
Esempio n. 4
0
    def lanczos_tridiag_test(self, matrix):
        """Run full-length Lanczos on ``matrix`` and check Q T Q^T recovers it."""
        n = matrix.shape[0]
        q_mat, t_mat = lanczos_tridiag(
            matrix.matmul,
            max_iter=n,
            dtype=matrix.dtype,
            device=matrix.device,
            matrix_shape=matrix.shape,
        )

        # Structural checks on the Lanczos factors before reconstruction.
        self.assert_valid_sizes(n, t_mat, q_mat)
        self.assert_tridiagonally_positive(t_mat)

        q_t = q_mat.transpose(-1, -2)
        rebuilt = q_mat.matmul(t_mat).matmul(q_t)
        self.assertTrue(approx_equal(rebuilt, matrix))
Esempio n. 5
0
def get_hessian_evals(loss,
                      model,
                      use_cuda=False,
                      n_eigs=100,
                      train_x=None,
                      train_y=None,
                      loader=None):
    """Lanczos approximation of the Hessian spectrum over all model parameters.

    Args:
        loss: criterion passed to ``eval_hess_vec_prod``.
        model: network whose parameters define the Hessian.
        use_cuda: move ``train_x``/``train_y`` (or the probe batch) to GPU.
        n_eigs: number of Lanczos iterations.
        train_x, train_y: in-memory data. If ``train_x`` is None, ``loader``
            is used instead.
        loader: dataloader used when ``train_x`` is None.

    Returns:
        (eigs, ritz_vectors): Ritz values and the matching Ritz vectors
        ``qmat @ t_evals`` (one column per Ritz value).
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        if train_y is not None:
            train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())

    def hvp(rhs):
        # The previous version allocated a zero tensor and never copied rhs
        # into it, so every product was H @ 0. There is no mask here, so rhs
        # already spans all parameters and is unflattened directly.
        rhs_per_param = unflatten_like(rhs.t(), model.parameters())
        eval_hess_vec_prod(rhs_per_param,
                           net=model,
                           criterion=loss,
                           inputs=train_x,
                           targets=train_y,
                           dataloader=loader,
                           use_cuda=use_cuda)
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)

    # Lanczos needs a dtype/device for its probe vectors; take them from the data.
    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp,
                                 n_eigs,
                                 dtype=dtype,
                                 device=device,
                                 matrix_shape=(total_pars, total_pars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)

    return eigs, qmat @ t_evals
Esempio n. 6
0
def hessian_eigenpairs(net,
                       criterion,
                       inputs=None,
                       targets=None,
                       dataloader=None,
                       n_eigs=20,
                       use_cuda=torch.cuda.is_available(),
                       verbose=False):
    """Lanczos eigenpairs of the Hessian restricted to multi-dimensional params.

    Only parameters with more than one dimension (e.g. weight matrices and
    conv kernels, not biases) are included.

    Args:
        net: the model.
        criterion: loss function.
        inputs, targets: optional in-memory data for the Hessian-vector product.
        dataloader: optional dataloader for the Hessian-vector product.
        n_eigs: number of Lanczos iterations.
        use_cuda: use GPU. The default is evaluated once, at import time.
        verbose: print per-product timing.

    Returns:
        (e_vals, e_vecs): Ritz values and Ritz vectors (one column each).
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        # Call counter; previously read in the verbose print but never set,
        # which raised AttributeError when verbose=True.
        hess_vec_prod.count += 1
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec,
                           params,
                           net,
                           criterion,
                           inputs=inputs,
                           targets=targets,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # Eigendecompose the tridiagonal T: its eigenvalues are the Ritz values.
    e_vals, e_vecs = lanczos_tridiag_to_diag(pos_t_mat)

    # Map eigenvectors of T back to parameter space: Ritz vectors = Q @ V.
    e_vecs = pos_q_mat.matmul(e_vecs)

    return e_vals, e_vecs
Esempio n. 7
0
def min_max_hessian_eigs(net, dataloader, criterion,
                         n_top_eigs=3, n_bottom_eigs=50, use_cuda=False,
                         verbose=False):
    """
        Compute the largest and the smallest eigenpairs of the Hessian matrix.

        Two Lanczos runs are made over all parameters of ``net``:
        1. on H itself (``n_top_eigs`` iterations), for the top of the spectrum;
        2. on ``shift * I - H`` with ``shift = 0.51 * max Ritz value``
           (``n_bottom_eigs`` iterations). Its Ritz values ``mu`` map back to
           Hessian eigenvalues as ``shift - mu``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            n_top_eigs: Lanczos iterations for the unshifted run.
            n_bottom_eigs: Lanczos iterations for the shifted run.
            use_cuda: use GPU
            verbose: print per-product timing and progress messages.
        Returns:
            (pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs): Ritz values
            and vectors of both runs. ``neg_eigvals`` are already mapped back
            to the Hessian's scale.
    """

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in net.parameters())

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        # vec covers every parameter, so no masking/padding is needed.
        vec_per_param = unflatten_like(vec.t(), net.parameters())

        start_time = time.time()
        eval_hess_vec_prod(vec_per_param, net=net, criterion=criterion,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net, include_bn=True)
        return out.unsqueeze(-1)

    hess_vec_prod.count = 0

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_top_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    pos_eigvecs = pos_q_mat @ pos_eigvecs

    # Shift so that the most negative Hessian eigenvalue becomes the largest
    # eigenvalue of (shift * I - H). This assumes the smallest eigenvalue is
    # zero or less, so 0.51 * max is more shift than needed.
    shift = 0.51 * pos_eigvals.max().item()
    if verbose:
        print("Pos Eigs Computed....\n")

    def shifted_hess_vec_prod(vec):
        return -hess_vec_prod(vec) + shift * vec

    # now run lanczos on the shifted operator
    neg_q_mat, neg_t_mat = lanczos_tridiag(
        shifted_hess_vec_prod,
        n_bottom_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    neg_eigvals, neg_eigvecs = lanczos_tridiag_to_diag(neg_t_mat)
    neg_eigvecs = neg_q_mat @ neg_eigvecs
    if verbose:
        print("Neg Eigs Computed...")
        print("neg eigs = ", neg_eigvals)

    # Map shifted Ritz values back to Hessian eigenvalues.
    neg_eigvals = -neg_eigvals + shift

    return pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs
Esempio n. 8
0
def min_max_hessian_eigs(net,
                         dataloader,
                         criterion,
                         rank=0,
                         use_cuda=False,
                         verbose=False,
                         nsteps=100):
    """
        Estimate the largest Hessian eigenvalue with a single Lanczos run.

        Only parameters with more than one dimension (not biases) are included.
        The smallest eigenvalue is NOT computed in this version; its return
        slots are ``None``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            rank: rank of the working node; only rank 0 prints.
            use_cuda: use GPU
            verbose: print per-product timing and the Ritz values.
            nsteps: number of Lanczos iterations.
        Returns:
            maxeig: largest Ritz value
            None: placeholder for the min eigenvalue (not computed)
            hess_vec_prod.count: number of Hessian-vector products evaluated
            pos_eigvals: all Ritz values of the run
            None: placeholder for the shifted-run Ritz values (not computed)
            pos_t_mat: the Lanczos tridiagonal matrix
    """

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)
    log = verbose and rank == 0

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        # Pass dataloader/use_cuda by keyword so they cannot bind to other
        # positional slots. Some eval_hess_vec_prod variants also accept
        # inputs/targets.
        eval_hess_vec_prod(vec, params, net, criterion,
                           dataloader=dataloader, use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # Ritz vectors (pos_q_mat @ eigvecs) are not returned, so they are not formed.
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    if log:
        print(pos_eigvals)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if log:
        print("max eigenvalue = %f" % maxeig)

    return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat
Esempio n. 9
0
def min_max_fisher_eigs(net,
                        dataloader,
                        criterion,
                        rank=0,
                        use_cuda=False,
                        verbose=False,
                        fvp_method='FVP_FD',
                        nsteps=100,
                        nsteps_shifted=200):
    """
        Compute the largest and the smallest eigenvalues of the Fisher matrix.

        Two Lanczos runs are made over all parameters:
        1. on the Fisher-vector product F, for the largest eigenvalue;
        2. on ``shift * I - F`` with ``shift = 0.51 * maxeig``. Its largest
           Ritz value ``mu`` maps back to the smallest eigenvalue ``shift - mu``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function (not used by this implementation).
            rank: rank of the working node; only rank 0 prints.
            use_cuda: use GPU
            verbose: print progress and intermediate values.
            fvp_method: 'FVP_FD' or 'FVP_AG', the Fisher-vector product backend.
            nsteps: Lanczos iterations for the unshifted run.
            nsteps_shifted: Lanczos iterations for the shifted run.
        Returns:
            maxeig: max eigenvalue
            mineig: min eigenvalue. Swapped with maxeig if maxeig <= 0 < mineig.
            fisher_vec_prod.count: number of Fisher-vector products evaluated
            pos_eigvals: Ritz values of the unshifted run
            neg_eigvals: raw Ritz values of the SHIFTED run (not mapped back)
        Raises:
            NotImplementedError: if ``fvp_method`` is not recognised.
    """
    log = verbose and rank == 0
    if fvp_method == 'FVP_FD':
        if log:
            print('Using FD for FVP')
        fvp_matmul = FVP_FD
    elif fvp_method == 'FVP_AG':
        if log:
            print('Using AG for FVP')
        fvp_matmul = FVP_AG
    else:
        raise NotImplementedError(
            'Only FD and AG have been implemented so far.')

    params = list(net.parameters())
    N = sum(p.numel() for p in params)

    def fisher_vec_prod(vec):
        fisher_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), net.parameters())

        start_time = time.time()
        out = eval_fisher_vec_prod(vec,
                                   net,
                                   dataloader,
                                   use_cuda,
                                   fvp_matmul=fvp_matmul)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (fisher_vec_prod.count, prod_time))
        return out

    fisher_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t matrix out
    _, pos_t_mat = lanczos_tridiag(fisher_vec_prod,
                                   nsteps,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    if log:
        print(pos_eigvals)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if log:
        print('max eigenvalue = %f' % maxeig)

    # Shift so the smallest eigenvalue becomes the largest of (shift * I - F).
    # This assumes the smallest eigenvalue is well below maxeig, so 0.51 * max
    # is enough.
    shift = 0.51 * maxeig.item()
    if log:
        print(shift)

    def shifted_hess_vec_prod(vec):
        return -fisher_vec_prod(vec) + shift * vec

    if log:
        print("Rank %d: Computing shifted eigenvalue" % rank)

    # now run lanczos on the shifted operator
    _, neg_t_mat = lanczos_tridiag(shifted_hess_vec_prod,
                                   nsteps_shifted,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    neg_eigvals, _ = lanczos_tridiag_to_diag(neg_t_mat)
    mineig = torch.max(neg_eigvals)
    if log:
        print(neg_eigvals)

    mineig = -mineig + shift
    if log:
        print(mineig)
        print('min eigenvalue = ' + str(mineig))

    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig

    return maxeig, mineig, fisher_vec_prod.count, pos_eigvals, neg_eigvals
Esempio n. 10
0
def compute_eigenspectrum(
    dataset: str,
    data_path: str,
    model: str,
    checkpoint_path: str,
    curvature_matrix: str = 'hessian_lanczos',
    use_test: bool = True,
    batch_size: int = 128,
    num_workers: int = 4,
    swag: bool = False,
    lanczos_iters: int = 100,
    num_subsamples: int = None,
    subsample_seed: int = None,
    bn_train_mode: bool = True,
    save_spectrum_path: str = None,
    save_eigvec: bool = False,
    seed: int = None,
    device: str = 'cuda',
):
    """
    Compute a Lanczos approximation of the eigenspectrum (Ritz values and, optionally, Ritz vectors)
    of a curvature matrix of a trained deep learning model.

    Parameters
    ----------
    dataset: str: ['CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'*]: the dataset on which the curvature is evaluated.
    For ImageNet 32, we use the downsampled 32 x 32 full ImageNet dataset. We do not provide a download due to
    proprietary issues; please place the ImageNet 32 data in the 'data/' folder.

    data_path: str: the path string of the dataset

    model: str: the neural network architecture. All available models are listed under 'models/'.
    Example: VGG16BN, PreResNet110 (Preactivated ResNet - 110 layers)

    checkpoint_path: str: the path string to the checkpoints generated by train_network, which contains the state_dict
    of the network and the optimizer.

    curvature_matrix: str: the type of curvature matrix and computation method desired.
    Possible values are:
        hessian_lanczos: Lanczos algorithm on the Hessian matrix
        ggn_lanczos: Lanczos algorithm on the Generalised Gauss-Newton (GGN) matrix
        cov_grad_lanczos: Lanczos algorithm on the covariance of gradients
    Any other value fails the assertion at the start of the function.

    use_test: bool: if True, the test set is used. If not, a portion of the training data is assigned as the
    validation set.

    batch_size: int: the minibatch size

    num_workers: int: number of workers for the dataloader

    swag: whether to load the checkpoint as a SWAG model and use its SWA weights

    lanczos_iters: int: number of Lanczos iterations. This also determines how many Ritz value - vector pairs are
    generated.

    num_subsamples: int: Number of subsamples to draw randomly from the training dataset. If None, the entire dataset
    will be used.

    subsample_seed: int: the pseudorandom number seed for the subsample draw above.

    bn_train_mode: bool: Applies only if the network architecture (''model'') contains batch normalization layers.
    Toggles whether BN layers should be in train or eval mode.

    save_spectrum_path: str: If provided, w, the Ritz values and gammas are saved with numpy.savez to this path
    (numpy appends '.npz' if missing).

    save_eigvec: bool: If True, a dict of torch Tensors including the Ritz vectors V is additionally saved with
    torch.save to save_spectrum_path itself. V has P columns, where P is the number of parameters in the model, so
    turning this mode on while running a large number of experiments could take lots of storage.

    seed: if not None, a manual seed for the pseudo-random number generation will be used.

    device: ['cpu', 'cuda']: the device on which the model and all computations are performed. Strongly recommend 'cuda'
    for GPU acceleration on CUDA-enabled Nvidia devices; falls back to 'cpu' if CUDA is unavailable.

    Returns
    -------
    dict with keys:
        'w': flattened model parameters (CPU tensor of length P)
        'eigvals': Ritz values of the Lanczos tridiagonal matrix T in the legacy ``Tensor.eig`` layout: shape (k, 2),
            columns [real, imaginary]
        'gammas': squared first components of the eigenvectors of T
        'V': (k, P) Ritz vectors in parameter space, one per row
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    assert curvature_matrix in [
        'hessian_lanczos',
        'ggn_lanczos',
        'cov_grad_lanczos',
    ], 'Unrecognised curvature_matrix argument ' + str(curvature_matrix)

    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    print('Using model %s' % model)
    model_cfg = getattr(models, model)

    # The curvature loader uses test-time transforms for both splits
    # (presumably to avoid random augmentation during the products).
    datasets, num_classes = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        train_subset=num_subsamples,
        train_subset_seed=subsample_seed,
    )

    loader = torch.utils.data.DataLoader(datasets['train'],
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         pin_memory=True)

    # Full training set, used only for the batch-norm statistics update below.
    full_datasets, _ = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_train,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
    )

    full_loader = torch.utils.data.DataLoader(full_datasets['train'],
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              pin_memory=True)

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))

    if not swag:
        model = model_cfg.base(*model_cfg.args,
                               num_classes=num_classes,
                               **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        swag_model = SWAG(model_cfg.base,
                          subspace_type='random',
                          *model_cfg.args,
                          num_classes=num_classes,
                          **model_cfg.kwargs)
        print('Loading %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
        swag_model.set_swa()
        model = swag_model.base_model

    model.to(device)

    num_params = sum(p.numel() for p in model.parameters())

    criterion = losses.cross_entropy

    class CurvVecProduct(object):
        """Callable computing (curvature matrix) @ vector for Lanczos."""

        def __init__(self,
                     loader,
                     model,
                     criterion,
                     curvature_matrix,
                     full_loader=None):
            self.loader = loader
            self.full_loader = full_loader
            self.model = model
            self.criterion = criterion
            self.iters = 0
            self.timestamp = time.time()
            self.curvature_matrix = curvature_matrix

        def __call__(self, vector):
            start_time = time.time()
            if self.curvature_matrix == 'hessian_lanczos':
                output = utils.hess_vec(
                    vector,
                    self.loader,
                    self.model,
                    self.criterion,
                    cuda=device == 'cuda',
                    bn_train_mode=bn_train_mode,
                )
            elif self.curvature_matrix == 'ggn_lanczos':
                output = utils.gn_vec(vector,
                                      self.loader,
                                      self.model,
                                      self.criterion,
                                      cuda=device == 'cuda',
                                      bn_train_mode=bn_train_mode)
            elif self.curvature_matrix == 'cov_grad_lanczos':
                output = utils.covgrad_vec(vector,
                                           self.loader,
                                           self.model,
                                           self.criterion,
                                           cuda=device == 'cuda',
                                           bn_train_mode=bn_train_mode)
            else:
                raise ValueError("Unrecognised curvature_matrix argument " +
                                 self.curvature_matrix)
            time_diff = time.time() - start_time
            self.iters += 1
            print('Iter %d. Time: %.2f' % (self.iters, time_diff))
            # Lanczos runs on CPU (device='cpu' below), so return a CPU column vector.
            return output.cpu().unsqueeze(1)

    w = torch.cat(
        [param.detach().cpu().view(-1) for param in model.parameters()])
    productor = CurvVecProduct(loader, model, criterion, curvature_matrix)
    utils.bn_update(full_loader, model)
    Q, T = lanczos_tridiag(productor,
                           lanczos_iters,
                           dtype=torch.float32,
                           device='cpu',
                           matrix_shape=(num_params, num_params))
    try:
        eigvals, eigvects = T.eig(eigenvectors=True)
    except (AttributeError, RuntimeError):
        # Tensor.eig is deprecated/removed in newer PyTorch. T is symmetric,
        # so eigh is exact. Keep the legacy (k, 2) [real, imag] layout;
        # eigh returns eigenvalues in ascending order.
        ritz_vals, eigvects = torch.linalg.eigh(T)
        eigvals = torch.stack([ritz_vals, torch.zeros_like(ritz_vals)], dim=-1)
    gammas = eigvects[0, :]**2
    V = eigvects.t() @ Q.t()
    if save_spectrum_path is not None:
        # Write the numpy archive first. If the path already ends in '.npz',
        # both writes target the same file, and the torch file (the superset
        # that includes V) must be the one that survives.
        np.savez(save_spectrum_path,
                 w=w.numpy(),
                 eigvals=eigvals.numpy(),
                 gammas=gammas.numpy())
        if save_eigvec:
            torch.save(
                {
                    'w': w,
                    'eigvals': eigvals,
                    'gammas': gammas,
                    'V': V,
                },
                save_spectrum_path,
            )
    return {'w': w, 'eigvals': eigvals, 'gammas': gammas, 'V': V}
Esempio n. 11
0
def min_max_hessian_eigs(net,
                         dataloader,
                         criterion,
                         rank=0,
                         use_cuda=False,
                         verbose=False,
                         nsteps=100,
                         return_evecs=False):
    """
        Estimate the largest Hessian eigenvalue (and optionally Ritz vectors)
        with a single Lanczos run.

        Only parameters with more than one dimension (not biases) are included.
        The smallest eigenvalue is NOT computed; its return slot is ``None``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            rank: rank of the working node; only rank 0 prints.
            use_cuda: use GPU
            verbose: print per-product timing and the Ritz values.
            nsteps: number of Lanczos iterations.
            return_evecs: if True, also form and return the Ritz vectors.
        Returns:
            maxeig: largest Ritz value
            None: placeholder for the min eigenvalue (not computed)
            hess_vec_prod.count: number of Hessian-vector products evaluated
            pos_eigvals: all Ritz values of the run
            pos_bases: Ritz vectors (one column each) if return_evecs, else None
            pos_t_mat: the Lanczos tridiagonal matrix
    """

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)
    log = verbose and rank == 0

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        # Pass dataloader/use_cuda by keyword so they cannot bind to other
        # positional slots. Some eval_hess_vec_prod variants also accept
        # inputs/targets.
        eval_hess_vec_prod(vec, params, net, criterion,
                           dataloader=dataloader, use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    if log:
        print(pos_eigvals)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if log:
        print("max eigenvalue = %f" % maxeig)

    # The (N x k) Ritz-vector product is only formed when requested.
    pos_bases = pos_q_mat @ pos_eigvecs if return_evecs else None
    return maxeig, None, hess_vec_prod.count, pos_eigvals, pos_bases, pos_t_mat
Esempio n. 12
0
    def criterion(current_pars, input_data, target, return_predictions=True):
        r"""
        Loss function for Laplace.

        In eval mode a parameter sample is drawn using a 30-step Lanczos
        decomposition of the scaled Fisher-vector product plus unit jitter,
        :math:`F \approx Q T Q^{\top}`, so that
        :math:`F^{-1/2} \approx Q V \Lambda^{-1/2}` over the retained
        (top-half) Ritz pairs. Outside eval mode ``current_pars[0]`` is used as is.
        Predictions come from ``Jacobian(...)._t_matmul`` applied to the
        (optionally bias-shifted) sample.

        current_pars (list/iterable): parameter list; only ``current_pars[0]`` is used
        input_data (tensor): input data for model
        target (tensor): response (class labels for cross-entropy)
        return_predictions (bool): currently ignored -- predictions are always returned

        Returns:
            (output, predictions_reshaped). In eval mode ``output`` is the
            cross-entropy summed over the batch (mean * batch size); otherwise
            it is ``num_data * loss + regularizer``.
        """
        if eval_mode:
            # F \approx Q T Q' => F^{-1} \approx Q T^{-1} Q'
            # F^{-1/2} \approx Q T^{-1/2}
            fvp = ((num_data / input_data.shape[0]) *
                   FVP_FD(model, input_data)).add_jitter(1.0)
            qmat, tmat = lanczos_tridiag(
                fvp.matmul,
                30,
                dtype=current_pars[0].dtype,
                device=current_pars[0].device,
                init_vecs=None,
                matrix_shape=[
                    current_pars[0].shape[0], current_pars[0].shape[0]
                ],
            )

            # torch.symeig is deprecated/removed in newer PyTorch. Prefer
            # torch.linalg.eigh when available; both return ascending
            # eigenvalues of the symmetric tridiagonal matrix.
            linalg = getattr(torch, "linalg", None)
            if linalg is not None and hasattr(linalg, "eigh"):
                eigs, evecs = linalg.eigh(tmat)
            else:
                eigs, evecs = torch.symeig(tmat, eigenvectors=True)

            # Only keep the top half of the Ritz values (the original author
            # considered them the reliable ones). max(1, n // 2) matches the
            # previous int(n / 2) slicing, including the n == 1 case.
            n_keep = max(1, tmat.shape[0] // 2)
            top_idx = torch.sort(eigs)[1][-n_keep:]

            # Ritz vectors in parameter space for the retained pairs.
            updated_evecs = (qmat @ evecs)[:, top_idx]

            z = torch.randn(top_idx.shape[0],
                            1,
                            device=tmat.device,
                            dtype=tmat.dtype)
            # Q V Lambda^{-1/2} z, with the diagonal scaling applied elementwise.
            inv_sqrt = (1.0 / eigs[top_idx].pow(0.5)).unsqueeze(-1)
            approx_lz = updated_evecs @ (inv_sqrt * z)
            sample = current_pars[0] + approx_lz
        else:
            sample = current_pars[0]

        rhs = sample
        if bias:
            rhs = sample - model_pars.view(-1, 1)

        predictions = Jacobian(model=model, data=input_data,
                               num_outputs=1)._t_matmul(rhs)
        predictions_reshaped = predictions.reshape(target.shape[0],
                                                   num_classes)

        if bias:
            predictions_reshaped = predictions_reshaped + model(input_data)

        loss = (
            torch.nn.functional.cross_entropy(predictions_reshaped, target) *
            target.shape[0])
        regularizer = current_pars[0].norm() * wd

        if eval_mode:
            output = loss
        else:
            output = num_data * loss + regularizer

        return output, predictions_reshaped