def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                             out_logits):
  """Checks analytic NTK ensemble mean/covariance against a Monte Carlo estimate.

  Builds a Dense(512)-Erf-Dense(out_logits) network and evaluates the
  infinite-width predictions at times `ts` using
  `predict.gradient_descent_mse_ensemble`. Then it samples 4096 random
  initializations, trains each one in function space with
  `predict.gradient_descent_mse` on its own empirical NTK, and asserts that the
  sample mean and covariance agree with the analytic ones to `rtol=0.05`.

  Note: the `network` argument is not used by this test.
  """
  key, x_test, x_train, y_train = self._get_inputs(
      out_logits, test_shape, train_shape)
  init_fn, f, kernel_fn = stax.serial(
      stax.Dense(512, W_std=1.2, b_std=0.05), stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  # Small diagonal regularizer shared by the analytic and empirical predictors.
  reg = 1e-6
  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                    y_train, diag_reg=reg)
  ts = np.array([1., 5., 10.])

  # The call returns a (mean, covariance) pair as unpacked here; the trailing
  # `True` presumably requests the covariance -- TODO confirm positional name.
  fx_test_inf, cov_test_inf = predictor(ts, x_test, 'ntk', True)
  self.assertEqual(cov_test_inf.shape[1], x_test.shape[0])
  # Covariance must be positive semi-definite up to numerical tolerance.
  self.assertGreater(np.min(np.linalg.eigh(cov_test_inf)[0]), -1e-8)

  # Passing `None` as the test inputs yields training-set predictions (the
  # shape check below is against `x_train`).
  fx_train_inf, cov_train_inf = predictor(ts, None, 'ntk', True)
  self.assertEqual(cov_train_inf.shape[1], x_train.shape[0])
  self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

  # Rebind `kernel_fn` to a jitted *empirical* NTK evaluated at given params.
  _kernel_fn = empirical.empirical_kernel_fn(f)
  kernel_fn = jit(
      lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

  def predict_empirical(key):
    """Returns train/test predictions at `ts` for one random initialization."""
    _, params = init_fn(key, train_shape)
    g_dd = kernel_fn(x_train, None, params)
    g_td = kernel_fn(x_test, x_train, params)
    predict_fn = predict.gradient_descent_mse(g_dd, y_train, diag_reg=reg)
    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)
    return predict_fn(ts, fx_train_0, fx_test_0, g_td)

  def predict_mc(count, key):
    """Returns sample mean and covariance over `count` initializations."""
    # One key per sample; `vmap` maps over the leading axis of the split keys.
    key = tf_random_split(key, count)
    fx_train, fx_test = vmap(predict_empirical)(key)
    fx_train_mean = np.mean(fx_train, axis=0)
    fx_test_mean = np.mean(fx_test, axis=0)
    fx_train_centered = fx_train - fx_train_mean
    fx_test_centered = fx_test - fx_test_mean
    # `_cov_empirical` is defined on `PredictTest` outside this view.
    cov_train = PredictTest._cov_empirical(fx_train_centered)
    cov_test = PredictTest._cov_empirical(fx_test_centered)
    return fx_train_mean, fx_test_mean, cov_train, cov_test

  fx_train_mc, fx_test_mc, cov_train_mc, cov_test_mc = predict_mc(
      4096, key)
  rtol = 0.05
  self._assertAllClose(fx_train_mc, fx_train_inf, rtol)
  self._assertAllClose(cov_train_mc, cov_train_inf, rtol)
  self._assertAllClose(cov_test_mc, cov_test_inf, rtol)
  self._assertAllClose(fx_test_mc, fx_test_inf, rtol)
