def testGpInference(self):
  reg = 1e-5
  key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                          dtype=tf.int32)
  x_train = np.asarray(normal((4, 2), seed=key))
  init_fn, apply_fn, kernel_fn_analytic = stax.serial(
      stax.Dense(32, 2., 0.5),
      stax.Relu(),
      stax.Dense(10, 2., 0.5))
  y_train = np.asarray(normal((4, 10), seed=key))
  for kernel_fn_is_analytic in [True, False]:
    if kernel_fn_is_analytic:
      kernel_fn = kernel_fn_analytic
    else:
      _, params = init_fn(key, x_train.shape)
      kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

      def kernel_fn(x1, x2, get):
        return kernel_fn_empirical(x1, x2, get, params)

    for get in [None, 'nngp', 'ntk', ('nngp',), ('ntk',), ('nngp', 'ntk'),
                ('ntk', 'nngp')]:
      k_dd = kernel_fn(x_train, None, get)

      gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
      gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)
      for x_test in [None, 'x_test']:
        x_test = None if x_test is None else np.asarray(
            normal((8, 2), seed=key))
        k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

        for compute_cov in [True, False]:
          with self.subTest(
              kernel_fn_is_analytic=kernel_fn_is_analytic,
              get=get,
              x_test=x_test if x_test is None else 'x_test',
              compute_cov=compute_cov):
            if compute_cov:
              nngp_tt = (True if x_test is None else
                         kernel_fn(x_test, None, 'nngp'))
            else:
              nngp_tt = None

            out_ens = gd_ensemble(None, x_test, get, compute_cov)
            out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
            self._assertAllClose(out_ens_inf, out_ens, 0.08)

            if (get is not None and
                'nngp' not in get and
                compute_cov and
                k_td is not None):
              with self.assertRaises(ValueError):
                out_gp_inf = gp_inference(get=get,
                                          k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
            else:
              out_gp_inf = gp_inference(get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
              self.assertAllClose(out_ens, out_gp_inf)
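
# A hedged usage sketch (not part of the test suite) of the two inference APIs
# exercised above: `predict.gp_inference` consumes precomputed kernels, while
# `predict.gradient_descent_mse_ensemble` takes the kernel function itself and
# should agree with it in the t -> infinity limit. The import paths, shapes,
# widths and `diag_reg` value below are illustrative assumptions, not taken
# from this file's setup.
def _example_gp_inference_vs_ensemble():
  import jax.numpy as jnp
  from jax import random
  from neural_tangents import predict, stax

  key = random.PRNGKey(0)
  x_train = random.normal(key, (4, 2))
  y_train = random.normal(key, (4, 10))
  x_test = random.normal(key, (8, 2))

  _, _, kernel_fn = stax.serial(
      stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))

  # Closed-form GP posterior mean from precomputed train-train and test-train
  # NTK matrices.
  k_dd = kernel_fn(x_train, None, 'ntk')
  k_td = kernel_fn(x_test, x_train, 'ntk')
  gp = predict.gp_inference(k_dd, y_train, diag_reg=1e-5)
  mean_gp = gp(get='ntk', k_test_train=k_td, nngp_test_test=None)

  # Ensemble of infinitely wide networks trained to convergence (t=None).
  ensemble = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, diag_reg=1e-5)
  mean_ens = ensemble(None, x_test, 'ntk', compute_cov=False)

  # The two means should coincide up to numerical error.
  return jnp.max(jnp.abs(mean_gp - mean_ens))
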
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  training_steps = 1000
  learning_rate = 0.1
  ensemble_size = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      kernel_fn,
      x_train,
      y_train,
      learning_rate=learning_rate,
      diag_reg=0.)

  train = (x_train, y_train)
  ensemble_key = tf_random_split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean(
      (apply_fn(params, x) - y) ** 2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)

  rtol = 0.08

  for x in [None, 'x_test']:
    with self.subTest(x=x):
      x = x if x is None else x_test
      x_fin = x_train if x is None else x_test
      ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

      mean_emp = np.mean(ensemble_fx, axis=0)
      mean_subtracted = ensemble_fx - mean_emp
      cov_emp = np.einsum(
          'ijk,ilk->jl', mean_subtracted, mean_subtracted,
          optimize=True) / (mean_subtracted.shape[0] *
                            mean_subtracted.shape[-1])

      ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, ntk.mean, rtol)
      self._assertAllClose(cov_emp, ntk.covariance, rtol)
def testPredictND(self):
  n_chan = 6
  key = random.PRNGKey(1)
  im_shape = (5, 4, 3)
  n_train = 2
  n_test = 2
  x_train = random.normal(key, (n_train,) + im_shape)
  y_train = random.uniform(key, (n_train, 3, 2, n_chan))
  init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
  _, params = init_fn(key, x_train.shape)
  fx_train_0 = apply_fn(params, x_train)

  for trace_axes in [(), (-1,), (-2,), (-3,), (0, 1), (2, 3), (2,), (1, 3),
                     (0, -1), (0, 0, -3), (0, 1, 2, 3), (0, 1, -1, 2)]:
    for ts in [None, np.arange(6).reshape((2, 3))]:
      for x in [None, 'x_test']:
        with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
          t_shape = ts.shape if ts is not None else ()
          y_test_shape = t_shape + (n_test,) + y_train.shape[1:]
          y_train_shape = t_shape + y_train.shape
          x = x if x is None else random.normal(key, (n_test,) + im_shape)
          fx_test_0 = None if x is None else apply_fn(params, x)

          kernel_fn = empirical.empirical_kernel_fn(apply_fn,
                                                    trace_axes=trace_axes)

          # TODO(romann): investigate the SIGTERM error on CPU.
          # kernel_fn = jit(kernel_fn, static_argnums=(2,))
          ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
          if x is not None:
            ntk_test_train = kernel_fn(x, x_train, 'ntk', params)

          loss = lambda x, y: 0.5 * np.mean(x - y)**2
          predict_fn_mse = predict.gradient_descent_mse(
              ntk_train_train, y_train, trace_axes=trace_axes)

          predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
              kernel_fn, x_train, y_train, trace_axes=trace_axes,
              params=params)

          if x is None:
            p_train_mse = predict_fn_mse(ts, fx_train_0)
          else:
            p_train_mse, p_test_mse = predict_fn_mse(
                ts, fx_train_0, fx_test_0, ntk_test_train)
            self.assertAllClose(y_test_shape, p_test_mse.shape)
          self.assertAllClose(y_train_shape, p_train_mse.shape)

          p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
              ts, x, ('nngp', 'ntk'), compute_cov=True)
          ref_shape = y_train_shape if x is None else y_test_shape
          self.assertAllClose(ref_shape, p_ntk_mse_ens.mean.shape)
          self.assertAllClose(ref_shape, p_nngp_mse_ens.mean.shape)

          if ts is not None:
            predict_fn = predict.gradient_descent(
                loss, ntk_train_train, y_train, trace_axes=trace_axes)

            if x is None:
              p_train = predict_fn(ts, fx_train_0)
            else:
              p_train, p_test = predict_fn(ts, fx_train_0, fx_test_0,
                                           ntk_test_train)
              self.assertAllClose(y_test_shape, p_test.shape)
            self.assertAllClose(y_train_shape, p_train.shape)
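
# A hedged sketch (not part of the test suite) of the `trace_axes` mechanics
# checked above: the empirical NTK contracts the output axes listed in
# `trace_axes`, so the kernel (and downstream predictions) keeps the remaining
# output axes. The top-level `neural_tangents` import, layer parameters and
# shapes are illustrative assumptions.
def _example_empirical_ntk_trace_axes():
  import neural_tangents as nt
  from jax import random
  from neural_tangents import stax

  key = random.PRNGKey(1)
  x_train = random.normal(key, (2, 5, 4, 3))

  init_fn, apply_fn, _ = stax.Conv(6, (3, 2), (1, 2))
  _, params = init_fn(key, x_train.shape)

  # Trace over the channel axis only; the spatial output axes remain part of
  # the kernel's shape.
  kernel_fn = nt.empirical_kernel_fn(apply_fn, trace_axes=(-1,))
  ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
  return ntk_train_train.shape
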
def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                             out_logits):
  _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                 train_shape)
  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  predictor = predict.gradient_descent_mse_ensemble(ker_fun,
                                                    x_train,
                                                    y_train,
                                                    diag_reg=reg)

  ts = np.logspace(-2, 8, 10).reshape((5, 2))

  for t in (None, 'ts'):
    for x in (None, 'x_test'):
      with self.subTest(t=t, x=x):
        x = x if x is None else x_test
        t = t if t is None else ts

        ntk = predictor(t=t, get='ntk', x_test=x)

        # Test time broadcasting.
        if t is not None:
          ntk_ind = np.array([predictor(t=t, get='ntk', x_test=x)
                              for t in t.ravel()]).reshape(
                                  t.shape + ntk.shape[2:])
          self.assertAllClose(ntk_ind, ntk)

        # Create a hacked kernel function that always returns the NTK kernel.
        def always_ntk(x1, x2, get=('nngp', 'ntk')):
          out = ker_fun(x1, x2, get=('nngp', 'ntk'))
          if get == 'nngp' or get == 'ntk':
            return out.ntk
          else:
            return out._replace(nngp=out.ntk)

        predictor_ntk = predict.gradient_descent_mse_ensemble(always_ntk,
                                                              x_train,
                                                              y_train,
                                                              diag_reg=reg)

        ntk_nngp = predictor_ntk(t=t, get='nngp', x_test=x)

        # Test that using the NNGP equations with the NTK kernel gives the
        # same mean.
        self.assertAllClose(ntk, ntk_nngp)

        # Next, test that going through the NTK code path with only the NNGP
        # kernel recreates the NNGP dynamics.
        # Create a hacked kernel function that always returns the NNGP kernel.
        def always_nngp(x1, x2, get=('nngp', 'ntk')):
          out = ker_fun(x1, x2, get=('nngp', 'ntk'))
          if get == 'nngp' or get == 'ntk':
            return out.nngp
          else:
            return out._replace(ntk=out.nngp)

        predictor_nngp = predict.gradient_descent_mse_ensemble(always_nngp,
                                                               x_train,
                                                               y_train,
                                                               diag_reg=reg)

        nngp_cov = predictor(t=t,
                             get='nngp',
                             x_test=x,
                             compute_cov=True).covariance

        # Test time broadcasting for the covariance.
        nngp_ntk_cov = predictor_nngp(t=t,
                                      get='ntk',
                                      x_test=x,
                                      compute_cov=True).covariance
        if t is not None:
          nngp_ntk_cov_ind = np.array(
              [predictor_nngp(t=t,
                              get='ntk',
                              x_test=x,
                              compute_cov=True).covariance
               for t in t.ravel()]).reshape(t.shape + nngp_cov.shape[2:])
          self.assertAllClose(nngp_ntk_cov_ind, nngp_ntk_cov)

        # Test that using the NTK equations with the NNGP kernel gives the
        # same covariance, though only roughly due to the accumulation of
        # numerical errors.
        self.assertAllClose(nngp_cov, nngp_ntk_cov)
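
# A hedged sketch (not part of the test suite) of the time broadcasting
# behaviour checked above: `gradient_descent_mse_ensemble` accepts an array of
# training times `t` and prepends `t.shape` to the shape of the returned mean.
# Import paths, shapes and widths are illustrative assumptions.
def _example_time_broadcasting():
  import jax.numpy as jnp
  from jax import random
  from neural_tangents import predict, stax

  key = random.PRNGKey(2)
  x_train = random.normal(key, (6, 3))
  y_train = random.normal(key, (6, 2))
  x_test = random.normal(key, (4, 3))

  _, _, kernel_fn = stax.serial(stax.Dense(64), stax.Erf(), stax.Dense(2))
  predictor = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, diag_reg=1e-7)

  # A (5, 2) grid of times yields a mean of shape (5, 2, 4, 2):
  # t.shape followed by (num test points, output dimension).
  ts = jnp.logspace(-2, 8, 10).reshape((5, 2))
  ntk_mean = predictor(t=ts, get='ntk', x_test=x_test)
  return ntk_mean.shape
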
def test_kwargs(self, do_batch, mode):
  rng = random.PRNGKey(1)

  x_train = random.normal(rng, (8, 7, 10))
  x_test = random.normal(rng, (4, 7, 10))
  y_train = random.normal(rng, (8, 1))

  rng_train, rng_test = random.split(rng, 2)

  pattern_train = random.normal(rng, (8, 7, 7))
  pattern_test = random.normal(rng, (4, 7, 7))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(8),
      stax.Relu(),
      stax.Dropout(rate=0.4),
      stax.Aggregate(),
      stax.GlobalAvgPool(),
      stax.Dense(1))

  kw_dd = dict(pattern=(pattern_train, pattern_train))
  kw_td = dict(pattern=(pattern_test, pattern_train))
  kw_tt = dict(pattern=(pattern_test, pattern_test))

  if mode == 'mc':
    kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                      batch_size=2 if do_batch else 0)

  elif mode == 'empirical':
    kernel_fn = empirical_kernel_fn(apply_fn)
    if do_batch:
      raise absltest.SkipTest('Batching of empirical kernel is not '
                              'implemented with keyword arguments.')

    for kw in (kw_dd, kw_td, kw_tt):
      kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                     get=('nngp', 'ntk')))

    kw_dd.update(dict(rng=(rng_train, None)))
    kw_td.update(dict(rng=(rng_test, rng_train)))
    kw_tt.update(dict(rng=(rng_test, None)))

  elif mode == 'analytic':
    if do_batch:
      kernel_fn = batch.batch(kernel_fn, batch_size=2)

  else:
    raise ValueError(mode)

  k_dd = kernel_fn(x_train, None, **kw_dd)
  k_td = kernel_fn(x_test, x_train, **kw_td)
  k_tt = kernel_fn(x_test, None, **kw_tt)

  # Infinite time NNGP/NTK.
  predict_fn_gp = predict.gp_inference(k_dd, y_train)
  out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

  if mode == 'empirical':
    for kw in (kw_dd, kw_td, kw_tt):
      kw.pop('get')

  predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                              x_train,
                                                              y_train,
                                                              **kw_dd)
  out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
  self.assertAllClose(out_gp, out_ensemble)

  # Finite time NTK test.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
  out_mse = predict_fn_mse(t=1.,
                           fx_train_0=None,
                           fx_test_0=0.,
                           k_test_train=k_td.ntk)
  out_ensemble = predict_fn_ensemble(t=1.,
                                     get='ntk',
                                     x_test=x_test,
                                     compute_cov=False,
                                     **kw_tt)
  self.assertAllClose(out_mse, out_ensemble)

  # Finite time NNGP train.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
  out_mse = predict_fn_mse(t=2.,
                           fx_train_0=0.,
                           fx_test_0=None,
                           k_test_train=k_td.nngp)
  out_ensemble = predict_fn_ensemble(t=2.,
                                     get='nngp',
                                     x_test=None,
                                     compute_cov=False,
                                     **kw_dd)
  self.assertAllClose(out_mse, out_ensemble)
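
# A hedged sketch (not part of the test suite) of the keyword-argument
# threading checked above: layers such as `stax.Aggregate` take an extra
# `pattern=(pattern1, pattern2)` input, which is forwarded through the kernel
# function and on to the predict functions. Import paths and shapes are
# illustrative assumptions; `stax.Dropout` is omitted to keep the analytic
# kernel call free of `rng` arguments.
def _example_kernel_fn_kwargs():
  from jax import random
  from neural_tangents import predict, stax

  rng = random.PRNGKey(1)
  x_train = random.normal(rng, (8, 7, 10))
  y_train = random.normal(rng, (8, 1))
  x_test = random.normal(rng, (4, 7, 10))
  pattern_train = random.normal(rng, (8, 7, 7))
  pattern_test = random.normal(rng, (4, 7, 7))

  _, _, kernel_fn = stax.serial(
      stax.Dense(8), stax.Relu(), stax.Aggregate(), stax.GlobalAvgPool(),
      stax.Dense(1))

  # Each kernel call receives the (x1, x2) pattern pair it needs.
  k_dd = kernel_fn(x_train, None, ('nngp', 'ntk'),
                   pattern=(pattern_train, pattern_train))
  k_td = kernel_fn(x_test, x_train, ('nngp', 'ntk'),
                   pattern=(pattern_test, pattern_train))
  k_tt = kernel_fn(x_test, None, ('nngp', 'ntk'),
                   pattern=(pattern_test, pattern_test))

  # Closed-form inference from the precomputed kernels.
  predict_fn_gp = predict.gp_inference(k_dd, y_train)
  return predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)
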