def generalization_expt_kteach(P, spectrum, M, d, kmax, num_repeats, X_teach, alpha_teach, spectrum_teach, num_test=1000):
    """Student/teacher generalization experiment with an SGD-trained network.

    For each of ``num_repeats`` trials: draw fresh student weights (Theta, r)
    and a fresh training set of ``P`` points, label the training set with the
    teacher kernel, train the student by SGD, then record per-mode errors and
    a Monte-Carlo estimate of the test error on ``num_test`` fresh points.

    Relies on module-level names not visible in this block: ``Theta_teach`` /
    ``r_teach`` (teacher weights), ``SGD``, ``get_mode_errs``, ``feedfoward``,
    ``sample_random_points_jit`` and ``compute_kernel``.
    NOTE(review): ``compute_kernel`` is called here with 2 arguments, but the
    5-argument version defined in this file needs (X, Xp, spectrum, d, kmax);
    confirm which kernel helper is actually in scope at call time.
    NOTE(review): the ``spectrum`` and ``spectrum_teach`` parameters are not
    read anywhere in this body.

    Returns
    -------
    (average_mc, std_mc, mean_training_error)
    """
    all_mode_errs = np.zeros((num_repeats, kmax))
    all_mc_errs = np.zeros(num_repeats)
    all_training_errs = np.zeros(num_repeats)
    # Degeneracy (multiplicity) of each spherical-harmonic mode k in dim d.
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    print("P = %d" % P)
    # Pre-allocate buffers reused (and filled in place) by the jitted sampler.
    Theta = np.zeros((M, d))
    X_test = np.zeros((num_test, d))
    X = np.zeros((P, d))
    for t in range(num_repeats):
        print("t=%d" % t)
        start = time.time()
        # Fresh random student: first-layer points Theta, readout weights r.
        Theta = sample_random_points_jit(M, d, Theta)
        r = np.random.standard_normal(M) / np.sqrt(M)
        X = sample_random_points_jit(P, d, X)
        # Teacher labels for the training set (see arity NOTE in docstring).
        K = compute_kernel(X_teach, X)
        Y = K.T @ alpha_teach
        num_iter = 3 * P
        Theta, r, E_tr = SGD(X, Y, Theta, r, num_iter, readout_only=False)
        print("Etr = %e" % E_tr)
        counter = 1
        print("finished SGD")
        print("num tries: %d" % counter)
        all_mode_errs[t, :] = get_mode_errs(Theta, Theta_teach, r, r_teach, kmax, d, degens)
        end = time.time()
        print("time = %lf" % (end - start))
        # Monte-Carlo test error on num_test fresh points.
        X_test = sample_random_points_jit(num_test, d, X_test)
        #X_test = sample_random_points(num_test, d)
        Y_test = feedfoward(X_test, Theta_teach, r_teach)
        Y_pred = feedfoward(X_test, Theta, r)
        all_mc_errs[t] = 1 / num_test * np.linalg.norm(Y_test - Y_pred)**2
        all_training_errs[t] = E_tr
    average_mode_errs = np.mean(all_mode_errs, axis=0)
    std_errs = np.std(all_mode_errs, axis=0)
    average_mc = np.mean(all_mc_errs)
    std_mc = np.std(all_mc_errs)
    print("average MC = %e" % average_mc)
    # Sanity check: summed mode errors should approximate the MC estimate.
    print("sum of modes = %e" % np.sum(average_mode_errs))
    return average_mc, std_mc, np.mean(all_training_errs)
def compute_kernel(X, Xp, spectrum, d, kmax):
    """Build the kernel matrix between two point sets from a mode spectrum.

    Evaluates Gegenbauer polynomials on every pairwise inner product and
    recombines them with ``spectrum[k] * degeneracy(d, k)`` weights, yielding
    the (X.shape[0], Xp.shape[0]) Gram matrix of the dot-product kernel.
    """
    n_left, n_right = X.shape[0], Xp.shape[0]
    # Flatten all pairwise overlaps so the Gegenbauer helper sees one vector.
    overlaps = np.reshape(X @ Xp.T, n_left * n_right)
    Q = gegenbauer.get_gegenbauer_fast2(overlaps, kmax, d)
    mode_weights = spectrum * np.array(
        [gegenbauer.degeneracy(d, k) for k in range(kmax)])
    # Weighted sum over modes, then restore the matrix shape.
    return np.reshape(Q.T @ mode_weights, (n_left, n_right))
