def mc_sampling(count=10):
    """Average test predictions over `count` random parameter draws.

    For every draw, parameters are initialised from a fresh PRNG split, the
    'ntk' kernel from `empirical.empirical_kernel_fn` is evaluated on the
    training inputs and between test and training inputs, and the predictor
    from `predict.gradient_descent_mse` is queried at time 1e8.

    Args:
        count: number of parameter draws to average over.

    Returns:
        The sum of the per-draw test predictions divided by `count`.
    """
    init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
    ntk_fn = empirical.empirical_kernel_fn(f)

    @jit
    def kernel_fn(x1, x2, params):
        return ntk_fn(x1, x2, params, 'ntk')

    rng = random.PRNGKey(100)
    running_total = 0.
    for _ in range(count):
        rng, init_rng = random.split(rng)
        _, params = init_fn(init_rng, train_shape)

        k_train_train = kernel_fn(data_train, None, params)
        k_test_train = kernel_fn(data_test, data_train, params)
        predictor = predict.gradient_descent_mse(
            k_train_train, data_labels, k_test_train)

        fx_train_0 = f(params, data_train)
        fx_test_0 = f(params, data_test)
        _, pred_test = predictor(1.0e8, fx_train_0, fx_test_0)

        running_total += pred_test

    return running_total / count
def mc_sampling(count=10):
    """Estimate mean and covariance of test predictions over random draws.

    For each of `count` parameter draws the 'ntk' kernel from
    `empirical.empirical_kernel_fn` is evaluated, and the predictor from
    `predict.gradient_descent_mse` is queried at time 1e8. The per-draw test
    predictions are stacked along a new leading axis.

    Args:
        count: number of parameter draws.

    Returns:
        A `(mean_emp, cov_emp)` pair: the mean of the stacked predictions
        over the draw axis, and the einsum `'ijk,ilk->jl'` of the
        mean-subtracted predictions with themselves, divided by the product
        of the sizes of the first and last axes.
    """
    init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
    ntk_fn = empirical.empirical_kernel_fn(f)

    @jit
    def kernel_fn(x1, x2, params):
        return ntk_fn(x1, x2, params, 'ntk')

    rng = random.PRNGKey(100)
    test_predictions = []
    for _ in range(count):
        rng, init_rng = random.split(rng)
        _, params = init_fn(init_rng, train_shape)

        k_train_train = kernel_fn(x_train, None, params)
        k_test_train = kernel_fn(x_test, x_train, params)
        predictor = predict.gradient_descent_mse(
            k_train_train, y_train, k_test_train)

        fx_train_0 = f(params, x_train)
        fx_test_0 = f(params, x_test)
        _, pred_test = predictor(1.0e8, fx_train_0, fx_test_0)
        test_predictions.append(pred_test)

    stacked = np.array(test_predictions)
    mean_emp = np.mean(stacked, axis=0)
    centered = stacked - mean_emp

    # Normalise by (size of first axis) * (size of last axis).
    normaliser = centered.shape[0] * centered.shape[-1]
    cov_emp = np.einsum(
        'ijk,ilk->jl', centered, centered, optimize=True) / normaliser
    return mean_emp, cov_emp
