def main(unused_argv):
    """Compute the empirical NTK of a one-hidden-layer ReLU net and pickle it.

    Reads ``data_<train_size>.p`` (a pickled ``(x_train, y_train, x_test,
    y_test)`` tuple), builds a 2048-unit ReLU network with a scalar output,
    and writes the train/train kernel to ``ntk_train_<train_size>.p`` and the
    test/train kernel to ``ntk_train_test_<train_size>.p``.

    Args:
        unused_argv: ignored.
    """
    train_size = FLAGS.train_size
    # NOTE: pickle.load executes arbitrary code; only point this at trusted
    # data files.
    with open("data_" + str(train_size) + ".p", "rb") as data_file:
        x_train, y_train, x_test, y_test = pickle.load(data_file)
    print("Got data")
    sys.stdout.flush()

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # Initialize the network once, with a random seed, to compute the NTK.
    # `random_integers` is deprecated; this is the call it forwards to
    # (inclusive upper bound == exclusive bound of max + 1).
    int32_range = np.iinfo(np.int32)
    randnnn = numpy.random.randint(int32_range.min,
                                   int32_range.max + 1,
                                   size=2,
                                   dtype=numpy.int64)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # We assume that the NTK is approximately the same for any sample of
    # parameters (true in the limit of infinite width).
    print("Making NTK")
    sys.stdout.flush()
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)

    g_dd = ntk(x_train, None, params)
    # Context managers make sure the dumps are flushed and closed.
    with open("ntk_train_" + str(FLAGS.train_size) + ".p", "wb") as out_file:
        pickle.dump(g_dd, out_file)

    g_td = ntk(x_test, x_train, params)
    with open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "wb") as out_file:
        pickle.dump(g_td, out_file)

    # Create an MSE predictor to solve the NTK equation in function space.
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
def main(unused_argv):
    """Run NNGP and NTK inference on CIFAR-10 with an infinite ReLU network.

    Dataset sizes and the kernel batch size come from ``FLAGS``. Timing and
    the two test summaries are printed; nothing is returned.

    Args:
        unused_argv: ignored.
    """
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'cifar10', FLAGS.train_size, FLAGS.test_size)

    # Only the analytic kernel of the architecture is needed here.
    _init_fn, _apply_fn, kernel_fn = stax.serial(
        stax.Dense(1, 2., 0.05),
        stax.Relu(),
        stax.Dense(1, 2., 0.05))

    # Evaluate the kernel in batches.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=0,
                         batch_size=FLAGS.batch_size)

    t0 = time.time()
    # Bayesian and infinite-time gradient descent inference.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    # Wait for the results so the measured time covers the computation.
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()
    duration = time.time() - t0
    print('Kernel construction and inference done in %s seconds.' % duration)

    def loss(fx, y_hat):
        """Half mean-squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def train(kernel_fn, x_train, y_train, x_test, y_test):
    """Fit NTK kernel regression on the training set and score it on the test set.

    Args:
        kernel_fn: kernel function whose result exposes an ``.ntk`` attribute.
        x_train: training inputs.
        y_train: training targets.
        x_test: test inputs.
        y_test: test targets. Accuracy compares them with predictions mapped
            to +1/-1, so this presumably expects +/-1 labels; confirm with
            the caller.

    Returns:
        Tuple ``(loss_d, acc_d)``: mean squared test error and the fraction
        of test points whose thresholded prediction equals ``y_test``.
    """
    batched_kernel_fn = nt.batch(kernel_fn, 25)
    K_test_train = batched_kernel_fn(x_test, x_train).ntk
    K_train_train = batched_kernel_fn(x_train, x_train).ntk

    # Solve K_train_train @ coeffs = y_train instead of forming the explicit
    # inverse: same result up to rounding, cheaper and numerically better
    # behaved for ill-conditioned kernels.
    coeffs = np.linalg.solve(K_train_train, y_train)
    y_test_pred = K_test_train @ coeffs

    loss_d = np.mean((y_test_pred - y_test)**2)
    # Threshold at zero to obtain +1/-1 class predictions.
    y_test_class = np.where(y_test_pred > 0, 1., -1.)
    acc_d = np.mean(y_test_class == y_test)
    return loss_d, acc_d
def main():
    """Run NNGP and NTK inference on CIFAR-10 with an infinite WideResnet.

    Uses 1000 training and 1000 test images and prints the elapsed time and
    the test summaries for both kernels.
    """
    train_size = 1000
    test_size = 1000
    batch_size = 0

    # Only the analytic kernel is used; init/apply functions are discarded.
    _, _, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
    x_train, y_train, x_test, y_test = get_dataset(
        'cifar10', train_size, test_size, do_flatten_and_normalize=False)

    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite
    # network. `gradient_descent_mse_ensemble` returns a predict function;
    # the test inputs and the requested kernels go to that function, not to
    # the factory (passing them to the factory made the unpacking below fail).
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def define(args):
    """Build a fully connected ReLU network and its kernel functions.

    Args:
        args: object providing ``num_hidden_layers``, ``hidden_neurons``,
            ``output_dim``, ``W_std``, ``b_std`` and ``batch_size``.

    Returns:
        Tuple ``(init_fn, apply_fn, kernel_fn, batched_kernel_fn)`` where
        ``apply_fn`` is jit-compiled and ``batched_kernel_fn`` is
        ``kernel_fn`` wrapped by ``nt.batch``.
    """

    def dense(width):
        # Every dense layer shares the same weight/bias standard deviations.
        return stax.Dense(width, W_std=args.W_std, b_std=args.b_std)

    layers = []
    for _ in range(args.num_hidden_layers):
        layers += [dense(args.hidden_neurons), stax.Relu()]
    layers.append(dense(args.output_dim))

    init_fn, raw_apply_fn, kernel_fn = stax.serial(*layers)
    compiled_apply_fn = jit(raw_apply_fn)
    batched_kernel_fn = nt.batch(kernel_fn,
                                 batch_size=args.batch_size,
                                 device_count=-1)
    return init_fn, compiled_apply_fn, kernel_fn, batched_kernel_fn
def main(unused_argv):
    """Train a finite Erf network on MNIST and compare it with the NTK solution.

    The network is trained with full-batch SGD for
    ``FLAGS.train_time // FLAGS.learning_rate`` steps; the analytic
    function-space prediction at ``FLAGS.train_time`` is computed from the
    empirical NTK and both are summarised on the train and test sets.

    Args:
        unused_argv: ignored.
    """
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', FLAGS.train_size, FLAGS.test_size)

    # One hidden layer of 512 Erf units, 10 outputs.
    init_fn, apply_fn, _kernel_fn = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))
    _, params = init_fn(random.PRNGKey(0), (-1, 784))

    # Optimizer state.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    def loss(fx, y_hat):
        """Half mean-squared error."""
        return 0.5 * np.mean((fx - y_hat) ** 2)

    def objective(net_params, x, y):
        return loss(apply_fn(net_params, x), y)

    grad_loss = jit(grad(objective))

    # Empirical NTK at initialization, evaluated in batches of 4.
    empirical_ntk = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)
    ntk = nt.batch(empirical_ntk, batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Network outputs at initialization (function-space starting point).
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))
    for step in range(train_steps):
        # `params` is read before the update, so after the loop it holds the
        # parameters that entered the final step.
        params = get_params(state)
        gradients = grad_loss(params, x_train, y_train)
        state = opt_apply(step, gradients, state)

    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Summaries comparing the trained network with the linearised prediction.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def maker(beta=0, W_std=0.8, b_std=0, diag_reg=1e-4):
    """Create a metrics function bound to an NTK predictor for one setting.

    Args:
        beta: forwarded to ``gradient_descent_mse_vib`` and ``calc_metrics``.
        W_std: weight standard deviation passed to ``get_network``.
        b_std: bias standard deviation passed to ``get_network``.
        diag_reg: forwarded to ``gradient_descent_mse_vib``.

    Returns:
        ``calc_metrics`` partially applied with the predictor, the time-zero
        predictions on the train and test images, and ``beta``.
    """
    _init_fn, _apply_fn, raw_kernel = get_network(
        net_type=net_type, w_std=W_std, b_std=b_std)
    batched_kernel = nt.batch(raw_kernel,
                              batch_size=batch_size,
                              device_count=num_of_gpus,
                              store_on_device=True)

    base_predict = gradient_descent_mse_vib(
        beta, train_images, train_labels, diag_reg, batched_kernel)
    # Fix the predictor to NTK means without covariance.
    ntk_predict = partial(base_predict, get='ntk', compute_cov=False)

    # Predictions at t=0, train set first, then test set.
    initial_predictions = [
        ntk_predict(t=0, x_test=images)
        for images in (train_images, test_images)
    ]
    return partial(calc_metrics,
                   predict_fn=ntk_predict,
                   init_preds=initial_predictions,
                   beta=beta)
def infinite_resnet(train_embedding, test_embedding, data_set):
    """Print NNGP/NTK test summaries for an infinite wide-resnet kernel.

    Args:
        train_embedding: training inputs handed to the kernel.
        test_embedding: test inputs handed to the kernel.
        data_set: mapping providing ``'Y_train'`` and ``'Y_test'``.
    """
    _init_fn, _apply_fn, resnet_kernel = wide_resnet(
        block_size=4, k=1, num_classes=2)
    resnet_kernel = nt.batch(resnet_kernel, device_count=0, batch_size=0)

    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(
        resnet_kernel,
        train_embedding,
        data_set['Y_train'],
        test_embedding,
        get=('nngp', 'ntk'),
        diag_reg=1e-3)
    for prediction in (fx_test_nngp, fx_test_ntk):
        prediction.block_until_ready()

    def loss(fx, y_hat):
        """Half mean-squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    # Accuracy and loss of the infinite-network predictions.
    util.print_summary('NNGP test', data_set['Y_test'], fx_test_nngp, None,
                       loss)
    util.print_summary('NTK test', data_set['Y_test'], fx_test_ntk, None, loss)
