Example #1
0
def generalization_expt_kteach(P,
                               spectrum,
                               M,
                               d,
                               kmax,
                               num_repeats,
                               X_teach,
                               alpha_teach,
                               spectrum_teach,
                               num_test=1000):
    """Monte-Carlo student/teacher generalization experiment with SGD training.

    For each of ``num_repeats`` trials: draw a fresh width-``M`` student
    network and ``P`` training inputs, label them with the teacher kernel,
    train the student with SGD, then record the per-mode error decomposition
    and a Monte-Carlo test error on ``num_test`` fresh points.

    Parameters
    ----------
    P : int
        Number of training samples per trial.
    spectrum : array
        Student kernel eigenvalues (not referenced in this body; kept for
        interface parity with sibling experiment functions).
    M : int
        Number of hidden units in the student network.
    d : int
        Input dimension.
    kmax : int
        Number of Gegenbauer modes tracked in the mode-error decomposition.
    num_repeats : int
        Number of independent trials averaged over.
    X_teach : array
        Teacher sample points used to generate training labels.
    alpha_teach : array
        Teacher coefficients; labels are ``K.T @ alpha_teach``.
    spectrum_teach : array
        Teacher spectrum (not referenced in this body).
    num_test : int
        Number of Monte-Carlo test points per trial.

    Returns
    -------
    tuple
        (mean MC test error, std of MC test error, mean training error).

    NOTE(review): this function also reads the globals ``Theta_teach`` and
    ``r_teach`` (the teacher network) — confirm they are defined at call time.
    """

    # Per-trial accumulators.
    all_mode_errs = np.zeros((num_repeats, kmax))
    all_mc_errs = np.zeros(num_repeats)
    all_training_errs = np.zeros(num_repeats)
    # Degeneracy of each harmonic mode k in dimension d.
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    print("P = %d" % P)
    # Pre-allocated buffers refilled in place by sample_random_points_jit.
    Theta = np.zeros((M, d))
    X_test = np.zeros((num_test, d))
    X = np.zeros((P, d))
    for t in range(num_repeats):
        print("t=%d" % t)
        start = time.time()

        # Fresh random student: hidden weights plus a small random readout.
        Theta = sample_random_points_jit(M, d, Theta)
        r = np.random.standard_normal(M) / np.sqrt(M)

        X = sample_random_points_jit(P, d, X)

        # Teacher labels on the training inputs.
        # NOTE(review): compute_kernel is called with two arguments here,
        # unlike the five-argument compute_kernel defined elsewhere in this
        # file — confirm which helper is actually in scope.
        K = compute_kernel(X_teach, X)
        Y = K.T @ alpha_teach

        # Train the full network (hidden weights and readout) with SGD.
        num_iter = 3 * P
        Theta, r, E_tr = SGD(X, Y, Theta, r, num_iter, readout_only=False)
        print("Etr = %e" % E_tr)
        counter = 1

        print("finished SGD")
        print("num tries: %d" % counter)

        # Mode-wise error between trained student and the (global) teacher.
        all_mode_errs[t, :] = get_mode_errs(Theta, Theta_teach, r, r_teach,
                                            kmax, d, degens)
        end = time.time()
        print("time = %lf" % (end - start))

        # Monte-Carlo test error on freshly sampled points.
        X_test = sample_random_points_jit(num_test, d, X_test)
        #X_test = sample_random_points(num_test, d)
        Y_test = feedfoward(X_test, Theta_teach, r_teach)
        Y_pred = feedfoward(X_test, Theta, r)
        all_mc_errs[t] = 1 / num_test * np.linalg.norm(Y_test - Y_pred)**2
        all_training_errs[t] = E_tr
    average_mode_errs = np.mean(all_mode_errs, axis=0)
    std_errs = np.std(all_mode_errs, axis=0)
    average_mc = np.mean(all_mc_errs)
    std_mc = np.std(all_mc_errs)
    print("average MC   = %e" % average_mc)
    print("sum of modes = %e" % np.sum(average_mode_errs))
    return average_mc, std_mc, np.mean(all_training_errs)
