def main(unused_argv):
    """Run infinite-width NNGP / NTK inference for a ReLU MLP on CIFAR-10.

    Loads ``FLAGS.train_size`` training and ``FLAGS.test_size`` test examples,
    evaluates the closed-form NNGP posterior mean and the infinite-time NTK
    gradient-descent prediction on the test set, and prints the wall-clock
    time plus accuracy / loss summaries for both.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    print('Loading data.')
    splits = datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)
    x_train, y_train, x_test, y_test = splits

    # Only the kernel function is kept (init/apply of the finite model are
    # discarded), so every layer width is 1.
    layers = (
        stax.Dense(1, 2., 0.05),
        stax.Relu(),
        stax.Dense(1, 2., 0.05),
    )
    kernel_fn = stax.serial(*layers)[2]

    # Evaluate the kernel in batches of FLAGS.batch_size.
    kernel_fn = nt.batch(kernel_fn, device_count=0, batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian (NNGP) and infinite-time gradient-descent (NTK) inference.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    # Wait for both results so the timer covers the actual computation.
    for prediction in (fx_test_nngp, fx_test_ntk):
        prediction.block_until_ready()
    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    def loss(fx, y_hat):
        """Half mean squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    summaries = (('NNGP test', fx_test_nngp), ('NTK test', fx_test_ntk))
    for label, prediction in summaries:
        util.print_summary(label, y_test, prediction, None, loss)