def testMaxLearningRate(self, train_shape, network, out_logits,
                        fn_and_kernel):
  """Checks that `predict.max_learning_rate` separates convergence from divergence.

  For each step-size multiplier, trains the network with SGD at
  `max_learning_rate(...) * lr_factor` for a fixed number of steps and
  compares the final MSE loss with the initial one:
    * `lr_factor == 0.5`: the loss must shrink by more than 10x.
    * `lr_factor == 3.`: the loss must grow by more than 10x, unless the ratio
      became NaN (treated as divergence and accepted).
  """
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = np.asarray(normal(train_shape, seed=split))

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  y_train = np.asarray(
      stateless_uniform(shape=(train_shape[0], out_logits), seed=split,
                        minval=0, maxval=1) < 0.5, np.float32)

  steps = 20

  for lr_factor in [0.5, 3.]:
    # The same `key` is reused, so both learning rates start from the same
    # initialization.
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    # Regress to an MSE loss. Defined after `f` is bound so that the jitted
    # gradient is traced with *this* iteration's network function, instead of
    # a closure over a name assigned later and a trace cached from the first
    # iteration.
    def loss(params, x):
      return 0.5 * np.mean((f(params, x) - y_train)**2)

    grad_loss = jit(grad(loss))

    g_dd = ntk(x_train, None, 'ntk')
    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size) * lr_factor
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    def get_loss(opt_state):
      return loss(get_params(opt_state), x_train)

    init_loss = get_loss(opt_state)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    # Small epsilon guards against division by a zero initial loss.
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor == 3.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
    else:
      self.assertLess(loss_ratio, 0.1)
def _kernel_fns(key, input_shape, network, out_logits, diagonal_axes,
                trace_axes):
  """Returns jitted implicit-NTK, direct-NTK and NNGP empirical kernel fns.

  The network from `_build_network` is initialized once with a leading batch
  dimension of 1. All three kernel functions are partially applied to those
  same parameters, so callers only supply inputs.

  Args:
    key: random key used to initialize the parameters.
    input_shape: per-example input shape, without the batch dimension.
    network: network identifier forwarded to `_build_network`.
    out_logits: output width forwarded to `_build_network`.
    diagonal_axes: forwarded to each empirical kernel constructor.
    trace_axes: forwarded to each empirical kernel constructor.

  Returns:
    A 3-tuple `(implicit_ntk_fn, direct_ntk_fn, nngp_fn)`.
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits)
  _, params = init_fn(key, (1, ) + input_shape)

  constructors = (
      empirical.empirical_implicit_ntk_fn,
      empirical.empirical_direct_ntk_fn,
      empirical.empirical_nngp_fn,
  )
  return tuple(
      partial(jit(make_kernel(apply_fn, trace_axes, diagonal_axes)),
              params=params)
      for make_kernel in constructors)
def _empirical_kernel(key, input_shape, network, out_logits, use_dropout):
  """Returns a jitted empirical NTK fn with `params` and `keys` pre-bound.

  `key` is split in two. The first half initializes the network, built with
  the given `use_dropout` flag, on a batch dimension of 1. The second half is
  bound as the `keys` keyword of the kernel function.

  Args:
    key: random key to split.
    input_shape: per-example input shape, without the batch dimension.
    network: network identifier forwarded to `_build_network`.
    out_logits: output width forwarded to `_build_network`.
    use_dropout: forwarded to `_build_network`.

  Returns:
    `partial(jit(empirical_ntk_fn(f)), params=..., keys=...)`.
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits,
                                        use_dropout)
  split_keys = tf_random_split(key)
  init_key, kernel_key = split_keys[0], split_keys[1]

  _, params = init_fn(init_key, (1, ) + input_shape)
  ntk_fn = jit(empirical.empirical_ntk_fn(apply_fn))
  return partial(ntk_fn, params=params, keys=kernel_key)
