def _check_agreement_with_empirical(self, W_std, b_std, filter_size, is_conv,
                                    is_ntk, is_res, layer_norm, padding, phi,
                                    proj_into_2d, same_inputs, strides,
                                    use_pooling, width):
  """Checks the analytic kernel of a `_get_net` model against Monte Carlo.

  The network is built by `_get_net` from the given hyper-parameters and is
  evaluated on inputs produced by `_get_inputs`. The closed-form kernel
  (`'ntk'` when `is_ntk`, otherwise `'nngp'`) is compared with an
  `N_SAMPLES`-sample Monte Carlo estimate at relative tolerance `RTOL`. The
  `shape1` / `shape2` values returned by `kernel_fn` are checked against the
  output shapes reported by `init_fn`.

  When `proj_into_2d == 'ATTN_PARAM'`, no analytic comparison is made. A
  single-sample Monte Carlo kernel is computed only to run the network.
  """
  rng = random.PRNGKey(1)
  x1, x2 = _get_inputs(rng, is_conv, same_inputs, INPUT_SHAPE)
  init_fn, apply_fn, kernel_fn = _get_net(
      W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding, phi,
      strides, width, is_ntk, proj_into_2d, layer_norm)

  # Reference output shapes from the initializer; the parameters are unused.
  expected_shape1, _ = init_fn(rng, x1.shape)
  if x2 is None:
    expected_shape2 = expected_shape1
  else:
    expected_shape2, _ = init_fn(rng, x2.shape)

  get = 'ntk' if is_ntk else 'nngp'

  def empirical_kernel(n_samples):
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, rng, n_samples)
    return mc_kernel_fn(x1, x2, get)

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available, just test forward/backward pass.
    empirical_kernel(1)
    return

  exact, shape1, shape2 = kernel_fn(x1, x2, (get, 'shape1', 'shape2'))
  empirical = empirical_kernel(N_SAMPLES)
  if is_ntk:
    empirical = np.reshape(empirical, exact.shape)

  utils.assert_close_matrices(self, empirical, exact, RTOL)
  self.assertEqual(shape1, expected_shape1)
  self.assertEqual(shape2, expected_shape2)
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
  """Compares analytic and Monte Carlo kernels for one parameterized model.

  Configurations that are duplicates, shape-incompatible, or unsupported or
  too slow on the current backend are skipped via `jtu.SkipTest`. For
  `proj_into_2d == 'ATTN_PARAM'`, only a single-sample Monte Carlo kernel is
  computed, because no analytic kernel is available.
  """
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if not is_conv:
    uses_conv_only_params = (filter_size != FILTER_SIZES[0] or
                             padding != PADDINGS[0] or
                             strides != STRIDES[0] or
                             proj_into_2d != PROJECTIONS[0] or
                             use_pooling)
    if uses_conv_only_params:
      raise jtu.SkipTest('FC models do not have these parameters.')
  else:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    strided = strides is not None and strides != (1, 1)
    shrinks_with_valid = padding == 'VALID' and filter_size != (1, 1)
    if is_res and (strided or shrinks_with_valid):
      raise jtu.SkipTest('Different paths in a residual models need to return'
                         ' outputs of the same shape.')

  # TODO(jirihron): speed up the vmap alternative impl or fix the current one
  if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
      padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
    raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                       ' the spatial dimensions is singleton.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  rng = random.PRNGKey(1)
  x1, x2 = _get_inputs(rng, is_conv, same_inputs, INPUT_SHAPE)
  init_fn, apply_fn, kernel_fn = _get_net(
      W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding, phi,
      strides, width, is_ntk, proj_into_2d)

  get = 'ntk' if is_ntk else 'nngp'

  def empirical_kernel(n_samples):
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, rng, n_samples)
    return mc_kernel_fn(x1, x2, get)

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available, just test forward/backward pass.
    empirical_kernel(1)
    return

  exact = kernel_fn(x1, x2, get)
  empirical = empirical_kernel(N_SAMPLES)
  if is_ntk:
    empirical = np.reshape(empirical, exact.shape)
  utils.assert_close_matrices(self, empirical, exact, RTOL)
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res):
  """Compares analytic and Monte Carlo kernels (legacy `ker_fun` API).

  Unsupported, duplicate or too-slow configurations are skipped via
  `jtu.SkipTest`. Otherwise the NTK (if `is_ntk`) or the NNGP of the
  `_get_net` model is compared with an `N_SAMPLES`-sample estimate from
  `monte_carlo.get_ker_fun_monte_carlo` at relative tolerance `RTOL`.
  """
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if not is_conv:
    if (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
        strides != STRIDES[0] or use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')
  else:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest(
          'Not running CNN models on CPU to save time.')
    if use_pooling and not same_inputs:
      raise jtu.SkipTest(
          'Pooling layers for different inputs or for same '
          'padding not implemented.')
    strided = strides is not None and strides != (1, 1)
    shrinks_with_valid = padding == 'VALID' and filter_size != (1, 1)
    if is_res and (strided or shrinks_with_valid):
      raise jtu.SkipTest(
          'Different paths in a residual models need to return'
          ' outputs of the same shape.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  rng = random.PRNGKey(1)
  x1, x2 = _get_inputs(rng, is_conv, same_inputs, INPUT_SHAPE)
  init_fun, apply_fun, ker_fun = _get_net(
      W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding, phi,
      strides, width, is_ntk)

  # The analytic kernel is evaluated first. The positional flags given to
  # `get_ker_fun_monte_carlo` are (False, True) when checking the NTK and
  # (True, False) when checking the NNGP.
  if is_ntk:
    exact = ker_fun(x1, x2).ntk
    mc_flags = (False, True)
  else:
    exact = ker_fun(x1, x2, compute_ntk=False).nngp
    mc_flags = (True, False)

  mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
      init_fun, apply_fun, *mc_flags)
  mc_kernels = mc_ker_fun(x1, x2, rng, N_SAMPLES)
  if is_ntk:
    empirical = np.reshape(mc_kernels.ntk, exact.shape)
  else:
    empirical = mc_kernels.nngp

  utils.assert_close_matrices(self, empirical, exact, RTOL)
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
  """Checks a 200-sample Monte Carlo NNGP against the analytic NNGP.

  `batch_size`, `device_count` and `store_on_device` are forwarded to
  `monte_carlo.monte_carlo_kernel_fn`. `utils.stub_out_pmap(batch,
  device_count)` is called first; presumably this emulates `device_count`
  devices (confirm in `utils`). The relative tolerance is 2e-2.
  """
  utils.stub_out_pmap(batch, device_count)

  on_tpu = xla_bridge.get_backend().platform == 'tpu'
  x1, x2, init_fn, apply_fn, analytic_kernel_fn, rng = _get_inputs_and_model(
      1024, 256, on_tpu)

  n_samples = 200
  mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, rng, n_samples, batch_size, device_count,
      store_on_device)

  empirical = mc_kernel_fn(x1, x2, 'nngp')
  analytic = analytic_kernel_fn(x1, x2, 'nngp')
  utils.assert_close_matrices(self, analytic, empirical, 2e-2)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
  """Checks a 100-sample Monte Carlo NTK against the analytic NTK.

  `batch_size`, `device_count` and `store_on_device` are forwarded to
  `monte_carlo.monte_carlo_kernel_fn`. The empirical NTK is summed over its
  two trailing axes and divided by the size of the last axis before it is
  compared with the analytic NTK at relative tolerance 2e-2.
  """
  utils.stub_out_pmap(batch, device_count)

  on_tpu = xla_bridge.get_backend().platform == 'tpu'
  x1, x2, init_fn, apply_fn, analytic_kernel_fn, rng = _get_inputs_and_model(
      256, 2, on_tpu)

  n_samples = 100
  mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, rng, n_samples, batch_size, device_count,
      store_on_device)

  raw_empirical = mc_kernel_fn(x1, x2, 'ntk')
  # Presumably reduces the per-output block to a scalar kernel comparable
  # with the analytic NTK -- TODO confirm the intended normalisation.
  empirical = (np.sum(raw_empirical, axis=(-1, -2)) /
               raw_empirical.shape[-1])
  analytic = analytic_kernel_fn(x1, x2, 'ntk')
  utils.assert_close_matrices(self, analytic, empirical, 2e-2)
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
  """Checks a 200-sample Monte Carlo NNGP against the analytic NNGP.

  Uses the legacy `monte_carlo.get_ker_fun_monte_carlo` API. The positional
  flags `(True, False)` select the NNGP, matching the `.nngp` attribute read
  below. The relative tolerance is 1e-2.
  """
  utils.stub_out_pmap(batch, device_count)
  x1, x2, init_fun, apply_fun, analytic_ker_fun, rng = _get_inputs_and_model(
      512, 512)

  mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
      init_fun, apply_fun, True, False, batch_size, device_count,
      store_on_device)
  empirical = mc_ker_fun(x1, x2, rng, 200).nngp

  analytic_kernels = analytic_ker_fun(
      x1, x2, compute_ntk=False, compute_nngp=True)
  analytic = analytic_kernels.nngp

  utils.assert_close_matrices(self, analytic, empirical, 1e-2)
