def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                        trace_axes, diagonal_axes):
    """Batched empirical NTK must match the unbatched one.

    Builds a Dense(256) -> Relu -> Dense(10) network, evaluates
    `empirical.empirical_ntk_fn` (with `vmap_axes=0`, `implementation=2`)
    both directly and through `batch.batch(..., batch_size=5)`, and asserts
    that the two results agree.

    Args:
      same_inputs: when True the second input is `None`; when False an
        independent second batch of 60 examples is drawn.
        NOTE(review): `None` is presumably interpreted by the kernel function
        as "use `test_x1` for both arguments" -- confirm against the library.
      device_count: forwarded to `batch.batch`.
      trace_axes: forwarded to `empirical.empirical_ntk_fn`.
      diagonal_axes: forwarded to `empirical.empirical_ntk_fn`.
    """
    test_utils.stub_out_pmap(batch, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    init_fn, apply_fn, _ = stax.serial(stax.Dense(256),
                                       stax.Relu(),
                                       stax.Dense(10))

    test_x1 = random.normal(input_key1, (50, 4, 4))
    # A distinct second input is only drawn when the inputs are NOT the same;
    # previously the condition was inverted relative to the parameter name.
    if same_inputs:
        test_x2 = None
    else:
        test_x2 = random.normal(input_key2, (60, 4, 4))

    kernel_fn = empirical.empirical_ntk_fn(apply_fn,
                                           trace_axes=trace_axes,
                                           diagonal_axes=diagonal_axes,
                                           vmap_axes=0,
                                           implementation=2)

    _, params = init_fn(net_key, test_x1.shape)

    true_kernel = kernel_fn(test_x1, test_x2, params)
    batched_fn = batch.batch(kernel_fn,
                             device_count=device_count,
                             batch_size=5)
    batch_kernel = batched_fn(test_x1, test_x2, params)
    self.assertAllClose(true_kernel, batch_kernel)
def test_parallel_in_out_empirical(self, same_inputs):
    """Batched empirical NNGP/NTK must match unbatched on nested inputs.

    The network is a `stax.parallel` of a Dense layer and a nested
    `stax.parallel` of two Dense layers, so inputs are nested tuples of
    the form `(a, (b, c))`.

    Args:
      same_inputs: when True, `x2` is `None`; otherwise `x2` is an
        independently drawn nested input with 8 examples per leaf.
        NOTE(review): this parameter used to be ignored, so both
        parameterised cases ran the identical two-input path. `None` is
        presumably read by the kernel functions as "x2 = x1" -- confirm the
        batched empirical kernels accept it for nested inputs.
    """
    test_utils.stub_out_pmap(batch, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
    x1 = (x1_1, (x1_2, x1_3))

    if same_inputs:
        x2 = None
    else:
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
        x2 = (x2_1, (x2_2, x2_3))

    def net(N_out):
        return stax.parallel(
            stax.Dense(N_out),
            stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

    # Input shape spec mirroring the nesting of `x1`.
    input_shapes = ((-1, 10), ((-1, 10), (-1, 10)))

    # Check NNGP.
    init_fn, apply_fn, _ = net(WIDTH)
    _, params = init_fn(net_key, input_shapes)

    kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
    batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, params),
                                     batch_kernel_fn(x1, x2, params),
                                     RTOL)

    # Check NTK.
    init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
    _, params = init_fn(net_key, input_shapes)

    kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
    batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, params),
                                     batch_kernel_fn(x1, x2, params),
                                     RTOL)
def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                  batch_size):
    """Compare `batch.batch` against the unbatched kernel function.

    The comparison is run twice: once with the default `batch.batch`
    settings and once with `store_on_device=False`.

    Args:
      train_shape: shape of the first data batch; `train_shape[1:]` is passed
        to the `kernel_fn` factory.
      test_shape: shape of the second data batch.
      network: forwarded to the `kernel_fn` factory.
      name: not read in the body.
      kernel_fn: factory called as `kernel_fn(key, input_shape, network)`
        that returns the kernel function under test.
      batch_size: forwarded to `batch.batch`.
    """
    test_utils.stub_out_pmap(batch, 2)

    root_key = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
    split_keys = tf_random_split(root_key, 3)
    net_key = split_keys[0]
    self_split = split_keys[1]
    other_split = split_keys[2]

    data_self = np.asarray(normal(train_shape, seed=self_split))
    data_other = np.asarray(normal(test_shape, seed=other_split))

    reference_fn = kernel_fn(net_key, train_shape[1:], network)

    # First pass uses the defaults, second pass sets store_on_device=False.
    for extra_kwargs in ({}, {'store_on_device': False}):
        kernel_batched = batch.batch(reference_fn,
                                     batch_size=batch_size,
                                     **extra_kwargs)
        _test_kernel_against_batched(self, reference_fn, kernel_batched,
                                     data_self, data_other)
