def testGPInferenceGet(self, train_shape, test_shape, network, out_logits):
    """Checks `predict.gp_inference` output structure for each `get` variant.

    A string `get` must yield a single `predict.Gaussian`; a tuple `get`
    must yield a tuple of `Gaussian`s in the order requested.
    """
    rng = random.PRNGKey(0)
    rng, subkey = random.split(rng)
    x_train = np.cos(random.normal(subkey, train_shape))
    rng, subkey = random.split(rng)
    y_train = np.array(
        random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
        np.float32)
    rng, subkey = random.split(rng)
    x_test = np.cos(random.normal(subkey, test_shape))

    _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

    def infer(get):
        # Thin wrapper: each variant below differs only in `get`.
        return predict.gp_inference(kernel_fn, x_train, y_train, x_test, get,
                                    diag_reg=0., compute_cov=True)

    # String `get` -> a single Gaussian.
    assert isinstance(infer('ntk'), predict.Gaussian)
    assert isinstance(infer('nngp'), predict.Gaussian)

    # Tuple `get` -> tuple of Gaussians, one per requested kernel.
    singleton = infer(('ntk',))
    assert len(singleton) == 1 and isinstance(singleton[0], predict.Gaussian)

    pair = infer(('ntk', 'nngp'))
    assert (len(pair) == 2 and isinstance(pair[0], predict.Gaussian)
            and isinstance(pair[1], predict.Gaussian))

    # Output order must follow the order of the `get` tuple.
    swapped = infer(('nngp', 'ntk'))
    self.assertAllClose(pair[0], swapped[1], True)
    self.assertAllClose(pair[1], swapped[0], True)
def testPredictOnCPU(self):
    """Checks batched kernel predictions across device/storage configurations.

    For every (store_on_device, device_count, get) combination, the
    t=None and t=inf gradient-descent predictions must agree with each
    other and with direct GP inference.
    """
    x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))
    y_train = random.uniform(random.PRNGKey(1), (10, 7))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    # Enumerate every configuration up front (same order as nested loops).
    configurations = [
        (store_on_device, device_count, get)
        for store_on_device in [False, True]
        for device_count in [0, 1]
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
    ]

    for store_on_device, device_count, get in configurations:
        with self.subTest(store_on_device=store_on_device,
                          device_count=device_count,
                          get=get):
            kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                            store_on_device)
            predictor = predict.gradient_descent_mse_gp(
                kernel_fn_batched, x_train, y_train, x_test, get, 0., True)
            gp_inference = predict.gp_inference(
                kernel_fn_batched, x_train, y_train, x_test, get, 0., True)
            self.assertAllClose(predictor(None), predictor(np.inf), True)
            self.assertAllClose(predictor(None), gp_inference, True)
def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                              out_logits, get):
    """Checks t=inf and t=None GD predictions agree with GP inference."""
    rng = random.PRNGKey(0)
    rng, subkey = random.split(rng)
    x_train = np.cos(random.normal(subkey, train_shape))
    rng, subkey = random.split(rng)
    y_train = np.array(
        random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
        np.float32)
    rng, subkey = random.split(rng)
    x_test = np.cos(random.normal(subkey, test_shape))

    _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

    reg = 1e-7
    prediction = predict.gradient_descent_mse_gp(
        kernel_fn, x_train, y_train, x_test, get, diag_reg=reg,
        compute_cov=True)
    gp_inference = predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                        get, reg, True)

    # Infinite-time and t=None predictions must coincide with each other
    # and with closed-form GP inference.
    finite_prediction = prediction(np.inf)
    finite_prediction_none = prediction(None)
    self.assertAllClose(finite_prediction_none, finite_prediction, True)
    self.assertAllClose(finite_prediction_none, gp_inference, True)
