def test_monte_carlo_generator(self, batch_size, device_count, store_on_device,
                               get):
    """Checks a multi-`n_samples` MC generator against single-`n` kernel fns.

    A `monte_carlo_kernel_fn` built with a list of sample counts returns one
    estimate per count. For every count, that estimate must match a kernel
    function built with the single count. This is checked on two input pairs,
    `(x1, x2)` and `(x3, x4)`. The estimate for the last (largest) count is
    also loosely compared with the analytic kernel.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn`; also
        used to stub out `pmap`.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      get: kernel selector. When `None`, the MC functions are called without a
        `get` argument and the analytic kernel is queried for
        `('nngp', 'ntk')`.
    """
    test_utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]

    def make_mc_kernel_fn(n):
        # `n` is either a single sample count or a list of them.
        return monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n, batch_size, device_count,
            store_on_device)

    # Forward `get` only when one was given, so the `get is None` case still
    # exercises the call form without a `get` argument.
    get_args = () if get is None else (get,)
    analytic_get = ('nngp', 'ntk') if get is None else get

    sample_generator = make_mc_kernel_fn(n_samples)
    samples_12 = sample_generator(x1, x2, *get_args)
    samples_34 = sample_generator(x3, x4, *get_args)

    count = 0
    for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        single_n_fn = make_mc_kernel_fn(n)
        sample_12 = single_n_fn(x1, x2, *get_args)
        sample_34 = single_n_fn(x3, x4, *get_args)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

    self.assertEqual(log_n_max, count)

    ker_analytic_12 = stax_kernel_fn(x1, x2, analytic_get)
    ker_analytic_34 = stax_kernel_fn(x3, x4, analytic_get)

    # `s_12` still holds the estimate from the final loop iteration.
    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)
def test_parallel_in_out_mc(self, same_inputs, batch_size):
    """Checks that batching leaves the MC NNGP of a nested `parallel` net unchanged.

    Inputs are nested tuples `(a, (b, c))` that match the nested
    `stax.parallel` structure of the network. An unbatched MC kernel function
    is the reference for a batched one.

    Args:
      same_inputs: if true, `x2` is `None`.
      batch_size: batch size of the kernel function under test.
    """
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 2, 5))
    x1 = (x1_1, (x1_2, x1_3))

    x2 = None
    if not same_inputs:
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 4, 5))
        x2 = (x2_1, (x2_2, x2_3))

    def net(n_out):
        outer_dense = stax.Dense(n_out)
        inner = stax.parallel(stax.Dense(n_out + 1), stax.Dense(n_out + 2))
        return stax.parallel(outer_dense, inner)

    init_fn, apply_fn, _ = net(WIDTH)

    def mc_kernel_fn(**extra_kwargs):
        return monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, net_key, n_samples=4, trace_axes=(-1, ),
            **extra_kwargs)

    # Check NNGP: unbatched reference vs. batched kernel function.
    nb_kernel_fn = mc_kernel_fn()
    kernel_fn = mc_kernel_fn(batch_size=batch_size)
    self.assertAllClose(kernel_fn(x1, x2, 'nngp'), nb_kernel_fn(x1, x2, 'nngp'))
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
    """Compares a 200-sample MC estimate of the NNGP with the analytic NNGP.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn`; also
        used to stub out `pmap`.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
    """
    utils.stub_out_pmap(batch, device_count)

    on_tpu = xla_bridge.get_backend().platform == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        1024, 256, on_tpu)

    n_samples = 200
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device)

    ker_empirical = mc_kernel_fn(x1, x2, 'nngp')
    ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

    utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
    """Compares a 100-sample MC estimate of the NTK with the analytic NTK.

    The MC kernel carries two trailing output-by-output axes. It is reduced to
    the mean of that output-space diagonal before being compared with the
    analytic kernel.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn`; also
        used to stub out `pmap`.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
    """
    utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        256, 2, xla_bridge.get_backend().platform == 'tpu')

    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                               batch_size, device_count,
                                               store_on_device)

    ker_empirical = sample(x1, x2, 'ntk')
    # Average the diagonal of the trailing output-by-output block, i.e.
    # trace / n_outputs. The previous reduction summed over *both* output
    # axes, which also folded in the cross-output entries. Those entries are
    # expected to vanish only on average over samples, so including them
    # adds noise without estimating the per-output kernel.
    ker_empirical = (
        np.trace(ker_empirical, axis1=-2, axis2=-1) / ker_empirical.shape[-1])

    ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

    utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
    """Compares a 100-sample MC estimate of the NTK with the analytic NTK.

    The MC kernel function is built with `vmap_axes=0`. Its output is compared
    directly with the analytic kernel, without any reduction.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn`; also
        used to stub out `pmap`.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
    """
    test_utils.stub_out_pmap(batch, device_count)

    on_tpu = xla_bridge.get_backend().platform == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 2, on_tpu)

    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, 100, batch_size, device_count,
        store_on_device, vmap_axes=0)

    ker_empirical = mc_kernel_fn(x1, x2, 'ntk')
    ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_sparse_inputs(self, act, kernel):
    """Compares analytic and MC kernels when some inputs are all-zero rows.

    The first `sparse_count` inputs are zeroed out. Only the kernel block
    between the remaining inputs is compared. The NaN check covers the whole
    analytic kernel.

    Args:
      act: 'relu' selects `stax.Relu()`; any other value selects `stax.Erf()`.
      kernel: kernel to compute (passed as `get`). 'ntk' uses a width-1
        readout layer; otherwise the readout has `width` channels.
    """
    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 128
    width = 4096

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    # `jtu._default_tolerance` is module-level state shared by all tests.
    # Loosen it only while this test runs, then restore it. Otherwise the
    # looser tolerances silently apply to every test that runs afterwards.
    saved_tolerance = dict(jtu._default_tolerance)
    try:
        if xla_bridge.get_backend().platform == 'gpu':
            jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
            samples = 100 * N_SAMPLES
        else:
            jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
            jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3

        # a batch of dense inputs
        x_dense = random.normal(key, (input_count, input_size))
        x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)

        activation = stax.Relu() if act == 'relu' else stax.Erf()

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(width),
            activation,
            stax.Dense(1 if kernel == 'ntk' else width))
        exact = kernel_fn(x_sparse, None, kernel)
        mc = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, random.split(key, 2)[0], samples)(
                x_sparse, None, kernel)
        mc = np.reshape(mc, exact.shape)

        assert not np.any(np.isnan(exact))
        self.assertAllClose(exact[sparse_count:, sparse_count:],
                            mc[sparse_count:, sparse_count:], True)
    finally:
        # Restore in place so any other holder of this dict sees the reset.
        jtu._default_tolerance.clear()
        jtu._default_tolerance.update(saved_tolerance)