def test_monte_carlo_generator(self, batch_size, device_count,
                               store_on_device, get):
    """Check the generator form of `monte_carlo.monte_carlo_kernel_fn`.

    When `n_samples` is a list, the sampler's output is iterated in lockstep
    with that list. Each yielded value is compared against a sampler built
    for that single `n`, and against the same quantities computed on a
    second set of inputs from an identically-configured
    `_get_inputs_and_model(8, 1)` call. The last yielded sample is finally
    compared, with loose tolerances, against the analytic kernel.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn` and used
        to stub out pmap.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      get: which kernel(s) to request. If `None`, the samplers are called
        without a `get` argument and the analytic kernel is requested with
        `('nngp', 'ntk')`.
    """
    test_utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]
    sample_generator = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device)

    # With `get=None` the samplers receive no third argument at all, while
    # the analytic kernel is asked for both kernels explicitly.
    sampler_get_args = () if get is None else (get,)
    analytic_get = ('nngp', 'ntk') if get is None else get

    samples_12 = sample_generator(x1, x2, *sampler_get_args)
    samples_34 = sample_generator(x3, x4, *sampler_get_args)

    count = 0
    for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        sample_fn = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n, batch_size, device_count,
            store_on_device)
        sample_12 = sample_fn(x1, x2, *sampler_get_args)
        sample_34 = sample_fn(x3, x4, *sampler_get_args)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

    # Every entry of `n_samples` must have produced a sample.
    self.assertEqual(log_n_max, count)

    ker_analytic_12 = stax_kernel_fn(x1, x2, analytic_get)
    ker_analytic_34 = stax_kernel_fn(x3, x4, analytic_get)

    # `s_12` is the last value yielded by the loop above.
    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)
def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
    """Compare `batch._parallel` against the unbatched kernel function.

    Args:
      train_shape: shape of the first data batch; `train_shape[1:]` is passed
        to the `kernel_fn` factory.
      test_shape: shape of the second data batch.
      network: forwarded to the `kernel_fn` factory.
      name: not read in the body.
      kernel_fn: factory called as
        `kernel_fn(key, input_shape, network, use_dropout=False)` that
        returns the kernel function under test.
    """
    test_utils.stub_out_pmap(batch, 2)

    net_key, self_split, other_split = random.split(random.PRNGKey(0), 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    reference_fn = kernel_fn(net_key,
                             train_shape[1:],
                             network,
                             use_dropout=False)
    parallel_fn = batch._parallel(reference_fn)

    _test_kernel_against_batched(self, reference_fn, parallel_fn, data_self,
                                 data_other, True)
def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                           get):
    """A single batched Monte Carlo sample must equal the unbatched one.

    Both samplers are built with `monte_carlo._sample_once_kernel_fn` from
    the same empirical kernel and `init_fn`, and are called with the same
    inputs, key and `get`.

    Args:
      batch_size: forwarded to the batched sampler.
      device_count: forwarded to the batched sampler and used to stub out
        pmap.
      store_on_device: forwarded to the batched sampler.
      get: passed to both samplers when drawing the sample.
    """
    test_utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    empirical_kernel_fn = empirical.empirical_kernel_fn(apply_fn)

    plain_sampler = monte_carlo._sample_once_kernel_fn(empirical_kernel_fn,
                                                       init_fn)
    batched_sampler = monte_carlo._sample_once_kernel_fn(
        empirical_kernel_fn, init_fn, batch_size, device_count,
        store_on_device)

    one_sample = plain_sampler(x1, x2, key, get)
    one_sample_batch = batched_sampler(x1, x2, key, get)
    self.assertAllClose(one_sample, one_sample_batch)
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
    """A 200-sample Monte Carlo NNGP estimate must match the analytic NNGP.

    Agreement is checked with `test_utils.assert_close_matrices` using a
    tolerance of 2e-2.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn` and used
        to stub out pmap.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
    """
    test_utils.stub_out_pmap(batch, device_count)

    # The third argument tells the input/model helper whether we run on TPU.
    on_tpu = xla_bridge.get_backend().platform == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, on_tpu)

    n_samples = 200
    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                               n_samples, batch_size,
                                               device_count, store_on_device)

    ker_empirical = sample(x1, x2, 'nngp')
    ker_analytic = stax_kernel_fn(x1, x2, 'nngp')
    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_parallel_in_out(self, same_inputs):
    """Batched analytic kernels must match unbatched on nested inputs.

    A read-in and a read-out network are built from the same nested
    `stax.parallel` template. Their analytic kernel functions are composed
    (`readout(readin(x1, x2))`) both directly and after wrapping each in
    `batch.batch(..., 2)`. The NNGP result and the NTK extracted from the
    `('nngp', 'ntk')` result are compared.

    Args:
      same_inputs: when True, `x2` is `None`; otherwise `x2` is an
        independently drawn nested input with 8 examples per leaf.
        NOTE(review): this parameter used to be ignored, so both
        parameterised cases ran the identical two-input path. `None` is
        presumably read by the kernel functions as "x2 = x1" -- confirm the
        batched kernels accept it for nested inputs.
    """
    test_utils.stub_out_pmap(batch, 2)
    rng = random.PRNGKey(0)
    # The third split key was bound to a name but never used.
    input_key1, input_key2, _ = random.split(rng, 3)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
    x1 = (x1_1, (x1_2, x1_3))

    if same_inputs:
        x2 = None
    else:
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
        x2 = (x2_1, (x2_2, x2_3))

    N = WIDTH

    def net(N_out):
        return stax.parallel(
            stax.Dense(N_out),
            stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

    # Check NNGP.
    readin = net(N)
    readout = net(1)

    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(partial(readout[2], get='nngp'))

    batch_K_readin_fn = batch.batch(K_readin_fn, 2)
    batch_K_readout_fn = batch.batch(K_readout_fn, 2)

    test_utils.assert_close_matrices(
        self,
        K_readout_fn(K_readin_fn(x1, x2)),
        batch_K_readout_fn(batch_K_readin_fn(x1, x2)),
        RTOL)

    # Check Both.
    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

    batch_K_readin_fn = batch.batch(K_readin_fn, 2)
    batch_K_readout_fn = batch.batch(K_readout_fn, 2)

    # Extract the `.ntk` attribute from each kernel in the nested result.
    get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)

    test_utils.assert_close_matrices(
        self,
        get_ntk(K_readout_fn(K_readin_fn(x1, x2))),
        get_ntk(batch_K_readout_fn(batch_K_readin_fn(x1, x2))),
        RTOL)
