Пример #1
0
def get_gegenbauer_gram(Theta1, Theta2):
    """Evaluate Gegenbauer polynomials on every pairwise inner product.

    Computes the gram matrix ``Theta1 @ Theta2.T`` (shape (M, N)), flattens
    it, and passes it to ``gegenbauer.get_gegenbauer_fast2``.

    NOTE: ``kmax`` and ``d`` are NOT parameters; they are read from module
    scope at call time and must be defined there.

    Args:
        Theta1: array of shape (M, dim).
        Theta2: array of shape (N, dim); N need not equal M.

    Returns:
        Whatever ``get_gegenbauer_fast2`` returns for the M*N flattened inner
        products -- elsewhere in this file that result is indexed as
        (kmax, M*N).
    """
    gram = Theta1 @ Theta2.T
    # Flatten using the true element count M*N so Theta2 may have a different
    # number of rows than Theta1.
    Q = gegenbauer.get_gegenbauer_fast2(np.reshape(gram, gram.size), kmax, d)
    return Q
def target_fn(x, target_pts, alpha, k, d):
    """Evaluate a single-degree target function at the points ``x``.

    Computes y_i = sum_j alpha_j * Q_k(x_i . t_j), where t_j are the rows of
    ``target_pts`` and Q_k is row ``k`` of the output of
    ``gegenbauer.get_gegenbauer_fast2`` (presumably the degree-``k``
    Gegenbauer polynomial in dimension ``d``).

    Args:
        x: evaluation points, one per row.
        target_pts: target centres, one per row.
        alpha: weights, one per row of ``target_pts``.
        k: polynomial degree (row index) to use.
        d: dimension passed to the Gegenbauer routine.

    Returns:
        ``npo.matmul`` of the selected polynomial values with ``alpha``.
    """
    gram = np.matmul(x, target_pts.T)
    # NOTE(review): unlike the other callers in this file, the 2D gram is not
    # flattened before get_gegenbauer_fast2; this assumes that routine
    # broadcasts over 2D input -- confirm. Degrees 0..k+1 are requested and
    # only row k is kept.
    Q = gegenbauer.get_gegenbauer_fast2(gram, k + 2, d)[k, :]
    # NOTE(review): `npo` (not `np`) is used here; presumably plain numpy.
    return npo.matmul(Q, alpha)
def compute_kernel(X, Xp, spectrum, d, kmax):
    """Build the kernel matrix between ``X`` and ``Xp`` from a mode spectrum.

    Each entry is sum_k spectrum[k] * degeneracy(d, k) * Q_k(x . x'), where
    Q_k is row k of ``gegenbauer.get_gegenbauer_fast2`` evaluated on the
    flattened inner products.

    Args:
        X: points, shape (P, dim).
        Xp: points, shape (Pp, dim).
        spectrum: per-mode coefficients; must broadcast against kmax entries.
        d: dimension passed to the Gegenbauer routines.
        kmax: number of modes (degrees 0..kmax-1).

    Returns:
        Kernel matrix of shape (P, Pp).
    """
    n_rows, n_cols = X.shape[0], Xp.shape[0]
    flat_inner = np.reshape(X @ Xp.T, n_rows * n_cols)
    poly = gegenbauer.get_gegenbauer_fast2(flat_inner, kmax, d)

    multiplicity = np.array(
        [gegenbauer.degeneracy(d, mode) for mode in range(kmax)])
    mode_weights = spectrum * multiplicity

    flat_kernel = poly.T @ mode_weights
    return np.reshape(flat_kernel, (n_rows, n_cols))
