def test_parallel_in_out_empirical(self, same_inputs):
  """Checks batched empirical kernels against unbatched ones on nested inputs.

  Inputs are trees of the form ``(x, (x, x))`` and the network is built from
  nested ``stax.parallel`` layers, so batching has to handle tree-structured
  inputs and outputs. The empirical NNGP of ``net(WIDTH)`` and the empirical
  NTK of ``stax.serial(net(WIDTH), net(1))`` are each compared against their
  ``batch.batch(..., 2)`` counterparts.

  Args:
    same_inputs: if True, ``x2`` is the same tree as ``x1``; otherwise ``x2``
      is an independent tree with 8 rows per leaf (``x1`` has 4).
  """
  test_utils.stub_out_pmap(batch, 2)
  rng = random.PRNGKey(0)
  input_key1, input_key2, net_key = random.split(rng, 3)

  x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
  x1 = (x1_1, (x1_2, x1_3))

  if same_inputs:
    # Reuse `x1` so this parametrization actually exercises the symmetric
    # (x1 vs. x1) case instead of duplicating the `same_inputs=False` run.
    x2 = x1
  else:
    x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
    x2 = (x2_1, (x2_2, x2_3))

  # Input-shape tree mirroring the structure of `x1` / `x2`.
  input_shape = ((-1, 10), ((-1, 10), (-1, 10)))

  def net(N_out):
    return stax.parallel(
        stax.Dense(N_out),
        stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

  def check(network, make_kernel_fn):
    """Asserts that `batch.batch` leaves `make_kernel_fn(apply_fn)` unchanged."""
    init_fn, apply_fn, _ = network
    _, params = init_fn(net_key, input_shape)
    kernel_fn = jit(make_kernel_fn(apply_fn))
    batch_kernel_fn = jit(batch.batch(kernel_fn, 2))
    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, params),
                                     batch_kernel_fn(x1, x2, params),
                                     RTOL)

  # Check NNGP.
  check(net(WIDTH), empirical.empirical_nngp_fn)

  # Check NTK.
  check(stax.serial(net(WIDTH), net(1)), empirical.empirical_ntk_fn)
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
  """Compares a 200-sample Monte Carlo NNGP with the analytic NNGP.

  The two kernels must agree to a relative tolerance of 2e-2.
  """
  test_utils.stub_out_pmap(batch, device_count)
  on_tpu = xla_bridge.get_backend().platform == 'tpu'
  x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
      WIDTH, 256, on_tpu)

  n_samples = 200
  mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, batch_size, device_count,
      store_on_device)

  nngp_mc = mc_kernel_fn(x1, x2, 'nngp')
  nngp_exact = stax_kernel_fn(x1, x2, 'nngp')
  test_utils.assert_close_matrices(self, nngp_exact, nngp_mc, 2e-2)
def test_parallel_in_out(self, same_inputs):
  """Checks batched analytic kernels against unbatched ones on nested inputs.

  A read-in network ``net(WIDTH)`` and a read-out network ``net(1)``, both
  built from nested ``stax.parallel`` layers, have their kernel functions
  composed. The composition is evaluated with and without ``batch.batch``,
  first requesting only the NNGP and then requesting both NNGP and NTK (the
  NTK leaves are compared in the latter case).

  Args:
    same_inputs: if True, ``x2`` is the same tree as ``x1``; otherwise ``x2``
      is an independent tree with 8 rows per leaf (``x1`` has 4).
  """
  test_utils.stub_out_pmap(batch, 2)
  rng = random.PRNGKey(0)
  input_key1, input_key2, _ = random.split(rng, 3)

  x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
  x1 = (x1_1, (x1_2, x1_3))

  if same_inputs:
    # Reuse `x1` so this parametrization actually exercises the symmetric
    # (x1 vs. x1) case instead of duplicating the `same_inputs=False` run.
    x2 = x1
  else:
    x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
    x2 = (x2_1, (x2_2, x2_3))

  def net(N_out):
    return stax.parallel(
        stax.Dense(N_out),
        stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

  readin = net(WIDTH)
  readout = net(1)

  # The read-in kernel function does not depend on `get`, so build it (and
  # its batched version) once and share it between both checks.
  K_readin_fn = jit(readin[2])
  batch_K_readin_fn = batch.batch(K_readin_fn, 2)

  def composed_kernels(get):
    """Returns (unbatched, batched) read-out kernels computed with `get`."""
    K_readout_fn = jit(partial(readout[2], get=get))
    batch_K_readout_fn = batch.batch(K_readout_fn, 2)
    return (K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)))

  # Check NNGP.
  exact, batched = composed_kernels('nngp')
  test_utils.assert_close_matrices(self, exact, batched, RTOL)

  # Check Both.
  get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)
  exact, batched = composed_kernels(('nngp', 'ntk'))
  test_utils.assert_close_matrices(self, get_ntk(exact), get_ntk(batched),
                                   RTOL)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
  """Compares a 100-sample Monte Carlo NTK with the analytic NTK.

  Before comparison, the Monte Carlo kernel is summed over its two trailing
  axes and divided by the size of the last axis. The result must agree with
  the analytic kernel to a relative tolerance of 2e-2.
  """
  test_utils.stub_out_pmap(batch, device_count)
  on_tpu = xla_bridge.get_backend().platform == 'tpu'
  x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
      256, 2, on_tpu)

  mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, 100, batch_size, device_count, store_on_device)
  ntk_mc = mc_kernel_fn(x1, x2, 'ntk')

  # NOTE(review): the two trailing axes are presumably the output-dimension
  # axes of the empirical NTK; reducing them this way yields a per-output
  # kernel comparable to the analytic one -- confirm.
  n_out = ntk_mc.shape[-1]
  ntk_mc = np.sum(ntk_mc, axis=(-1, -2)) / n_out

  ntk_exact = stax_kernel_fn(x1, x2, 'ntk')
  test_utils.assert_close_matrices(self, ntk_exact, ntk_mc, 2e-2)