def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
    """Compare `batch._parallel` against the unbatched kernel function.

    Args:
      train_shape: shape of the first data batch; `train_shape[1:]` is passed
        to the `kernel_fn` factory.
      test_shape: shape of the second data batch.
      network: forwarded to the `kernel_fn` factory.
      name: not read in the body.
      kernel_fn: factory called as
        `kernel_fn(key, input_shape, network, use_dropout=False)` that
        returns the kernel function under test.
    """
    test_utils.stub_out_pmap(batch, 2)

    root_key = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
    split_keys = tf_random_split(root_key, 3)
    net_key = split_keys[0]
    self_split = split_keys[1]
    other_split = split_keys[2]

    data_self = np.asarray(normal(train_shape, seed=self_split))
    data_other = np.asarray(normal(test_shape, seed=other_split))

    reference_fn = kernel_fn(net_key,
                             train_shape[1:],
                             network,
                             use_dropout=False)
    parallel_fn = batch._parallel(reference_fn)

    _test_kernel_against_batched(self, reference_fn, parallel_fn, data_self,
                                 data_other, True)
def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                  batch_size):
    """Compare `batch.batch` against the unbatched kernel function.

    The comparison is run twice: once with the default `batch.batch`
    settings and once with `store_on_device=False`.

    Args:
      train_shape: shape of the first data batch; `train_shape[1:]` is passed
        to the `kernel_fn` factory.
      test_shape: shape of the second data batch.
      network: forwarded to the `kernel_fn` factory.
      name: not read in the body.
      kernel_fn: factory called as `kernel_fn(key, input_shape, network)`
        that returns the kernel function under test.
      batch_size: forwarded to `batch.batch`.
    """
    test_utils.stub_out_pmap(batch, 2)

    net_key, self_split, other_split = random.split(random.PRNGKey(0), 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    reference_fn = kernel_fn(net_key, train_shape[1:], network)

    # First pass uses the defaults, second pass sets store_on_device=False.
    for extra_kwargs in ({}, {'store_on_device': False}):
        kernel_batched = batch.batch(reference_fn,
                                     batch_size=batch_size,
                                     **extra_kwargs)
        _test_kernel_against_batched(self, reference_fn, kernel_batched,
                                     data_self, data_other)
def testComposition(self, train_shape, test_shape, network, name, kernel_fn,
                    batch_size):
    """Check both nesting orders of `batch._parallel` and `batch._serial`.

    The unbatched kernel function is compared first against
    `_parallel(_serial(fn))` and then against `_serial(_parallel(fn))`.

    Args:
      train_shape: shape of the first data batch; `train_shape[1:]` is passed
        to the `kernel_fn` factory.
      test_shape: shape of the second data batch.
      network: forwarded to the `kernel_fn` factory.
      name: not read in the body.
      kernel_fn: factory called as `kernel_fn(key, input_shape, network)`
        that returns the kernel function under test.
      batch_size: forwarded to `batch._serial`.
    """
    test_utils.stub_out_pmap(batch, 2)

    net_key, self_split, other_split = random.split(random.PRNGKey(0), 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    reference_fn = kernel_fn(net_key, train_shape[1:], network)

    # Each wrapper is built lazily, right before it is checked, so that
    # construction and checking interleave in a fixed order.
    compositions = (
        lambda fn: batch._parallel(batch._serial(fn, batch_size=batch_size)),
        lambda fn: batch._serial(batch._parallel(fn), batch_size=batch_size),
    )
    for compose in compositions:
        kernel_batched = compose(reference_fn)
        _test_kernel_against_batched(self, reference_fn, kernel_batched,
                                     data_self, data_other)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
    """A 100-sample Monte Carlo NTK estimate must match the analytic NTK.

    The sampled NTK is reduced by summing over its last two axes and dividing
    by the size of its last axis before it is compared with the analytic
    kernel (tolerance 2e-2).

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn` and used
        to stub out pmap.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
    """
    test_utils.stub_out_pmap(batch, device_count)

    # The third argument tells the input/model helper whether we run on TPU.
    on_tpu = xla_bridge.get_backend().platform == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        256, 2, on_tpu)

    n_samples = 100
    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                               n_samples, batch_size,
                                               device_count, store_on_device)

    sampled_ntk = sample(x1, x2, 'ntk')
    last_axis_size = sampled_ntk.shape[-1]
    ker_empirical = np.sum(sampled_ntk, axis=(-1, -2)) / last_axis_size

    ker_analytic = stax_kernel_fn(x1, x2, 'ntk')
    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_jit_or_pmap_broadcast(self):
    """Exercise `batch._jit_or_pmap_broadcast` for device counts 0, 1 and 2.

    A toy `kernel_fn` mixing array, boolean, PRNG-key, pytree and keyword
    arguments is wrapped for each device count. The wrapped output is
    compared with a direct call for every combination of `do_flip` and
    `do_square`.

    NOTE(review): the relative order of `test_utils.stub_out_pmap` and the
    construction of the wrapped function differs between the device_count=1
    and device_count=2 sections. That order is kept exactly as it was.
    """

    def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None,
                  p=0.65):
        res = np.abs(np.matmul(x1, x2))
        if do_square:
            res *= res
        if do_flip:
            res = -res
        res *= random.uniform(keys) * p
        return [res, params]

    def broadcast(arg):
        # Prepend a leading axis of size 2 (one entry per stubbed device).
        return np.broadcast_to(arg, (2,) + arg.shape)

    # Same iteration order as two nested `for ... in [True, False]` loops.
    flag_grid = [(flip, square)
                 for flip in [True, False]
                 for square in [True, False]]

    params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
    x2 = np.arange(0, 10).reshape((10,))
    keys = random.PRNGKey(1)

    # --- device_count=0: outputs are compared as-is. ---
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                     device_count=0)
    x1 = np.arange(0, 10).reshape((1, 10))
    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=0):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 _unused=True, p=0.65)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=True)
            self.assertAllClose(expected, actual)

    # --- device_count=1: returned params gain a leading axis of size 1. ---
    test_utils.stub_out_pmap(batch, 1)
    x1 = np.arange(0, 10).reshape((1, 10))
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                     device_count=1)
    add_leading_axis = partial(np.expand_dims, axis=0)
    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=1):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 _unused=False, p=0.65)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=None)
            self.assertAllClose(expected[0], actual[0])
            self.assertAllClose(tree_map(add_leading_axis, expected[1]),
                                actual[1])

    # --- device_count=2: returned params are broadcast to two devices. ---
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                     device_count=2)
    x1 = np.arange(0, 20).reshape((2, 10))
    test_utils.stub_out_pmap(batch, 2)
    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=2):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 p=0.2)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=None, p=0.2)
            # Compare the kernel row by row, one row per device.
            for device in (0, 1):
                self.assertAllClose(expected[0][device], actual[0][device])
            self.assertAllClose(tree_map(broadcast, expected[1]), actual[1])
def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
    """Run the analytic kernel composition check through `batch.batch`.

    Args:
      store_on_device: bound as the `store_on_device` keyword of
        `batch.batch`.
      batch_size: bound as the `batch_size` keyword of `batch.batch`.
    """
    test_utils.stub_out_pmap(batch, 2)
    batching_fn = partial(batch.batch,
                          batch_size=batch_size,
                          store_on_device=store_on_device)
    self._test_analytic_kernel_composition(batching_fn)
def testAnalyticKernelComposeParallel(self):
    """Run the analytic kernel composition check through `batch._parallel`."""
    test_utils.stub_out_pmap(batch, 2)
    batching_fn = batch._parallel
    self._test_analytic_kernel_composition(batching_fn)