def _get_empirical(n_samples, get):
    """Returns the MC estimate of kernel `get` on `(x1, x2)`.

    The estimate uses `n_samples` random draws. `init_fn`, `apply_fn`, `key`,
    `x1` and `x2` are free names resolved from the enclosing scope.
    """
    return monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples)(x1, x2, get)
def test_monte_carlo_generator(self, batch_size, device_count, store_on_device,
                               get):
    """Checks a multi-`n_samples` MC generator against single-`n` kernel fns.

    For every sample count in `n_samples`, the generator's estimate must match
    that of a kernel function built with that single count. This is checked on
    two input pairs. The final (largest-count) estimate, with its trailing
    singleton NTK output axes squeezed away, is then loosely compared with the
    analytic kernel.

    Args:
      batch_size: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      device_count: forwarded to `monte_carlo.monte_carlo_kernel_fn`; also
        used to stub out `pmap`.
      store_on_device: forwarded to `monte_carlo.monte_carlo_kernel_fn`.
      get: kernel selector. When `None`, every kernel function (MC and
        analytic) is called without a `get` argument.
    """
    utils.stub_out_pmap(batch, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]

    def make_mc_kernel_fn(n):
        # `n` is either a single sample count or a list of them.
        return monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n, batch_size, device_count,
            store_on_device)

    # Forward `get` only when one was given, so the `get is None` case keeps
    # calling every kernel function without a `get` argument.
    get_args = () if get is None else (get,)

    sample_generator = make_mc_kernel_fn(n_samples)
    samples_12 = sample_generator(x1, x2, *get_args)
    samples_34 = sample_generator(x3, x4, *get_args)

    count = 0
    for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        single_n_fn = make_mc_kernel_fn(n)
        sample_12 = single_n_fn(x1, x2, *get_args)
        sample_34 = single_n_fn(x3, x4, *get_args)
        self.assertAllClose(s_12, sample_12, True)
        self.assertAllClose(s_12, s_34, True)
        self.assertAllClose(s_12, sample_34, True)
        count += 1

    self.assertEqual(log_n_max, count)

    ker_analytic_12 = stax_kernel_fn(x1, x2, *get_args)
    ker_analytic_34 = stax_kernel_fn(x3, x4, *get_args)

    # `s_12` holds the final loop estimate. Drop its two trailing singleton
    # NTK axes so it lines up with the analytic kernel.
    if get == 'ntk':
        s_12 = np.squeeze(s_12, (-1, -2))
    elif get is None or 'ntk' in get:
        s_12 = s_12._replace(ntk=np.squeeze(s_12.ntk, (-1, -2)))

    self.assertAllClose(ker_analytic_12, s_12, True, 2., 2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34, True)