def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs):
  """Calls `f` through a cached `jit`/`pmap`-compiled closure.

  Positional arguments for which `_is_np_ndarray` is true are passed to the
  compiled function and broadcast with `broadcast`. All other positional
  arguments, and all keyword arguments, are closed over. They therefore form
  part of the cache key.

  NOTE(review): `key`, `cache`, `device_count`, `broadcast` and `f` are not
  defined here; they presumably come from an enclosing scope -- confirm.

  Args:
    x_or_kernel: first input, forwarded to `f` unchanged.
    *args: further positional arguments for `f`.
    **kwargs: keyword arguments for `f`; always treated as static.

  Returns:
    The result of the compiled closure applied to `x_or_kernel` and the
    broadcast array arguments.
  """
  args_np, args_np_idxs = [], []
  args_other = {}

  # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
  # https://github.com/google/jax/issues/912
  # Filter out `np.ndarray`s from other arguments.
  for i, arg in enumerate(args):
    if _is_np_ndarray(arg):
      args_np.append(arg)
      args_np_idxs.append(i)
    else:
      args_other[i] = arg

  def _hashable(item):
    """Returns a hashable stand-in for `item` for use in the cache key."""
    if isinstance(item, tf.Tensor):
      # TF tensors are not hashable; `ref()` returns a hashable,
      # identity-based reference. It is not iterable, so use it directly.
      return item.ref()
    if isinstance(item, onp.ndarray):
      # Key arrays by value. Shape and dtype are included so that arrays of
      # any rank work and equal-valued arrays of different dtypes don't
      # collide (a collision would reuse a closure over the wrong array).
      return (onp.ndarray, item.shape, item.dtype.str, item.tobytes())
    if isinstance(item, tuple):
      return tuple(_hashable(sub_item) for sub_item in item)
    return item

  # Check cache before jitting.
  _key = _hashable(key +
                   tuple(args_other.items()) +
                   tuple(kwargs.items()))

  if _key in cache:
    _f = cache[_key]
  else:
    # Define a `np.ndarray`-only function as a closure over other arguments.
    def _f(_x_or_kernel, *_args_np):
      # Merge args back into their original positional order.
      _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
      _args = {**_args_np, **args_other}
      _args = tuple(v for k, v in sorted(_args.items()))
      return f(_x_or_kernel, *_args, **kwargs)

    _f = jit(_f) if device_count == 0 else pmap(_f)
    cache[_key] = _f

  # Broadcast `np.ndarray` arguments and apply the new function to them.
  args_np = tree_map(broadcast, args_np)
  return _f(x_or_kernel, *args_np)
