def main(unused_argv):
  """Train a two-layer network on MNIST and compare it to the NTK prediction.

  The function builds a `Dense(4096) -> Tanh -> Dense(10)` network, trains it
  with full-batch SGD on an MSE loss for `FLAGS.train_time / FLAGS.learning_rate`
  steps, and separately evolves the initial function-space outputs with the
  analytic MSE predictor built from the network's tangent kernel. Both results
  are then handed to `util.print_summary` for the train and test sets.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(
      FLAGS.train_size, FLAGS.test_size)

  # Build the network.
  init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a jitted gradient function.
  def loss(fx, y_hat):
    return 0.5 * np.mean((fx - y_hat)**2)

  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  # The kernels are evaluated at the initial parameters. (A leftover
  # `ipdb.set_trace()` breakpoint used to sit here and stopped every run.)
  theta = tangents.ntk(f, batch_size=32)
  g_dd = theta(params, x_train)
  g_td = theta(params, x_test, x_train)
  predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = f(params, x_train)
  fx_test = f(params, x_test)

  # Train the network. The number of steps is chosen so that
  # steps * learning_rate covers FLAGS.train_time.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Re-read the parameters after the final update; inside the loop `params`
  # is fetched before each step, so without this the summary below would be
  # computed on parameters that lag the optimizer state by one step.
  params = get_params(state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
  util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
def testNTKMSEPrediction(self, shape, out_logits):
  """Compare the analytic NTK MSE predictor against SGD on a linear model.

  Random normal inputs and Bernoulli labels are generated from a fixed PRNG
  key. A linear model `x @ w / shape[-1] + b` is trained with SGD on an MSE
  loss, and its outputs are compared with those of
  `tangents.analytic_mse_predictor`:

  * at time 0 the predictor must return the initial outputs unchanged;
  * after `train_time`, the difference between trained and predicted outputs,
    normalised by the RMS displacement from the initial outputs, must be
    close to zero.

  Args:
    shape: Shape of the train/test input arrays; `shape[0]` is used as the
      number of label rows and `shape[-1]` as the input feature size.
    out_logits: Number of model outputs per example.
  """
  n_features = shape[-1]

  # Draw data and parameters; the split order fixes the random values.
  rng = random.PRNGKey(0)

  rng, train_rng = random.split(rng)
  x_train = random.normal(train_rng, shape)

  rng, label_rng = random.split(rng)
  y_train = np.array(
      random.bernoulli(label_rng, shape=(shape[0], out_logits)), np.float32)

  rng, test_rng = random.split(rng)
  x_test = random.normal(test_rng, shape)

  rng, w_rng, b_rng = random.split(rng, 3)
  weights_init = random.normal(w_rng, (n_features, out_logits))
  bias_init = random.normal(b_rng, (out_logits,))
  model_params = (weights_init, bias_init)

  def f(params, x):
    weights, bias = params
    return np.dot(x, weights) / n_features + bias

  # Regress to an MSE loss.
  def loss(params, x):
    residual = f(params, x) - y_train
    return 0.5 * np.mean(residual ** 2)

  # Build the analytic predictor from the train-train and test-train kernels.
  kernel_fn = tangents.ntk(f)
  kernel_train_train = kernel_fn(model_params, x_train)
  kernel_test_train = kernel_fn(model_params, x_test, x_train)
  predictor = tangents.analytic_mse_predictor(
      kernel_train_train, y_train, kernel_test_train)

  step_size = 1.0
  train_time = 100.0
  n_steps = int(train_time / step_size)

  opt_init, opt_update, get_params = opt.sgd(step_size)
  opt_state = opt_init(model_params)

  f0_train = f(model_params, x_train)
  f0_test = f(model_params, x_test)

  # At t = 0 the predictor should hand back the initial outputs.
  pred_train, pred_test = predictor(f0_train, f0_test, 0.0)

  # NOTE(schsam): I think at the moment stax always generates 32-bit results
  # since the weights are explicitly cast to float32.
  self.assertAllClose(f0_train, pred_train, False)
  self.assertAllClose(f0_test, pred_test, False)

  # Train with SGD for the same total time as the analytic prediction.
  loss_grad = grad(loss)
  for step in range(n_steps):
    model_params = get_params(opt_state)
    gradients = loss_grad(model_params, x_train)
    opt_state = opt_update(step, gradients, opt_state)
  model_params = get_params(opt_state)

  trained_train = f(model_params, x_train)
  trained_test = f(model_params, x_test)

  pred_train, pred_test = predictor(f0_train, f0_test, train_time)

  # Normalise the prediction error by how far the outputs actually moved.
  rms_move_train = np.sqrt(np.mean((trained_train - f0_train) ** 2))
  rms_move_test = np.sqrt(np.mean((trained_test - f0_test) ** 2))

  rel_err_train = (trained_train - pred_train) / rms_move_train
  rel_err_test = (trained_test - pred_test) / rms_move_test

  self.assertAllClose(
      rel_err_train, np.zeros_like(rel_err_train), False, 0.1, 0.1)
  self.assertAllClose(
      rel_err_test, np.zeros_like(rel_err_test), False, 0.1, 0.1)
key = random.PRNGKey(run) _, params = net_init(key, (-1, 1)) # data task = sinusoid_task(n_support=args.n_support) x_train, y_train, x_test, y_test = task['x_train'], task['y_train'], task[ 'x_test'], task['y_test'] # linearized network f_lin = tangents.linearize(f, params) # Create an MSE predictor to solve the NTK equation in function space. theta = tangents.ntk(f, batch_size=32) g_dd = theta(params, x_train) g_td = theta(params, x_test, x_train) predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td) import ipdb ipdb.set_trace() # Get initial values of the network in function space. fx_train_ana_init = f(params, x_train) fx_test_ana_init = f(params, x_test) # optimizer for f and f_lin if args.inner_opt_alg == 'sgd': optimizer = partial(optimizers.sgd, step_size=args.inner_step_size) elif args.inner_opt_alg == 'momentum': optimizer = partial(optimizers.momentum, step_size=args.inner_step_size, mass=0.9) elif args.inner_opt_alg == 'adam':