Пример #4
0
def get_effective_spectrum(layers, kmax, d, ker='NTK'):
    """Project NTK / NNGP kernels of several depths onto Gegenbauer modes.

    Each kernel is evaluated on 5000 Gauss-Gegenbauer quadrature nodes
    (``scipy.special.roots_gegenbauer`` with alpha = d/2 - 1). It is then
    integrated against the first ``kmax`` polynomials returned by
    ``gegenbauer.get_gegenbauer_fast2``, each weighted by its degeneracy, and
    scaled by surface_area(d-1) / surface_area(d).

    Args:
        layers: sequence of depth values, each passed to the kernel function.
        kmax: number of modes (degrees 0..kmax-1).
        d: input dimension.
        ker: 'NTK' selects ``NTK_recurse2``; any other value selects ``NNGP2``.

    Returns:
        Array of shape (len(layers), kmax). Entry [i, k] is the projected
        coefficient for depth layers[i] and mode k divided by
        degeneracy(d, k). Coefficients <= 1e-14 (including negative ones)
        are zeroed before the division. Afterwards, along each row, a mode is
        zeroed when it is below 1e-5 times the preceding (already pruned)
        mode.
    """
    num_pts = 5000
    alpha = d / 2.0 - 1
    z, w = sp.special.roots_gegenbauer(num_pts, alpha)
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    Q = gegenbauer.get_gegenbauer_fast2(z, kmax, d)

    kernel_fn = NTK_recurse2 if ker == 'NTK' else NNGP2
    NTK_mat = np.zeros((len(layers), num_pts))
    for i, l in enumerate(layers):
        NTK_mat[i, :] = kernel_fn(z, l)

    # Quadrature: weight every kernel row by w, every polynomial row by its
    # degeneracy, then contract over the nodes.
    scaled_NTK = NTK_mat * w[np.newaxis, :]
    scaled_Q = Q * degens[:, np.newaxis]
    spectrum_scaled = (scaled_NTK @ scaled_Q.T
                       * gegenbauer.surface_area(d - 1)
                       / gegenbauer.surface_area(d))

    # Suppress quadrature noise: entries <= 1e-14 (incl. negatives) -> 0.
    spectrum_scaled = spectrum_scaled * np.heaviside(spectrum_scaled - 1e-14, 0)

    # Divide each row (one per depth) by the per-mode degeneracies.
    spectrum_true = spectrum_scaled / degens[np.newaxis, :]

    # Sequential pruning: order matters, since each comparison uses the
    # already-pruned previous mode.
    for i in range(len(layers)):
        for j in range(kmax - 1):
            if spectrum_true[i, j + 1] < spectrum_true[i, j] * 1e-5:
                spectrum_true[i, j + 1] = 0

    return spectrum_true
def _generalization_mode_errors(X, X_teach, alpha, alpha_teach, spectrum,
                                degens, kmax, d):
    """Return the per-mode (theoretical) student/teacher error, shape (kmax,).

    For each k < kmax:
        spectrum[k]**2 * degens[k]
        * (a^T Q_ss a - 2 a^T Q_st b + b^T Q_tt b),
    where a / b are the student / teacher weights and Q_ss, Q_st, Q_tt are
    row k of ``gegenbauer.get_gegenbauer_fast2`` on the student-student,
    student-teacher and teacher-teacher inner products.
    """
    P = X.shape[0]
    P_teach = X_teach.shape[0]
    Q_ss = gegenbauer.get_gegenbauer_fast2(np.reshape(X @ X.T, P**2), kmax, d)
    Q_st = gegenbauer.get_gegenbauer_fast2(
        np.reshape(X @ X_teach.T, P * P_teach), kmax, d)
    Q_tt = gegenbauer.get_gegenbauer_fast2(
        np.reshape(X_teach @ X_teach.T, P_teach**2), kmax, d)

    errors = np.zeros(kmax)
    for k in range(kmax):
        Q_ssk = np.reshape(Q_ss[k, :], (P, P))
        Q_stk = np.reshape(Q_st[k, :], (P, P_teach))
        Q_ttk = np.reshape(Q_tt[k, :], (P_teach, P_teach))
        errors[k] = spectrum[k]**2 * degens[k] * (
            alpha @ Q_ssk @ alpha - 2 * alpha @ Q_stk @ alpha_teach +
            alpha_teach @ Q_ttk @ alpha_teach)
    return errors