def infinite_fcn(train_embedding, test_embedding, data_set, binary=True):
    """Print the NNGP test summary for an infinite three-layer ReLU network.

    Args:
        train_embedding: training inputs handed to the kernel.
        test_embedding: test inputs handed to the kernel.
        data_set: mapping providing ``'Y_train'`` and ``'Y_test'``.
        binary: forwarded to ``utils.print_summary``.
    """
    layer_stack = (
        stax.Dense(64, 2., 0.05),
        stax.Relu(),
        stax.Dense(32, 2., 0.05),
        stax.Relu(),
        stax.Dense(4, 2., 0.05),
        stax.Relu(),
    )
    _init_fn, _apply_fn, fcn_kernel = stax.serial(*layer_stack)

    # batch_size=0 means no batching: the whole input is processed at once.
    fcn_kernel = nt.batch(fcn_kernel, device_count=0, batch_size=0)

    t0 = time.time()
    # Bayesian inference with the infinite network (NNGP mean and covariance).
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        fcn_kernel,
        train_embedding,
        data_set['Y_train'],
        diag_reg_absolute_scale=True,
        learning_rate=1,
        diag_reg=1e-3)
    nngp_mean, nngp_covariance = predict_fn(
        x_test=test_embedding, get='nngp', compute_cov=True)
    duration = time.time() - t0
    print('Kernel construction and inference done in %s seconds.' % duration)

    def loss(fx, y_hat):
        """Half mean-squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    utils.print_summary('NNGP test',
                        data_set['Y_test'],
                        nngp_mean,
                        None,
                        loss,
                        nngp_covariance,
                        binary=binary)
def main(unused_argv):
    """Run batched NNGP/NTK inference with an infinite LeNet on MNIST.

    Test predictions are computed ``FLAGS.batch_size_output`` examples at a
    time, including a final partial batch, and then summarised.

    Args:
        unused_argv: ignored.
    """
    # Load and normalize data.
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', n_train=10, n_test=10, permute_train=True)

    # Reformat MNIST data to 28x28x1 pictures.
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print(f'Data loaded and reshaped with n_train = {x_train.shape[0]} (batch size {FLAGS.batch_size_kernel}) and '
          f'n_test = {x_test.shape[0]}.')

    # Build the LeNet network.
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print('Network build complete')

    # Construct the kernel function.
    # Use 'store_on_device = False' for larger kernels.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=-1,
                         batch_size=FLAGS.batch_size_kernel,
                         store_on_device=False)

    start_inf = time.time()

    # Bayesian and infinite-time gradient descent inference with infinite
    # network.
    print('Starting bayesian and infinite-time gradient descent inference with infinite network')
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6)
    # NOTE(review): this times only the construction of `predict_fn`;
    # confirm whether any kernel work actually happens before the first call.
    duration_kernel = time.time() - start_inf
    print(f'Kernel constructed in {duration_kernel} seconds.')

    # Compute predictions in batches. Stepping through start offsets (rather
    # than looping n_test // batch times) keeps the trailing partial batch,
    # so the predictions always line up with y_test.
    n_test = x_test.shape[0]
    batch_size_output = FLAGS.batch_size_output
    nngp_batches, ntk_batches = [], []
    for i, start in enumerate(range(0, n_test, batch_size_output)):
        time_batch = time.time()
        x = x_test[start:start + batch_size_output]
        tmp_nngp, tmp_ntk = predict_fn(x_test=x, get=('nngp', 'ntk'))
        duration_batch = time.time() - time_batch
        print(f'Batch {i+1} predicted in {duration_batch} seconds.')
        nngp_batches.append(tmp_nngp)
        ntk_batches.append(tmp_ntk)

    fx_test_nngp = np.concatenate(nngp_batches, axis=0)
    fx_test_ntk = np.concatenate(ntk_batches, axis=0)

    duration_inf = time.time() - start_inf
    print(f'Inference done in {duration_inf} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    """Run NNGP/NTK sentiment inference with an infinite attention network.

    Args:
        *args: ignored.
        use_dummy_data: when true, use ``_get_dummy_data`` instead of loading
            and embedding the IMDb reviews.
        **kwargs: ignored.
    """
    # Value used to mark padded positions.
    mask_constant = 100.

    if not use_dummy_data:
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')

        # Embed words and pad / truncate sentences to a fixed size.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)
    else:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)

    def positional_decay(d):
        return 1 / (1 + d**2)

    # Only the infinite-width kernel is used, so every width is set to 1.
    _init_fn, _apply_fn, kernel_fn = stax.serial(
        stax.Conv(out_chan=1, filter_shape=(9, ), strides=(1, ),
                  padding='VALID'),
        stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1,
                                 n_chan_key=1,
                                 n_chan_val=1,
                                 pos_emb_type='SUM',
                                 W_pos_emb_std=1.,
                                 pos_emb_decay_fn=positional_decay,
                                 n_heads=1),
        stax.Relu(),
        stax.GlobalAvgPool(),
        stax.Dense(out_dim=1))

    # Evaluate the kernel in batches.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=-1,
                         batch_size=FLAGS.batch_size)

    t0 = time.time()
    # Bayesian and infinite-time gradient descent inference.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)
    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()
    duration = time.time() - t0
    print(f'Kernel construction and inference done in {duration} seconds.')

    def loss(fx, y_hat):
        """Half mean-squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def main(unused_argv):
    """Track the NNGP kernel layer by layer and plot its final spectrum.

    Builds a ReLU network one layer at a time, recording the NNGP kernel on
    the training set after every dense layer, runs GP inference with the full
    network, plots the log10 eigenvalues of the last recorded kernel and
    saves that kernel to a ``.mat`` file.

    Args:
        unused_argv: ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the infinite network.
    depth = 5  # number of (Relu, Dense) blocks stacked on the first layer
    w_std = 1.5
    b_std = 2

    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]
    K = net0[2](x_train, None)
    k_layer = [K.nngp]
    # A separate loop variable is used so the depth is no longer overwritten
    # by the loop index.
    for _ in range(depth):
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        K = net_l[2](K)  # propagate the kernel through the new block
        k_layer.append(K.nngp)
        nets.append(stax.serial(nets[-1], net_l))

    kernel_fn = nets[-1][2]

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=0,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite
    # network.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(
        kernel_fn, x_train, y_train, x_test,
        get=('nngp', 'ntk'), diag_reg=1e-3)
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)

    # Spectrum of the last layer's kernel. The matrix is a symmetric kernel
    # on the training set, so use the symmetric solver: it returns real
    # eigenvalues, whereas a general `eig` may hand back a complex array.
    # The previous per-layer K[5, 5] bookkeeping fed only a disabled plot and
    # required more than five training points; it has been dropped.
    w = np.sort(LA.eigvalsh(k_layer[-1]))
    index = list(range(1, len(w) + 1))

    plt.scatter(index, np.log(w)[::-1] / np.log(10))
    plt.ylabel("log10[eigen val]")
    plt.show()

    # NOTE(review): the file name encodes hyper-parameters that differ from
    # the w_std/b_std/depth used above; confirm before relying on it.
    sio.savemat('mnist_l10_wvar=0_85_b_var=0_1.mat', {
        'kernel': k_layer[-1]
    })
def main(unused_argv):
    """Compute a PAC-Bayes style bound and compare SGD with the NTK prediction.

    Steps: build a two-hidden-layer ReLU network for MNIST, compute empirical
    NTK and NNGP kernels, evaluate a KL term and the resulting bound, sample
    initial function values from a Gaussian with covariance ``bigK``, train
    the finite network with SGD and print summaries against the analytic
    prediction.

    Args:
        unused_argv: ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    m = FLAGS.train_size
    print(m)
    n = m*10
    m_test = FLAGS.test_size
    n_test = m_test*10

    # Flatten the 4-D kernels into (examples * outputs) square matrices.
    theta = g_dd.transpose((0,2,1,3)).reshape(n,n)
    theta_test = ntk(x_test, None, params).transpose((0,2,1,3)).reshape(n_test,n_test)
    theta_tilde = g_td.transpose((0,2,1,3)).reshape(n_test,n)

    # NNGP kernels, expanded with an identity over the 10 outputs. The
    # original overwrote each NNGP kernel with kron(theta, I), which has ten
    # times the side length of every matrix it is combined with below.
    # NOTE(review): assumes empirical_nngp_fn returns one entry per pair of
    # examples (shape (m, m) etc.); confirm for the installed version.
    nngp = nt.empirical_nngp_fn(apply_fn)
    K = np.kron(nngp(x_train, None, params), np.eye(10))
    K_test = np.kron(nngp(x_test, None, params), np.eye(10))
    K_tilde = np.kron(nngp(x_test, x_train, params), np.eye(10))

    # `t` was never defined in this function.
    # NOTE(review): the training time also handed to `predictor` below is
    # assumed to be the intended value; confirm.
    t = FLAGS.train_time
    decay_matrix = np.eye(n)-scipy.linalg.expm(-t*theta)
    # Sigma is computed as before but not consumed by anything below.
    Sigma = K + np.matmul(decay_matrix, np.matmul(K, np.matmul(np.linalg.inv(theta), np.matmul(decay_matrix, theta))) - 2*K)

    sigma_noise = 1.0
    Y = y_train.reshape(n)
    alpha = np.matmul(np.linalg.inv(np.eye(n)*(sigma_noise**2)+K),Y)

    eigs = np.linalg.eigh(K)[0]
    logdetcoviK = np.sum(np.log((eigs+sigma_noise**2) /sigma_noise**2))
    # `coviK` was only defined in commented-out lines, so the KL expression
    # raised NameError. I + K / sigma^2 is the matrix whose log-determinant
    # is computed just above.
    coviK = np.eye(n) + K/(sigma_noise**2)
    KL = 0.5*logdetcoviK + 0.5*np.trace(np.linalg.inv(coviK)) + 0.5*np.matmul(alpha.T,np.matmul(K,alpha)) - n/2
    print(KL)

    delta = 2**-10
    bound = (KL+2*np.log(m)+1-np.log(delta))/m
    bound = 1-np.exp(-bound)
    print("bound", bound)

    import numpy
    # Joint covariance over train and test outputs used to sample the
    # initial function values.
    # NOTE(review): the train block uses the NNGP kernel K while the other
    # blocks use NTK matrices (K_tilde / K_test are computed above but not
    # used); confirm which combination is intended.
    bigK = numpy.zeros((n+n_test,n+n_test))
    bigK[0:n,0:n] = K
    bigK[0:n,n:] = theta_tilde.T
    bigK[n:,0:n] = theta_tilde
    bigK[n:,n:] = theta_test
    init_ntk_f = numpy.random.multivariate_normal(np.zeros(n+n_test),bigK)
    fx_train = init_ntk_f[:n].reshape(m,10)
    fx_test = init_ntk_f[n:].reshape(m_test,10)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),\ stax.Relu(),\ stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\ stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),\ stax.Relu(),\ stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\ stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\ stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),\ stax.Flatten(),\ stax.Dense(10, W_std, b_std)) else: raise Exception('Invalid Input Error') apply_fn = jit(apply_fn) kernel_fn = jit(kernel_fn, static_argnums=(2, )) kernel_fn = nt.batch(kernel_fn, batch_size=20) X1 = X[row_id * m:(row_id + 1) * m, :, :, :] assert X1.shape[0] == m and X1.shape[1] == 32 and X1.shape[ 2] == 32 and X1.shape[3] == 3 # Training kernel K = onp.zeros((m, n), dtype=onp.float32) col_count = onp.int(n / m) for col_id in range(row_id, col_count): t1 = time.time() X2 = X[col_id * m:(col_id + 1) * m, :, :, :] assert X2.shape[0] == m and X2.shape[1] == 32 and X2.shape[ 2] == 32 and X2.shape[3] == 3 temp = kernel_fn(X1, X2, 'ntk') K[:, col_id * m:(col_id + 1) * m] = temp.astype(onp.float32)