def main(unused_argv):
    """Train an Erf MLP and its first-order linearization side by side on MNIST.

    Both models start from the same initial parameters and are updated with the
    same momentum optimizer on the same minibatches. Losses are printed once per
    epoch. Train (first 10000 examples) and test summaries of both models are
    printed at the end.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Build data and network.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', permute_train=True)

    init_fn, f, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # One optimizer definition, two independent states (full and linearized).
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)
    state = opt_init(params)
    state_lin = opt_init(params)

    # Cross-entropy averaged over every logit entry; assumes one-hot y_hat.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Gradient functions for the full and the linearized network.
    grad_loss = jit(grad(lambda p, x, y: loss(f(p, x), y)))
    grad_loss_lin = jit(grad(lambda p, x, y: loss(f_lin(p, x), y)))

    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Derive the epoch length from the data actually loaded instead of a
    # hard-coded 50000 (presumably datasets.minibatch walks the whole training
    # set once per epoch). Keep it >= 1 so `i % steps_per_epoch` cannot divide
    # by zero when the batch size exceeds the training-set size.
    steps_per_epoch = max(1, x_train.shape[0] // FLAGS.batch_size)
    batches = datasets.minibatch(
        x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)
    for i, (x, y) in enumerate(batches):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)
        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)
        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # The loop reads parameters *before* each update. Fetch the final ones so
    # the summaries include the last step and still work if no batch was seen.
    params = get_params(state)
    params_lin = get_params(state_lin)

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary(
        'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
def main(unused_argv):
    """Compare SGD training of a finite MLP with its NTK function-space solution.

    Trains an Erf MLP on MNIST with full-batch SGD for
    ``FLAGS.train_time / FLAGS.learning_rate`` steps. The same training time is
    fed to the closed-form gradient-descent MSE predictor built from the
    empirical NTK at initialization. Train/test summaries of both are printed.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # MSE loss and its gradient with respect to the parameters.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda p, x, y: loss(apply_fn(p, x), y)))

    # Empirical NTK at initialization, evaluated in batches of 4 examples.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
                   batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Initial network outputs in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))
    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # The loop reads parameters *before* each update. Fetch the final ones so
    # the trained network includes the last SGD step, matching the analytic
    # prediction evaluated at the full FLAGS.train_time.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary(
        'train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    """Run infinite-width NNGP / NTK inference with a self-attention net on IMDb.

    Args:
      *args: Positional arguments (not used by this function).
      use_dummy_data: If True, take data from ``_get_dummy_data`` instead of
        loading IMDb reviews and embedding them with GloVe.
      **kwargs: Keyword arguments (not used by this function).
    """
    # Value written into padded positions; also handed to inference as the
    # mask value.
    mask_constant = 100.

    if not use_dummy_data:
        # Build data pipelines.
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')
        # Embed words and pad / truncate sentences to a fixed length.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)
    else:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)

    # Only the kernel function is kept (the finite model is not used), so
    # every width is 1.
    attention = stax.GlobalSelfAttention(
        n_chan_out=1,
        n_chan_key=1,
        n_chan_val=1,
        pos_emb_type='SUM',
        W_pos_emb_std=1.,
        pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
        n_heads=1)
    network = stax.serial(
        stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,),
                  padding='VALID'),
        stax.Relu(),
        attention,
        stax.Relu(),
        stax.GlobalAvgPool(),
        stax.Dense(out_dim=1),
    )
    kernel_fn = network[2]

    # Evaluate the kernel in batches, across all available devices.
    kernel_fn = nt.batch(kernel_fn, device_count=-1,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian (NNGP) and infinite-time gradient-descent (NTK) inference.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)
    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))
    # Wait for both results so the timer covers the actual computation.
    for prediction in (fx_test_nngp, fx_test_ntk):
        prediction.block_until_ready()
    duration = time.time() - start
    print(f'Kernel construction and inference done in {duration} seconds.')

    def loss(fx, y_hat):
        """Half mean squared error."""
        return 0.5 * np.mean((fx - y_hat)**2)

    for label, prediction in (('NNGP test', fx_test_nngp),
                              ('NTK test', fx_test_ntk)):
        util.print_summary(label, y_test, prediction, None, loss)
def main(unused_argv):
    """Study how the NNGP kernel of a deep ReLU MLP evolves with depth on MNIST.

    Builds ``Dense -> (Relu -> Dense) * n_layers`` in the infinite-width limit
    and records the training-set NNGP kernel after every Dense layer. It then
    runs NNGP / NTK inference and prints the NNGP test summary. Finally it
    scatter-plots the log10 eigenvalue spectrum of the deepest kernel and saves
    that kernel to a ``.mat`` file.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the infinite network (width 1: only kernel functions are used).
    n_layers = 5
    w_std = 1.5
    b_std = 2

    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]
    # Training-set NNGP kernel after each Dense layer (n_layers + 1 entries).
    k_layer = []
    kernel = net0[2](x_train, None)
    k_layer.append(kernel.nngp)
    for _ in range(n_layers):
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        # Feed the previous Kernel object forward rather than recomputing the
        # whole stack from the inputs.
        kernel = net_l[2](kernel)
        k_layer.append(kernel.nngp)
        nets.append(stax.serial(nets[-1], net_l))
    kernel_fn = nets[-1][2]

    # Evaluate the kernel in batches of FLAGS.batch_size.
    kernel_fn = nt.batch(kernel_fn, device_count=0,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(
        kernel_fn, x_train, y_train, x_test, get=('nngp', 'ntk'),
        diag_reg=1e-3)
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()
    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for the NNGP prediction.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)

    # k_layer[-1] is a train-vs-train kernel (x2=None), hence symmetric. The
    # symmetric solver returns real eigenvalues in ascending order; the general
    # `eig` may return complex values, which makes the log below ill-defined.
    eigvals = LA.eigh(k_layer[-1])[0]
    index = list(range(1, len(eigvals) + 1))
    # Largest eigenvalue first. Round-off negatives, if any, show up as NaN.
    plt.scatter(index, np.log10(eigvals)[::-1])
    plt.ylabel("log10[eigen val]")
    plt.show()

    # NOTE(review): this file name encodes l=10, w_var=0.85, b_var=0.1, which
    # does not match n_layers / w_std / b_std above. Confirm the intended name
    # (or derive it from the parameters) before relying on the saved file.
    sio.savemat('mnist_l10_wvar=0_85_b_var=0_1.mat', {'kernel': k_layer[-1]})
def _gp_posterior_kl(K, Y, sigma_noise):
    """Return the Gaussian KL term used by ``main`` for GP regression.

    With ``alpha = (K + s^2 I)^{-1} Y`` and ``coviK = I + K / s^2`` this
    evaluates::

        0.5 * log det(coviK) + 0.5 * tr(coviK^{-1}) + 0.5 * alpha^T K alpha - n/2

    That is the KL divergence between the GP posterior
    ``N(K alpha, (K^{-1} + I / s^2)^{-1})`` and the prior ``N(0, K)``. Both
    spectral terms come from a single eigendecomposition of ``K``.

    Args:
      K: Symmetric (n, n) prior covariance.
      Y: Flattened targets of length n.
      sigma_noise: Observation-noise standard deviation ``s``.

    Returns:
      The scalar KL value.
    """
    noise_var = sigma_noise ** 2
    n = K.shape[0]
    # solve() instead of inv(...) @ Y: cheaper and better conditioned.
    alpha = np.linalg.solve(np.eye(n) * noise_var + K, Y)
    eigs = np.linalg.eigh(K)[0]
    logdet_coviK = np.sum(np.log((eigs + noise_var) / noise_var))
    # Eigenvalues of coviK^{-1} are s^2 / (lambda + s^2).
    trace_inv_coviK = np.sum(noise_var / (eigs + noise_var))
    quad = np.matmul(alpha, np.matmul(K, alpha))
    return 0.5 * logdet_coviK + 0.5 * trace_inv_coviK + 0.5 * quad - n / 2