def predict_empirical(key):
    """Return MSE gradient-descent predictions for one parameter draw.

    Parameters are initialised from `key`; kernels come from the enclosing
    scope's `kernel_fn`, and the predictor is built by
    `predict.gradient_descent_mse` with `diag_reg=reg`.

    Args:
        key: PRNG key passed to `init_fn`.

    Returns:
        Whatever the predictor returns when called with `ts`, the initial
        train/test outputs and the test-train kernel.
    """
    _, params = init_fn(key, train_shape)

    k_train_train = kernel_fn(x_train, None, params)
    k_test_train = kernel_fn(x_test, x_train, params)
    predict_fn = predict.gradient_descent_mse(
        k_train_train, y_train, diag_reg=reg)

    initial_train = f(params, x_train)
    initial_test = f(params, x_test)
    return predict_fn(ts, initial_train, initial_test, k_test_train)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                        fn_and_kernel, momentum, learning_rate, t, loss):
    """Compare predictor outputs at step `t` with a network trained `t` steps.

    The predictor is `predict.gradient_descent_mse` for `loss ==
    'mse_analytic'` (skipped when `momentum` is set) or
    `predict.gradient_descent` for `loss == 'mse'`. The network is trained
    with `optimizers.sgd` or `optimizers.momentum`, and its train/test
    outputs are asserted close to the predictor's.

    Raises:
        absltest.SkipTest: for 'mse_analytic' with a non-None `momentum`.
        NotImplementedError: for any other value of `loss`.
    """
    key, x_test, x_train, y_train = self._get_inputs(
        out_logits, test_shape, train_shape)
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    def loss_fn(y, y_hat):
        return 0.5 * np.mean((y - y_hat)**2)

    def train_loss(p, x):
        return loss_fn(f(p, x), y_train)

    grad_loss = jit(grad(train_loss))

    if g_dd.ndim == 4:
        trace_axes = ()
    else:
        trace_axes = (-1,)

    if loss == 'mse_analytic':
        if momentum is not None:
            raise absltest.SkipTest(momentum)
        predictor = predict.gradient_descent_mse(
            g_dd, y_train,
            learning_rate=learning_rate,
            trace_axes=trace_axes)
    elif loss == 'mse':
        predictor = predict.gradient_descent(
            loss_fn, g_dd, y_train,
            learning_rate=learning_rate,
            momentum=momentum,
            trace_axes=trace_axes)
    else:
        raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
        self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    if momentum is None:
        optimizer = optimizers.sgd(learning_rate)
    else:
        optimizer = optimizers.momentum(learning_rate, momentum)
    opt_init, opt_update, get_params = optimizer

    opt_state = opt_init(params)
    for step in range(t):
        params = get_params(opt_state)
        opt_state = opt_update(step, grad_loss(params, x_train), opt_state)

    # Read the parameters back after the final update.
    params = get_params(opt_state)

    fx_train_nn = f(params, x_train)
    fx_test_nn = f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