def _check_agreement_with_empirical(self, net, same_inputs, is_conv,
                                    use_dropout, is_ntk, proj_into_2d):
  """Checks a prebuilt network's analytic kernel against Monte Carlo.

  Args:
    net: `((init_fn, apply_fn, kernel_fn), input_shape)`.
    same_inputs: passed to `_get_inputs`; when true, `x2` must be `None`.
    is_conv: passed to `_get_inputs`.
    use_dropout: if true, five times `N_SAMPLES` Monte Carlo samples are
      drawn.
    is_ntk: compare the `'ntk'` kernel instead of the `'nngp'` kernel.
    proj_into_2d: if `'ATTN_PARAM'`, no analytic comparison is made; a
      single-sample Monte Carlo kernel is computed only to run the network.
  """
  (init_fn, apply_fn, kernel_fn), input_shape = net
  n_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES

  rng = random.PRNGKey(1)
  x1, x2 = _get_inputs(rng, is_conv, same_inputs, input_shape)

  # Reference output shapes from the initializer; the parameters are unused.
  expected_shape1, _ = init_fn(rng, x1.shape)
  if same_inputs:
    assert x2 is None
  if x2 is None:
    expected_shape2 = expected_shape1
  else:
    expected_shape2, _ = init_fn(rng, x2.shape)

  get = 'ntk' if is_ntk else 'nngp'

  def empirical_kernel(n):
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, rng, n)
    return mc_kernel_fn(x1, x2, get)

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available, just test forward/backward pass.
    empirical_kernel(1)
    return

  exact, shape1, shape2 = kernel_fn(x1, x2, (get, 'shape1', 'shape2'))
  empirical = empirical_kernel(n_samples)
  if is_ntk:
    empirical = np.reshape(empirical, exact.shape)

  utils.assert_close_matrices(self, exact, empirical, RTOL)
  self.assertEqual(shape1, expected_shape1)
  self.assertEqual(shape2, expected_shape2)
