def testNTKBatched(self, shape, out_logits):
  key = random.PRNGKey(0)
  key, self_split, other_split = random.split(key, 3)
  data_self = random.normal(self_split, shape)
  data_other = random.normal(other_split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  g_fn = tangents.ntk(f)
  g_batched_fn = tangents.ntk(f, batch_size=2)

  # The batched NTK must agree with the unbatched NTK.
  g = g_fn(params, data_self)
  g_batched = g_batched_fn(params, data_self)
  self.assertAllClose(g, g_batched, check_dtypes=False)

  g = g_fn(params, data_other, data_self)
  g_batched = g_batched_fn(params, data_other, data_self)
  self.assertAllClose(g, g_batched, check_dtypes=False)
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

  # Build the network.
  init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  theta = tangents.ntk(f, batch_size=32)
  g_dd = theta(params, x_train)
  g_td = theta(params, x_test, x_train)
  predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = f(params, x_train)
  fx_test = f(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
  util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
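# Hedged sketch (added for illustration; not part of the example above): the
# closed-form train-set dynamics that `tangents.analytic_mse_predictor` is
# assumed to implement under continuous-time gradient flow on an MSE loss,
# f(t) = y + exp(-g_dd * t) (f(0) - y), written directly with an
# eigendecomposition. The helper name, and the exact time / learning-rate
# normalization, are assumptions; shapes are taken to be flattened so that
# `g_dd` is (n, n) and `fx0`, `y` are (n,).
import numpy as onp


def analytic_mse_train_sketch(g_dd, y, fx0, t):
  # g_dd is symmetric PSD, so exp(-g_dd * t) = V diag(exp(-lambda * t)) V^T.
  evals, evecs = onp.linalg.eigh(g_dd)
  expm = onp.dot(evecs * onp.exp(-evals * t), evecs.T)
  return y + onp.dot(expm, fx0 - y)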
def testNTKAgainstDirect(self, shape, out_logits):

  def sum_and_contract(j1, j2):
    def contract(x, y):
      param_count = int(np.prod(x.shape[2:]))
      x = np.reshape(x, (-1, param_count))
      y = np.reshape(y, (-1, param_count))
      return np.dot(x, np.transpose(y))
    return tree_reduce(operator.add, tree_multimap(contract, j1, j2))

  def ntk_direct(f, params, x1, x2):
    # Compute the empirical NTK explicitly by contracting Jacobians.
    jac_fn = jacobian(f)
    j1 = jac_fn(params, x1)
    if x2 is None:
      j2 = j1
    else:
      j2 = jac_fn(params, x2)
    return sum_and_contract(j1, j2)

  key = random.PRNGKey(0)
  key, self_split, other_split = random.split(key, 3)
  data_self = random.normal(self_split, shape)
  data_other = random.normal(other_split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  g_fn = tangents.ntk(f)
  g = g_fn(params, data_self)
  g_direct = ntk_direct(f, params, data_self, data_self)
  self.assertAllClose(g, g_direct, check_dtypes=False)

  g = g_fn(params, data_other, data_self)
  g_direct = ntk_direct(f, params, data_other, data_self)
  self.assertAllClose(g, g_direct, check_dtypes=False)
                  n_query=args.n_query)
else:
  raise ValueError

ntk_frequency = 50
plot_update_frequency = 100

for i, task_batch in tqdm(
    enumerate(taskbatch(task_fn=task_fn,
                        batch_size=args.task_batch_size,
                        n_task=args.n_train_task)),
    total=args.n_train_task // args.task_batch_size):
  aux = dict()

  # ntk
  if i == 0 or (i + 1) % (
      args.n_train_task // args.task_batch_size // ntk_frequency) == 0:
    ntk = tangents.ntk(f, batch_size=100)(
        outer_get_params(outer_state), task_eval['x_train'])
    aux['ntk_train_rank_eval'] = onp.linalg.matrix_rank(ntk)

    f_lin = tangents.linearize(f, outer_get_params(outer_state_lin))
    ntk_lin = tangents.ntk(f_lin, batch_size=100)(
        outer_get_params(outer_state_lin), task_eval['x_train'])
    aux['ntk_train_rank_eval_lin'] = onp.linalg.matrix_rank(ntk_lin)

    log.append([(key, aux[key]) for key in win_rank_eval_keys])

    # spectrum
    evals, evecs = onp.linalg.eigh(ntk)  # eigenvectors are columns
    for j in range(len(evals)):
      aux[f'ntk_spectrum_{j}_eval'] = evals[j]
    log.append([(key, aux[key]) for key in win_spectrum_eval_keys])

    evals = evals.clip(min=1e-10)
    ind = onp.arange(len(evals)) + 1  # +1 because we are taking log
def testNTKMomentumPrediction(self, shape, out_logits):
  key = random.PRNGKey(1)
  key, split = random.split(key)
  data_train = random.normal(split, shape)

  key, split = random.split(key)
  label_ids = random.randint(split, (shape[0],), 0, out_logits)
  data_labels = np.eye(out_logits)[label_ids]

  key, split = random.split(key)
  data_test = random.normal(split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat) ** 2)
  grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

  theta = tangents.ntk(f)
  g_dd = theta(params, data_train)
  g_td = theta(params, data_test, data_train)

  step_size = 1.0
  train_time = 100.0
  # The momentum predictor measures continuous time in units of
  # sqrt(step_size) per optimizer step.
  steps = int(train_time / np.sqrt(step_size))

  init_fn, predict_fn, get_fn = tangents.momentum_predictor(
      g_dd, data_labels, loss, step_size, g_td)
  opt_init, opt_update, get_params = momentum(step_size, 0.9)
  opt_state = opt_init(params)

  fx_initial_train = f(params, data_train)
  fx_initial_test = f(params, data_test)

  lin_state = init_fn(fx_initial_train, fx_initial_test)

  for i in range(steps):
    params = get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

  params = get_params(opt_state)
  fx_train = f(params, data_train)
  fx_test = f(params, data_test)

  lin_state = predict_fn(lin_state, train_time)
  fx_pred_train, fx_pred_test = get_fn(lin_state)

  # Put errors in units of the RMS distance moved by the function values
  # during optimization.
  fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
  fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

  fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
  fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

  self.assertAllClose(
      fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
  self.assertAllClose(
      fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
def testNTKGDPrediction(self, shape, out_logits):
  key = random.PRNGKey(1)
  key, split = random.split(key)
  data_train = random.normal(split, shape)

  key, split = random.split(key)
  label_ids = random.randint(split, (shape[0],), 0, out_logits)
  data_labels = np.eye(out_logits)[label_ids]

  key, split = random.split(key)
  data_test = random.normal(split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat) ** 2)
  grad_loss = grad(lambda params, x: loss(f(params, x), data_labels))

  theta = tangents.ntk(f)
  g_dd = theta(params, data_train)
  g_td = theta(params, data_test, data_train)

  predictor = tangents.gradient_descent_predictor(
      g_dd, data_labels, loss, g_td)

  step_size = 1.0
  train_time = 100.0
  steps = int(train_time / step_size)

  opt_init, opt_update, get_params = opt.sgd(step_size)
  opt_state = opt_init(params)

  fx_initial_train = f(params, data_train)
  fx_initial_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(
      fx_initial_train, fx_initial_test, 0.0)

  # NOTE(schsam): I think at the moment stax always generates 32-bit results
  # since the weights are explicitly cast to float32.
  self.assertAllClose(fx_initial_train, fx_pred_train, False)
  self.assertAllClose(fx_initial_test, fx_pred_test, False)

  for i in range(steps):
    params = get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

  params = get_params(opt_state)
  fx_train = f(params, data_train)
  fx_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(
      fx_initial_train, fx_initial_test, train_time)

  # Put errors in units of the RMS distance moved by the function values
  # during optimization.
  fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
  fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

  fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
  fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

  self.assertAllClose(
      fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
  self.assertAllClose(
      fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
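# Hedged sketch (illustration only; not the library's implementation): the
# function-space gradient-flow ODE that `tangents.gradient_descent_predictor`
# is assumed to integrate, df_train/dt = -g_dd . grad_f loss(f_train, y) and
# df_test/dt = -g_td . grad_f loss(f_train, y), stepped here with forward
# Euler. The function name, argument layout, and the time normalization are
# assumptions; g_dd and g_td are taken to be flattened (n*logits, n*logits)
# and (m*logits, n*logits) kernels as in the tests above.
from jax import grad
import jax.numpy as np


def euler_function_space_sketch(g_dd, g_td, y, loss, fx_train, fx_test,
                                dt, n_steps):
  dl_df = grad(loss)  # gradient of the scalar loss w.r.t. the train outputs
  for _ in range(n_steps):
    g = np.reshape(dl_df(fx_train, y), (-1,))
    fx_train = fx_train - dt * np.reshape(np.dot(g_dd, g), fx_train.shape)
    fx_test = fx_test - dt * np.reshape(np.dot(g_td, g), fx_test.shape)
  return fx_train, fx_test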
def testNTKMSEPrediction(self, shape, out_logits):
  key = random.PRNGKey(0)
  key, split = random.split(key)
  data_train = random.normal(split, shape)

  key, split = random.split(key)
  data_labels = np.array(
      random.bernoulli(split, shape=(shape[0], out_logits)), np.float32)

  key, split = random.split(key)
  data_test = random.normal(split, shape)

  key, w_split, b_split = random.split(key, 3)
  params = (random.normal(w_split, (shape[-1], out_logits)),
            random.normal(b_split, (out_logits,)))

  def f(params, x):
    w, b = params
    return np.dot(x, w) / shape[-1] + b

  # Regress to an MSE loss.
  loss = lambda params, x: \
      0.5 * np.mean((f(params, x) - data_labels) ** 2)

  theta = tangents.ntk(f)
  g_dd = theta(params, data_train)
  g_td = theta(params, data_test, data_train)

  predictor = tangents.analytic_mse_predictor(g_dd, data_labels, g_td)

  step_size = 1.0
  train_time = 100.0
  steps = int(train_time / step_size)

  opt_init, opt_update, get_params = opt.sgd(step_size)
  opt_state = opt_init(params)

  fx_initial_train = f(params, data_train)
  fx_initial_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(
      fx_initial_train, fx_initial_test, 0.0)

  # NOTE(schsam): I think at the moment stax always generates 32-bit results
  # since the weights are explicitly cast to float32.
  self.assertAllClose(fx_initial_train, fx_pred_train, False)
  self.assertAllClose(fx_initial_test, fx_pred_test, False)

  for i in range(steps):
    params = get_params(opt_state)
    opt_state = opt_update(i, grad(loss)(params, data_train), opt_state)

  params = get_params(opt_state)
  fx_train = f(params, data_train)
  fx_test = f(params, data_test)

  fx_pred_train, fx_pred_test = predictor(
      fx_initial_train, fx_initial_test, train_time)

  fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train) ** 2))
  fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test) ** 2))

  fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
  fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

  self.assertAllClose(
      fx_error_train, np.zeros_like(fx_error_train), False, 0.1, 0.1)
  self.assertAllClose(
      fx_error_test, np.zeros_like(fx_error_test), False, 0.1, 0.1)
              norm=args.norm)

# initialize network
key = random.PRNGKey(run)
_, params = net_init(key, (-1, 1))

# data
task = sinusoid_task(n_support=args.n_support)
x_train, y_train, x_test, y_test = (
    task['x_train'], task['y_train'], task['x_test'], task['y_test'])

# linearized network
f_lin = tangents.linearize(f, params)

# Create an MSE predictor to solve the NTK equation in function space.
theta = tangents.ntk(f, batch_size=32)
g_dd = theta(params, x_train)
g_td = theta(params, x_test, x_train)
predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

# Get initial values of the network in function space.
fx_train_ana_init = f(params, x_train)
fx_test_ana_init = f(params, x_test)

# optimizer for f and f_lin
if args.inner_opt_alg == 'sgd':
  optimizer = partial(optimizers.sgd, step_size=args.inner_step_size)
elif args.inner_opt_alg == 'momentum':
  optimizer = partial(optimizers.momentum,
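# Hedged sketch (illustration only; not the library's code): the first-order
# Taylor expansion around `params0` that `tangents.linearize(f, params0)` is
# assumed to return, written directly with jax.jvp. The helper name is an
# assumption; `tree_multimap` is used as elsewhere in this code.
from jax import jvp, tree_util


def linearize_sketch(f, params0):
  def f_lin(params, x):
    # f_lin(p, x) = f(p0, x) + J_p f(p0, x) . (p - p0)
    dparams = tree_util.tree_multimap(lambda p, p0: p - p0, params, params0)
    f0, df = jvp(lambda p: f(p, x), (params0,), (dparams,))
    return f0 + df
  return f_lin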