def compute_kernel(X, Xp, spectrum, d, kmax):
    """Evaluate the dot-product kernel between the rows of X and Xp.

    The kernel is a Gegenbauer expansion: each mode k < kmax contributes
    spectrum[k] * degeneracy(d, k) times the degree-k polynomial of the
    pairwise inner products.

    Returns an array of shape (X.shape[0], Xp.shape[0]).
    """
    n_left = X.shape[0]
    n_right = Xp.shape[0]
    # Pairwise inner products, flattened so the Gegenbauer helper can
    # evaluate every entry in a single call.
    overlaps = np.reshape(X @ Xp.T, n_left * n_right)
    Q = gegenbauer.get_gegenbauer_fast2(overlaps, kmax, d)
    mode_weights = spectrum * np.array(
        [gegenbauer.degeneracy(d, k) for k in range(kmax)])
    # Weighted sum over modes, restored to matrix shape.
    return np.reshape(Q.T @ mode_weights, (n_left, n_right))
Example #3
0
def one_layer_NNGP(kmax, d):
    """Gegenbauer coefficients of the one-hidden-layer ReLU NNGP kernel.

    Projects the arc-cosine kernel of order 1 onto the first kmax
    Gegenbauer modes via Gauss-Gegenbauer quadrature, zeroing coefficients
    below numerical noise (1e-25).
    """
    num_pts = 10000
    weight_exp = d / 2.0 - 1
    nodes, weights = sp.special.roots_gegenbauer(num_pts, weight_exp)

    # Arc-cosine kernel of order 1 evaluated at cos(theta) = nodes.
    kernel_vals = (1 / math.pi * (1 - nodes**2)**(0.5)
                   + (1 - 1 / math.pi * np.arccos(nodes)) * nodes)
    Q = gegenbauer.get_gegenbauer(nodes, kmax, d)

    norms = np.array(
        [gegenbauer.normalizing_factor(k, d / 2.0 - 1) for k in range(kmax)])
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])

    # Quadrature projection onto each mode, normalized per mode.
    coeffs = Q @ (weights * kernel_vals) / (norms * degens)

    # Suppress numerically negligible coefficients.
    coeffs[coeffs < 1e-25] = 0
    return coeffs
def get_effective_spectrum_hermite(layers, kmax, d, ker='NTK'):
    """Gegenbauer spectra of deep NTK/NNGP kernels via Hermite quadrature.

    For each depth in ``layers``, expands the kernel in Hermite polynomials
    with Gauss-Hermite quadrature, then converts the Hermite coefficients
    to Gegenbauer coefficients.

    Parameters
    ----------
    layers : list of int
        Network depths to evaluate.
    kmax : int
        Number of polynomial modes to compute.
    d : int
        Input dimension.
    ker : str
        'NTK' for the neural tangent kernel; anything else uses the NNGP.

    Returns
    -------
    Array of shape (len(layers), kmax) of Gegenbauer coefficients; entries
    below numerical noise are zeroed.
    """
    num_pts = 2000
    z, w = sp.special.roots_hermite(num_pts)

    # Normalization of the (physicists') Hermite polynomials:
    # integral of H_k(z)^2 exp(-z^2) dz = 2^k k! sqrt(pi).
    scales = np.array(
        [2**k * math.factorial(k) * np.sqrt(math.pi) for k in range(kmax)])

    # Keep only nodes with |z| < sqrt(d) so z / sqrt(d) is a valid argument
    # for np.arccos below.
    inds_valid = [i for i in range(num_pts) if np.abs(z[i]) < np.sqrt(d)]
    z_valid = z[inds_valid]
    w_valid = w[inds_valid]
    num_pts = len(inds_valid)

    H = gegenbauer.get_hermite_fast(z_valid, kmax, d)

    # Kernel values at each quadrature node, one row per depth.
    NTK_mat = np.zeros((len(layers), num_pts))
    for i in range(len(layers)):
        l = layers[i]
        angles = np.arccos(z_valid / np.sqrt(d))
        if ker == 'NTK':
            NTK_mat[i, :] = NTK(angles, l)
        else:
            NTK_mat[i, :] = NNGP(angles, l)

    # Quadrature: coefficient_k ~ sum_j w_j K(z_j) H_k(z_j) / scale_k.
    scaled_NTK = NTK_mat * np.outer(np.ones(len(layers)), w_valid)
    scaled_H = H * np.outer(scales**(-1), np.ones(num_pts))
    spectrum_scaled = scaled_NTK @ scaled_H.T

    # Zero out entries below numerical noise.  BUG FIX: the original
    # subtracted 1e-30 * np.ones(len(spectrum_scaled)) — a length-len(layers)
    # vector that fails to broadcast against the (len(layers), kmax) matrix
    # whenever len(layers) > 1 and kmax != len(layers).  A scalar threshold
    # is shape-safe and numerically identical.
    spectrum_scaled = spectrum_scaled * np.heaviside(
        spectrum_scaled - 1e-30, 0)

    # Convert Hermite coefficients to Gegenbauer coefficients, per depth.
    spectrum_true = np.zeros(spectrum_scaled.shape)
    for i in range(len(layers)):
        spectrum_true[i, :] = gegenbauer.hermite_to_gegenbauer_coeffs(
            spectrum_scaled[i, :], d)

    # Suppress modes that drop more than five orders of magnitude below the
    # preceding mode — quadrature noise, not true spectrum.
    for i in range(len(layers)):
        for j in range(kmax - 1):
            if spectrum_true[i, j + 1] < spectrum_true[i, j] * 1e-5:
                spectrum_true[i, j + 1] = 0

    return spectrum_true