def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                     store_on_device):
  """Checks a 100-sample Monte Carlo NTK against the analytic NTK.

  Uses the legacy `monte_carlo.get_ker_fun_monte_carlo` API with positional
  flags `(False, True)`, matching the `.ntk` attribute read below. The
  empirical NTK is summed over its two trailing axes and divided by the size
  of the last axis before comparison at relative tolerance 1e-2.
  """
  utils.stub_out_pmap(batch, device_count)
  x1, x2, init_fun, apply_fun, analytic_ker_fun, rng = _get_inputs_and_model(
      512, 2)

  mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
      init_fun, apply_fun, False, True, batch_size, device_count,
      store_on_device)
  raw_empirical = mc_ker_fun(x1, x2, rng, 100).ntk
  empirical = (np.sum(raw_empirical, axis=(-1, -2)) /
               raw_empirical.shape[-1])

  analytic_kernels = analytic_ker_fun(
      x1, x2, compute_ntk=True, compute_nngp=True)
  analytic = analytic_kernels.ntk

  utils.assert_close_matrices(self, analytic, empirical, 1e-2)
def _check_agreement_with_empirical(self, W_std, b_std, filter_size, is_conv,
                                    is_ntk, is_res, layer_norm, padding, phi,
                                    proj_into_2d, same_inputs, strides,
                                    use_pooling, width, parameterization,
                                    use_dropout):
  """Checks a `_get_net` model's analytic kernel against Monte Carlo.

  For conv models, a dimension order (`spec`) is drawn with `prandom.choice`.
  On TPU it is drawn from a smaller set that keeps the batch axis leading.
  `INPUT_SHAPE` (given in `'NHWC'` order) is permuted to match, and
  `layer_norm` is remapped to axis indices in that order. For FC models, a
  truthy `layer_norm` is replaced by a random choice of `(1,)` or `(-1,)`.

  The analytic kernel (`'ntk'` if `is_ntk`, else `'nngp'`) is compared with a
  Monte Carlo estimate using `N_SAMPLES` samples, or five times as many with
  `use_dropout`, at relative tolerance `RTOL`. The output shapes reported by
  `kernel_fn` are also checked against `init_fn`. For
  `proj_into_2d == 'ATTN_PARAM'`, only a single-sample estimate is computed.
  """

  def choose_layout(layer_norm):
    """Returns `(spec, input_shape, layer_norm)` for this run."""
    if not is_conv:
      # Only `NC` dimension order is supported and is enforced by layers.
      if layer_norm:
        layer_norm = prandom.choice([(1, ), (-1, )])
      return None, INPUT_SHAPE, layer_norm

    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)
    return spec, shape, layer_norm

  spec, input_shape, layer_norm = choose_layout(layer_norm)

  n_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
  rng = random.PRNGKey(1)
  dimension_numbers = (spec, 'HWIO', spec)

  x1, x2 = _get_inputs(rng, is_conv, same_inputs, input_shape)
  init_fn, apply_fn, kernel_fn = _get_net(
      W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding, phi,
      strides, width, is_ntk, proj_into_2d, layer_norm, parameterization,
      use_dropout, dimension_numbers)

  # Reference output shapes from the initializer; the parameters are unused.
  expected_shape1, _ = init_fn(rng, x1.shape)
  if same_inputs:
    assert x2 is None
  if x2 is None:
    expected_shape2 = expected_shape1
  else:
    expected_shape2, _ = init_fn(rng, x2.shape)

  get = 'ntk' if is_ntk else 'nngp'

  def empirical_kernel(n):
    mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, rng, n)
    return mc_kernel_fn(x1, x2, get)

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available, just test forward/backward pass.
    empirical_kernel(1)
    return

  exact, shape1, shape2 = kernel_fn(x1, x2, (get, 'shape1', 'shape2'))
  empirical = empirical_kernel(n_samples)
  if is_ntk:
    empirical = np.reshape(empirical, exact.shape)

  utils.assert_close_matrices(self, empirical, exact, RTOL)
  self.assertEqual(shape1, expected_shape1)
  self.assertEqual(shape2, expected_shape2)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  """Compares analytic and Monte Carlo kernels of a multi-branch CNN.

  Architecture: a shared `conv`, then `FanOut(n_branches)`. Branch `b` has
  `b` extra conv+phi layers. The branches are merged by `FanInSum` (when
  `axis is None`) or `FanInConcat(axis)`, followed by `Relu`, then
  `GlobalAvgPool` (`readout == 'pool'`) or `Flatten`, and a final `Dense`.
  With `branch_in == 'dense_before_branch_in'` the shared `conv` ends each
  branch. With `'dense_after_branch_in'` it is inserted right after the
  fan-in.
  """
  if xla_bridge.get_backend().platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')
  if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')
  if axis == 3 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  rng = random.PRNGKey(1)
  x_1 = random.normal(rng, (2, 5, 6, 3))
  x_2 = None if same_inputs else random.normal(rng, (3, 5, 6, 3))

  on_tpu = xla_bridge.get_backend().platform == 'tpu'
  width, n_samples, tol = (2048, 1024, 0.02) if on_tpu else (1024, 512, 0.01)

  conv = stax.Conv(out_chan=width, filter_shape=(3, 3), padding='SAME',
                   W_std=1.25, b_std=0.1)

  def make_branch(depth):
    """Builds branch number `depth`: phi, then `depth` conv+phi pairs."""
    layers = [FanInTest._get_phi(depth)]
    for i in range(depth):
      layers.append(stax.Conv(out_chan=width, filter_shape=(i + 1, 4 - i),
                              padding='SAME', W_std=1.25 + i, b_std=0.1 + i))
      layers.append(FanInTest._get_phi(i))
    if branch_in == 'dense_before_branch_in':
      layers.append(conv)
    return stax.serial(*layers)

  input_layers = [conv, stax.FanOut(n_branches)]
  branches = [make_branch(b) for b in range(n_branches)]

  fan_in = stax.FanInSum() if axis is None else stax.FanInConcat(axis)
  readout_layer = stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  output_layers = [fan_in, stax.Relu(), readout_layer]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*input_layers, stax.parallel(*branches), *output_layers)
  out_width = 1 if get == 'ntk' else width
  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(out_width, 1.25, 0.5))

  # `device_count=0` is used for concatenation along axis 0 (== -4 for these
  # rank-4 inputs); presumably multi-device batching is not valid there.
  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, rng, n_samples,
      device_count=0 if axis in (0, -4) else -1)

  exact = kernel_fn(x_1, x_2, get=get)
  empirical = kernel_fn_mc(x_1, x_2, get=get).reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
  """Compares analytic and Monte Carlo kernels of a multi-branch FC network.

  Architecture: a shared `dense` layer, then `FanOut(n_branches)`. Branch `b`
  has `b` extra Dense+phi layers. The branches are merged by `FanInSum`
  (when `axis is None`) or `FanInConcat(axis)`, followed by `Relu`. With
  `branch_in == 'dense_before_branch_in'` the shared `dense` ends each
  branch. With `'dense_after_branch_in'` it is inserted right after the
  fan-in. For `get == 'ntk'` a final `Dense(1, 1.25, 0.5)` is appended.

  Raises:
    jtu.SkipTest: for axis / `branch_in` combinations that are not supported.
    ValueError: if `get` is neither `'nngp'` nor `'ntk'`.
  """
  if axis in (None, 0) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                       'require `is_gaussian`.')
  if axis == 1 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (10, 20))
  X0_2 = None if same_inputs else random.normal(key, (8, 20))

  if xla_bridge.get_backend().platform == 'tpu':
    width, n_samples, tol = 2048, 1024, 0.02
  else:
    width, n_samples, tol = 1024, 256, 0.01

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Dense(width, 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, device_count=0)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)