def one_layer_NNGP(kmax, d):
    """Gegenbauer coefficients of the one-hidden-layer ReLU NNGP kernel.

    Projects the arc-cosine (ReLU NNGP) kernel onto the first ``kmax``
    Gegenbauer modes on the sphere via Gauss-Gegenbauer quadrature, then
    zeros coefficients below 1e-25 to suppress quadrature noise.
    """
    quad_pts = 10000
    geg_order = d / 2.0 - 1
    nodes, weights = sp.special.roots_gegenbauer(quad_pts, geg_order)
    # Arc-cosine kernel of a single ReLU layer evaluated at the nodes.
    kernel_vals = (1 / math.pi * (1 - nodes**2)**(0.5)
                   + (1 - 1 / math.pi * np.arccos(nodes)) * nodes)
    Q = gegenbauer.get_gegenbauer(nodes, kmax, d)
    norms = np.array([gegenbauer.normalizing_factor(k, d / 2.0 - 1)
                      for k in range(kmax)])
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    # Quadrature projection, normalized per mode.
    coeffs = Q @ (weights * kernel_vals) / (norms * degens)
    # Clamp numerically-zero coefficients.
    coeffs[coeffs < 1e-25] = 0
    return coeffs
def get_effective_spectrum_hermite(layers, kmax, d, ker='NTK'):
    """Gegenbauer spectra of deep NTK/NNGP kernels via Gauss-Hermite quadrature.

    For each depth in ``layers``, evaluates the kernel (``NTK`` or ``NNGP``,
    selected by ``ker``) on Hermite quadrature nodes, projects onto Hermite
    polynomials, and converts the Hermite coefficients to Gegenbauer
    coefficients on the sphere.

    Parameters
    ----------
    layers : sequence of int -- network depths to compute spectra for
    kmax : int -- number of modes returned per depth
    d : int -- input dimension
    ker : str -- 'NTK' selects the NTK kernel; anything else selects NNGP

    Returns
    -------
    ndarray of shape (len(layers), kmax) -- nonnegative mode coefficients;
    modes falling more than 5 orders of magnitude below the previous mode
    are zeroed as numerical noise.
    """
    num_pts = 2000
    z, w = sp.special.roots_hermite(num_pts)
    # Hermite normalization constants: h_k = 2^k * k! * sqrt(pi).
    scales = np.array(
        [2**k * math.factorial(k) * np.sqrt(math.pi) for k in range(kmax)])
    # Keep only nodes with |z| < sqrt(d) so z/sqrt(d) stays inside the
    # domain of arccos below.
    inds_valid = [i for i in range(num_pts) if np.abs(z[i]) < np.sqrt(d)]
    z_valid = z[inds_valid]
    w_valid = w[inds_valid]
    num_pts = len(inds_valid)
    H = gegenbauer.get_hermite_fast(z_valid, kmax, d)
    NTK_mat = np.zeros((len(layers), num_pts))
    for i in range(len(layers)):
        l = layers[i]
        if ker == 'NTK':
            NTK_mat[i, :] = NTK(np.arccos(z_valid / np.sqrt(d)), l)
        else:
            NTK_mat[i, :] = NNGP(np.arccos(z_valid / np.sqrt(d)), l)
    # Quadrature-weight the kernel rows; inverse-scale the Hermite rows.
    scaled_NTK = NTK_mat * np.outer(np.ones(len(layers)), w_valid)
    scaled_H = H * np.outer(scales**(-1), np.ones(num_pts))
    spectrum_scaled = scaled_NTK @ scaled_H.T
    # BUG FIX: the threshold used to be 1e-30 * np.ones(len(spectrum_scaled)),
    # a length-len(layers) vector that broadcasts incorrectly (or raises)
    # against the (len(layers), kmax) matrix; a scalar threshold is intended.
    spectrum_scaled = spectrum_scaled * np.heaviside(spectrum_scaled - 1e-30, 0)
    spectrum_true = np.zeros(spectrum_scaled.shape)
    for i in range(len(layers)):
        spectrum_true[i, :] = gegenbauer.hermite_to_gegenbauer_coeffs(
            spectrum_scaled[i, :], d)
    # Zero modes that drop more than 5 orders of magnitude below the
    # previous mode -- treated as quadrature noise.
    for i in range(len(layers)):
        for j in range(kmax - 1):
            if spectrum_true[i, j + 1] < spectrum_true[i, j] * 1e-5:
                spectrum_true[i, j + 1] = 0
    return spectrum_true
