def testMaxLearningRate(self, train_shape, network, out_logits, fn_and_kernel):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  if len(train_shape) == 2:
    train_shape = (train_shape[0] * 5, train_shape[1] * 10)
  else:
    train_shape = (16, 8, 8, 3)
  x_train = np.asarray(normal(train_shape, seed=split))

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  y_train = np.asarray(
      stateless_uniform(shape=(train_shape[0], out_logits), seed=split,
                        minval=0, maxval=1) < 0.5, np.float32)

  # Regress to an MSE loss. `f` is bound lazily; it is (re)defined inside the
  # loop below once `fn_and_kernel` has built the network.
  loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
  grad_loss = jit(grad(loss))

  def get_loss(opt_state):
    return loss(get_params(opt_state), x_train)

  steps = 20

  for lr_factor in [0.5, 3.]:
    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')
    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size) * lr_factor
    opt_init, opt_update, get_params = optimizers.sgd(step_size)
    opt_state = opt_init(params)

    init_loss = get_loss(opt_state)
    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)
    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)

    # Above the critical rate training should diverge (or produce NaNs);
    # below it training should make substantial progress.
    if lr_factor == 3.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
    else:
      self.assertLess(loss_ratio, 0.1)
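# A minimal sketch (not the library's exact formula) of the stability bound
# exercised above, assuming `g_dd` flattens to a square PSD Gram matrix of
# side `y_train_size`: full-batch gradient descent on an MSE loss diverges
# once the step size passes roughly 2 / lambda_max of that matrix, which is
# why the `3 * max_learning_rate` branch must blow up while the
# `0.5 * max_learning_rate` branch must converge.
import numpy as onp  # plain NumPy, used only by this sketch

def _critical_lr_sketch(g_dd, y_train_size):
  mat = onp.asarray(g_dd).reshape(y_train_size, y_train_size)
  return 2. / onp.linalg.eigvalsh(mat)[-1]  # eigvalsh sorts ascending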
def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                  batch_size):
  test_utils.stub_out_pmap(batch, 2)

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key, 3)
  key = keys[0]
  self_split = keys[1]
  other_split = keys[2]
  data_self = np.asarray(normal(train_shape, seed=self_split))
  data_other = np.asarray(normal(test_shape, seed=other_split))

  kernel_fn = kernel_fn(key, train_shape[1:], network)

  kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other)

  kernel_batched = batch.batch(kernel_fn, batch_size=batch_size,
                               store_on_device=False)
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other)
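# Note: `batch.batch` tiles the Gram-matrix computation into `batch_size`
# blocks and reassembles the result; `store_on_device=False` additionally
# moves finished blocks to host memory. It composes the `_serial` and
# `_parallel` primitives that are exercised individually in other tests in
# this file.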
def testNTKAgainstDirect(self, train_shape, test_shape, network, name,
                         kernel_fn):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=3)
  key = splits[0]
  self_split = splits[1]
  other_split = splits[2]
  data_self = np.asarray(normal(train_shape, seed=self_split))
  data_other = np.asarray(normal(test_shape, seed=other_split))

  implicit, direct, _ = kernel_fn(key, train_shape[1:], network,
                                  diagonal_axes=(), trace_axes=())

  # The implicit (jvp/vjp-based) and direct (explicit-Jacobian) NTK
  # estimators must agree on a single input set and on a pair of input sets.
  g = implicit(data_self, None)
  g_direct = direct(data_self, None)
  self.assertAllClose(g, g_direct)

  g = implicit(data_other, data_self)
  g_direct = direct(data_other, data_self)
  self.assertAllClose(g, g_direct)
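# A toy sketch of the "direct" route above, assuming a JAX-style function
# f(params, x) that returns a 1-D output vector: materialize the Jacobian
# with respect to `params` and contract it. The `implicit` estimator computes
# the same Gram matrix via jvp/vjp compositions without storing the full
# Jacobian. Names below are illustrative, not library internals.
import jax
import jax.numpy as jnp

def _ntk_direct_sketch(f, params, x1, x2):
  j1 = jax.jacobian(f)(params, x1)  # pytree; leaves: (out_dim,) + param_shape
  j2 = jax.jacobian(f)(params, x2)
  flatten = lambda j: jnp.concatenate(
      [jnp.reshape(t, (t.shape[0], -1))
       for t in jax.tree_util.tree_leaves(j)], axis=1)
  # Sum of J(x1) J(x2)^T over all parameter tensors.
  return flatten(j1) @ flatten(j2).T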
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(rng)
  rng_self = keys[0]
  rng_other = keys[1]
  x_self = np.asarray(normal((8, 10), seed=rng_self))
  x_other = np.asarray(normal((2, 10), seed=rng_other))

  Block = stax.serial(stax.Dense(256), stax.Relu())
  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)
  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    # In the parallel setting, `x1_is_x2` is not computed correctly
    # when x1 == x2.
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  # Check convolutional + pooling.
  x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
  x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)
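# The property exercised above: an analytic `kernel_fn` maps a Kernel to a
# Kernel, so composing layers corresponds to composing kernel functions. A
# usage sketch (layer widths here are illustrative):
#
#   _, _, block_ker = stax.serial(stax.Dense(256), stax.Relu())
#   _, _, two_block_ker = stax.serial(stax.Dense(256), stax.Relu(),
#                                     stax.Dense(256), stax.Relu())
#   two_block_ker(x)          # one shot
#   block_ker(block_ker(x))   # same kernel, computed block by block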
def _empirical_kernel(key, input_shape, network, out_logits, use_dropout):
  init_fn, f, _ = _build_network(input_shape, network, out_logits, use_dropout)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  _, params = init_fn(key, (1,) + input_shape)
  kernel_fn = jit(empirical.empirical_ntk_fn(f))
  return partial(kernel_fn, params=params, keys=split)
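# Usage sketch for the helper above (shapes and `CONV` as used elsewhere in
# this file): the returned partial closes over the sampled `params` and the
# dropout key, so it is called like an analytic kernel function,
#
#   ntk_fn = _empirical_kernel(key, (8, 8, 3), CONV, out_logits=3,
#                              use_dropout=False)
#   g_dd = ntk_fn(x_train, None)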
def testLinearization(self, shape):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=4)
  key = splits[0]
  s1 = splits[1]
  s2 = splits[2]
  s3 = splits[3]
  w1 = np.asarray(normal(shape, seed=s1))
  w1 = 0.5 * (w1 + w1.T)
  w2 = np.asarray(normal(shape, seed=s2))
  b = np.asarray(normal((shape[-1],), seed=s3))
  params = (w1, w2, b)

  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=2)
  key = splits[0]
  split = splits[1]
  x0 = np.asarray(normal((shape[-1],), seed=split))

  f_lin = empirical.linearize(EmpiricalTest.f, x0)

  for _ in range(TAYLOR_RANDOM_SAMPLES):
    for do_alter in [True, False]:
      for do_shift_x in [True, False]:
        splits = tf_random_split(
            seed=tf.convert_to_tensor(key, dtype=tf.int32), num=2)
        key = splits[0]
        split = splits[1]
        x = np.asarray(normal((shape[-1],), seed=split))
        self.assertAllClose(
            EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                      do_shift_x=do_shift_x),
            f_lin(x, params, do_alter, do_shift_x=do_shift_x))
def predict_mc(count, key):
  key = tf_random_split(key, count)
  fx_train, fx_test = vmap(predict_empirical)(key)
  fx_train_mean = np.mean(fx_train, axis=0)
  fx_test_mean = np.mean(fx_test, axis=0)

  fx_train_centered = fx_train - fx_train_mean
  fx_test_centered = fx_test - fx_test_mean

  cov_train = PredictTest._cov_empirical(fx_train_centered)
  cov_test = PredictTest._cov_empirical(fx_test_centered)

  return fx_train_mean, fx_test_mean, cov_train, cov_test
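# A sketch of the empirical covariance estimator used above, assuming the
# centered samples have shape (n_samples, n_points, n_outputs) and that
# `PredictTest._cov_empirical` averages over the ensemble and output axes,
# mirroring the einsum in `testTrainedEnsemblePredCov` below:
#
#   def _cov_empirical_sketch(fx_centered):
#     n_samples, _, n_outputs = fx_centered.shape
#     return np.einsum('ijk,ilk->jl', fx_centered, fx_centered) / (
#         n_samples * n_outputs)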
def _get_inputs(cls, out_logits, test_shape, train_shape):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  x_train = np.asarray(normal(train_shape, seed=split))

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  y_train = np.asarray(
      stateless_uniform(shape=(train_shape[0], out_logits), seed=split,
                        minval=0, maxval=1) < 0.5, np.float32)

  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  x_test = np.asarray(normal(test_shape, seed=split))
  return key, x_test, x_train, y_train
def testAxes(self, diagonal_axes, trace_axes):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=3)
  key = splits[0]
  self_split = splits[1]
  other_split = splits[2]
  data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
  data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

  _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
  _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

  if any(d == c for d in _diagonal_axes for c in _trace_axes):
    raise absltest.SkipTest(
        'diagonal axes must be different from channel axes.')

  implicit, direct, nngp = KERNELS['empirical_logits_3'](
      key, (5, 6, 3), CONV,
      diagonal_axes=diagonal_axes,
      trace_axes=trace_axes)

  n_marg = len(_diagonal_axes)
  n_chan = len(_trace_axes)

  g = implicit(data_self, None)
  g_direct = direct(data_self, None)
  g_nngp = nngp(data_self, None)

  self.assertAllClose(g, g_direct)
  self.assertEqual(g_nngp.shape, g.shape)
  self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

  if 0 not in _trace_axes and 0 not in _diagonal_axes:
    g = implicit(data_other, data_self)
    g_direct = direct(data_other, data_self)
    g_nngp = nngp(data_other, data_self)

    self.assertAllClose(g, g_direct)
    self.assertEqual(g_nngp.shape, g.shape)
    self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)
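# Shape arithmetic behind the assertions above: each trace axis is contracted
# away on both sides of the kernel and each diagonal axis appears once
# instead of twice, so
#   g.ndim == 2 * (x.ndim - len(trace_axes)) - len(diagonal_axes).
# E.g. inputs of shape (4, 5, 6, 3) with one trace axis and one diagonal axis
# give 2 * (4 - 1) - 1 = 5 kernel dimensions.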
def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
               batch_size):
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key, 3)
  key = keys[0]
  self_split = keys[1]
  other_split = keys[2]
  data_self = np.asarray(normal(train_shape, seed=self_split))
  data_other = np.asarray(normal(test_shape, seed=other_split))
  kernel_fn = kernel_fn(key, train_shape[1:], network)
  kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other)
def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
  test_utils.stub_out_pmap(batch, 2)
  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key, 3)
  key = keys[0]
  self_split = keys[1]
  other_split = keys[2]
  data_self = np.asarray(normal(train_shape, seed=self_split))
  data_other = np.asarray(normal(test_shape, seed=other_split))
  kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
  kernel_batched = batch._parallel(kernel_fn)
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other, True)
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
  key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(key)
  key = keys[0]
  split = keys[1]
  x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
  x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

  if not use_conv:
    x1 = np.reshape(x1, (x1.shape[0], -1))
    x2 = np.reshape(x2, (x2.shape[0], -1))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
      stax.Relu(),
      stax.Flatten(),
      stax.Dense(n_classes, 2., 0.5))
  return x1, x2, init_fn, apply_fn, kernel_fn, key
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  training_steps = 1000
  learning_rate = 0.1
  ensemble_size = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, learning_rate=learning_rate, diag_reg=0.)

  train = (x_train, y_train)
  ensemble_key = tf_random_split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean(
      (apply_fn(params, x) - y) ** 2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)
  rtol = 0.08

  for x in [None, 'x_test']:
    with self.subTest(x=x):
      # `x=None` means "predict on the training set" for the ensemble
      # predictor; pick the matching inputs for the empirical ensemble.
      x = x if x is None else x_test
      x_fin = x_train if x is None else x_test
      ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

      mean_emp = np.mean(ensemble_fx, axis=0)
      mean_subtracted = ensemble_fx - mean_emp
      cov_emp = np.einsum(
          'ijk,ilk->jl', mean_subtracted, mean_subtracted,
          optimize=True) / (mean_subtracted.shape[0] *
                            mean_subtracted.shape[-1])

      ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, ntk.mean, rtol)
      self._assertAllClose(cov_emp, ntk.covariance, rtol)
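# The comparison above is the finite-width check of the NTK ensemble claim:
# the empirical mean and covariance of many independently trained networks
# should approach the closed-form `gradient_descent_mse_ensemble`
# predictions, with `rtol = 0.08` absorbing finite-width and finite-ensemble
# error.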
def testTaylorExpansion(self, shape):

  def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
    w1, w2, b = params
    f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter, do_shift_x)
    if do_shift_x:
      x0 = x0 * 2 + 1.
      x = x * 2 + 1.
    if do_alter:
      b *= 2.
      w1 += 5.
      w2 /= 0.9
    dx = x - x0
    return f_lin + 0.5 * np.dot(np.dot(dx.T, w1), dx)

  key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=4)
  key = splits[0]
  s1 = splits[1]
  s2 = splits[2]
  s3 = splits[3]
  w1 = np.asarray(normal(shape, seed=s1))
  w1 = 0.5 * (w1 + w1.T)
  w2 = np.asarray(normal(shape, seed=s2))
  b = np.asarray(normal((shape[-1],), seed=s3))
  params = (w1, w2, b)

  splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                           num=2)
  key = splits[0]
  split = splits[1]
  x0 = np.asarray(normal((shape[-1],), seed=split))

  f_lin = empirical.taylor_expand(EmpiricalTest.f, x0, 1)
  f_2 = empirical.taylor_expand(EmpiricalTest.f, x0, 2)

  for _ in range(TAYLOR_RANDOM_SAMPLES):
    for do_alter in [True, False]:
      for do_shift_x in [True, False]:
        splits = tf_random_split(
            seed=tf.convert_to_tensor(key, dtype=tf.int32), num=2)
        key = splits[0]
        split = splits[1]
        x = np.asarray(normal((shape[-1],), seed=split))
        self.assertAllClose(
            EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                      do_shift_x=do_shift_x),
            f_lin(x, params, do_alter, do_shift_x=do_shift_x))
        self.assertAllClose(
            f_2_exact(x0, x, params, do_alter, do_shift_x=do_shift_x),
            f_2(x, params, do_alter, do_shift_x=do_shift_x))
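# In formula form, with dx = x - x0, the truncations checked above are
#   f_lin(x) = f(x0) + J_f(x0) dx                  (order 1)
#   f_2(x)   = f_lin(x) + 0.5 * dx^T H_f(x0) dx    (order 2)
# `empirical.taylor_expand(f, x0, k)` returns the order-k truncation of f
# around `x0` in its first argument; `f_2_exact` hand-codes the quadratic
# term for the test function.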