Example #5
0
def get_effective_spectrum(layers, kmax, d, ker='NTK'):
    """Gegenbauer spectra of deep NTK/NNGP kernels via Gegenbauer quadrature.

    Parameters
    ----------
    layers : list of int
        Network depths to evaluate.
    kmax : int
        Number of Gegenbauer modes to compute.
    d : int
        Input dimension.
    ker : str
        'NTK' for the neural tangent kernel; anything else uses the NNGP.

    Returns
    -------
    Array of shape (len(layers), kmax): entry [i, k] is the eigenvalue of
    mode k for the depth-layers[i] kernel, with numerically negligible
    entries zeroed.
    """
    num_pts = 5000
    alpha = d / 2.0 - 1  # Gegenbauer weight exponent for dimension d.
    z, w = sp.special.roots_gegenbauer(num_pts, alpha)
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    Q = gegenbauer.get_gegenbauer_fast2(z, kmax, d)

    # Kernel values at each quadrature node, one row per depth.
    NTK_mat = np.zeros((len(layers), num_pts))
    for i in range(len(layers)):
        l = layers[i]
        if ker == 'NTK':
            NTK_mat[i, :] = NTK_recurse2(z, l)
        else:
            NTK_mat[i, :] = NNGP2(z, l)

    # Quadrature projection onto each mode, scaled by the sphere
    # surface-area ratio.
    scaled_NTK = NTK_mat * np.outer(np.ones(len(layers)), w)
    scaled_Q = Q * np.outer(degens, np.ones(num_pts))
    spectrum_scaled = (scaled_NTK @ scaled_Q.T *
                       gegenbauer.surface_area(d - 1) /
                       gegenbauer.surface_area(d))

    # Zero out entries below numerical noise.
    spectrum_scaled = spectrum_scaled * np.heaviside(
        spectrum_scaled - 1e-14, 0)
    for i in range(kmax):
        if spectrum_scaled[0, i] < 1e-18:
            spectrum_scaled[0, i] = 0

    # BUG FIX: the original divided by np.outer(len(layers), degens), which
    # equals len(layers) * degens and therefore over-divides every spectrum
    # by len(layers) whenever more than one depth is requested (identical
    # for a single depth).  The parallel scalings above show the intended
    # form: divide each row by the mode degeneracies only.
    spectrum_true = spectrum_scaled / np.outer(np.ones(len(layers)), degens)

    # Suppress modes that drop more than five orders of magnitude below the
    # preceding mode — quadrature noise, not true spectrum.
    for i in range(len(layers)):
        for j in range(kmax - 1):
            if spectrum_true[i, j + 1] < spectrum_true[i, j] * 1e-5:
                spectrum_true[i, j + 1] = 0

    return spectrum_true
