예제 #1
0
def get_hessian_eigs(loss, model, mask=None,
                     use_cuda=False, n_eigs=100, train_x=None, train_y=None,
                     loader=None, evals=False):
    """Estimate eigenvalues (optionally eigenvectors) of the loss Hessian.

    The Hessian is taken w.r.t. all parameters of ``model``, optionally
    restricted to the principal sub-block selected by ``mask``.

    Args:
        loss: criterion forwarded to ``eval_hess_vec_prod`` / ``get_hessian``.
        model: module whose parameters define the Hessian.
        mask: optional flat tensor with one entry per scalar parameter; only
            entries equal to 1 are kept. ``None`` keeps every parameter.
        use_cuda: move ``train_x`` / ``train_y`` to the GPU first.
        n_eigs: iteration budget passed to ``lanczos_tridiag``. ``-1`` instead
            forms the (masked) Hessian explicitly and returns all eigenvalues.
        train_x, train_y: full-batch data; when ``train_x`` is None the
            ``loader`` is used instead.
        loader: dataloader used when ``train_x`` is None.
        evals: Lanczos path only -- also return the Ritz vectors.

    Returns:
        Lanczos path: ``eigs``, or ``(eigs, eigvecs)`` when ``evals`` is True;
        the vectors live in the masked coordinate space.
        Exact path (``n_eigs == -1``): numpy array of the real parts of the
        eigenvalues of the masked Hessian.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())
    if mask is None:
        # Build the default mask up front so BOTH paths can use it; previously
        # the exact path crashed on ``mask.cpu()`` when no mask was given.
        p = next(iter(model.parameters()))
        mask = torch.ones(total_pars, dtype=p.dtype, device=p.device)
    # Hoisted once instead of recomputed inside every Hessian-vector product.
    keep = mask == 1

    if n_eigs == -1:
        # form and extract the sub-Hessian, then diagonalise it densely
        hessian = get_hessian(train_x, train_y, loss, model, use_cuda=use_cuda)
        keepers = np.where(keep.cpu().numpy())[0]
        sub_hess = hessian[np.ix_(keepers, keepers)]
        e_val, _ = np.linalg.eig(sub_hess.cpu().detach())
        return e_val.real

    # Count exactly the entries selected by ``keep`` so the Lanczos operator
    # size always matches the masked indexing below.
    numpars = int(keep.sum().item())

    def hvp(rhs):
        # Scatter the masked vector(s) back into full parameter space.
        padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                                 device=rhs.device, dtype=rhs.dtype)
        padded_rhs[keep] = rhs
        padded_rhs = unflatten_like(padded_rhs.t(), model.parameters())
        eval_hess_vec_prod(padded_rhs, net=model,
                           criterion=loss, inputs=train_x,
                           targets=train_y, dataloader=loader, use_cuda=use_cuda)
        # The product is read back from the model's gradient tensors
        # (presumably populated by eval_hess_vec_prod), then re-masked.
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp[keep].unsqueeze(-1)

    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp, n_eigs, dtype=dtype, device=device,
                                 matrix_shape=(numpars, numpars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)
    if evals:
        # Map eigenvectors of T back through the Krylov basis Q.
        return eigs, qmat @ t_evals
    return eigs
예제 #2
0
    def _exact_predictive_covar_inv_quad_form_root(self, expanded_lt,
                                                   test_train_covar):
        r"""
        Computes a root :math:`S` of the inverse of ``expanded_lt`` via Lanczos,
        i.e. :math:`SS^{\top} \approx \text{expanded\_lt}^{-1}`.

        ``expanded_lt`` is presumably :math:`K_{XX} + \sigma^2 I`
        (NOTE(review): confirm against the caller).

        Side effects:
            Stores the Lanczos Ritz vectors and values of ``expanded_lt`` on
            ``self.gram_evecs`` and ``self.gram_evals``.

        Args:
            expanded_lt: matrix / lazy tensor providing ``matmul``, ``shape``,
                ``device`` and ``dtype``; the operator whose inverse root is
                computed.
            test_train_covar: not used by this implementation.
                NOTE(review): the base-class contract multiplies by
                :math:`K_{X^{*}X}`; confirm the caller applies it.

        Returns
            :math:`S = Q V \Lambda^{-1/2}`, where :math:`Q` is the Lanczos
            basis and :math:`V, \Lambda` are the eigenpairs of the
            tridiagonal matrix.
        """
        qmats, tmats = lanczos_tridiag(
            expanded_lt.matmul,
            max_iter=settings.max_root_decomposition_size.value(),
            matrix_shape=expanded_lt.shape,
            device=expanded_lt.device,
            dtype=expanded_lt.dtype,
        )
        evals, evecs = lanczos_tridiag_to_diag(tmats)

        self.gram_evecs = qmats @ evecs
        self.gram_evals = evals

        # Scale each Ritz-vector column by evals^{-1/2}. This equals
        # ``gram_evecs @ torch.diag(evals.pow(-0.5))`` for 1-D ``evals``, but
        # broadcasts correctly over batch dimensions, where torch.diag would
        # instead extract a diagonal.
        covar_root = self.gram_evecs * evals.pow(-0.5).unsqueeze(-2)

        return covar_root
예제 #3
0
def get_hessian_evals(loss,
                      model,
                      use_cuda=False,
                      n_eigs=100,
                      train_x=None,
                      train_y=None,
                      loader=None):
    """Lanczos estimate of Hessian eigenpairs over all model parameters.

    Args:
        loss: criterion forwarded to ``eval_hess_vec_prod``.
        model: module whose parameters define the Hessian.
        use_cuda: move ``train_x`` / ``train_y`` to the GPU first.
        n_eigs: iteration budget passed to ``lanczos_tridiag``.
        train_x, train_y: full-batch data; when ``train_x`` is None the
            ``loader`` is used instead.
        loader: dataloader used when ``train_x`` is None.

    Returns:
        ``(eigs, eigvecs)``: Ritz values and Ritz vectors (``qmat @ t_evals``)
        in flat parameter space of length ``sum(p.numel())``.
    """
    if train_x is not None and use_cuda:
        train_x = train_x.cuda()
        train_y = train_y.cuda()

    total_pars = sum(m.numel() for m in model.parameters())

    def hvp(rhs):
        # ``rhs`` already spans every parameter (matrix_shape is
        # total_pars x total_pars), so it is reshaped to the parameter list
        # directly. Previously a fresh zero tensor was passed instead of
        # ``rhs``, which made every product H @ 0.
        vecs = unflatten_like(rhs.t(), model.parameters())
        eval_hess_vec_prod(vecs,
                           net=model,
                           criterion=loss,
                           inputs=train_x,
                           targets=train_y,
                           dataloader=loader,
                           use_cuda=use_cuda)
        # The product is read back from the model's gradient tensors
        # (presumably populated by eval_hess_vec_prod).
        full_hvp = gradtensor_to_tensor(model, include_bn=True)
        return full_hvp.unsqueeze(-1)

    if train_x is None:
        data = next(iter(loader))[0]
        if use_cuda:
            data = data.cuda()
        dtype, device = data.dtype, data.device
    else:
        dtype, device = train_x.dtype, train_x.device

    qmat, tmat = lanczos_tridiag(hvp,
                                 n_eigs,
                                 dtype=dtype,
                                 device=device,
                                 matrix_shape=(total_pars, total_pars))
    eigs, t_evals = lanczos_tridiag_to_diag(tmat)

    # Map eigenvectors of T back through the Krylov basis Q.
    return eigs, qmat @ t_evals
예제 #4
0
def hessian_eigenpairs(net,
                       criterion,
                       inputs=None,
                       targets=None,
                       dataloader=None,
                       n_eigs=20,
                       use_cuda=torch.cuda.is_available(),
                       verbose=False):
    """Lanczos estimate of Hessian eigenpairs restricted to multi-dim params.

    Only parameters with more than one dimension are included; 1-D parameters
    (e.g. biases) are excluded.

    Args:
        net: the model.
        criterion: loss function.
        inputs, targets: optional full-batch data forwarded to
            ``eval_hess_vec_prod``.
        dataloader: optional dataloader forwarded to ``eval_hess_vec_prod``.
        n_eigs: iteration budget passed to ``lanczos_tridiag``.
        use_cuda: forwarded to ``eval_hess_vec_prod``. Note the default is
            evaluated once, when this function is defined.
        verbose: print timing for every Hessian-vector product.

    Returns:
        ``(e_vals, e_vecs)``: Ritz values and Ritz vectors in the flattened
        space of the selected parameters (length ``N``).
    """
    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        # Counter used by the verbose log line; it must be initialised below,
        # otherwise verbose=True raised AttributeError.
        hess_vec_prod.count += 1
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec,
                           params,
                           net,
                           criterion,
                           inputs=inputs,
                           targets=targets,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        prod_time = time.time() - start_time
        if verbose:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        # NOTE(review): assumes gradtensor_to_tensor(net) with default args
        # flattens the same parameter subset as ``params`` (length N) -- confirm.
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0

    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_eigs,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    e_vals, e_vecs = lanczos_tridiag_to_diag(pos_t_mat)

    # Eigenvectors of T are expressed in the Krylov basis; map them back to
    # parameter space through Q.
    e_vecs = pos_q_mat.matmul(e_vecs)

    return e_vals, e_vecs
예제 #5
0
def min_max_hessian_eigs(net, dataloader, criterion,
                         n_top_eigs=3, n_bottom_eigs=50, use_cuda=False):
    """
        Compute the largest and the smallest eigenvalues of the Hessian matrix.

        Two Lanczos runs are made: one on H for the top of the spectrum, and
        one on the shifted operator ``shift * I - H`` (``shift = 0.51 * max
        Ritz value``), whose top Ritz values map back to the bottom of H's
        spectrum via ``shift - theta``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            n_top_eigs: iteration budget for the first Lanczos run.
            n_bottom_eigs: iteration budget for the shifted Lanczos run.
            use_cuda: use GPU
        Returns:
            pos_eigvals, pos_eigvecs: Ritz values / vectors of H.
            neg_eigvals, neg_eigvecs: Ritz values of H recovered from the
                shifted run, and the matching Ritz vectors.
            All vectors are in flat parameter space of length N.
    """
    # N counts every parameter; HVPs are read back with include_bn=True,
    # presumably covering the same set -- NOTE(review): confirm.
    N = sum(p.numel() for p in net.parameters())
    first_param = next(iter(net.parameters()))

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        # ``vec`` already spans every parameter, so no masking/padding is needed.
        vec = unflatten_like(vec.t(), net.parameters())
        eval_hess_vec_prod(vec, net=net, criterion=criterion,
                           dataloader=dataloader,
                           use_cuda=use_cuda)
        out = gradtensor_to_tensor(net, include_bn=True)
        return out.unsqueeze(-1)

    hess_vec_prod.count = 0

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        n_top_eigs,
        device=first_param.device,
        dtype=first_param.dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    pos_eigvecs = pos_q_mat @ pos_eigvecs

    # If the largest eigenvalue is positive, shift matrix so that any negative
    # eigenvalue is now the largest. We assume the smallest eigenvalue is zero
    # or less, so this shift is more than what we need.
    shift = 0.51 * pos_eigvals.max().item()
    print("Pos Eigs Computed....\n")

    def shifted_hess_vec_prod(vec):
        return -hess_vec_prod(vec) + shift * vec

    # now run lanczos on the shifted operator
    neg_q_mat, neg_t_mat = lanczos_tridiag(
        shifted_hess_vec_prod,
        n_bottom_eigs,
        device=first_param.device,
        dtype=first_param.dtype,
        matrix_shape=(N, N),
    )
    neg_eigvals, neg_eigvecs = lanczos_tridiag_to_diag(neg_t_mat)
    neg_eigvecs = neg_q_mat @ neg_eigvecs
    print("Neg Eigs Computed...")

    # Undo the shift: eigenvalue theta of (shift*I - H) is shift - lambda.
    neg_eigvals = -neg_eigvals + shift

    return pos_eigvals, pos_eigvecs, neg_eigvals, neg_eigvecs
예제 #6
0
def min_max_hessian_eigs(net,
                         dataloader,
                         criterion,
                         rank=0,
                         use_cuda=False,
                         verbose=False,
                         nsteps=100):
    """
        Compute the largest eigenvalue of the Hessian matrix via Lanczos.

        Only the top of the spectrum is computed; the smallest-eigenvalue
        slots of the returned tuple are ``None``. Only parameters with more
        than one dimension are included.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            rank: rank of the working node.
            use_cuda: use GPU
            verbose: print more information (only on rank 0)
            nsteps: iteration budget passed to ``lanczos_tridiag``.
        Returns:
            maxeig: max eigenvalue (0-d tensor)
            mineig: always None (not computed)
            hess_vec_prod.count: number of Hessian-vector products performed
            pos_eigvals: all Ritz values from the Lanczos run
            neg_eigvals: always None (not computed)
            pos_t_mat: the Lanczos tridiagonal matrix
    """
    log = verbose and rank == 0

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # Only T is needed; the Ritz vectors (Q @ evecs) were previously formed
    # and then discarded, costing an N x nsteps matmul for nothing.
    _, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)

    if log:
        print(pos_eigvals)
        print("max eigenvalue = %f" % maxeig)

    return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat
예제 #7
0
def min_max_fisher_eigs(net,
                        dataloader,
                        criterion,
                        rank=0,
                        use_cuda=False,
                        verbose=False,
                        fvp_method='FVP_FD',
                        pos_steps=100,
                        neg_steps=200):
    """
        Compute the largest and the smallest eigenvalues of the Fisher matrix.

        The top of the spectrum comes from Lanczos on F directly. The bottom
        comes from Lanczos on ``shift * I - F`` with ``shift = 0.51 * maxeig``;
        its largest Ritz value ``theta`` maps back to ``mineig = shift - theta``.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function. Accepted for signature parity with the
                Hessian helpers; not used here.
            rank: rank of the working node.
            use_cuda: use GPU
            verbose: print more information (only on rank 0)
            fvp_method: 'FVP_FD' or 'FVP_AG' (presumably finite-difference /
                autograd Fisher-vector products).
            pos_steps: iteration budget for the first Lanczos run.
            neg_steps: iteration budget for the shifted Lanczos run.
        Returns:
            maxeig: max eigenvalue
            mineig: min eigenvalue
            fisher_vec_prod.count: number of Fisher-vector products performed
            pos_eigvals: Ritz values of F
            neg_eigvals: Ritz values of the SHIFTED operator (not mapped back)
        Raises:
            NotImplementedError: for an unknown ``fvp_method``.
    """
    if fvp_method == 'FVP_FD':
        print('Using FD for FVP')
        fvp_matmul = FVP_FD
    elif fvp_method == 'FVP_AG':
        print('Using AG for FVP')
        fvp_matmul = FVP_AG
    else:
        raise NotImplementedError(
            'Only FD and AG have been implemented so far.')

    log = verbose and rank == 0

    # All parameters are used (unlike the Hessian helpers, which keep only
    # multi-dimensional ones).
    params = list(net.parameters())
    N = sum(p.numel() for p in params)

    def fisher_vec_prod(vec):
        fisher_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), net.parameters())

        start_time = time.time()
        out = eval_fisher_vec_prod(vec,
                                   net,
                                   dataloader,
                                   use_cuda,
                                   fvp_matmul=fvp_matmul)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (fisher_vec_prod.count, prod_time))
        return out

    fisher_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t matrix out
    _, pos_t_mat = lanczos_tridiag(fisher_vec_prod,
                                   pos_steps,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, _ = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if log:
        print(pos_eigvals)
        print('max eigenvalue = %f' % maxeig)

    # Shift so that the bottom of F's spectrum becomes the top of the shifted
    # operator's spectrum.
    shift = 0.51 * maxeig.item()
    if log:
        print(shift)

    def shifted_hess_vec_prod(vec):
        return -fisher_vec_prod(vec) + shift * vec

    if log:
        print("Rank %d: Computing shifted eigenvalue" % rank)

    # now run lanczos on the shifted operator
    _, neg_t_mat = lanczos_tridiag(shifted_hess_vec_prod,
                                   neg_steps,
                                   device=params[0].device,
                                   dtype=params[0].dtype,
                                   matrix_shape=(N, N))
    neg_eigvals, _ = lanczos_tridiag_to_diag(neg_t_mat)
    # Largest eigenvalue of (shift*I - F) corresponds to F's smallest.
    mineig = -torch.max(neg_eigvals) + shift
    if log:
        print(neg_eigvals)
        print('min eigenvalue = ' + str(mineig))

    # Kept from the original: swap if the first run reported a non-positive
    # maximum but the shifted run a positive minimum.
    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig

    return maxeig, mineig, fisher_vec_prod.count, pos_eigvals, neg_eigvals
예제 #8
0
def min_max_hessian_eigs(net,
                         dataloader,
                         criterion,
                         rank=0,
                         use_cuda=False,
                         verbose=False,
                         nsteps=100,
                         return_evecs=False):
    """
        Compute the largest eigenvalue of the Hessian matrix via Lanczos.

        Only the top of the spectrum is computed; ``mineig`` is always None.
        Only parameters with more than one dimension are included.

        Args:
            net: the trained model.
            dataloader: dataloader for the dataset, may use a subset of it.
            criterion: loss function.
            rank: rank of the working node.
            use_cuda: use GPU
            verbose: print more information (only on rank 0)
            nsteps: iteration budget passed to ``lanczos_tridiag``.
            return_evecs: also return the Ritz vectors in parameter space.
        Returns:
            maxeig: max eigenvalue (0-d tensor)
            mineig: always None (not computed)
            hess_vec_prod.count: number of Hessian-vector products performed
            pos_eigvals: all Ritz values from the Lanczos run
            pos_bases: Ritz vectors (Q @ evecs) if ``return_evecs`` else None
            pos_t_mat: the Lanczos tridiagonal matrix
    """
    log = verbose and rank == 0

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    def hess_vec_prod(vec):
        hess_vec_prod.count += 1  # simulates a static variable
        vec = unflatten_like(vec.t(), params)

        start_time = time.time()
        eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda)
        prod_time = time.time() - start_time
        if log:
            print("   Iter: %d  time: %f" % (hess_vec_prod.count, prod_time))
        out = gradtensor_to_tensor(net)
        return out.unsqueeze(1)

    hess_vec_prod.count = 0
    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    # use lanczos to get the t and q matrices out
    pos_q_mat, pos_t_mat = lanczos_tridiag(
        hess_vec_prod,
        nsteps,
        device=params[0].device,
        dtype=params[0].dtype,
        matrix_shape=(N, N),
    )
    # convert the tridiagonal t matrix to the eigenvalues
    pos_eigvals, pos_eigvecs = lanczos_tridiag_to_diag(pos_t_mat)
    # eigenvalues may not be sorted
    maxeig = torch.max(pos_eigvals)
    if log:
        print(pos_eigvals)
        print("max eigenvalue = %f" % maxeig)

    if not return_evecs:
        return maxeig, None, hess_vec_prod.count, pos_eigvals, None, pos_t_mat

    # Only form the N x nsteps Ritz-vector matrix when it is actually returned.
    pos_bases = pos_q_mat @ pos_eigvecs
    return maxeig, None, hess_vec_prod.count, pos_eigvals, pos_bases, pos_t_mat