def main(unused_argv):
  """Trains an Erf MLP on MNIST and compares it with the NTK function-space solution.

  The network is trained with full-batch SGD for
  `FLAGS.train_time // FLAGS.learning_rate` steps. Its outputs are compared,
  via `util.print_summary`, with the closed-form MSE prediction obtained from
  the empirical NTK at initialization.
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the network
  init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                     stax.Dense(10, 1., 0.05))

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  _, params = init_fn(key, (1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an mse loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # The loop reads `params` *before* each update. Refresh them so the network
  # includes the final step and matches the analytic prediction at
  # `FLAGS.train_time`, rather than lagging one step behind.
  params = get_params(state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
                     loss)
def _jit_vmap(f): return jit(vmap(f))
def _theoretical_kernel(key, input_shape, network, out_logits):
  """Builds a network and returns `(params, apply_fn, jitted analytic kernel fn)`.

  Parameters are initialized with a batch dimension of -1. The third
  positional argument of the kernel function (e.g. `'ntk'`) is marked static
  for `jit`.
  """
  init_fn, apply_fn, analytic_kernel_fn = _build_network(
      input_shape, network, out_logits)
  batch_input_shape = (-1, ) + input_shape
  _out_shape, params = init_fn(key, batch_input_shape)
  jitted_kernel_fn = jit(analytic_kernel_fn, static_argnums=(2, ))
  return params, apply_fn, jitted_kernel_fn
def _empirical_kernel(key, input_shape, network, out_logits):
  """Builds a network and returns `(params, apply_fn, jitted empirical kernel fn)`.

  The empirical kernel is created with `trace_axes=()` and evaluated at the
  freshly initialized parameters. The returned function has signature
  `(x1, x2, get)`, with `get` (argument index 2) marked static for `jit`.
  """
  init_fn, apply_fn, _unused_kernel = _build_network(input_shape, network,
                                                     out_logits)
  _out_shape, params = init_fn(key, (-1, ) + input_shape)
  empirical_fn = empirical.empirical_kernel_fn(apply_fn, trace_axes=())

  def kernel_fn(x1, x2, get):
    return empirical_fn(x1, x2, get, params)

  return params, apply_fn, jit(kernel_fn, static_argnums=(2, ))
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  """Checks analytic ensemble mean/covariance against a trained ensemble.

  Trains `ensemble_size` independently initialized Dense(128)-Erf-Dense
  networks with SGD for `training_steps` steps. The empirical mean and
  covariance of their outputs are compared, on the training inputs and on the
  test inputs, with `predict.gradient_descent_mse_ensemble(..., 'ntk',
  compute_cov=True)` at the same number of steps, to `rtol=0.08`.

  Note: the `network` argument is not used by this test.
  """
  training_steps = 1000
  learning_rate = 0.1
  ensemble_size = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(
      out_logits, test_shape, train_shape)
  # `learning_rate` is passed so the analytic predictor's time argument can be
  # given in the same units as `training_steps` below.
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, learning_rate=learning_rate,
      diag_reg=0.)

  train = (x_train, y_train)
  # One key per ensemble member.
  ensemble_key = tf_random_split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean(
      (apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)
                  (get_params(state), x, y))

  def train_network(key):
    """Initializes one network from `key` and trains it with full-batch SGD."""
    _, params = init_fn(key, (-1, ) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)

    return get_params(opt_state)

  # Leading axis of every leaf in `params` indexes ensemble members.
  params = vmap(train_network)(ensemble_key)
  rtol = 0.08

  # `None` selects the training inputs; the string is a placeholder that is
  # replaced by `x_test` below.
  for x in [None, 'x_test']:
    with self.subTest(x=x):
      x = x if x is None else x_test
      x_fin = x_train if x is None else x_test
      # Apply every ensemble member to the same inputs.
      ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

      mean_emp = np.mean(ensemble_fx, axis=0)
      mean_subtracted = ensemble_fx - mean_emp
      # Sums over ensemble members (axis 0) and the last axis (presumably the
      # logits -- confirm), then normalizes by both sizes.
      cov_emp = np.einsum(
          'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
              mean_subtracted.shape[0] * mean_subtracted.shape[-1])

      ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, ntk.mean, rtol)
      self._assertAllClose(cov_emp, ntk.covariance, rtol)
def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                        fn_and_kernel, momentum, learning_rate, t, loss):
  """Checks NTK gradient-descent predictions against actual network training.

  Builds an analytic (`loss == 'mse_analytic'`) or ODE-based
  (`loss == 'mse'`) predictor from the NTK at initialization. It first runs
  the zero-time, multi-step and (analytic only) infinite-time helper checks.
  It then trains the real network for `t` steps with SGD, or momentum when
  `momentum` is not None, and asserts that the outputs match the predictor at
  time `t` to `RTOL`/`ATOL`.

  Raises:
    absltest.SkipTest: for `'mse_analytic'` combined with momentum.
    NotImplementedError: for any other `loss` value.
  """
  key, x_test, x_train, y_train = self._get_inputs(
      out_logits, test_shape, train_shape)

  params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

  g_dd = ntk(x_train, None, 'ntk')
  g_td = ntk(x_test, x_train, 'ntk')

  # Regress to an MSE loss.
  loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
  grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

  # A 4-D `g_dd` presumably still carries untraced output axes, so no axes are
  # traced; otherwise the last axis is -- TODO confirm against the kernel fns.
  trace_axes = () if g_dd.ndim == 4 else (-1, )
  if loss == 'mse_analytic':
    if momentum is not None:
      raise absltest.SkipTest(momentum)
    predictor = predict.gradient_descent_mse(
        g_dd, y_train, learning_rate=learning_rate, trace_axes=trace_axes)
  elif loss == 'mse':
    predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                         learning_rate=learning_rate,
                                         momentum=momentum,
                                         trace_axes=trace_axes)
  else:
    raise NotImplementedError(loss)

  predictor = jit(predictor)

  fx_train_0 = f(params, x_train)
  fx_test_0 = f(params, x_test)

  self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
  self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
  if loss == 'mse_analytic':
    self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

  if momentum is None:
    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  else:
    opt_init, opt_update, get_params = optimizers.momentum(
        learning_rate, momentum)

  opt_state = opt_init(params)
  for i in range(t):
    params = get_params(opt_state)
    opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

  # Re-read after the loop so `params` include the final update.
  params = get_params(opt_state)

  fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
  fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

  self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
  self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
def main(unused_argv):
  """Trains an Erf MLP and its linearization on MNIST and compares them.

  Both models start from the same parameters and are trained with the same
  momentum optimizer on a softmax cross-entropy loss. The loss of each model
  is printed once per epoch, and a summary comparing them is printed at the
  end via `util.print_summary`.
  """
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network
  init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                              stax.Dense(10, 1., 0.05))

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  _, params = init_fn(key, (1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = momentum(FLAGS.learning_rate, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(
      datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                         FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{}\t{}'.format(epoch,
                                loss(f(params, x), y),
                                loss(f_lin(params_lin, x), y)))
      epoch += 1

  # The loop reads parameters *before* each update. Refresh them so the
  # summary reflects the fully trained models (this also defines
  # `params_lin` if the minibatch iterator yielded nothing).
  params = get_params(state)
  params_lin = get_params(state_lin)

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', y_test, f(params, x_test),
                     f_lin(params_lin, x_test), loss)