def testGpInference(self):
    """Compares `predict.gp_inference` to `gradient_descent_mse_ensemble`.

    Sweeps analytic vs. empirical kernels, every `get` variant, presence or
    absence of test inputs, and covariance computation on/off. The ensemble
    prediction at t=None and t=inf must agree, and (except for the invalid
    configuration below) must match direct GP inference.
    """
    reg = 1e-5
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
        if kernel_fn_is_analytic:
            kernel_fn = kernel_fn_analytic
        else:
            # Finite-width empirical kernel at a fixed random initialization;
            # `params` is closed over by the wrapper below.
            _, params = init_fn(key, x_train.shape)
            kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

            def kernel_fn(x1, x2, get):
                return kernel_fn_empirical(x1, x2, get, params)

        for get in [None, 'nngp', 'ntk', ('nngp',), ('ntk',),
                    ('nngp', 'ntk'), ('ntk', 'nngp')]:
            # k_dd: train-train kernel; k_td: test-train (when x_test given).
            k_dd = kernel_fn(x_train, None, get)
            gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
            gd_ensemble = predict.gradient_descent_mse_ensemble(
                kernel_fn, x_train, y_train, diag_reg=reg)
            for x_test in [None, 'x_test']:
                # The string sentinel is replaced by actual test inputs.
                x_test = None if x_test is None else random.normal(key, (8, 2))
                k_td = None if x_test is None else kernel_fn(
                    x_test, x_train, get)

                for compute_cov in [True, False]:
                    with self.subTest(
                        kernel_fn_is_analytic=kernel_fn_is_analytic,
                        get=get,
                        x_test=x_test if x_test is None else 'x_test',
                        compute_cov=compute_cov):
                        if compute_cov:
                            # `True` requests train-set covariance when
                            # there are no test inputs.
                            nngp_tt = (True if x_test is None else
                                       kernel_fn(x_test, None, 'nngp'))
                        else:
                            nngp_tt = None

                        out_ens = gd_ensemble(None, x_test, get, compute_cov)
                        out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                  compute_cov)
                        self._assertAllClose(out_ens_inf, out_ens, 0.08)

                        if (get is not None and
                            'nngp' not in get and
                            compute_cov and
                            k_td is not None):
                            # Test covariance without the NNGP kernel is
                            # an invalid request and must raise.
                            with self.assertRaises(ValueError):
                                out_gp_inf = gp_inference(
                                    get=get, k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                        else:
                            out_gp_inf = gp_inference(
                                get=get, k_test_train=k_td,
                                nngp_test_test=nngp_tt)
                            self.assertAllClose(out_ens, out_gp_inf)
def testNTKMeanPrediction(self, train_shape, test_shape, network, out_logits):
    """Checks the analytic NTK mean prediction against a Monte-Carlo estimate.

    Also verifies the predicted variance is PSD (up to small numerical
    slack) with one row per test point.
    """
    rng = random.PRNGKey(0)
    rng, subkey = random.split(rng)
    data_train = np.cos(random.normal(subkey, train_shape))
    rng, subkey = random.split(rng)
    data_labels = np.array(
        random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
        np.float32)
    rng, subkey = random.split(rng)
    data_test = np.cos(random.normal(subkey, test_shape))

    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)
    mean_pred, var = predict.gp_inference(
        ker_fun, data_train, data_labels, data_test, diag_reg=0., mode='NTK',
        compute_var=True)

    # On TPU, use the numpy eigensolver instead of the jax one.
    is_tpu = xla_bridge.get_backend().platform == 'tpu'
    eigh = np.onp.linalg.eigh if is_tpu else np.linalg.eigh

    self.assertEqual(var.shape[0], data_test.shape[0])
    self.assertGreater(np.min(eigh(var)[0]) + 1e-10, 0.)

    def mc_sampling(count=10):
        # Average finite-width test predictions over `count` random inits.
        sampling_key = random.PRNGKey(100)
        running_sum = 0.
        for _ in range(count):
            sampling_key, init_key = random.split(sampling_key)
            params, f, theta = _empirical_kernel(init_key, train_shape[1:],
                                                 network, out_logits)
            g_dd = theta(data_train, None)
            g_td = theta(data_test, data_train)
            predictor = predict.gradient_descent_mse(g_dd, data_labels, g_td)
            # t=1e8 approximates infinite training time.
            _, fx_pred_test = predictor(1.0e8, f(params, data_train),
                                        f(params, data_test))
            running_sum += fx_pred_test
        return running_sum / count

    self.assertAllClose(mean_pred, mc_sampling(100), True, RTOL, ATOL)
