def test_stochastic_hessian(model,
                            criterion,
                            real_hessian,
                            x,
                            y,
                            bs=10,
                            ntrials=10):
    """
    Plot Hessian eigendecomposition estimates obtained with batches of size
    ``bs`` against the exact eigendecomposition of ``real_hessian``.

    ``compute_hessian_eigenthings`` is run ``ntrials`` times over a
    ``DataLoader`` with ``batch_size=bs`` built from the ``(x, y)`` pairs.
    The sorted estimates are plotted next to the eigenpairs from
    ``np.linalg.eig(real_hessian)``. The figure is saved to
    ``"stochastic.png"`` and the current figure is then cleared.

    :param model: model passed to ``compute_hessian_eigenthings``.
    :param criterion: loss passed to ``compute_hessian_eigenthings``.
    :param real_hessian: reference Hessian, a square matrix accepted by
        ``np.linalg.eig``; its length sets the number of eigenpairs requested.
    :param x: inputs, zipped element-wise with ``y`` into samples.
    :param y: targets.
    :param bs: DataLoader batch size.
    :param ntrials: number of independent estimation runs.
    """
    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # full dataset, served in batches of ``bs``
    dataloader = DataLoader(samples, batch_size=bs)

    nparams = len(real_hessian)

    eigenvals = []
    eigenvecs = []
    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            power_iter_steps=10,
            power_iter_err_threshold=1e-5,
            momentum=0,
            use_gpu=False,
        )
        est_eigenvals = np.array(est_eigenvals)
        # One row per estimated eigenvector.
        est_eigenvecs = np.array([t.numpy() for t in est_eigenvecs])

        # Sort estimates by descending eigenvalue.
        est_inds = np.argsort(est_eigenvals)[::-1]
        eigenvals.append(est_eigenvals[est_inds])
        eigenvecs.append(est_eigenvecs[est_inds])

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    # np.linalg.eig returns eigenvector i as COLUMN i; transpose so each row
    # is one eigenvector, matching the layout of the estimates above.
    real_eigenvecs = np.asarray(real_eigenvecs).T
    real_inds = np.argsort(real_eigenvals)[::-1]
    real_eigenvals = np.asarray(real_eigenvals)[real_inds]
    real_eigenvecs = real_eigenvecs[real_inds]

    # Plot eigenvalue error
    plt.suptitle("Stochastic Hessian eigendecomposition errors: %d trials" %
                 ntrials)
    plt.subplot(1, 2, 1)
    plt.title("Eigenvalues")
    plt.plot(list(range(nparams)), real_eigenvals, label="True Eigenvals")
    plot_eigenval_estimates(eigenvals, label="Estimates")
    plt.legend()
    # Plot eigenvector cosine-similarity error
    plt.subplot(1, 2, 2)
    plt.title("Eigenvector cosine simliarity")
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label="Estimates")
    plt.legend()
    plt.savefig("stochastic.png")
    plt.clf()
