def testPredictOnCPU(self):
  x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

  y_train = random.uniform(random.PRNGKey(1), (10, 7))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        with self.subTest(
            store_on_device=store_on_device,
            device_count=device_count,
            get=get):
          kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                          store_on_device)
          predictor = predict.gradient_descent_mse_gp(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)
          gp_inference = predict.gp_inference(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)

          self.assertAllClose(predictor(None), predictor(np.inf), True)
          self.assertAllClose(predictor(None), gp_inference, True)
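# A minimal standalone sketch (not part of the test suite) of the contract
# checked above, assuming this file's `stax`, `batch` and `predict` imports:
# `predictor(None)` requests the closed-form t -> infinity solution, so it
# should coincide with both `predictor(np.inf)` and the exact GP posterior
# from `predict.gp_inference`. The helper name below is hypothetical.
def _predict_on_cpu_sketch():
  x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
  y_train = random.uniform(random.PRNGKey(1), (10, 7))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))
  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))
  # Evaluate the kernel in batches of 2 on the default device(s).
  kernel_fn = batch.batch(kernel_fn, 2)
  predictor = predict.gradient_descent_mse_gp(
      kernel_fn, x_train, y_train, x_test, 'ntk', 0., True)
  return predictor(None)  # Should match predictor(np.inf).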
def testZeroTimeAgreement(self, train_shape, test_shape, network, out_logits):
  """Test that the NTK and NNGP agree at t=0."""
  key = random.PRNGKey(0)
  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)),
      np.float32)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  prediction = predict.gradient_descent_mse_gp(
      ker_fun,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get=('NTK', 'NNGP'),
      compute_cov=True)

  zero_prediction = prediction(0.0)

  self.assertAllClose(zero_prediction.ntk, zero_prediction.nngp, True)
  reference = (np.zeros((test_shape[0], out_logits)),
               ker_fun(x_test, x_test, get='nngp'))
  self.assertAllClose((reference,) * 2, zero_prediction, True)
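# At t = 0 no gradient descent has occurred, so both the NTK and NNGP code
# paths should return the prior on the test set: a zero mean and a predictive
# covariance equal to K_nngp(x_test, x_test). That prior is exactly the
# `reference` tuple the test above compares against.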
def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                              out_logits, get):
  key = random.PRNGKey(0)
  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)),
      np.float32)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  prediction = predict.gradient_descent_mse_gp(
      kernel_fn,
      x_train,
      y_train,
      x_test,
      get,
      diag_reg=reg,
      compute_cov=True)

  infinite_prediction = prediction(np.inf)
  none_prediction = prediction(None)
  gp_inference = predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                      get, reg, True)

  self.assertAllClose(none_prediction, infinite_prediction, True)
  self.assertAllClose(none_prediction, gp_inference, True)
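# For reference, the infinite-time mean that both code paths above compute is
# plain kernel regression. A sketch (hypothetical helper; assumes a string
# `get` makes `kernel_fn` return dense kernel matrices, and the library's
# internal scaling of `diag_reg` may differ from the raw `reg * I` used here):
def _kernel_regression_mean_sketch(kernel_fn, x_train, y_train, x_test, reg):
  k_dd = kernel_fn(x_train, x_train, get='ntk')  # Train/train kernel.
  k_td = kernel_fn(x_test, x_train, get='ntk')   # Test/train kernel.
  n = x_train.shape[0]
  return np.dot(k_td, np.linalg.solve(k_dd + reg * np.eye(n), y_train))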
def testNTKPredCovPosDef(self, train_shape, test_shape, network, out_logits):
  key = random.PRNGKey(0)
  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)),
      np.float32)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  ntk_predictions = predict.gradient_descent_mse_gp(
      ker_fun,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get='ntk',
      compute_cov=True)

  ts = np.logspace(-2, 8, 10)
  ntk_cov_predictions = [ntk_predictions(t).covariance for t in ts]

  if xla_bridge.get_backend().platform == 'tpu':
    eigh = np.onp.linalg.eigh
  else:
    eigh = np.linalg.eigh

  check_symmetric = np.array(
      [np.max(np.abs(cov - cov.T)) for cov in ntk_cov_predictions])
  check_pos_evals = np.min(
      np.array([eigh(cov)[0] + 1e-10 for cov in ntk_cov_predictions]))

  self.assertAllClose(check_symmetric, np.zeros_like(check_symmetric), True)
  self.assertGreater(check_pos_evals, 0., True)
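# Note on the check above: `eigh` returns eigenvalues in ascending order, and
# the 1e-10 offset absorbs floating-point round-off, so the assertions verify
# that every predictive covariance along the trajectory is numerically
# symmetric and positive semi-definite.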
def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                             out_logits):
  key = random.PRNGKey(0)
  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)),
      np.float32)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

  reg = 1e-7
  prediction = predict.gradient_descent_mse_gp(
      ker_fun,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get='NTK',
      compute_cov=True)

  ts = np.logspace(-2, 8, 10)
  ntk_predictions = [prediction(t).mean for t in ts]

  # A hacked kernel function that always returns the NTK kernel.
  def always_ntk(x1, x2, get=('nngp', 'ntk')):
    out = ker_fun(x1, x2, get=('nngp', 'ntk'))
    if get == 'nngp' or get == 'ntk':
      return out.ntk
    else:
      return out._replace(nngp=out.ntk)

  ntk_nngp_prediction = predict.gradient_descent_mse_gp(
      always_ntk,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get='NNGP',
      compute_cov=True)

  ntk_nngp_predictions = [ntk_nngp_prediction(t).mean for t in ts]

  # Test that using the NNGP equations with the NTK kernel gives the same
  # mean prediction.
  self.assertAllClose(ntk_predictions, ntk_nngp_predictions, True)

  # Next, test that going through the NTK code path with only the NNGP
  # kernel recreates the NNGP dynamics.
  nngp_prediction = predict.gradient_descent_mse_gp(
      ker_fun,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get='NNGP',
      compute_cov=True)

  # A hacked kernel function that always returns the NNGP kernel.
  def always_nngp(x1, x2, get=('nngp', 'ntk')):
    out = ker_fun(x1, x2, get=('nngp', 'ntk'))
    if get == 'nngp' or get == 'ntk':
      return out.nngp
    else:
      return out._replace(ntk=out.nngp)

  nngp_ntk_prediction = predict.gradient_descent_mse_gp(
      always_nngp,
      x_train,
      y_train,
      x_test,
      diag_reg=reg,
      get='NTK',
      compute_cov=True)

  nngp_cov_predictions = [nngp_prediction(t).covariance for t in ts]
  nngp_ntk_cov_predictions = [
      nngp_ntk_prediction(t).covariance for t in ts
  ]

  # Test that using the NTK equations with the NNGP kernel gives the same
  # covariance, though only approximately due to accumulated numerical error.
  self.assertAllClose(nngp_cov_predictions, nngp_ntk_cov_predictions, True)
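# Why the two agreement checks above hold: the mean of the predictive
# distribution under gradient flow on the MSE loss,
#
#   mean(t) = K_td K_dd^{-1} (I - exp(-t K_dd)) y_train,
#
# has the same functional form for both code paths (up to regularization and
# learning-rate conventions) and depends only on the kernel K driving the
# dynamics; the NTK and NNGP paths differ only in the predictive covariance.
# So swapping kernels via `always_ntk` / `always_nngp` must reproduce the
# corresponding mean and covariance trajectories.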