def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                              out_logits, mode):
    """Checks `gp_inference` matches `gradient_descent_mse_gp` at t=inf."""
    # TODO(alemi): Add some finite time tests.
    rng = random.PRNGKey(0)
    rng, subkey = random.split(rng)
    data_train = np.cos(random.normal(subkey, train_shape))
    rng, subkey = random.split(rng)
    data_labels = np.array(
        random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
        np.float32)
    rng, subkey = random.split(rng)
    data_test = np.cos(random.normal(subkey, test_shape))

    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

    reg = 1e-7
    mean_pred, var = predict.gp_inference(
        ker_fun, data_train, data_labels, data_test, diag_reg=reg, mode=mode,
        compute_var=True)
    prediction = predict.gradient_descent_mse_gp(
        ker_fun, data_train, data_labels, data_test, diag_reg=reg, mode=mode,
        compute_var=True)
    inf_mean_pred, inf_var = prediction(np.inf)

    self.assertAllClose(mean_pred, inf_mean_pred, True)
    self.assertAllClose(var, inf_var, True)
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
    """Compares an SGD-trained ensemble to the NTK-GP posterior.

    Trains `ensemble_size` finite-width networks with full-batch SGD, then
    asserts the empirical mean and covariance of their test outputs match
    the infinite-time NTK Gaussian-process prediction.
    """
    if xla_bridge.get_backend().platform == 'gpu' and config.read(
        'jax_enable_x64'):
        raise jtu.SkipTest('Not running GPU x64 to save time.')

    training_steps = 5000
    learning_rate = 1.0
    ensemble_size = 50

    init_fn, apply_fn, ker_fn = stax.serial(
        stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_update = jit(opt_update)

    key = random.PRNGKey(0)
    key, = random.split(key, 1)

    key, split = random.split(key)
    x_train = np.cos(random.normal(split, train_shape))

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)),
        np.float32)
    train = (x_train, y_train)

    key, split = random.split(key)
    x_test = np.cos(random.normal(split, test_shape))

    # One independent PRNG key per ensemble member.
    ensemble_key = random.split(key, ensemble_size)

    loss = jit(lambda params, x, y: 0.5 * np.mean(
        (apply_fn(params, x) - y)**2))
    grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

    def train_network(key):
        # One full SGD run from a fresh initialization.
        _, params = init_fn(key, (-1,) + train_shape[1:])
        opt_state = opt_init(params)
        for i in range(training_steps):
            opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
        return get_params(opt_state)

    params = vmap(train_network)(ensemble_key)
    ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)

    ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
    ensemble_loss = np.mean(ensemble_loss)
    # Fix: the previous `assertLess(ensemble_loss, 1e-5, True)` passed `True`
    # as the `msg` argument of `assertLess`; `msg` should be a string or
    # omitted entirely.
    self.assertLess(ensemble_loss, 1e-5)

    mean_emp = np.mean(ensemble_fx, axis=0)
    mean_subtracted = ensemble_fx - mean_emp
    cov_emp = np.einsum(
        'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
            mean_subtracted.shape[0] * mean_subtracted.shape[-1])

    reg = 1e-7
    ntk_predictions = predict.gp_inference(
        ker_fn, x_train, y_train, x_test, 'ntk', reg, compute_cov=True)

    self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
    self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL, ATOL)