def _get_empirical(n_samples, get):
    """Returns the MC estimate of kernel `get` on `(x1, x2)`.

    The estimate uses `n_samples` random draws. `init_fn`, `apply_fn`, `key`,
    `x1`, `x2` and `same_inputs` are free names resolved from the enclosing
    scope. When `same_inputs` is set, `x2` is asserted to be `None`.
    """
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples)
    # Identical-input runs must pass `None` as the second argument.
    assert (not same_inputs) or x2 is None
    return mc_kernel_fn(x1, x2, get)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
    """Compares analytic and MC kernels of a CNN whose branches fan back in.

    A shared conv layer fans out into `n_branches` branches. Branch `b` has
    `b` extra conv/nonlinearity pairs. The branches are merged by `FanInSum`
    (when `axis` is None) or by `FanInConcat(axis)`, then passed through
    ReLU, a pooling/flattening readout and a final dense layer.

    Args:
      same_inputs: if true, the second input batch is `None`.
      axis: `None` for `FanInSum`, otherwise the `FanInConcat` axis.
      n_branches: number of parallel branches.
      get: kernel to compare; 'ntk' uses a width-1 final dense layer.
      branch_in: 'dense_before_branch_in' appends the shared conv to every
        branch; 'dense_after_branch_in' inserts it right after the fan-in.
      readout: 'pool' selects `GlobalAvgPool`; anything else selects `Flatten`.
    """
    platform = xla_bridge.get_backend().platform
    if platform == 'cpu':
        raise jtu.SkipTest('Not running CNNs on CPU to save time.')

    if branch_in == 'dense_after_branch_in' and axis in (None, 0, 1, 2):
        raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                           'require `is_gaussian`.')

    if branch_in == 'dense_before_branch_in' and axis == 3:
        raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                           'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    # TPU runs use a wider network and more samples, with a looser tolerance.
    if platform == 'tpu':
        width, n_samples, tol = 2048, 1024, 0.02
    else:
        width, n_samples, tol = 1024, 512, 0.01

    conv = stax.Conv(out_chan=width, filter_shape=(3, 3), padding='SAME',
                     W_std=1.25, b_std=0.1)

    input_layers = [conv, stax.FanOut(n_branches)]

    def make_branch(b):
        # Branch `b`: phi_b followed by `b` (conv, phi_i) pairs.
        layers = [FanInTest._get_phi(b)]
        for i in range(b):
            layers.append(stax.Conv(out_chan=width, filter_shape=(i + 1, 4 - i),
                                    padding='SAME', W_std=1.25 + i,
                                    b_std=0.1 + i))
            layers.append(FanInTest._get_phi(i))
        if branch_in == 'dense_before_branch_in':
            layers.append(conv)
        return stax.serial(*layers)

    branches = [make_branch(b) for b in range(n_branches)]

    fan_in = stax.FanInSum() if axis is None else stax.FanInConcat(axis)
    relu = stax.Relu()
    readout_layer = (stax.GlobalAvgPool() if readout == 'pool'
                     else stax.Flatten())
    output_layers = [fan_in]
    if branch_in == 'dense_after_branch_in':
        output_layers.append(conv)
    output_layers.extend([relu, readout_layer])

    nn = stax.serial(*input_layers, stax.parallel(*branches), *output_layers)

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    # For these rank-4 inputs, axes 0 and -4 both name the leading axis.
    # NOTE(review): device_count=0 there presumably avoids splitting the batch
    # across devices when concatenating along it; confirm.
    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=0 if axis in (0, -4) else -1)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get).reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
    """Compares analytic and MC kernels of a dense net whose branches fan in.

    A shared dense layer fans out into `n_branches` branches. Branch `b` has
    `b` extra dense/nonlinearity pairs. The branches are merged by
    `FanInSum` (when `axis` is None) or by `FanInConcat(axis)`, followed by
    ReLU. For 'ntk', a width-1 dense readout is appended.

    Args:
      same_inputs: if true, the second input batch is `None`.
      axis: `None` for `FanInSum`, otherwise the `FanInConcat` axis.
      n_branches: number of parallel branches.
      get: 'nngp' or 'ntk'.
      branch_in: 'dense_before_branch_in' appends the shared dense layer to
        every branch; 'dense_after_branch_in' inserts it right after fan-in.

    Raises:
      ValueError: if `get` is neither 'nngp' nor 'ntk'.
    """
    if axis in (None, 0) and branch_in == 'dense_after_branch_in':
        raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                           'require `is_gaussian`.')

    if axis == 1 and branch_in == 'dense_before_branch_in':
        # Keep the trailing space: without it the two literals join into
        # "layerafter".
        raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                           'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (10, 20))
    X0_2 = None if same_inputs else random.normal(key, (8, 20))

    # TPU runs use a wider network and more samples, with a looser tolerance.
    if xla_bridge.get_backend().platform == 'tpu':
        width = 2048
        n_samples = 1024
        tol = 0.02
    else:
        width = 1024
        n_samples = 256
        tol = 0.01

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense, stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
        branch_layers = [FanInTest._get_phi(b)]
        for i in range(b):
            branch_layers += [
                stax.Dense(width, 1. + 2 * i, 0.5 + i),
                FanInTest._get_phi(i)]

        if branch_in == 'dense_before_branch_in':
            branch_layers += [dense]
        branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
        output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
        init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
        init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
        raise ValueError(get)

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, device_count=0)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)