def test_matrix(mat, ntrials, mode):
    """
    Plot the accuracy of an iterative eigensolver on the given matrix.

    Runs the solver selected by ``mode`` ``ntrials`` times on ``mat``.
    ``'lanczos'`` selects ``lanczos``; any other value selects
    ``deflated_power_iteration``. It then plots the top
    ``args.num_eigenthings`` estimated eigenvalues and the eigenvector
    cosine-similarity errors against ``np.linalg.eig(mat)``. The figure is
    displayed with ``plt.show()``.

    Reads the module-level ``args`` (``num_eigenthings`` and ``fp16``).

    :param mat: square numpy array (converted with ``torch.from_numpy``).
    :param ntrials: number of independent solver runs.
    :param mode: ``'lanczos'``, or anything else for deflated power iteration.
    """
    tensor = torch.from_numpy(mat).float()

    # for non-gpu tests, addmv not implemented for fp16 on CPU. have to do float.
    op = LambdaOperator(lambda x: torch.matmul(tensor, x.float()),
                        tensor.size()[:1])
    real_eigenvals, true_eigenvecs = np.linalg.eig(mat)
    # np.linalg.eig returns eigenvector i as column i; store one per row.
    real_eigenvecs = [true_eigenvecs[:, i] for i in range(len(real_eigenvals))]

    # The solver does not change between trials, so select it once.
    method = lanczos if mode == 'lanczos' else deflated_power_iteration

    eigenvals = []
    eigenvecs = []
    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = method(
            op,
            num_eigenthings=args.num_eigenthings,
            use_gpu=False,
            fp16=args.fp16)
        # Sort estimates by descending eigenvalue.
        est_inds = np.argsort(est_eigenvals)
        est_eigenvals = np.array(est_eigenvals)[est_inds][::-1]
        est_eigenvecs = np.array(est_eigenvecs)[est_inds][::-1]

        eigenvals.append(est_eigenvals)
        eigenvecs.append(est_eigenvecs)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    # Keep only the top ``num_eigenthings`` true eigenpairs, in descending
    # order, so they line up with the estimates.
    real_inds = np.argsort(real_eigenvals)
    real_eigenvals = np.array(
        real_eigenvals)[real_inds][-args.num_eigenthings:][::-1]
    real_eigenvecs = np.array(
        real_eigenvecs)[real_inds][-args.num_eigenthings:][::-1]

    # Plot eigenvalue error
    plt.suptitle('Random Matrix Eigendecomposition Errors: %d trials' %
                 ntrials)
    plt.subplot(1, 2, 1)
    plt.title('Eigenvalues')
    plt.plot(list(range(len(real_eigenvals))),
             real_eigenvals,
             label='True Eigenvals')
    plot_eigenval_estimates(eigenvals, label='Estimates')
    plt.legend()
    # Plot eigenvector cosine-similarity error
    plt.subplot(1, 2, 2)
    plt.title('Eigenvector cosine simliarity')
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label='Estimates')
    plt.legend()
    plt.show()
def test_full_hessian(model, criterion, x, y, ntrials=10):
    """
    Plot full-batch Hessian eigendecomposition estimates against the exact
    Hessian, and return that exact Hessian.

    The exact Hessian is built by ``get_full_hessian`` from the gradient of
    ``criterion(model(x), y)``, taken with ``create_graph=True``.
    ``compute_hessian_eigenthings`` is then run ``ntrials`` times over a
    ``DataLoader`` whose single batch holds all of ``(x, y)``. The sorted
    estimates are plotted against ``np.linalg.eig(real_hessian)`` and saved to
    ``"full.png"``; the current figure is then cleared.

    :param model: callable model whose ``parameters()`` are differentiated.
    :param criterion: loss function ``criterion(model(x), y)``.
    :param x: inputs.
    :param y: targets.
    :param ntrials: number of independent estimation runs.
    :return: the Hessian produced by ``get_full_hessian``.
    """
    loss = criterion(model(x), y)
    loss_grad = torch.autograd.grad(loss,
                                    model.parameters(),
                                    create_graph=True)
    real_hessian = get_full_hessian(loss_grad, model)

    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # full dataset in a single batch
    dataloader = DataLoader(samples, batch_size=len(x))

    nparams = len(real_hessian)

    eigenvals = []
    eigenvecs = []
    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            power_iter_steps=10,
            power_iter_err_threshold=1e-5,
            momentum=0,
            use_gpu=False,
        )
        est_eigenvals = np.array(est_eigenvals)
        # Convert tensors to numpy (as the sibling tests do), one row per
        # eigenvector; non-tensor entries are passed through unchanged.
        est_eigenvecs = np.array([
            t.numpy() if torch.is_tensor(t) else np.asarray(t)
            for t in est_eigenvecs
        ])

        # Sort estimates by descending eigenvalue.
        est_inds = np.argsort(est_eigenvals)[::-1]
        eigenvals.append(est_eigenvals[est_inds])
        eigenvecs.append(est_eigenvecs[est_inds])

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    # np.linalg.eig returns eigenvector i as COLUMN i; transpose so each row
    # is one eigenvector, matching the layout of the estimates above.
    real_eigenvecs = np.asarray(real_eigenvecs).T
    real_inds = np.argsort(real_eigenvals)[::-1]
    real_eigenvals = np.asarray(real_eigenvals)[real_inds]
    real_eigenvecs = real_eigenvecs[real_inds]

    # Plot eigenvalue error
    plt.suptitle("Hessian eigendecomposition errors: %d trials" % ntrials)
    plt.subplot(1, 2, 1)
    plt.title("Eigenvalues")
    plt.plot(list(range(nparams)), real_eigenvals, label="True Eigenvals")
    plot_eigenval_estimates(eigenvals, label="Estimates")
    plt.legend()
    # Plot eigenvector cosine-similarity error
    plt.subplot(1, 2, 2)
    plt.title("Eigenvector cosine simliarity")
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label="Estimates")
    plt.legend()
    plt.savefig("full.png")
    plt.clf()
    return real_hessian