def main(unused_argv):
    """Train a network with SGD and compare it with the NTK MSE prediction.

    Data comes from `datasets.mnist`, sized by `FLAGS.train_size` and
    `FLAGS.test_size`. A Dense-Erf-Dense network is trained with
    `optimizers.sgd` for `int(FLAGS.train_time // FLAGS.learning_rate)`
    steps, and its train/test outputs are summarised next to the outputs of
    `predict.gradient_descent_mse` evaluated at `FLAGS.train_time`.

    Args:
        unused_argv: ignored.
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.mnist(
        FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = batch(get_ntk_fun_empirical(f), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Read the parameters back after the final update. Without this, `params`
    # holds the values from before the last SGD step, so the summary below
    # would compare a network one step behind the analytic prediction.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
def mc_sampling(count=10):
    """Average test predictions over `count` draws from `_empirical_kernel`.

    Each draw calls `_empirical_kernel` with a fresh PRNG split to obtain
    parameters, the network function and a kernel function. A predictor is
    built with `predict.gradient_descent_mse` and queried at time 1e8.

    Args:
        count: number of draws to average over.

    Returns:
        The sum of the per-draw test predictions divided by `count`.
    """
    rng = random.PRNGKey(100)
    running_total = 0.

    for _ in range(count):
        rng, draw_rng = random.split(rng)
        params, f, theta = _empirical_kernel(
            draw_rng, train_shape[1:], network, out_logits)

        k_train_train = theta(data_train, None)
        k_test_train = theta(data_test, data_train)
        predictor = predict.gradient_descent_mse(
            k_train_train, data_labels, k_test_train)

        fx_train_0 = f(params, data_train)
        fx_test_0 = f(params, data_test)
        _, pred_test = predictor(1.0e8, fx_train_0, fx_test_0)

        running_total += pred_test

    return running_total / count
def testNTKMSEPrediction(self, train_shape, test_shape, network, out_logits,
                         fn_and_kernel):
    """Check the MSE predictor against a network trained with SGD.

    Random inputs and Bernoulli targets are drawn, a predictor is built with
    `predict.gradient_descent_mse`, and the network is trained with
    `optimizers.sgd` for `train_time / step_size` steps. The difference
    between trained and predicted outputs, divided by the RMS displacement
    of the trained outputs from their initial values, is asserted close to
    zero. Tolerances are doubled when `train_shape` has more than two
    dimensions.
    """
    key = random.PRNGKey(0)

    key, x_train_key = random.split(key)
    x_train = random.normal(x_train_key, train_shape)

    key, y_train_key = random.split(key)
    y_train = np.array(
        random.bernoulli(y_train_key, shape=(train_shape[0], out_logits)),
        np.float32)

    key, x_test_key = random.split(key)
    x_test = random.normal(x_test_key, test_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Regress to an MSE loss.
    def loss(p, x):
        return 0.5 * np.mean((f(p, x) - y_train) ** 2)

    grad_loss = jit(grad(loss))

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Hacky way to up the tolerance just for convolutions.
    if len(train_shape) > 2:
        atol, rtol = ATOL * 2, RTOL * 2
    else:
        atol, rtol = ATOL, RTOL
    step_size = 0.1

    train_time = 100.0
    steps = int(train_time / step_size)

    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    fx_initial_train = f(params, x_train)
    fx_initial_test = f(params, x_test)

    # At time zero the predictor must return the initial outputs.
    pred_train_0, pred_test_0 = predictor(
        0.0, fx_initial_train, fx_initial_test)
    self.assertAllClose(fx_initial_train, pred_train_0, True)
    self.assertAllClose(fx_initial_test, pred_test_0, True)

    for step in range(steps):
        params = get_params(opt_state)
        opt_state = opt_update(step, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    fx_pred_train, fx_pred_test = predictor(
        train_time, fx_initial_train, fx_initial_test)

    def relative_error(trained, predicted, initial):
        # Error normalised by the RMS displacement from the initial outputs.
        displacement = np.sqrt(np.mean((trained - initial)**2))
        return (trained - predicted) / displacement

    fx_error_train = relative_error(fx_train, fx_pred_train, fx_initial_train)
    fx_error_test = relative_error(fx_test, fx_pred_test, fx_initial_test)

    self.assertAllClose(
        fx_error_train, np.zeros_like(fx_error_train), True, rtol, atol)
    self.assertAllClose(
        fx_error_test, np.zeros_like(fx_error_test), True, rtol, atol)
def testPredictND(self):
    """Check output shapes of the predictors for multi-dimensional outputs.

    A single `stax.Conv` layer supplies the network outputs. For every
    combination of `trace_axes`, time array `ts` (None or a (2, 3) array)
    and presence/absence of test inputs, the shapes returned by
    `predict.gradient_descent_mse`, `predict.gradient_descent_mse_ensemble`
    and (when `ts` is not None) `predict.gradient_descent` are compared
    with the expected train/test shapes.
    """
    n_chan = 6
    key = random.PRNGKey(1)
    im_shape = (5, 4, 3)
    n_train = 2
    n_test = 2
    x_train = random.normal(key, (n_train,) + im_shape)
    y_train = random.uniform(key, (n_train, 3, 2, n_chan))
    init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
    _, params = init_fn(key, x_train.shape)
    fx_train_0 = apply_fn(params, x_train)

    # Mean squared error. The square must sit inside the mean: the previous
    # `np.mean(x - y)**2` squared the mean residual instead.
    loss = lambda x, y: 0.5 * np.mean((x - y)**2)

    for trace_axes in [(), (-1,), (-2,), (-3,), (0, 1), (2, 3), (2,),
                       (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                       (0, 1, -1, 2)]:
        for ts in [None, np.arange(6).reshape((2, 3))]:
            for x in [None, 'x_test']:
                with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                    t_shape = ts.shape if ts is not None else ()
                    y_test_shape = (
                        t_shape + (n_test,) + y_train.shape[1:])
                    y_train_shape = t_shape + y_train.shape

                    # `x` is only a flag; the actual test inputs get their
                    # own name so the loop variable is not overwritten.
                    x_test = x if x is None else random.normal(
                        key, (n_test,) + im_shape)
                    fx_test_0 = (
                        None if x_test is None else apply_fn(params, x_test))

                    kernel_fn = empirical.empirical_kernel_fn(
                        apply_fn, trace_axes=trace_axes)

                    # TODO(romann): investigate the SIGTERM error on CPU.
                    # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                    ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
                    if x_test is not None:
                        ntk_test_train = kernel_fn(
                            x_test, x_train, 'ntk', params)

                    predict_fn_mse = predict.gradient_descent_mse(
                        ntk_train_train, y_train, trace_axes=trace_axes)

                    predict_fn_mse_ensemble = (
                        predict.gradient_descent_mse_ensemble(
                            kernel_fn, x_train, y_train,
                            trace_axes=trace_axes,
                            params=params))

                    if x_test is None:
                        p_train_mse = predict_fn_mse(ts, fx_train_0)
                    else:
                        p_train_mse, p_test_mse = predict_fn_mse(
                            ts, fx_train_0, fx_test_0, ntk_test_train)
                        self.assertAllClose(y_test_shape, p_test_mse.shape)
                    self.assertAllClose(y_train_shape, p_train_mse.shape)

                    p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                        ts, x_test, ('nngp', 'ntk'), compute_cov=True)
                    # The ensemble returns test predictions when test inputs
                    # are given and train predictions otherwise.
                    ref_shape = (
                        y_train_shape if x_test is None else y_test_shape)
                    self.assertAllClose(ref_shape, p_ntk_mse_ens.mean.shape)
                    self.assertAllClose(ref_shape, p_nngp_mse_ens.mean.shape)

                    if ts is not None:
                        predict_fn = predict.gradient_descent(
                            loss, ntk_train_train, y_train,
                            trace_axes=trace_axes)

                        if x_test is None:
                            p_train = predict_fn(ts, fx_train_0)
                        else:
                            p_train, p_test = predict_fn(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test.shape)
                        self.assertAllClose(y_train_shape, p_train.shape)
def test_kwargs(self, do_batch, mode):
    """Check that keyword arguments are forwarded through the predictors.

    A kernel function is chosen by `mode` ('mc', 'empirical' or 'analytic'),
    optionally batched, and called with per-pair keyword arguments
    (`pattern`, plus `params`, `get` and `rng` in empirical mode). Outputs of
    `predict.gp_inference` and `predict.gradient_descent_mse` computed from
    the precomputed kernels are asserted close to those of
    `predict.gradient_descent_mse_ensemble`.

    Raises:
        absltest.SkipTest: in empirical mode when `do_batch` is set.
        ValueError: if `mode` is not one of the three supported values.
    """
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    # Keyword arguments for the train-train, test-train and test-test calls.
    kw_dd = {'pattern': (pattern_train, pattern_train)}
    kw_td = {'pattern': (pattern_test, pattern_train)}
    kw_tt = {'pattern': (pattern_test, pattern_test)}
    all_kwargs = (kw_dd, kw_td, kw_tt)

    if mode == 'mc':
        mc_batch_size = 2 if do_batch else 0
        kernel_fn = monte_carlo_kernel_fn(
            init_fn, apply_fn, rng, 2, batch_size=mc_batch_size)

    elif mode == 'empirical':
        kernel_fn = empirical_kernel_fn(apply_fn)
        if do_batch:
            raise absltest.SkipTest('Batching of empirical kernel is not '
                                    'implemented with keyword arguments.')

        for kw in all_kwargs:
            kw.update(params=init_fn(rng, x_train.shape)[1],
                      get=('nngp', 'ntk'))

        kw_dd['rng'] = (rng_train, None)
        kw_td['rng'] = (rng_test, rng_train)
        kw_tt['rng'] = (rng_test, None)

    elif mode == 'analytic':
        if do_batch:
            kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
        raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
        # Drop `get` from the kwargs before they are reused for the
        # ensemble predictor.
        for kw in all_kwargs:
            del kw['get']

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, **kw_dd)
    out_ensemble = predict_fn_ensemble(
        x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(
        t=1., fx_train_0=None, fx_test_0=0., k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(
        t=1., get='ntk', x_test=x_test, compute_cov=False, **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(
        t=2., fx_train_0=0., fx_test_0=None, k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(
        t=2., get='nngp', x_test=None, compute_cov=False, **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)