def _generalization_test_error(X, X_teach, alpha, alpha_teach, spectrum,
                               kmax, d, num_test):
    """Return the mean squared student/teacher gap on ``num_test`` fresh points."""
    X_test = sample_random_points(num_test, d)
    K_s = compute_kernel(X, X_test, spectrum, d, kmax)
    K_t = compute_kernel(X_teach, X_test, spectrum, d, kmax)
    y_s = K_s.T @ alpha
    y_t = K_t.T @ alpha_teach
    return 1 / num_test * np.linalg.norm(y_s - y_t)**2


def generalization(P, X_teach, spectrum, kmax, d, num_repeats, lamb=1e-6):
    """Compare theoretical and Monte-Carlo generalization error of kernel regression.

    Each of ``num_repeats`` trials does the following:
      * Draws a fresh teacher: points plus random +/-1 weights, with labels
        ``compute_kernel(X, X_teach) @ alpha_teach``.
      * Draws a fresh training set of ``P`` points.
      * Fits the student by kernel ridge regression with ridge ``lamb``.
      * Measures the error per Gegenbauer mode from ``spectrum``.
      * Measures the error empirically on 2500 random test points.
    Both errors are printed for each trial.

    Args:
        P: number of training points.
        X_teach: teacher points. Only ``X_teach.shape[0]`` (the teacher size)
            is used; the teacher points are redrawn every trial.
        spectrum: per-mode kernel coefficients, indexed for k < kmax.
        kmax: number of modes (degrees 0..kmax-1).
        d: input dimension.
        num_repeats: number of independent trials.
        lamb: ridge regularizer added to the student kernel diagonal.

    Returns:
        (errors_avg, errors_tot_MC, std_errs, std_MC):
          * errors_avg: per-mode theory error averaged over trials, shape (kmax,).
          * errors_tot_MC: empirical test error averaged over trials.
          * std_errs, std_MC: standard errors of the mean
            (``scipy.stats.sem``) of those two quantities.

    Raises:
        numpy.linalg.LinAlgError: if the regularized student kernel is singular.
    """
    # The teacher size comes from the argument instead of an implicit
    # module-level global.
    P_teach = X_teach.shape[0]
    num_test = 2500
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])

    errors_avg = np.zeros(kmax)
    errors_tot_MC = 0
    all_errs = np.zeros((kmax, num_repeats))
    all_MC = np.zeros(num_repeats)

    for i in range(num_repeats):
        # Fresh teacher: random points with random +/-1 weights.
        X_teach = sample_random_points(P_teach, d)
        alpha_teach = np.sign(
            np.random.random_sample(P_teach) - 0.5 * np.ones(P_teach))

        # Student: kernel ridge regression on the teacher's labels.
        X = sample_random_points(P, d)
        K_student = compute_kernel(X, X, spectrum, d, kmax)
        K_stu_te = compute_kernel(X, X_teach, spectrum, d, kmax)
        y = K_stu_te @ alpha_teach
        # solve() avoids forming the explicit inverse (cheaper, more stable).
        alpha = np.linalg.solve(K_student + lamb * np.eye(P), y)

        errors = _generalization_mode_errors(
            X, X_teach, alpha, alpha_teach, spectrum, degens, kmax, d)
        tot_error = _generalization_test_error(
            X, X_teach, alpha, alpha_teach, spectrum, kmax, d, num_test)

        print("errors")
        print("expt:   %e" % tot_error)
        print("theory: %e" % np.sum(errors))

        errors_avg += errors / num_repeats
        errors_tot_MC += tot_error / num_repeats
        all_errs[:, i] = errors
        all_MC[i] = tot_error

    std_errs = sp.stats.sem(all_errs, axis=1)
    std_MC = sp.stats.sem(all_MC)

    return errors_avg, errors_tot_MC, std_errs, std_MC