def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                             out_logits):
    """Checks analytic NTK-GP mean/covariance against Monte-Carlo estimates.

    Also verifies the predicted covariance has one row per test point and
    is PSD up to small numerical slack.
    """
    key = random.PRNGKey(0)
    key, split = random.split(key)
    x_train = np.cos(random.normal(split, train_shape))
    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)),
        np.float32)
    key, split = random.split(key)
    x_test = np.cos(random.normal(split, test_shape))
    _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)
    mean_pred, cov_pred = predict.gp_inference(
        kernel_fn, x_train, y_train, x_test, 'ntk', diag_reg=0.,
        compute_cov=True)

    # On TPU, use the numpy eigensolver instead of the jax one — presumably
    # for numerical accuracy; TODO confirm.
    if xla_bridge.get_backend().platform == 'tpu':
        eigh = np.onp.linalg.eigh
    else:
        eigh = np.linalg.eigh

    self.assertEqual(cov_pred.shape[0], x_test.shape[0])
    min_eigh = np.min(eigh(cov_pred)[0])
    self.assertGreater(min_eigh + 1e-10, 0.)

    def mc_sampling(count=10):
        # Empirical mean/covariance of finite-width NTK gradient-descent
        # predictions over `count` random initializations.
        key = random.PRNGKey(100)
        init_fn, f, _ = _build_network(train_shape[1:], network, out_logits)
        _kernel_fn = empirical.empirical_kernel_fn(f)
        kernel_fn = jit(
            lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))
        collect_test_predict = []
        for _ in range(count):
            key, split = random.split(key)
            _, params = init_fn(split, train_shape)
            g_dd = kernel_fn(x_train, None, params)
            g_td = kernel_fn(x_test, x_train, params)
            predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)
            fx_initial_train = f(params, x_train)
            fx_initial_test = f(params, x_test)
            # t=1e8 approximates infinite training time.
            _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                        fx_initial_test)
            collect_test_predict.append(fx_pred_test)
        collect_test_predict = np.array(collect_test_predict)
        mean_emp = np.mean(collect_test_predict, axis=0)
        mean_subtracted = collect_test_predict - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted,
            optimize=True) / (mean_subtracted.shape[0] *
                              mean_subtracted.shape[-1])
        return mean_emp, cov_emp

    atol = ATOL
    rtol = RTOL
    mean_emp, cov_emp = mc_sampling(100)

    self.assertAllClose(mean_pred, mean_emp, True, rtol, atol)
    self.assertAllClose(cov_pred, cov_emp, True, rtol, atol)
def test_kwargs(self, do_batch, mode):
    """Checks keyword-argument plumbing through prediction functions.

    Builds kernels that require extra kwargs (`pattern` for `Aggregate`,
    plus `params`/`rng`/`get` in empirical mode) and verifies that GP
    inference, the GD-MSE ensemble, and plain GD-MSE predictions agree
    when the kwargs are threaded through each API.
    """
    rng = random.PRNGKey(1)
    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))
    rng_train, rng_test = random.split(rng, 2)

    # Aggregation patterns for the `stax.Aggregate` layer.
    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    # kw_dd / kw_td / kw_tt: kwargs for train-train, test-train and
    # test-test kernel evaluations respectively.
    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
        kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                          batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
        kernel_fn = empirical_kernel_fn(apply_fn)
        if do_batch:
            raise absltest.SkipTest('Batching of empirical kernel is not '
                                    'implemented with keyword arguments.')

        # Empirical kernels also need `params`, `get`, and dropout rngs.
        for kw in (kw_dd, kw_td, kw_tt):
            kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                           get=('nngp', 'ntk')))

        kw_dd.update(dict(rng=(rng_train, None)))
        kw_td.update(dict(rng=(rng_test, rng_train)))
        kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
        if do_batch:
            kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
        raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
        # The ensemble API receives kernels, not `get`, so drop it.
        for kw in (kw_dd, kw_td, kw_tt):
            kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(
        kernel_fn, x_train, y_train, **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True,
                                       **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)