def generalization(P, X_teach, spectrum, kmax, d, num_repeats, lamb=1e-6):
    """Compare theoretical mode-wise errors against Monte-Carlo test error
    for kernel ridge regression in a student/teacher setup.

    Parameters
    ----------
    P : int
        Number of student (training) samples.
    X_teach : array
        Unused — the teacher set is resampled every repeat; kept for
        interface compatibility with existing callers.
    spectrum : array
        Kernel eigenvalues, indexed by mode k < kmax.
    kmax : int
        Number of Gegenbauer modes in the error decomposition.
    d : int
        Input dimension.
    num_repeats : int
        Number of independent teacher/student draws averaged over.
    lamb : float
        Ridge regularization added to the student Gram matrix.

    Returns
    -------
    tuple
        (mean mode-wise errors, mean MC test error,
         SEM of mode-wise errors, SEM of MC error).

    NOTE(review): the teacher-set size comes from the global ``P_teach`` —
    confirm it is defined at call time.
    """
    errors_avg = np.zeros(kmax)
    errors_tot_MC = 0
    all_errs = np.zeros((kmax, num_repeats))
    all_MC = np.zeros(num_repeats)

    # Hoisted out of the loop: degeneracies do not change between repeats.
    # (The original also drew a throwaway teacher here that was immediately
    # overwritten on the first iteration — dead code, removed.)
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])

    for i in range(num_repeats):

        # Fresh random teacher: sample points with +/-1 coefficients.
        X_teach = sample_random_points(P_teach, d)
        alpha_teach = np.sign(
            np.random.random_sample(P_teach) - 0.5 * np.ones(P_teach))

        X = sample_random_points(P, d)
        K_student = compute_kernel(X, X, spectrum, d, kmax)
        K_stu_te = compute_kernel(X, X_teach, spectrum, d, kmax)
        y = K_stu_te @ alpha_teach

        # Kernel ridge regression.  solve() replaces the original explicit
        # inverse-then-multiply: same solution, better numerical stability.
        alpha = np.linalg.solve(K_student + lamb * np.eye(P), y)

        gram_ss = X @ X.T
        gram_st = X @ X_teach.T
        gram_tt = X_teach @ X_teach.T

        Q_ss = gegenbauer.get_gegenbauer_fast2(np.reshape(gram_ss, P**2),
                                               kmax, d)
        Q_st = gegenbauer.get_gegenbauer_fast2(
            np.reshape(gram_st, P * P_teach), kmax, d)
        Q_tt = gegenbauer.get_gegenbauer_fast2(
            np.reshape(gram_tt, P_teach**2), kmax, d)

        # Mode-wise generalization error: quadratic form of the
        # student/teacher coefficient difference in each Gegenbauer mode.
        errors = np.zeros(kmax)
        for k in range(kmax):
            Q_ssk = np.reshape(Q_ss[k, :], (P, P))
            Q_stk = np.reshape(Q_st[k, :], (P, P_teach))
            Q_ttk = np.reshape(Q_tt[k, :], (P_teach, P_teach))
            errors[k] = spectrum[k]**2 * degens[k] * (
                alpha.T @ Q_ssk @ alpha - 2 * alpha.T @ Q_stk @ alpha_teach +
                alpha_teach.T @ Q_ttk @ alpha_teach)
        errors_avg += 1 / num_repeats * errors
        all_errs[:, i] = errors

        # Monte-Carlo estimate of the same error on fresh test points.
        num_test = 2500
        X_test = sample_random_points(num_test, d)
        K_s = compute_kernel(X, X_test, spectrum, d, kmax)
        K_t = compute_kernel(X_teach, X_test, spectrum, d, kmax)

        y_s = K_s.T @ alpha
        y_t = K_t.T @ alpha_teach
        tot_error = 1 / num_test * np.linalg.norm(y_s - y_t)**2
        print("errors")
        print("expt:   %e" % tot_error)
        print("theory: %e" % np.sum(errors))

        errors_tot_MC += 1 / num_repeats * tot_error
        all_MC[i] = tot_error

    std_errs = sp.stats.sem(all_errs, axis=1)
    std_MC = sp.stats.sem(all_MC)

    return errors_avg, errors_tot_MC, std_errs, std_MC
# Command-line options; the ArgumentParser itself (and at least the
# --input_dim option read below) is created outside this chunk.
parser.add_argument('--lamb',
                    type=float,
                    help='explicit regularization penalty',
                    default=0)
parser.add_argument('--NTK_depth',
                    type=int,
                    default=3,
                    help='depth of Fully Connected ReLU NTK')

args = parser.parse_args()
d = args.input_dim
lamb = args.lamb
depth = args.NTK_depth

# Number of Gegenbauer modes tracked throughout the experiment.
kmax = 30
degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
# Effective NTK spectrum for the requested depth; [0, :] selects the single
# depth out of the (1, kmax) result.
spectrum = compute_NTK_spectrum.get_effective_spectrum([depth],
                                                       kmax,
                                                       d,
                                                       ker='NTK')[0, :]

# Strictly positive eigenvalues.  NOTE(review): 's' is not used in this
# chunk — confirm whether it is consumed later in the file.
s = [i for i in spectrum if i > 0]
P = 50
P_teach = 300
# Training-set sizes, log-spaced from ~2 to 1000 (truncated to ints).
P_vals = np.logspace(0.25, 3, num=15).astype('int')
num_repeats = 50

# Result accumulators: one row/entry per training-set size in P_vals.
all_errs = np.zeros((len(P_vals), kmax))
all_mc = np.zeros(len(P_vals))
std_errs = np.zeros((len(P_vals), kmax))
std_MC = np.zeros(len(P_vals))