# params # Create and initialize an optimizer. opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate) state = opt_init(params) # state #%% # Create an mse loss function and a gradient function. loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2) grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y))) # Create an MSE predictor to solve the NTK equation in function space. ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0) g_dd = ntk(x_train, None, params) g_td = ntk(x_test, x_train, params) predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td) #%% m = FLAGS.train_size n = m*10 m_test = FLAGS.test_size n_test = m_test*10 # g_td.shape # predictor # g_dd # type(g_dd) # g_dd.shape theta = g_dd.transpose((0,2,1,3)).reshape(n,n)
from train import validate
import datasets
from matplotlib import pyplot as plt

# Model definition: two hidden Erf layers of width 512 and 10 outputs.
init_fn, apply_fn, kernel_fn = nt.stax.serial(
    nt.stax.Dense(512, 1.0, 0.05),
    nt.stax.Erf(),
    nt.stax.Dense(512, 1.0, 0.05),
    nt.stax.Erf(),
    nt.stax.Dense(10, 1.0, 0.05),
)
apply_fn = jax.jit(apply_fn)
kernel_fn = nt.batch(kernel_fn, 64)


def kernel_fit(x_tr, y_tr, lam=1e-3):
    """Fit an NTK predictor on the training data and return a test-time model.

    Args:
        x_tr: training inputs.
        y_tr: training targets; 0.1 is subtracted before fitting.
        lam: passed as ``diag_reg`` to ``gradient_descent_mse``.

    Returns:
        A function mapping test inputs to the predictor's output.
    """
    train_kernel = kernel_fn(x_tr, None, "ntk")
    shifted_targets = y_tr - 0.1
    fitted = nt.predict.gradient_descent_mse(
        train_kernel, shifted_targets, diag_reg=lam)

    def model(x_te):
        test_train_kernel = kernel_fn(x_te, x_tr, "ntk")
        # NOTE(review): arguments are passed positionally as
        # (None, None, -1, kernel); confirm they line up with the installed
        # predictor's signature.
        return fitted(None, None, -1, test_train_kernel)

    return model


n_train, n_test = 2048, 128

# Generating dataset
def main(unused_argv):
    """Compare SGD training with the NTK prediction on MNIST digit parity.

    The one-hot MNIST labels are reduced to a single 0/1 column (digit
    modulo 2). A randomly seeded one-hidden-layer Erf network is trained with
    full-batch SGD and summarised against the analytic function-space
    prediction at ``FLAGS.train_time``.

    Args:
        unused_argv: ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    import numpy

    # Parity targets: class index modulo 2, kept as a column.
    y_train_tmp = numpy.argmax(y_train, 1) % 2
    y_train = np.expand_dims(y_train_tmp, 1)
    y_test_tmp = numpy.argmax(y_test, 1) % 2
    y_test = np.expand_dims(y_test_tmp, 1)

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(1, 1., 0.05))

    # Random seed for the parameters. `random_integers` is deprecated; this
    # is the call it forwards to (inclusive upper bound == exclusive bound of
    # max + 1).
    int32_range = np.iinfo(np.int32)
    randnnn = numpy.random.randint(int32_range.min,
                                   int32_range.max + 1,
                                   size=2,
                                   dtype=numpy.int64)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        # `params` is read before the update, so after the loop it holds the
        # parameters that entered the final step.
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)