def main(unused_argv):
    """Print a PAC-Bayes-style bound from empirical NTK kernels, then compare
    SGD training with the NTK function-space prediction on MNIST.

    Steps:
      1. Build a 2-hidden-layer ReLU MLP and compute empirical NTK blocks
         (train/train, test/train, test/test) at initialization.
      2. Compute the GP-regression KL term with the train NTK as prior
         covariance, then print the KL and the derived bound.
      3. Sample initial train/test outputs from the joint NTK covariance.
      4. Train the finite network with SGD and compare it with the analytic
         gradient-descent MSE prediction started from the sampled outputs.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Host NumPy: the joint covariance is filled in place and sampled with
    # numpy.random.
    import numpy

    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network.
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # MSE loss and its gradient with respect to the parameters.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda p, x, y: loss(apply_fn(p, x), y)))

    # Empirical NTK at initialization and the analytic MSE predictor.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    n_out = 10
    m = FLAGS.train_size
    n = m * n_out
    m_test = FLAGS.test_size
    n_test = m_test * n_out

    # Assumes ntk returns per-logit blocks shaped (rows, cols, 10, 10) — TODO
    # confirm for the installed neural_tangents version. transpose((0, 2, 1, 3))
    # orders the flat index as example * 10 + logit, matching y.reshape(n).
    theta = g_dd.transpose((0, 2, 1, 3)).reshape(n, n)
    theta_test = ntk(x_test, None, params).transpose(
        (0, 2, 1, 3)).reshape(n_test, n_test)
    theta_tilde = g_td.transpose((0, 2, 1, 3)).reshape(n_test, n)

    # Prior covariance of the flattened training outputs. theta is already
    # (n, n), so no Kronecker expansion is applied (kron(theta, I_10) would be
    # (10n, 10n) and incompatible with eye(n)).
    # NOTE(review): this uses the NTK, consistent with the NTK blocks of
    # big_k below. To use the NNGP as prior instead, flatten
    # nt.empirical_nngp_fn outputs to (n, n) and use them in big_k as well.
    K = theta

    sigma_noise = 1.0
    Y = y_train.reshape(n)
    kl = _gp_posterior_kl(K, Y, sigma_noise)
    print(kl)

    delta = 2**-10
    bound = (kl + 2 * np.log(m) + 1 - np.log(delta)) / m
    bound = 1 - np.exp(-bound)
    print("bound", bound)

    # Joint (train + test) covariance used to sample initial outputs.
    big_k = numpy.zeros((n + n_test, n + n_test))
    big_k[:n, :n] = K
    big_k[:n, n:] = theta_tilde.T
    big_k[n:, :n] = theta_tilde
    big_k[n:, n:] = theta_test
    init_ntk_f = numpy.random.multivariate_normal(
        numpy.zeros(n + n_test), big_k)
    fx_train = init_ntk_f[:n].reshape(m, n_out)
    fx_test = init_ntk_f[n:].reshape(m_test, n_out)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))
    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)
    # The loop reads parameters *before* each update. Fetch the final ones so
    # the comparison includes the last SGD step.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary(
        'train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def data_load():
    """Load the CIFAR-10 split sized by ``FLAGS.train_size`` / ``FLAGS.test_size``.

    Returns:
      The 4-tuple ``(x_train, y_train, x_test, y_test)`` produced by
      ``datasets.get_dataset``.
    """
    print('Loading data.')
    x_tr, y_tr, x_te, y_te = datasets.get_dataset(
        'cifar10', FLAGS.train_size, FLAGS.test_size)
    return x_tr, y_tr, x_te, y_te
duration = time.time() - start print('Kernel construction and inference done in %s seconds.' % duration) # Print out accuracy and loss for infinite network predictions. loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2) n_accuracy, n_loss_x = util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss) n_accuracy_x, n_loss = util.print_summary('NNGP test', y_train, fx_test_nngp_train, None, loss) return (n_accuracy, n_loss, k_layer) # Build data pipelines. print('Loading data.') x_train, y_train, x_test, y_test = \ datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size) # testing for various w_std layer = 5 w_start = 0.91 b_std = 2 w_choice = [] n_accuracy = [] n_loss = [] kernel_evolution = [] def avg_k_evolution(x): avg = 0 for i in range(len(x) - 1):