def get_effective_spectrum(layers, kmax, d, ker='NTK'):
    """Gegenbauer spectra of deep NTK/NNGP kernels via Gegenbauer quadrature.

    For each depth in ``layers``, evaluates the kernel (``NTK_recurse2`` or
    ``NNGP2``, selected by ``ker``) on Gauss-Gegenbauer nodes and projects it
    onto the first ``kmax`` Gegenbauer modes on the sphere S^{d-1}.

    Parameters
    ----------
    layers : sequence of int -- network depths to compute spectra for
    kmax : int -- number of modes returned per depth
    d : int -- input dimension
    ker : str -- 'NTK' selects the NTK kernel; anything else selects NNGP

    Returns
    -------
    ndarray of shape (len(layers), kmax) -- per-mode eigenvalues; tiny
    values and modes that fall 5+ orders of magnitude below the previous
    mode are zeroed as numerical noise.
    """
    num_pts = 5000
    alpha = d / 2.0 - 1
    z, w = sp.special.roots_gegenbauer(num_pts, alpha)
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    Q = gegenbauer.get_gegenbauer_fast2(z, kmax, d)
    NTK_mat = np.zeros((len(layers), num_pts))
    for i in range(len(layers)):
        l = layers[i]
        if ker == 'NTK':
            NTK_mat[i, :] = NTK_recurse2(z, l)
        else:
            NTK_mat[i, :] = NNGP2(z, l)
    # Quadrature-weight the kernel rows; degeneracy-weight the polynomial rows.
    scaled_NTK = NTK_mat * np.outer(np.ones(len(layers)), w)
    scaled_Q = Q * np.outer(degens, np.ones(num_pts))
    spectrum_scaled = (scaled_NTK @ scaled_Q.T
                       * gegenbauer.surface_area(d - 1) / gegenbauer.surface_area(d))
    # Suppress quadrature noise below ~1e-14.
    spectrum_scaled = spectrum_scaled * np.heaviside(spectrum_scaled - 1e-14, 0)
    for i in range(kmax):
        if spectrum_scaled[0, i] < 1e-18:
            spectrum_scaled[0, i] = 0
    # BUG FIX: previously divided by np.outer(len(layers), degens), which
    # treats the layer count as a scalar and scales every degeneracy by the
    # *number of layers* (correct only when len(layers) == 1). The intended
    # normalization divides each row by the degeneracies alone.
    spectrum_true = spectrum_scaled / np.outer(np.ones(len(layers)), degens)
    # Zero modes that drop 5+ orders of magnitude below the previous mode.
    for i in range(len(layers)):
        for j in range(kmax - 1):
            if spectrum_true[i, j + 1] < spectrum_true[i, j] * 1e-5:
                spectrum_true[i, j + 1] = 0
    return spectrum_true