Пример #4
0
def test_fixed_mini(model, criterion, real_hessian, x, y, bs=10, ntrials=10):
    """
    Plot Lanczos Hessian eigendecomposition estimates computed on one fixed
    mini-batch (the first ``bs`` samples) against ``real_hessian``.

    ``compute_hessian_eigenthings`` (``mode='lanczos'``) is run ``ntrials``
    times over a ``DataLoader`` whose single batch is ``x[:bs], y[:bs]``. The
    sorted estimates are plotted against ``np.linalg.eig(real_hessian)`` and
    saved to ``'fixed.png'``.

    :param model: model passed to ``compute_hessian_eigenthings``.
    :param criterion: loss passed to ``compute_hessian_eigenthings``.
    :param real_hessian: reference Hessian, a square matrix accepted by
        ``np.linalg.eig``; its length sets the number of eigenpairs requested.
    :param x: inputs; only the first ``bs`` are used.
    :param y: targets; only the first ``bs`` are used.
    :param bs: size of the fixed mini-batch.
    :param ntrials: number of independent estimation runs.
    """
    x = x[:bs]
    y = y[:bs]

    samples = [(x_i, y_i) for x_i, y_i in zip(x, y)]
    # the fixed mini-batch, served as a single batch
    dataloader = DataLoader(samples, batch_size=len(x))

    nparams = len(real_hessian)

    eigenvals = []
    eigenvecs = []
    for _ in range(ntrials):
        est_eigenvals, est_eigenvecs = compute_hessian_eigenthings(
            model,
            dataloader,
            criterion,
            num_eigenthings=nparams,
            mode='lanczos',
            power_iter_steps=10,
            power_iter_err_threshold=1e-5,
            momentum=0,
            use_gpu=False)
        est_eigenvals = np.array(est_eigenvals)
        # One row per estimated eigenvector.
        est_eigenvecs = np.array([t.numpy() for t in est_eigenvecs])

        # Sort estimates by descending eigenvalue.
        est_inds = np.argsort(est_eigenvals)[::-1]
        eigenvals.append(est_eigenvals[est_inds])
        eigenvecs.append(est_eigenvecs[est_inds])

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    real_eigenvals, real_eigenvecs = np.linalg.eig(real_hessian)
    # np.linalg.eig returns eigenvector i as COLUMN i; transpose so each row
    # is one eigenvector, matching the layout of the estimates above.
    real_eigenvecs = np.asarray(real_eigenvecs).T
    real_inds = np.argsort(real_eigenvals)[::-1]
    real_eigenvals = np.asarray(real_eigenvals)[real_inds]
    real_eigenvecs = real_eigenvecs[real_inds]

    # Plot eigenvalue error
    plt.suptitle(
        'Fixed mini-batch Hessian eigendecomposition errors: %d trials' %
        ntrials)
    plt.subplot(1, 2, 1)
    plt.title('Eigenvalues')
    plt.plot(list(range(nparams)), real_eigenvals, label='True Eigenvals')
    plot_eigenval_estimates(eigenvals, label='Estimates')
    plt.legend()
    # Plot eigenvector cosine-similarity error
    plt.subplot(1, 2, 2)
    plt.title('Eigenvector cosine simliarity')
    plot_eigenvec_errors(real_eigenvecs, eigenvecs, label='Estimates')
    plt.legend()
    plt.savefig('fixed.png')