def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key) key = keys[0] split = keys[1] if len(train_shape) == 2: train_shape = (train_shape[0] * 5, train_shape[1] * 10) else: train_shape = (16, 8, 8, 3) x_train = np.asarray(normal(train_shape, seed=split)) keys = tf_random_split(key) key = keys[0] split = keys[1] y_train = np.asarray( stateless_uniform(shape=(train_shape[0], out_logits), seed=split, minval=0, maxval=1) < 0.5, np.float32) # Regress to an MSE loss. loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2) grad_loss = jit(grad(loss)) def get_loss(opt_state): return loss(get_params(opt_state), x_train) steps = 20 for lr_factor in [0.5, 3.]: params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits) g_dd = ntk(x_train, None, 'ntk') step_size = predict.max_learning_rate( g_dd, y_train_size=y_train.size) * lr_factor opt_init, opt_update, get_params = optimizers.sgd(step_size) opt_state = opt_init(params) init_loss = get_loss(opt_state) for i in range(steps): params = get_params(opt_state) opt_state = opt_update(i, grad_loss(params, x_train), opt_state) trained_loss = get_loss(opt_state) loss_ratio = trained_loss / (init_loss + 1e-12) if lr_factor == 3.: if not math.isnan(loss_ratio): self.assertGreater(loss_ratio, 10.) else: self.assertLess(loss_ratio, 0.1)
def testPredictOnCPU(self): key1 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) key2 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) key3 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) x_train = np.asarray(normal((4, 4, 4, 2), seed=key1)) x_test = np.asarray(normal((8, 4, 4, 2), seed=key2)) y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3)) _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1)) for store_on_device in [False, True]: for device_count in [0, 1]: for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]: for x in [None, 'x_test']: with self.subTest(store_on_device=store_on_device, device_count=device_count, get=get, x=x): kernel_fn_batched = batch.batch( kernel_fn, 2, device_count, store_on_device) predictor = predict.gradient_descent_mse_ensemble( kernel_fn_batched, x_train, y_train) x = x if x is None else x_test predict_none = predictor(None, x, get, compute_cov=True) predict_inf = predictor(np.inf, x, get, compute_cov=True) self.assertAllClose(predict_none, predict_inf) if x is not None: on_cpu = (not store_on_device or xla_bridge.get_backend().platform == 'cpu') self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf)) self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
def init_fun(rng, input_shape): output_shape = input_shape[:-1] + (out_dim,) keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2) k1 = keys[0] k2 = keys[1] # convert the two keys from shape (2,) into a scalar k1 = stateless_uniform(shape=[], seed=k1, minval=None, maxval=None, dtype=tf.int32) k2 = stateless_uniform(shape=[], seed=k2, minval=None, maxval=None, dtype=tf.int32) W = W_init(seed=k1, shape=(input_shape[-1], out_dim)) b = b_init(seed=k2, shape=(out_dim,)) return tfnp.zeros(output_shape), (W.numpy(), b.numpy())
def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn, batch_size): test_utils.stub_out_pmap(batch, 2) key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key, 3) key = keys[0] self_split = keys[1] other_split = keys[2] data_self = np.asarray(normal(train_shape, seed=self_split)) data_other = np.asarray(normal(test_shape, seed=other_split)) kernel_fn = kernel_fn(key, train_shape[1:], network) kernel_batched = batch.batch(kernel_fn, batch_size=batch_size) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) kernel_batched = batch.batch(kernel_fn, batch_size=batch_size, store_on_device=False) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other)
def testNTKAgainstDirect(self, train_shape, test_shape, network, name, kernel_fn): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=3) key = splits[0] self_split = splits[1] other_split = splits[2] data_self = np.asarray(normal(train_shape, seed=self_split)) data_other = np.asarray(normal(test_shape, seed=other_split)) implicit, direct, _ = kernel_fn(key, train_shape[1:], network, diagonal_axes=(), trace_axes=()) g = implicit(data_self, None) g_direct = direct(data_self, None) self.assertAllClose(g, g_direct) g = implicit(data_other, data_self) g_direct = direct(data_other, data_self) self.assertAllClose(g, g_direct)
def _test_analytic_kernel_composition(self, batching_fn): # Check Fully-Connected. rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(rng) rng_self = keys[0] rng_other = keys[1] x_self = np.asarray(normal((8, 10), seed=rng_self)) x_other = np.asarray(normal((2, 10), seed=rng_other)) Block = stax.serial(stax.Dense(256), stax.Relu()) _, _, ker_fn = Block ker_fn = batching_fn(ker_fn) _, _, composed_ker_fn = stax.serial(Block, Block) ker_out = ker_fn(ker_fn(x_self)) composed_ker_out = composed_ker_fn(x_self) if batching_fn == batch._parallel: # In the parallel setting, `x1_is_x2` is not computed correctly # when x1==x2. composed_ker_out = composed_ker_out.replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out) ker_out = ker_fn(ker_fn(x_self, x_other)) composed_ker_out = composed_ker_fn(x_self, x_other) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out.replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out) # Check convolutional + pooling. x_self = np.asarray(normal((8, 10, 10, 3), seed=rng)) x_other = np.asarray(normal((2, 10, 10, 3), seed=rng)) Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu()) Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10)) block_ker_fn, readout_ker_fn = Block[2], Readout[2] _, _, composed_ker_fn = stax.serial(Block, Readout) block_ker_fn = batching_fn(block_ker_fn) readout_ker_fn = batching_fn(readout_ker_fn) ker_out = readout_ker_fn(block_ker_fn(x_self)) composed_ker_out = composed_ker_fn(x_self) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out.replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out) ker_out = readout_ker_fn(block_ker_fn(x_self, x_other)) composed_ker_out = composed_ker_fn(x_self, x_other) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out.replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out)
def main(unused_argv): # Build data pipelines. print('Loading data.') x_train, y_train, x_test, y_test = \ datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size) # Build the network init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(), stax.Dense(10, 1., 0.05)) key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) _, params = init_fn(key, (1, 784)) # Create and initialize an optimizer. opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate) state = opt_init(params) # Create an mse loss function and a gradient function. loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2) grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y))) # Create an MSE predictor to solve the NTK equation in function space. ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0) g_dd = ntk(x_train, None, params) g_td = ntk(x_test, x_train, params) predictor = nt.predict.gradient_descent_mse(g_dd, y_train) # Get initial values of the network in function space. fx_train = apply_fn(params, x_train) fx_test = apply_fn(params, x_test) # Train the network. train_steps = int(FLAGS.train_time // FLAGS.learning_rate) print('Training for {} steps'.format(train_steps)) for i in range(train_steps): params = get_params(state) state = opt_apply(i, grad_loss(params, x_train, y_train), state) # Get predictions from analytic computation. print('Computing analytic prediction.') fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td) # Print out summary data comparing the linear / nonlinear model. util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss) util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def _get_inputs(cls, out_logits, test_shape, train_shape): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key) key = keys[0] split = keys[1] x_train = np.asarray(normal(train_shape, seed=split)) keys = tf_random_split(key) key = keys[0] split = keys[1] y_train = np.asarray( stateless_uniform(shape=(train_shape[0], out_logits), seed=split, minval=0, maxval=1) < 0.5, np.float32) keys = tf_random_split(key) key = keys[0] split = keys[1] x_test = np.asarray(normal(test_shape, seed=split)) return key, x_test, x_train, y_train
def testAxes(self, diagonal_axes, trace_axes): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=3) key = splits[0] self_split = splits[1] other_split = splits[2] data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split)) data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split)) _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self) _trace_axes = utils.canonicalize_axis(trace_axes, data_self) if any(d == c for d in _diagonal_axes for c in _trace_axes): raise absltest.SkipTest( 'diagonal axes must be different from channel axes.') implicit, direct, nngp = KERNELS['empirical_logits_3']( key, (5, 6, 3), CONV, diagonal_axes=diagonal_axes, trace_axes=trace_axes) n_marg = len(_diagonal_axes) n_chan = len(_trace_axes) g = implicit(data_self, None) g_direct = direct(data_self, None) g_nngp = nngp(data_self, None) self.assertAllClose(g, g_direct) self.assertEqual(g_nngp.shape, g.shape) self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim) if 0 not in _trace_axes and 0 not in _diagonal_axes: g = implicit(data_other, data_self) g_direct = direct(data_other, data_self) g_nngp = nngp(data_other, data_self) self.assertAllClose(g, g_direct) self.assertEqual(g_nngp.shape, g.shape) self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)
def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65): res = np.abs(np.matmul(x1, x2)) if do_square: res *= res if do_flip: res = -res res *= stateless_uniform(shape=[], seed=keys) * p return [res, params]
def testGradientDescentMseEnsembleTrain(self): key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) x = np.asarray(normal((8, 4, 6, 3), seed=key)) _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(), stax.Conv(1, (2, 1))) y = np.asarray(normal((8, 2, 5, 1), seed=key)) predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y) for t in [None, np.array([0., 1., 10.])]: with self.subTest(t=t): y_none = predictor(t, None, None, compute_cov=True) y_x = predictor(t, x, None, compute_cov=True) self._assertAllClose(y_none, y_x, 0.04)
def apply_fun(params, inputs, **kwargs): rng = kwargs.get('rng', None) if rng is None: msg = ("Dropout layer requires apply_fun to be called with a PRNG key " "argument. That is, instead of `apply_fun(params, inputs)`, call " "it like `apply_fun(params, inputs, rng)` where `rng` is a " "jax.random.PRNGKey value.") raise ValueError(msg) if mode == 'train': # keep = random.bernoulli(rng, rate, inputs.shape) # bernoulli = tfp.distributions.Bernoulli(probs=rate) # keep = bernoulli.sample(sample_shape=inputs.shape, seed=rng[0]) prob = tf.ones(inputs.shape) * rate keep = stateless_uniform(shape=inputs.shape, seed=rng, minval=0, maxval=1) < prob return tfnp.where(keep, inputs / rate, 0) else: return inputs
def get_samples(x1: np.ndarray, x2: Optional[np.ndarray], get: Get, **apply_fn_kwargs): _key = stateless_uniform(shape=[2], seed=key, minval=None, maxval=None, dtype=tf.int32) ker_sampled = None for n in range(1, max(n_samples) + 1): _key, split = tf_split(_key) one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs) if ker_sampled is None: ker_sampled = one_sample else: ker_sampled = tree_multimap(operator.add, ker_sampled, one_sample) yield n, ker_sampled
def testSerial(self, train_shape, test_shape, network, name, kernel_fn, batch_size): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key, 3) key = keys[0] self_split = keys[1] other_split = keys[2] data_self = np.asarray(normal(train_shape, seed=self_split)) data_other = np.asarray(normal(test_shape, seed=other_split)) kernel_fn = kernel_fn(key, train_shape[1:], network) kernel_batched = batch._serial(kernel_fn, batch_size=batch_size) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other)
def testLinearization(self, shape): key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=4) key = splits[0] s1 = splits[1] s2 = splits[2] s3 = splits[3] w1 = np.asarray(normal(shape, seed=s1)) w1 = 0.5 * (w1 + w1.T) w2 = np.asarray(normal(shape, seed=s2)) b = np.asarray(normal((shape[-1], ), seed=s3)) params = (w1, w2, b) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=2) key = splits[0] split = splits[1] x0 = np.asarray(normal((shape[-1], ), seed=split)) f_lin = empirical.linearize(EmpiricalTest.f, x0) for _ in range(TAYLOR_RANDOM_SAMPLES): for do_alter in [True, False]: for do_shift_x in [True, False]: splits = tf_random_split(seed=tf.convert_to_tensor( key, dtype=tf.int32), num=2) key = splits[0] split = splits[1] x = np.asarray(normal((shape[-1], ), seed=split)) self.assertAllClose( EmpiricalTest.f_lin_exact(x0, x, params, do_alter, do_shift_x=do_shift_x), f_lin(x, params, do_alter, do_shift_x=do_shift_x))
def testParallel(self, train_shape, test_shape, network, name, kernel_fn): test_utils.stub_out_pmap(batch, 2) key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key, 3) key = keys[0] self_split = keys[1] other_split = keys[2] data_self = np.asarray(normal(train_shape, seed=self_split)) data_other = np.asarray(normal(test_shape, seed=other_split)) kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False) kernel_batched = batch._parallel(kernel_fn) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other, True)
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True): key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) keys = tf_random_split(key) key = keys[0] split = keys[1] x1 = np.asarray(normal((8, 4, 3, 2), seed=key)) x2 = np.asarray(normal((4, 4, 3, 2), seed=split)) if not use_conv: x1 = np.reshape(x1, (x1.shape[0], -1)) x2 = np.reshape(x2, (x2.shape[0], -1)) init_fn, apply_fn, kernel_fn = stax.serial( stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width), stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5)) return x1, x2, init_fn, apply_fn, kernel_fn, key
def test_jit_or_pmap_broadcast(self): def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65): res = np.abs(np.matmul(x1, x2)) if do_square: res *= res if do_flip: res = -res res *= stateless_uniform(shape=[], seed=keys) * p return [res, params] params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5]))) x2 = np.arange(0, 10).reshape((10, )) keys = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=0) x1 = np.arange(0, 10).reshape((1, 10)) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=0): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=True, p=0.65) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=True) self.assertAllClose(res_1, res_2) test_utils.stub_out_pmap(batch, 1) x1 = np.arange(0, 10).reshape((1, 10)) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=1) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=1): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=False, p=0.65) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=None) self.assertAllClose(res_1[0], res_2[0]) self.assertAllClose( tree_map(partial(np.expand_dims, axis=0), res_1[1]), res_2[1]) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=2) x1 = np.arange(0, 20).reshape((2, 10)) test_utils.stub_out_pmap(batch, 2) def broadcast(arg): return np.broadcast_to(arg, (2, ) + arg.shape) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=2): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, p=0.2) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.2) self.assertAllClose(res_1[0][0], res_2[0][0]) self.assertAllClose(res_1[0][1], res_2[0][1]) self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1])
def main(unused_argv): # Build data and . print('Loading data.') x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', permute_train=True) # Build the network init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(), stax.Dense(10, 1., 0.05)) key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) _, params = init_fn(key, (1, 784)) # Linearize the network about its initial parameters. f_lin = nt.linearize(f, params) # Create and initialize an optimizer for both f and f_lin. opt_init, opt_apply, get_params = momentum(FLAGS.learning_rate, 0.9) # opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate) opt_apply = jit(opt_apply) state = opt_init(params) state_lin = opt_init(params) # momentum = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9) # momentum_lin = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9) # Create a cross-entropy loss function. loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat) # Specialize the loss function to compute gradients for both linearized and # full networks. grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y))) grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y))) # Train the network. print('Training.') print('Epoch\tLoss\tLinearized Loss') print('------------------------------------------') epoch = 0 steps_per_epoch = 50000 // FLAGS.batch_size for i, (x, y) in enumerate( datasets.minibatch(x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)): params = get_params(state) state = opt_apply(i, grad_loss(params, x, y), state) params_lin = get_params(state_lin) state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin) # x = np.asarray(x) # y = np.asarray(y) # momentum.apply_gradients((grad_loss(params, x, y), params)) # momentum.apply_gradients((grad_loss_lin(params_lin, x, y), params_lin)) if i % steps_per_epoch == 0: print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y))) epoch += 1 # Print out summary data comparing the linear / nonlinear model. x, y = x_train[:10000], y_train[:10000] util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss) util.print_summary('test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
def testGpInference(self): reg = 1e-5 key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) x_train = np.asarray(normal((4, 2), seed=key)) init_fn, apply_fn, kernel_fn_analytic = stax.serial( stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5)) y_train = np.asarray(normal((4, 10), seed=key)) for kernel_fn_is_analytic in [True, False]: if kernel_fn_is_analytic: kernel_fn = kernel_fn_analytic else: _, params = init_fn(key, x_train.shape) kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn) def kernel_fn(x1, x2, get): return kernel_fn_empirical(x1, x2, get, params) for get in [ None, 'nngp', 'ntk', ('nngp', ), ('ntk', ), ('nngp', 'ntk'), ('ntk', 'nngp') ]: k_dd = kernel_fn(x_train, None, get) gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg) gd_ensemble = predict.gradient_descent_mse_ensemble( kernel_fn, x_train, y_train, diag_reg=reg) for x_test in [None, 'x_test']: x_test = None if x_test is None else np.asarray( normal((8, 2), seed=key)) k_td = None if x_test is None else kernel_fn( x_test, x_train, get) for compute_cov in [True, False]: with self.subTest( kernel_fn_is_analytic=kernel_fn_is_analytic, get=get, x_test=x_test if x_test is None else 'x_test', compute_cov=compute_cov): if compute_cov: nngp_tt = (True if x_test is None else kernel_fn(x_test, None, 'nngp')) else: nngp_tt = None out_ens = gd_ensemble(None, x_test, get, compute_cov) out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov) self._assertAllClose(out_ens_inf, out_ens, 0.08) if (get is not None and 'nngp' not in get and compute_cov and k_td is not None): with self.assertRaises(ValueError): out_gp_inf = gp_inference( get=get, k_test_train=k_td, nngp_test_test=nngp_tt) else: out_gp_inf = gp_inference( get=get, k_test_train=k_td, nngp_test_test=nngp_tt) self.assertAllClose(out_ens, out_gp_inf)
def testTaylorExpansion(self, shape): def f_2_exact(x0, x, params, do_alter, do_shift_x=True): w1, w2, b = params f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter, do_shift_x) if do_shift_x: x0 = x0 * 2 + 1. x = x * 2 + 1. if do_alter: b *= 2. w1 += 5. w2 /= 0.9 dx = x - x0 return f_lin + 0.5 * np.dot(np.dot(dx.T, w1), dx) key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=tf.int32) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=4) key = splits[0] s1 = splits[1] s2 = splits[2] s3 = splits[3] w1 = np.asarray(normal(shape, seed=s1)) w1 = 0.5 * (w1 + w1.T) w2 = np.asarray(normal(shape, seed=s2)) b = np.asarray(normal((shape[-1], ), seed=s3)) params = (w1, w2, b) splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32), num=2) key = splits[0] split = splits[1] x0 = np.asarray(normal((shape[-1], ), seed=split)) f_lin = empirical.taylor_expand(EmpiricalTest.f, x0, 1) f_2 = empirical.taylor_expand(EmpiricalTest.f, x0, 2) for _ in range(TAYLOR_RANDOM_SAMPLES): for do_alter in [True, False]: for do_shift_x in [True, False]: splits = tf_random_split(seed=tf.convert_to_tensor( key, dtype=tf.int32), num=2) key = splits[0] split = splits[1] x = np.asarray(normal((shape[-1], ), seed=split)) self.assertAllClose( EmpiricalTest.f_lin_exact(x0, x, params, do_alter, do_shift_x=do_shift_x), f_lin(x, params, do_alter, do_shift_x=do_shift_x)) self.assertAllClose( f_2_exact(x0, x, params, do_alter, do_shift_x=do_shift_x), f_2(x, params, do_alter, do_shift_x=do_shift_x))
def testPredictND(self): n_chan = 6 key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32) im_shape = (5, 4, 3) n_train = 2 n_test = 2 x_train = np.asarray(normal((n_train, ) + im_shape, seed=key)) y_train = stateless_uniform(shape=(n_train, 3, 2, n_chan), seed=key) init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2)) _, params = init_fn(key, x_train.shape) fx_train_0 = apply_fn(params, x_train) for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ), (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3), (0, 1, -1, 2)]: for ts in [None, np.arange(6).reshape((2, 3))]: for x in [None, 'x_test']: with self.subTest(trace_axes=trace_axes, ts=ts, x=x): t_shape = ts.shape if ts is not None else () y_test_shape = t_shape + (n_test, ) + y_train.shape[1:] y_train_shape = t_shape + y_train.shape x = x if x is None else np.asarray( normal((n_test, ) + im_shape, seed=key)) fx_test_0 = None if x is None else apply_fn(params, x) kernel_fn = empirical.empirical_kernel_fn( apply_fn, trace_axes=trace_axes) # TODO(romann): investigate the SIGTERM error on CPU. # kernel_fn = jit(kernel_fn, static_argnums=(2,)) ntk_train_train = kernel_fn(x_train, None, 'ntk', params) if x is not None: ntk_test_train = kernel_fn(x, x_train, 'ntk', params) loss = lambda x, y: 0.5 * np.mean(x - y)**2 predict_fn_mse = predict.gradient_descent_mse( ntk_train_train, y_train, trace_axes=trace_axes) predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble( kernel_fn, x_train, y_train, trace_axes=trace_axes, params=params) if x is None: p_train_mse = predict_fn_mse(ts, fx_train_0) else: p_train_mse, p_test_mse = predict_fn_mse( ts, fx_train_0, fx_test_0, ntk_test_train) self.assertAllClose(y_test_shape, p_test_mse.shape) self.assertAllClose(y_train_shape, p_train_mse.shape) p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble( ts, x, ('nngp', 'ntk'), compute_cov=True) ref_shape = y_train_shape if x is None else y_test_shape self.assertAllClose(ref_shape, p_ntk_mse_ens.mean.shape) self.assertAllClose(ref_shape, p_nngp_mse_ens.mean.shape) if ts is not None: predict_fn = predict.gradient_descent( loss, ntk_train_train, y_train, trace_axes=trace_axes) if x is None: p_train = predict_fn(ts, fx_train_0) else: p_train, p_test = predict_fn( ts, fx_train_0, fx_test_0, ntk_test_train) self.assertAllClose(y_test_shape, p_test.shape) self.assertAllClose(y_train_shape, p_train.shape)