def _check_agreement_with_empirical(self, net, same_inputs, is_conv,
                                    use_dropout, is_ntk, proj_into_2d):
  """Compares an analytic kernel with its Monte Carlo estimate.

  Args:
    net: a pair ``((init_fn, apply_fn, kernel_fn), input_shape)``.
    same_inputs: forwarded to `_get_inputs`; when True, `x2` must be None.
    is_conv: forwarded to `_get_inputs`.
    use_dropout: if True, five times as many Monte Carlo samples are drawn.
    is_ntk: compare the NTK if True, otherwise the NNGP.
    proj_into_2d: when equal to ``'ATTN_PARAM'``, only a single-sample
      Monte Carlo kernel is evaluated and no comparison is made.
  """
  (init_fn, apply_fn, kernel_fn), input_shape = net
  # NOTE(review): presumably dropout adds variance, hence more samples.
  num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
  key = random.PRNGKey(1)
  x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape)

  # Only the output shapes are needed; the initialized parameters are unused.
  x1_out_shape, _ = init_fn(key, x1.shape)
  if same_inputs:
    assert x2 is None
  if x2 is not None:
    x2_out_shape, _ = init_fn(key, x2.shape)
  else:
    x2_out_shape = x1_out_shape

  def _get_empirical(n_samples, get):
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples)
    if same_inputs:
      assert x2 is None
    return mc_kernel_fn(x1, x2, get)

  get = 'ntk' if is_ntk else 'nngp'

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available; just test the forward/backward pass.
    _get_empirical(1, get)
    return

  exact, shape1, shape2 = kernel_fn(x1, x2, (get, 'shape1', 'shape2'))
  empirical = _get_empirical(num_samples, get)
  if is_ntk:
    empirical = np.reshape(empirical, exact.shape)

  test_utils.assert_close_matrices(self, exact, empirical, RTOL)
  self.assertEqual(shape1, x1_out_shape)
  self.assertEqual(shape2, x2_out_shape)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  """Compares analytic and Monte Carlo kernels of a branching conv network.

  The network is: conv -> FanOut(n_branches) -> per-branch stacks of
  nonlinearities/convs -> FanInSum or FanInConcat(axis) -> Relu -> pooling
  or flattening -> Dense readout. Depending on `branch_in`, an extra conv
  layer is placed either at the end of every branch or right after the
  fan-in.
  """
  platform = xla_bridge.get_backend().platform
  if platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')

  if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')

  if axis == 3 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest(
        '`FanInConcat` on feature axis requires a dense layer '
        'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if platform == 'tpu':
    width, n_samples, tol = 2048, 1024, 0.02
  else:
    width, n_samples, tol = 1024, 512, 0.01

  conv = stax.Conv(out_chan=width, filter_shape=(3, 3), padding='SAME',
                   W_std=1.25, b_std=0.1)
  input_layers = [conv, stax.FanOut(n_branches)]

  def make_branch(b):
    """Builds branch `b`: phi_b, then b (conv, phi_i) pairs, then maybe conv."""
    layers = [FanInTest._get_phi(b)]
    for i in range(b):
      layers.append(stax.Conv(out_chan=width, filter_shape=(i + 1, 4 - i),
                              padding='SAME', W_std=1.25 + i, b_std=0.1 + i))
      layers.append(FanInTest._get_phi(i))
    if branch_in == 'dense_before_branch_in':
      layers.append(conv)
    return stax.serial(*layers)

  branches = [make_branch(b) for b in range(n_branches)]

  fan_in = stax.FanInSum() if axis is None else stax.FanInConcat(axis)
  relu = stax.Relu()
  head = stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  if branch_in == 'dense_after_branch_in':
    output_layers = [fan_in, conv, relu, head]
  else:
    output_layers = [fan_in, relu, head]

  nn = stax.serial(*input_layers, stax.parallel(*branches), *output_layers)

  readout_width = 1 if get == 'ntk' else width
  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(readout_width, 1.25, 0.5))

  # NOTE(review): concatenation along axis 0 / -4 presumably cannot be split
  # across devices, hence `device_count=0` there -- confirm.
  device_count = 0 if axis in (0, -4) else -1
  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, device_count=device_count)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get).reshape(exact.shape)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
  """Compares analytic and Monte Carlo kernels of a branching dense network.

  The network is: dense -> FanOut(n_branches) -> per-branch stacks of
  nonlinearities/dense layers -> FanInSum or FanInConcat(axis) -> Relu. For
  ``get == 'ntk'`` a width-1 Dense readout is appended. Depending on
  `branch_in`, an extra dense layer is placed either at the end of every
  branch or right after the fan-in.

  Raises:
    ValueError: if `get` is neither ``'nngp'`` nor ``'ntk'``.
  """
  if axis in (None, 0) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                       'require `is_gaussian`.')

  if axis == 1 and branch_in == 'dense_before_branch_in':
    # Trailing space needed: the adjacent literals are concatenated verbatim.
    raise jtu.SkipTest(
        '`FanInConcat` on feature axis requires a dense layer '
        'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (10, 20))
  X0_2 = None if same_inputs else random.normal(key, (8, 20))

  if xla_bridge.get_backend().platform == 'tpu':
    width, n_samples, tol = 2048, 1024, 0.02
  else:
    width, n_samples, tol = 1024, 256, 0.01

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Dense(width, 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)
      ]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                                   n_samples, device_count=0)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  test_utils.assert_close_matrices(self, empirical, exact, tol)