def generalization(P, X_teach, spectrum, kmax, d, num_repeats, lamb=1e-6):
    """Kernel-regression student/teacher experiment: theory vs Monte Carlo.

    Each repeat draws a fresh random teacher (``P_teach`` points with +-1
    coefficients) and a fresh student training set of ``P`` points, fits the
    student by kernel ridge regression, decomposes the generalization error
    into Gegenbauer modes (theory), and compares against a Monte-Carlo
    estimate on 2500 test points.

    NOTE(review): ``P_teach`` is read from module scope, and the ``X_teach``
    argument is resampled at the top of every repeat, so it is effectively
    unused -- kept only for interface compatibility.

    Parameters
    ----------
    P : int -- student training-set size
    X_teach : ndarray -- ignored (see note above)
    spectrum : ndarray -- kernel eigenvalues per mode, length >= kmax
    kmax : int -- number of modes in the error decomposition
    d : int -- input dimension
    num_repeats : int -- number of independent repeats averaged over
    lamb : float -- ridge regularization added to the student Gram matrix

    Returns
    -------
    (errors_avg, errors_tot_MC, std_errs, std_MC) -- per-mode mean errors,
    mean Monte-Carlo error, and standard errors of the mean for both.
    """
    errors_avg = np.zeros(kmax)
    errors_tot_MC = 0
    all_errs = np.zeros((kmax, num_repeats))
    all_MC = np.zeros(num_repeats)
    # (Dead code removed: X_teach/alpha_teach were also sampled here before
    # the loop, then unconditionally resampled at the top of every repeat.)
    # Mode degeneracies are loop-invariant, so compute them once.
    degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
    for i in range(num_repeats):
        # Fresh random teacher: sphere points with random +-1 coefficients.
        X_teach = sample_random_points(P_teach, d)
        alpha_teach = np.sign(
            np.random.random_sample(P_teach) - 0.5 * np.ones(P_teach))
        X = sample_random_points(P, d)
        # Kernel ridge regression for the student coefficients alpha.
        K_student = compute_kernel(X, X, spectrum, d, kmax)
        K_stu_te = compute_kernel(X, X_teach, spectrum, d, kmax)
        y = K_stu_te @ alpha_teach
        K_inv = np.linalg.inv(K_student + lamb * np.eye(P))
        alpha = K_inv @ y
        # Gegenbauer expansions of the student/student, student/teacher and
        # teacher/teacher Gram matrices, for the mode-wise decomposition.
        gram_ss = X @ X.T
        gram_st = X @ X_teach.T
        gram_tt = X_teach @ X_teach.T
        Q_ss = gegenbauer.get_gegenbauer_fast2(np.reshape(gram_ss, P**2), kmax, d)
        Q_st = gegenbauer.get_gegenbauer_fast2(
            np.reshape(gram_st, P * P_teach), kmax, d)
        Q_tt = gegenbauer.get_gegenbauer_fast2(
            np.reshape(gram_tt, P_teach**2), kmax, d)
        errors = np.zeros(kmax)
        for k in range(kmax):
            Q_ssk = np.reshape(Q_ss[k, :], (P, P))
            Q_stk = np.reshape(Q_st[k, :], (P, P_teach))
            Q_ttk = np.reshape(Q_tt[k, :], (P_teach, P_teach))
            # Quadratic form of (f_student - f_teacher) projected on mode k.
            errors[k] = spectrum[k]**2 * degens[k] * (
                alpha.T @ Q_ssk @ alpha
                - 2 * alpha.T @ Q_stk @ alpha_teach
                + alpha_teach.T @ Q_ttk @ alpha_teach)
        errors_avg += 1 / num_repeats * errors
        all_errs[:, i] = errors
        # Monte-Carlo estimate of the total generalization error.
        num_test = 2500
        X_test = sample_random_points(num_test, d)
        K_s = compute_kernel(X, X_test, spectrum, d, kmax)
        K_t = compute_kernel(X_teach, X_test, spectrum, d, kmax)
        y_s = K_s.T @ alpha
        y_t = K_t.T @ alpha_teach
        tot_error = 1 / num_test * np.linalg.norm(y_s - y_t)**2
        print("errors")
        print("expt: %e" % tot_error)
        print("theory: %e" % np.sum(errors))
        errors_tot_MC += 1 / num_repeats * tot_error
        all_MC[i] = tot_error
    std_errs = sp.stats.sem(all_errs, axis=1)
    std_MC = sp.stats.sem(all_MC)
    return errors_avg, errors_tot_MC, std_errs, std_MC
# --- Script setup: command-line options and experiment sweep parameters ---
# NOTE(review): ``parser`` is created earlier in the file (not visible here).
parser.add_argument('--lamb', type=float, help='explicit regularization penalty', default=0)
parser.add_argument('--NTK_depth', type=int, default=3, help='depth of Fully Connected ReLU NTK')
args = parser.parse_args()
d = args.input_dim  # input dimension (points live on the sphere S^{d-1})
lamb = args.lamb  # ridge regularization strength
depth = args.NTK_depth  # depth of the fully connected ReLU NTK
kmax = 30  # number of Gegenbauer modes tracked in the error decomposition
degens = np.array([gegenbauer.degeneracy(d, k) for k in range(kmax)])
# Gegenbauer spectrum of the depth-`depth` ReLU NTK (single-depth row 0).
spectrum = compute_NTK_spectrum.get_effective_spectrum([depth], kmax, d, ker='NTK')[0, :]
s = [i for i in spectrum if i > 0]  # strictly positive modes -- presumably used further down; confirm
P = 50
P_teach = 300  # teacher set size (read as a module global by `generalization`)
P_vals = np.logspace(0.25, 3, num=15).astype('int')  # training-set sizes swept
num_repeats = 50  # independent repeats per training-set size
# Result accumulators: per-mode errors, Monte-Carlo errors, and their spreads.
all_errs = np.zeros((len(P_vals), kmax))
all_mc = np.zeros(len(P_vals))
std_errs = np.zeros((len(P_vals), kmax))
std_MC = np.zeros(len(P_vals))