class ParallelInOutTest(test_utils.NeuralTangentsTestCase):
  """Compares `stax.parallel` kernels with Monte Carlo estimates.

  Every test builds a network whose inputs and/or outputs are tuples routed
  through `stax.parallel`, then checks the closed-form kernel against the one
  produced by `nt.monte_carlo_kernel_fn`.
  """

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_parallel_in(self, same_inputs, kernel_type):
    """Two inputs, each read in by its own `Dense`, summed into one output."""
    rtol = 0.05 if default_backend() == 'tpu' else RTOL

    key_a, key_b, key_mc = random.split(random.PRNGKey(0), 3)
    # `zip` regroups the per-input `(x1_i, x2_i)` pairs into
    # `x1 = (x1_a, x1_b)` and `x2 = (x2_a, x2_b)`.
    x1, x2 = zip(_get_inputs(key_a, same_inputs, (BATCH_SIZE, 2)),
                 _get_inputs(key_b, same_inputs, (BATCH_SIZE, 3)))

    width = 2 ** 7
    readout_width = width if kernel_type == 'nngp' else 1
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.parallel(stax.Dense(width), stax.Dense(width)),
        stax.serial(stax.FanInSum(), stax.Dense(readout_width)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_mc, N_SAMPLES,
        trace_axes=(-1,),
        implementation=2,
        vmap_axes=((0, 0), 0, {}))

    expected = kernel_fn(x1, x2, kernel_type)
    actual = kernel_fn_empirical(x1, x2, kernel_type)
    test_utils.assert_close_matrices(self, expected, actual, rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_parallel_out(self, same_inputs, kernel_type):
    """One input fanned out into two parallel `Dense` readouts."""
    rtol = 0.05 if default_backend() == 'tpu' else RTOL

    key_in, key_mc = random.split(random.PRNGKey(0), 2)
    x1, x2 = _get_inputs(key_in, same_inputs, (BATCH_SIZE, 1))

    width = 2 ** 10
    readout_width = width if kernel_type == 'nngp' else 1
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        stax.FanOut(2),
        stax.parallel(stax.Dense(readout_width), stax.Dense(readout_width)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_mc, N_SAMPLES,
        trace_axes=(-1,),
        implementation=2,
        vmap_axes=(0, [0, 0], {}))

    expected = kernel_fn(x1, x2, kernel_type)
    actual = kernel_fn_empirical(x1, x2, kernel_type)
    test_utils.assert_close_matrices(self, expected, actual, rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type,
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_parallel_in_out(self, same_inputs, kernel_type):
    """Parallel read-in and parallel read-out, composed via kernel fns."""
    rtol = 0.05 if default_backend() == 'tpu' else RTOL

    key_a, key_b, key_mc = random.split(random.PRNGKey(0), 3)
    x1, x2 = zip(_get_inputs(key_a, same_inputs, (BATCH_SIZE, 1)),
                 _get_inputs(key_b, same_inputs, (BATCH_SIZE, 2)))

    hidden = 2 ** 10
    out = hidden if kernel_type == 'nngp' else 1

    readin = stax.serial(
        stax.parallel(stax.Dense(hidden), stax.Dense(hidden)),
        stax.FanInSum())
    # Three readouts of widths `out`, `out + 1` and `out + 2`.
    readout = stax.serial(
        stax.FanOut(3),
        stax.parallel(*(stax.Dense(out + offset) for offset in range(3))))
    init_fn, apply_fn, _ = stax.serial(readin, readout)

    _, _, readin_kernel_fn = readin
    _, _, readout_kernel_fn = readout
    readin_k = jit(readin_kernel_fn)
    readout_k = jit(functools.partial(readout_kernel_fn, get=kernel_type))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_mc, N_SAMPLES,
        trace_axes=(-1,),
        implementation=2,
        vmap_axes=((0, 0), [0, 0, 0], {}))

    expected = readout_k(readin_k(x1, x2))
    actual = kernel_fn_empirical(x1, x2, get=kernel_type)
    test_utils.assert_close_matrices(self, expected, actual, rtol)

    # Requesting both kernels at once: here we only check that the output
    # _can_ be computed.
    readout_both = jit(functools.partial(readout_kernel_fn,
                                         get=('nngp', 'ntk')))
    readout_both(readin_k(x1, x2))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type,
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_nested_parallel(self, same_inputs, kernel_type):
    """Three levels of nested `stax.parallel` over partially masked inputs."""
    on_tpu = default_backend() == 'tpu'
    rtol = 0.05 if on_tpu else RTOL

    (key_1, key_2, key_3, key_4,
     mask_key, key_mc) = random.split(random.PRNGKey(0), 6)

    x1_1, x2_1 = _get_inputs(key_1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(key_2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(key_3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(key_4, same_inputs, (BATCH_SIZE, 3, 4))

    mask_1, mask_2, mask_3, mask_4 = random.split(mask_key, 4)

    def masked(x, axes, key):
      return test_utils.mask(x, mask_constant=-1, mask_axis=axes, key=key,
                             p=0.5)

    # Two of the first-batch inputs are masked; two second-batch inputs are
    # masked too, but only when a distinct second batch exists.
    x1_1 = masked(x1_1, (1,), mask_1)
    x1_2 = masked(x1_2, (2, 3), mask_2)
    if not same_inputs:
      x2_3 = masked(x2_3, (1, 3), mask_3)
      x2_4 = masked(x2_4, (2,), mask_4)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = None if same_inputs else (((x2_1, x2_2), x2_3), x2_4)

    width = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Identity() if on_tpu else stax.Dropout(0.9)

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(
                stax.Dense(width),
                stax.serial(stax.Conv(width + 1, (2, 2)), stax.Flatten())),
            stax.serial(stax.Conv(width + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(width + 3, (2,)))

    vmap_axes = ((((0, 0), 0), 0), (((0, 0), 0), 0), {}) if on_tpu else None
    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key_mc, N_SAMPLES,
        implementation=2,
        vmap_axes=vmap_axes)

    expected = kernel_fn(x1, x2, get=kernel_type, mask_constant=-1)
    actual = kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1)
    test_utils.assert_close_matrices(self, expected, actual, rtol)
class StaxTest(test_utils.NeuralTangentsTestCase):
  """Compares infinite-width `stax` kernels with empirical / MC estimates."""

  def _skip_test(self, filter_shape, is_conv, is_res, padding, proj_into_2d,
                 strides, use_pooling):
    """Skips duplicate, incorrectly-shaped, or backend-unsupported configs.

    Args:
      filter_shape: convolution filter shape of the case.
      is_conv: whether the model is convolutional.
      is_res: whether the model has residual branches.
      padding: convolution padding of the case.
      proj_into_2d: projection used to go from spatial to flat outputs.
      strides: convolution strides of the case.
      use_pooling: whether the case uses pooling.

    Raises:
      absltest.SkipTest: when the case should not run.
    """
    if is_conv:
      test_utils.skip_test(self)
      # A strided conv, or a `VALID` conv with a non-1x1 filter, changes the
      # output shape of a path, which residual models cannot tolerate.
      if is_res and ((strides is not None and strides != (1, 1)) or
                     (padding == 'VALID' and filter_shape != (1, 1))):
        raise absltest.SkipTest('Different paths in a residual models need to '
                                'return outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      # FC models ignore conv-only parameters: only the first value of each is
      # run, so that identical FC cases are not repeated.
      raise absltest.SkipTest('FC models do not have these parameters.')

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP',
                  'RESNET' if is_res else 'serial',
                  proj_into_2d),
          'model': model,
          'width': width,
          'strides': strides,
          'padding': padding,
          'phi': phi,
          'same_inputs': same_inputs,
          'filter_shape': filter_shape,
          'use_pooling': use_pooling,
          'is_ntk': is_ntk,
          'is_res': is_res,
          'proj_into_2d': proj_into_2d
      }
                                 for model in MODELS
                                 for width in WIDTHS
                                 for phi, phi_name in ACTIVATIONS.items()
                                 for same_inputs in [False]
                                 for padding in PADDINGS
                                 for strides in STRIDES
                                 for filter_shape in FILTER_SHAPES
                                 for use_pooling in [False, True]
                                 for is_ntk in [False, True]
                                 for is_res in [False, True]
                                 for proj_into_2d in PROJECTIONS))
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_shape, use_pooling, is_ntk, is_res, proj_into_2d):
    """Generic architectures built by `_get_net` vs. empirical kernels."""
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    pool_type = 'AVG'
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(
        self, net, same_inputs, use_dropout, is_ntk, RTOL, 1.1)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}'.format(
                  model, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'NTK' if is_ntk else 'NNGP',
                  proj_into_2d,
                  'layer_norm=%s' % str(layer_norm)),
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'proj_into_2d': proj_into_2d,
          'layer_norm': layer_norm
      }
                                 for model in MODELS
                                 for width in WIDTHS
                                 for same_inputs in [False]
                                 for is_ntk in [False, True]
                                 for proj_into_2d in PROJECTIONS[:2]
                                 for layer_norm in LAYER_NORM))
  def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                     layer_norm):
    """ReLU networks with `layer_norm` vs. empirical kernels."""
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
      raise absltest.SkipTest('FC models do not have these parameters.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_shape = FILTER_SHAPES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    parameterization = 'ntk'
    pool_type = 'AVG'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk,
                                    0.07)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'NTK' if is_ntk else 'NNGP',
                  'pool_type=%s' % str(pool_type),
                  'normalize_edges=%s' % str(normalize_edges)),
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'pool_type': pool_type,
          'padding': padding,
          'filter_shape': filter_shape,
          'strides': strides,
          'normalize_edges': normalize_edges
      }
                                 for width in WIDTHS
                                 for same_inputs in [False]
                                 for is_ntk in [False, True]
                                 for pool_type in POOL_TYPES
                                 for padding in PADDINGS
                                 for filter_shape in FILTER_SHAPES
                                 for strides in STRIDES
                                 for normalize_edges in [True, False]))
  def test_pool(self, width, same_inputs, is_ntk, pool_type, padding,
                filter_shape, strides, normalize_edges):
    """Pooling networks built by `_get_net_pool` vs. empirical kernels."""
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    test_utils.skip_test(self)
    if pool_type == 'SUM' and normalize_edges:
      raise absltest.SkipTest('normalize_edges not applicable to SumPool.')

    net = _get_net_pool(width, is_ntk, pool_type, padding, filter_shape,
                        strides, normalize_edges)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk)

  def test_avg_pool(self):
    """Checks `stax.AvgPool` outputs and kernels with/without edge norm."""
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                          normalize_edges=False)
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                                    normalize_edges=True)
    _, apply_fn_stax = ostax.AvgPool((2, 2), (1, 1), 'SAME')

    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)

    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)

    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    # Edge-normalized pooling must agree with the reference `ostax` layer.
    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm))

    # Without edge normalization, windows overhanging the border of an
    # all-ones input average in zeros, giving these per-position values.
    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)

    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2))

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp)
    self.assertAllClose(np.ones_like(ker_norm.cov1), ker_norm.cov1)
    self.assertAllClose(np.ones_like(ker_norm.cov2), ker_norm.cov2)

    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.cov1.shape, ker.cov1.shape)
    self.assertEqual(ker_norm.cov2.shape, ker.cov2.shape)

    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    cov1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov1.shape)
    cov2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov2.shape)
    self.assertAllClose((nngp, cov1, cov2), (ker.nngp, ker.cov1, ker.cov2))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape),
                  'padding=%s' % padding,
                  'strides=%s' % str(strides),
                  'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP',
                  proj_into_2d),
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'is_ntk': is_ntk,
          'padding': padding,
          'strides': strides,
          'filter_shape': filter_shape,
          'phi': phi,
          'use_pooling': use_pooling,
          'proj_into_2d': proj_into_2d
      }
                                 for model in MODELS
                                 for width in WIDTHS
                                 for same_inputs in [True, False]
                                 for phi, phi_name in ACTIVATIONS.items()
                                 for padding in ['SAME']
                                 for strides in STRIDES
                                 for filter_shape in [(2, 1)]
                                 for is_ntk in [True, False]
                                 for use_pooling in [True, False]
                                 for proj_into_2d in ['FLAT', 'POOL']))
  def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                   filter_shape, phi, use_pooling, proj_into_2d):
    """Networks with dropout enabled vs. empirical kernels."""
    pool_type = 'AVG'
    use_dropout = True
    is_conv = 'conv' in model
    is_res = False
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_act={act}_kernel={kern}_do_stabilize={do_stabilize}',
          'act': act,
          'kernel': kern,
          'do_stabilize': do_stabilize
      }
                                 for act in ['erf', 'relu']
                                 for do_stabilize in [True, False]
                                 for kern in ['nngp', 'ntk']))
  def test_sparse_inputs(self, act, kernel, do_stabilize):
    """Kernels on inputs whose first `sparse_count` rows are all zeros.

    Only the block between the non-zero rows is compared against the Monte
    Carlo estimate.
    """
    if do_stabilize and act != 'relu':
      raise absltest.SkipTest('Stabilization possible only in Relu.')

    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 3
    width = 1024

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    if default_backend() == 'gpu':
      tol = 5e-4
      samples = 100 * N_SAMPLES
    else:
      tol = {onp.dtype(onp.float32): 5e-2, onp.dtype(onp.float64): 5e-3}

    # A batch of dense inputs whose first `sparse_count` rows are zeroed out.
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = x_dense.at[:sparse_count, :].set(0.)

    activation = (stax.Relu(do_stabilize=do_stabilize) if act == 'relu'
                  else stax.Erf())

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)

    mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn,
        random.split(key, 2)[0],
        samples,
        vmap_axes=0,
        device_count=-1,
        implementation=2
    )(x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    # `assert` is stripped under `python -O`; use the test-case assertion.
    self.assertFalse(np.any(np.isnan(exact)))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:], rtol=tol, atol=tol)

  def test_composition_dense(self):
    """Applying a block's kernel fn twice equals `serial(Block, Block)`."""
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (2, 3))
    x2 = random.normal(rng, (4, 3))

    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(x1))
    composed_ker_out = composed_ker_fn(x1)
    self.assertAllClose(ker_out, composed_ker_out)

    ker_out = ker_fn(ker_fn(x1, x2))
    composed_ker_out = composed_ker_fn(x1, x2)
    self.assertAllClose(ker_out, composed_ker_out)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_avg_pool={}_same_inputs={}'.format(avg_pool, same_inputs),
          'avg_pool': avg_pool,
          'same_inputs': same_inputs
      }
                                 for avg_pool in [True, False]
                                 for same_inputs in [True, False]))
  def test_composition_conv(self, avg_pool, same_inputs):
    """Composing conv block and readout kernel fns equals `serial`.

    With the conv + `GlobalAvgPool` readout, feeding a block kernel computed
    with `diagonal_spatial=True` is expected to raise `ValueError`. Presumably
    this is because the readout `Conv` needs the full spatial covariance.
    """
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
class ParameterizationTest(test_utils.NeuralTangentsTestCase):
  """Checks `ntk` and `standard` parameterizations against empirical kernels."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_get={get}'
              f'_s={s}'
              f'_depth={depth}'
              f'_same_inputs={same_inputs}'
              f'_b_std={b_std}_'
              f'_W_std={W_std}'
              f'_param={parameterization}',
          'get': get,
          's': s,
          'depth': depth,
          'same_inputs': same_inputs,
          'b_std': b_std,
          'W_std': W_std,
          'parameterization': parameterization,
      }
                                 for get in ['nngp', 'ntk']
                                 for s in [2**9, 2**8, 2**7]
                                 for depth in [0, 1, 2]
                                 for same_inputs in [True, False]
                                 for W_std in [0., 1., 2.]
                                 for b_std in [None, 0., 0.5**0.5, 2]
                                 for parameterization in ['ntk', 'standard']))
  def test_linear(
      self,
      get,
      s,
      depth,
      same_inputs,
      b_std,
      W_std,
      parameterization,
  ):
    """A stack of `depth + 1` linear `Dense` layers with per-layer scales."""
    if parameterization == 'ntk':
      # Only one `s` value is run, since it does not affect this case.
      if s != 2**9:
        raise absltest.SkipTest(
            '"ntk" parameterization does not depend on "s".')
      width = 2**10
    elif parameterization == 'standard':
      width = 2**9 // s
    else:
      raise ValueError(parameterization)

    def dense(i):
      # Layer `i` (0-based) scales its width and W/b stds by `i + 1`. When
      # checking the NTK, the last layer is a 1-unit readout with `s_out = 1`.
      factor = i + 1
      is_ntk_readout = i == depth and get == 'ntk'
      return stax.Dense(
          1 if is_ntk_readout else width * factor,
          W_std=W_std / factor,
          b_std=None if b_std is None else b_std / factor,
          parameterization=parameterization,
          s=(s if i > 0 else 1, 1 if is_ntk_readout else s))

    layers = [dense(i) for i in range(depth + 1)]
    net = (stax.serial(*layers), (BATCH_SIZE, 3), -1, 1)
    _check_agreement_with_empirical(self, net, same_inputs, False, get == 'ntk',
                                    rtol=0.02, atol=10)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_model={model}'
              f'_width={width}'
              f'_same_inputs={same_inputs}'
              f'_filter_shape={filter_shape}'
              f'_proj={proj_into_2d}_'
              f'_is_ntk={is_ntk}_'
              f'_b_std={b_std}_'
              f'_W_std={W_std}'
              f'_param={parameterization}'
              f'_s={s}',
          'model': model,
          'width': width,
          'same_inputs': same_inputs,
          'filter_shape': filter_shape,
          'proj_into_2d': proj_into_2d,
          'is_ntk': is_ntk,
          'b_std': b_std,
          'W_std': W_std,
          'parameterization': parameterization,
          's': s
      }
                                 for model in MODELS
                                 for width in [2**11]
                                 for same_inputs in [False]
                                 for is_ntk in [False, True]
                                 for filter_shape in FILTER_SHAPES
                                 for proj_into_2d in PROJECTIONS[:2]
                                 for W_std in [0., 1., 2.]
                                 for b_std in [None, 0., 0.5**0.5]
                                 for parameterization in ['ntk', 'standard']
                                 for s in [2**10]))
  def test_nonlinear(
      self,
      model,
      width,
      same_inputs,
      is_ntk,
      filter_shape,
      proj_into_2d,
      b_std,
      W_std,
      parameterization,
      s
  ):
    """A ReLU network built by `_get_net` under both parameterizations."""
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or filter_shape != FILTER_SHAPES[0]:
      raise absltest.SkipTest('FC models do not have these parameters.')

    if parameterization == 'standard':
      width //= s

    use_dropout = False
    net = _get_net(W_std=W_std,
                   b_std=b_std,
                   filter_shape=filter_shape,
                   is_conv=is_conv,
                   use_pooling=False,
                   is_res=False,
                   padding=PADDINGS[0],
                   phi=stax.Relu(),
                   strides=STRIDES[0],
                   width=width,
                   is_ntk=is_ntk,
                   proj_into_2d=proj_into_2d,
                   pool_type='AVG',
                   layer_norm=None,
                   parameterization=parameterization,
                   s=s,
                   use_dropout=use_dropout)
    _check_agreement_with_empirical(self,
                                    net=net,
                                    same_inputs=same_inputs,
                                    use_dropout=use_dropout,
                                    is_ntk=is_ntk,
                                    rtol=0.015,
                                    atol=1000)
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):
  """Compares analytic activation kernels to `stax.ElementwiseNumerical`."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}'.format(
                  model,
                  phi[0].__name__,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get),
          'model': model,
          'phi': phi,
          'same_inputs': same_inputs,
          'get': get,
      }
                                 for model in ['fc', 'conv-pool', 'conv-flatten']
                                 for phi in [
                                     stax.Erf(),
                                     stax.Gelu(),
                                     stax.Sin(),
                                 ]
                                 for same_inputs in [False, True]
                                 for get in ['nngp', 'ntk']))
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    """The numerically-integrated kernel of `phi` must match the analytic one."""
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    W_std, b_std = 1.0, 0.01
    output_dim = 1

    # Base tolerance, doubled for the NTK and doubled again on TPU.
    rtol = 2e-3
    rtol *= 2 if get == 'ntk' else 1
    rtol *= 2 if default_backend() == 'tpu' else 1

    if model != 'fc':
      batch1, batch2, sample_shape = 2, 3, (8, 8, 3)
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std,
                         padding='SAME')
      to_flat = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
      readout = stax.serial(to_flat, stax.Dense(output_dim))
      depth = 2
    else:
      batch1, batch2, sample_shape = 3, 5, (7,)
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    x1 = random.normal(key, (batch1,) + sample_shape)
    x2 = None if same_inputs else random.normal(split,
                                                (batch2,) + sample_shape)

    def kernel_with(activation):
      # `depth` repetitions of (affine, activation) followed by the readout.
      _, _, kernel_fn = stax.serial(*[affine, activation] * depth, readout)
      return kernel_fn(x1, x2, get)

    analytic_kernel = kernel_with(phi)
    numerical_kernel = kernel_with(
        stax.ElementwiseNumerical(lambda x: phi[1]((), x), deg=25))

    test_utils.assert_close_matrices(self, analytic_kernel, numerical_kernel,
                                     rtol)
class AutodiffTest(test_utils.NeuralTangentsTestCase):
  """Checks forward- and reverse-mode derivatives of `stax` kernel fns."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': f'{get}-{same_inputs}-{phi.__name__}',
          'get': get,
          'same_inputs': same_inputs,
          'phi': phi,
      }
                                 for get in [
                                     'ntk',
                                     'nngp'
                                 ]
                                 for same_inputs in [True, False, None]
                                 for phi in [
                                     stax.Erf,
                                     stax.Sin,
                                     stax.Gelu,
                                     stax.Relu,
                                     stax.ElementwiseNumerical
                                 ]))
  def test_autodiff(self, get, same_inputs, phi):
    """Taylor expansion and jacfwd/jacrev/symmetry checks of the kernel."""
    x1 = np.cos(random.normal(random.PRNGKey(1), (3, 1, 2, 3)))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(random.PRNGKey(2), (4, 1, 2, 3)))

    name = phi.__name__
    if name == 'LeakyRelu':
      phi = phi(0.1)
    elif name == 'ElementwiseNumerical':
      phi = phi(fn=np.cos, deg=25)
    else:
      phi = phi()

    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.01), phi,
                                  stax.Dense(1, 2., 0.01), phi)

    def k(x1, x2):
      return kernel_fn(x1, x2, get)

    dx1 = random.normal(random.PRNGKey(3), x1.shape) * 0.01
    if x2 is None:
      dx2 = None
    else:
      dx2 = random.normal(random.PRNGKey(4), x2.shape) * 0.01

    def dk(x1, x2):
      return jvp(k, (x1, x2), (dx1, dx2))[1]

    def d2k(x1, x2):
      return jvp(dk, (x1, x2), (dx1, dx2))[1]

    _dk = dk(x1, x2)

    if (same_inputs is not False and
        get == 'ntk' and
        ('Relu' in name or 'Abs' in name)):
      # TODO(romann): revisit numerical issues of second derivative of `Relu`
      _d2k = 0
      tol = 0.01
    else:
      _d2k = d2k(x1, x2)
      tol = 2e-3 if name == 'ElementwiseNumerical' else 1e-4

    def assert_close(x, y, tol=3e-5):
      """Asserts max |x - y| relative to mean |x| + mean |y| is below `tol`."""
      if default_backend() == 'tpu':
        # TODO(romann): understand why TPUs have high errors.
        tol = 0.21
      self.assertLess(
          np.max(np.abs(x - y)) / (np.mean(np.abs(x)) + np.mean(np.abs(y))),
          tol)

    # k(x + dx) ~ k(x) + dk(x) dx + 1/2 dx^T d2k(x) dx
    assert_close(k(x1 + dx1, None if same_inputs is None else x2 + dx2),
                 k(x1, x2) + _dk + _d2k / 2,
                 tol=tol)

    # d/dx1
    k_fwd_0 = jacfwd(k)(x1, x2)
    k_rev_0 = jacrev(k)(x1, x2)
    assert_close(k_fwd_0, k_rev_0)

    if same_inputs is not None:
      # d/dx2
      k_fwd_1 = jacfwd(k, 1)(x1, x2)
      k_rev_1 = jacrev(k, 1)(x1, x2)
      assert_close(k_fwd_1, k_rev_1)

      # dk(x2, x1)/dx2 = dk(x1, x2)/dx1
      k_fwd_01 = jacfwd(k, 1)(x2, x1)
      k_rev_01 = jacrev(k, 1)(x2, x1)
      assert_close(np.moveaxis(k_fwd_0, (0, 2, 4), (1, 3, 5)), k_fwd_01)
      assert_close(np.moveaxis(k_rev_0, (0, 2, 4), (1, 3, 5)), k_rev_01)

      # dk(x2, x1)/dx1 = dk(x1, x2)/dx2
      k_fwd_10 = jacfwd(k)(x2, x1)
      k_rev_10 = jacrev(k)(x2, x1)
      assert_close(np.moveaxis(k_fwd_1, (0, 2, 4), (1, 3, 5)), k_fwd_10)
      assert_close(np.moveaxis(k_rev_1, (0, 2, 4), (1, 3, 5)), k_rev_10)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'get={get}-'
              f'param={parameterization}-'
              f'param_out={parameterization_out}-'
              f'x1={x1_type}-'
              f'x2={x2_type}-'
              f'phi={phi.__name__}-'
              f'b_std={b_std}-'
              f'jit={do_jit}-',
          'get': get,
          'parameterization': parameterization,
          'parameterization_out': parameterization_out,
          'x1_type': x1_type,
          'x2_type': x2_type,
          'phi': phi,
          'b_std': b_std,
          'do_jit': do_jit
      }
                                 for get in [
                                     'ntk',
                                     'nngp'
                                 ]
                                 for parameterization in [
                                     'standard',
                                     'ntk'
                                 ]
                                 for parameterization_out in [
                                     'ntk'
                                 ]
                                 for do_jit in [
                                     True,
                                 ]
                                 for x1_type in [
                                     'zeros',
                                     'ones',
                                     'random',
                                 ]
                                 for x2_type in [
                                     'zeros',
                                     'ones',
                                     'random',
                                     'x1',
                                     'none',
                                 ]
                                 for b_std in [
                                     None,
                                     0.1,
                                 ]
                                 for phi in [
                                     stax.Identity,
                                     stax.Erf,
                                     stax.Abs,
                                     stax.Gelu,
                                     stax.Relu,
                                     stax.Sigmoid_like,
                                     stax.ABRelu,
                                     stax.Exp,
                                     stax.ExpNormalized,
                                     stax.Gaussian,
                                     stax.Sign,
                                     stax.Rbf,
                                     stax.Cos,
                                     stax.Sin
                                 ]))
  def test_activations(
      self,
      get,
      parameterization,
      parameterization_out,
      x1_type,
      x2_type,
      b_std,
      phi,
      do_jit
  ):
    """Tests forward- and reverse-mode autodiff for nonlinearities."""
    if phi == stax.ABRelu:
      phi_ = phi(0.25, 0.5)
    else:
      phi_ = phi()

    # Every nonlinearity other than `Relu` goes through `skip_test`, which
    # presumably skips only on some backends.
    if phi not in [stax.Relu]:
      test_utils.skip_test(self)

    n_out = 1 if get == 'ntk' else 1024
    width = 2**10

    W_std_in = width**(-0.5) if parameterization_out == 'standard' else 1.
    if phi == stax.Exp:
      W_std_in /= 10.

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(
            width,
            W_std=W_std_in,
            b_std=b_std,
            parameterization=parameterization),
        phi_,
        stax.Dense(
            n_out,
            b_std=b_std,
            parameterization=parameterization_out
        ),
    )

    def get_x(x_type, key):
      """Returns a (1, 2) input of the requested kind, or `None`."""
      shape = (1, 2)

      if x_type == 'zeros':
        x = np.zeros(shape)
      elif x_type == 'ones':
        x = np.ones(shape)
      elif x_type == 'random':
        x = random.normal(random.PRNGKey(key), shape)
      elif x_type == 'sin':
        x = np.sin(random.normal(random.PRNGKey(key), shape))
      elif x_type == 'none':
        return None
      else:
        raise ValueError(x_type)

      return x

    x1 = get_x(x1_type, 1)
    if x2_type == 'x1':
      x2 = x1
    else:
      x2 = get_x(x2_type, 2)

    def kernel_scalar(x1, x2):
      return kernel_fn(x1, x2, get)[0, 0]

    if do_jit:
      kernel_scalar = jit(kernel_scalar)

    k1 = kernel_scalar(x1, x2)
    k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2)

    # Compare to forward-mode.
    k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
    k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2_fwd)
    self.assertAllClose(k2_grad, k2_grad_fwd)

    # `stax.ExpNormalized` has no forward pass.
    # `stax.Sign` is discontinuous at `0`, so NTK MC kernel does not converge to
    # infinite-width kernel.
    if phi == stax.ExpNormalized or (get == 'ntk' and phi == stax.Sign):
      raise absltest.SkipTest('Not comparing against MC kernels.')

    _kernel_scalar_mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key=random.PRNGKey(3),
        n_samples=1,
        device_count=0,
    )

    def kernel_scalar_mc(x1, x2):
      return _kernel_scalar_mc(x1, x2, get)[0, 0]

    k_mc = kernel_scalar_mc(x1, x2)
    k_mc2, k_mc2_grad = value_and_grad(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2)

    # Compare MC to forward-mode.
    k_mc2_fwd, _ = jvp(kernel_scalar_mc, (x1, x2), (x1, x2))
    k_mc2_grad_fwd = jacfwd(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2_fwd)
    self.assertAllClose(k_mc2_grad, k_mc2_grad_fwd)

    def kernel_fn_emp(x1, x2, get, params):
      return nt.empirical_kernel_fn(apply_fn)(x1, x2, get, params)[0, 0]

    kernel_fn_emp_g = jit(value_and_grad(kernel_fn_emp), static_argnums=(2,))

    def kernel_scalar_mc_grad_mean(x1, x2):
      """Averages empirical kernel values and gradients over 2**9 inits."""
      key = random.PRNGKey(4)
      n_samples = 2**9
      k, k_grad = 0., 0.

      for _ in range(n_samples):
        _, params = init_fn(key, x1.shape)
        k_mc2, k_mc2_grad = kernel_fn_emp_g(x1, x2, get, params)
        k += k_mc2
        k_grad += k_mc2_grad
        key, _ = random.split(key)

      k /= n_samples
      k_grad /= n_samples
      return k, k_grad

    k_mc2_mean, k_mc2_grad_mean = kernel_scalar_mc_grad_mean(x1, x2)

    # Compare kernels.
    self.assertAllClose(k1, k_mc2_mean, atol=4e-3, rtol=4e-2)

    if phi == stax.Sign and get == 'nngp':
      raise absltest.SkipTest('Derivative of the empirical NNGP of a '
                              'discontinuous function does not converge '
                              'to the derivative of the infinite width NNGP.')

    if (phi in [stax.Abs, stax.Relu, stax.LeakyRelu, stax.ABRelu] and
        get == 'ntk'):
      raise absltest.SkipTest('Derivative of the empirical NTK of a '
                              'non-differentiable function does not converge '
                              'to the derivative of the infinite width NTK.')

    atol = 1e-2

    # Compare gradient of the analytic kernel to empirical kernel.
    if np.max(np.abs(k2_grad - k_mc2_grad_mean)) > atol:
      test_utils.assert_close_matrices(self,
                                       k_mc2_grad_mean,
                                       k2_grad,
                                       rtol=0.05,
                                       atol=10.)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'get={get}-'
              f'architecture={architecture}-'
              f'jit={do_jit}-',
          'get': get,
          'architecture': architecture,
          'do_jit': do_jit
      }
                                 for architecture in [
                                     'conv',
                                     'wrn'
                                 ]
                                 for get in [
                                     'ntk',
                                     'nngp'
                                 ]
                                 for do_jit in [
                                     True,
                                 ]))
  def test_issue_123(
      self,
      get,
      architecture,
      do_jit
  ):
    """Regression test for issue 123 (per the method name).

    On all-zero inputs, the kernel's value, reverse- and forward-mode
    derivatives must agree, and the gradient must be zero.
    """
    if architecture == 'wrn':
      # Wide ResNet reproduction of the issue.
      def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
        main = stax.serial(
            stax.Relu(),
            stax.Conv(
                channels, (3, 3), strides, padding='SAME',
                parameterization='standard'
            ),
            stax.Relu(),
            stax.Conv(channels, (3, 3), padding='SAME',
                      parameterization='standard'),
        )
        shortcut = (
            stax.Identity() if not channel_mismatch else stax.Conv(
                channels, (3, 3), strides, padding='SAME',
                parameterization='standard'
            )
        )
        return stax.serial(stax.FanOut(2),
                           stax.parallel(main, shortcut),
                           stax.FanInSum())

      def WideResnetGroup(n, channels, strides=(1, 1)):
        blocks = []
        blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
        for _ in range(n - 1):
          blocks += [WideResnetBlock(channels, (1, 1))]
        return stax.serial(*blocks)

      def WideResnet(block_size, k, num_classes):
        return stax.serial(
            stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
            WideResnetGroup(block_size, int(16 * k)),
            WideResnetGroup(block_size, int(32 * k), (2, 2)),
            WideResnetGroup(block_size, int(64 * k), (2, 2)),
            stax.AvgPool((8, 8), padding='SAME'),
            stax.Flatten(),
            stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
        )

      _, _, kernel_fn = WideResnet(block_size=1, k=1, num_classes=1)

    elif architecture == 'conv':
      # Small convolutional reproduction of the issue.
      _, _, kernel_fn = stax.serial(
          stax.Conv(
              1,
              (3, 3)
          ),
          stax.Relu(),
          stax.Flatten(),
      )

    else:
      raise ValueError(architecture)

    x1 = x2 = np.zeros((1, 8, 8, 3))

    def kernel_scalar(x1, x2):
      return kernel_fn(x1, x2, get)[0, 0]

    if do_jit:
      kernel_scalar = jit(kernel_scalar)

    # Compare forward pass to `value_and_grad`.
    k1 = kernel_scalar(x1, x2)
    k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2)

    # Compare to forward-mode.
    k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
    k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2_fwd)
    self.assertAllClose(k2_grad, k2_grad_fwd)

    # Compare to 0.
    self.assertAllClose(grad(kernel_scalar)(x1, x2), np.zeros_like(x1))
class ElementwiseTest(test_utils.NeuralTangentsTestCase):
  """Checks `stax.Elementwise` with closed-form NNGP fns against built-ins."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_n={}_diag_batch={}_spatial={}'.format(
                  phi[0].__name__, same_inputs, n, diagonal_batch,
                  diagonal_spatial),
          'phi': phi,
          'same_inputs': same_inputs,
          'n': n,
          'diagonal_batch': diagonal_batch,
          'diagonal_spatial': diagonal_spatial
      }
                                 for phi in [
                                     stax.Identity(),
                                     stax.Erf(),
                                     stax.Sin(),
                                     stax.Relu(),
                                 ]
                                 for same_inputs in [False, True, None]
                                 for n in [0, 1, 2]
                                 for diagonal_batch in [True, False]
                                 for diagonal_spatial in [True, False]))
  def test_elementwise(self, same_inputs, phi, n, diagonal_batch,
                       diagonal_spatial):
    """`Elementwise(fn, nngp_fn)` must reproduce the built-in layer `phi`."""
    name = phi[0].__name__

    def apply_phi(x):
      return phi[1]((), x)

    def identity_nngp(cov12, var1, var2):
      return cov12

    def erf_nngp(cov12, var1, var2):
      prod = (1 + 2 * var1) * (1 + 2 * var2)
      return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

    def sin_nngp(cov12, var1, var2):
      half_sum = -0.5 * (var1 + var2)
      return (np.exp(half_sum + cov12) - np.exp(half_sum - cov12)) / 2

    def relu_nngp(cov12, var1, var2):
      # The `1e-30` floor keeps the square root finite at `cov12**2 == prod`.
      sqrt = np.sqrt(np.maximum(var1 * var2 - cov12 ** 2, 1e-30))
      dot_sigma = (1 - np.arctan2(sqrt, cov12) / np.pi) / 2
      return sqrt / (2 * np.pi) + dot_sigma * cov12

    # Checked in order; the first token contained in `name` selects the
    # closed form. Lookup happens on each call, so an unknown activation only
    # raises once the kernel is actually evaluated.
    closed_forms = (
        ('Identity', identity_nngp),
        ('Erf', erf_nngp),
        ('Sin', sin_nngp),
        ('Relu', relu_nngp),
    )

    def nngp_fn(cov12, var1, var2):
      for token, closed_form in closed_forms:
        if token in name:
          return closed_form(cov12, var1, var2)
      raise NotImplementedError(name)

    custom = stax.Elementwise(apply_phi, nngp_fn)
    _, _, kernel_fn = stax.serial(stax.Dense(1), custom,
                                  stax.Dense(1), custom)
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is True:
      x2 = x1
    elif same_inputs is None:
      x2 = None
    else:
      x2 = random.normal(key, (6,) + shape)

    k = kernel_fn(x1, x2,
                  diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)
    # NOTE(review): `is_gaussian` is reset on the built-in kernel, presumably
    # because the custom `Elementwise` output is not flagged as Gaussian.
    k_manual = kernel_fn_manual(
        x1, x2,
        diagonal_batch=diagonal_batch,
        diagonal_spatial=diagonal_spatial).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
class ActivationTest(test_utils.NeuralTangentsTestCase):
  """Analytic vs. Monte Carlo kernels for elementwise activation layers."""

  @stax.layer
  def _RBF(self, gamma):
    """Reference RBF layer (NNGP only) used to cross-check `stax.Rbf`."""
    init_fn = lambda key, input_shape: (input_shape, ())

    def apply_fn(unused_params, unused_xs, **kwargs):
      raise NotImplementedError()

    def kernel_fn(kernels, **kwargs):
      if kernels.ntk is not None:
        raise ValueError('RBF Kernel does not have an associated NTK.')

      if kernels.nngp.ndim > 2:
        raise ValueError(
            ('RBF Kernel is not defined for covariance matrices with dimension'
             ' greater than two.'))

      input_dim = kernels.shape1[1]
      cov1 = kernels.cov1
      cov1 = np.reshape(cov1, (cov1.shape[0], 1))
      cov2 = cov1 if kernels.cov2 is None else kernels.cov2
      cov2 = np.reshape(cov2, (1, cov2.shape[0]))
      nngp = kernels.nngp

      # TODO(schsam): Update cov1 and cov2 if we want to compose this kernel
      # with other kernels.
      return kernels.replace(
          nngp=np.exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp)))
    return init_fn, apply_fn, kernel_fn

  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    """Compares the analytic kernel of `activation_fn` to a MC estimate.

    Args:
      activation_fn: layer triple `(init_fn, apply_fn, kernel_fn)`.
      same_inputs: if truthy, the second input is `None`.
      model: 'fc', or a conv model name containing 'pool' or 'flatten'.
      get: which kernel to compare ('nngp' or 'ntk').
      rbf_gamma: if given (and `model == 'fc'`, `get == 'nngp'`), also checks
        the result against the explicit `_RBF` reference layer.
    """
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_model={}_phi={}_{}_get={}_abc={}_approximate={}'.format(
                  model,
                  phi_name,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  abc,
                  approximate),
          'model': model,
          'phi_name': phi_name,
          'same_inputs': same_inputs,
          'get': get,
          'abc': abc,
          'approximate': approximate
      }
                                  for model in ['fc', 'conv-pool',
                                                'conv-flatten']
                                  for phi_name in [
                                      'Sin',
                                      'Cos',
                                      'Erf',
                                      'Gelu',
                                      'Sign',
                                  ]
                                  for same_inputs in [False]
                                  for get in ['nngp', 'ntk']
                                  for approximate in [True, False]
                                  for abc in itertools.product(
                                      [2., 0.3],
                                      [1.5, 0.3],
                                      [0., -np.pi/4., np.pi/2.]
                                  )))
  def test_activation(
      self,
      same_inputs,
      model,
      phi_name,
      get,
      abc,
      approximate
  ):
    """Checks one parameterized activation against a Monte Carlo estimate."""
    # `abc` comes from `itertools.product`, i.e. it is a tuple. Comparing it to
    # a list is always unequal, which used to route every case to `skip_test`.
    if tuple(abc) != (0.3, 1.5, -np.pi / 4):
      test_utils.skip_test(self)

    if approximate and phi_name != 'Gelu':
      raise absltest.SkipTest(
          f'{phi_name} does not have an `approximate` parameter.')

    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
    elif phi_name in ['Gelu', 'Sign', 'Cos']:
      if a != 0.3 or b != 0.3 or c != 0.:
        raise absltest.SkipTest('Skip `Gelu/Sign/Cos` test if '
                                ' (a, b, c) != (.3, .3, 0.).')
      if phi_name == 'Gelu':
        # Forward `approximate` so both Gelu variants are actually exercised.
        activation = stax.Gelu(approximate=approximate)
      elif phi_name == 'Sign':
        activation = stax.Sign()
      else:
        # a * cos(b * x + c) == a * sin(b * x + c + pi / 2).
        activation = stax.Sin(a=a, b=b, c=c + np.pi / 2)
    else:
      raise NotImplementedError(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_Rbf_{}_{}_{}'.format(
                  model,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  gamma),
          'model': model,
          'same_inputs': same_inputs,
          'get': get,
          'gamma': gamma,
      }
                                  for model in ['fc', 'conv-pool',
                                                'conv-flatten']
                                  for same_inputs in [False, True]
                                  for get in ['nngp', 'ntk']
                                  for gamma in [1e-6, 1e-4, 1e-2, 1.0, 2.]
                                  ))
  def test_rbf(self, same_inputs, model, get, gamma):
    activation = stax.Rbf(gamma)
    self._test_activation(activation, same_inputs, model, get,
                          rbf_gamma=gamma)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': f'{phi.__name__}_{same_inputs}_a={a}_b={b}_n={n}',
          'same_inputs': same_inputs,
          'a': a,
          'b': b,
          'n': n,
          'phi': phi
      }
                                  for a in [-0.5, 0.25]
                                  for b in [-0.5, -0.1, 0.1]
                                  for phi in [stax.Gaussian, stax.Exp]
                                  for same_inputs in [False, True, None]
                                  for n in [0]))
  def test_nonlineariy(self, phi, same_inputs, a, b, n):
    width = 2**10
    n_samples = 2**9
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(1))

    key1, key2, key_mc = random.split(random.PRNGKey(1), 3)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = np.cos(random.normal(key1, (2,) + shape))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(key2, (3,) + shape))

    k = kernel_fn(x1, x2)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key_mc,
                                            n_samples)
    k_mc = mc_kernel_fn(x1, x2, ('nngp', 'ntk'))
    test_utils.assert_close_matrices(self, k_mc.nngp, k.nngp, 6e-2)
    test_utils.assert_close_matrices(self, k_mc.ntk, k.ntk, 6e-2)

  def test_exp_normalized(self):
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))

    for do_clip in [True, False]:
      for gamma in [1., 2., 0.5]:
        for get in ['nngp', 'ntk']:
          with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
            _, _, kernel_fn = stax.serial(
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.GlobalAvgPool(),
                stax.Dense(1)
            )
            k_12 = kernel_fn(x1, x2, get=get)
            self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

            k_11 = kernel_fn(x1, None, get=get)
            self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_11)), 0)

            k_22 = kernel_fn(x2, None, get=get)
            self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_22)), 0)

  def test_exp_normalized_ntk(self):
    def nngp_fn(cov12, var1, var2):
      prod = np.sqrt(var1 * var2)
      return prod * np.exp(cov12 / prod - 1)

    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(nngp_fn=nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1),
                                         stax.ExpNormalized())
    key = random.PRNGKey(1)
    x1 = random.normal(key, (5, 4, 3, 1))
    x2 = random.normal(key, (6, 4, 3, 1))

    k = kernel_fn(x1, x2)
    k_manual = kernel_fn_manual(x1, x2)
    self.assertAllClose(k_manual, k)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_degree={}_get={}_readout={}'.format(
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  degree,
                  get,
                  readout
              ),
          'same_inputs': same_inputs,
          'degree': degree,
          'get': get,
          'readout': readout
      }
                                  for same_inputs in [False, True]
                                  for degree in [1, 2, 3, 4, 5, 6]
                                  for get in ['ntk', 'nngp']
                                  for readout in ['pool', 'flatten']))
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    # Relative tolerance grows with the Hermite degree.
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
class BatchTest(test_utils.NeuralTangentsTestCase):
  """Checks serial/parallel/automatic batching against unbatched kernels."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'
              .format(train, test, network, name, batch_size),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'name': name,
          'kernel_fn': kernel_fn,
          'batch_size': batch_size
      }
                                  for train, test, network in zip(
                                      TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                  for name, kernel_fn in KERNELS.items()
                                  for batch_size in [2, 8]))
  def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                 batch_size):
    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network)
    kernel_batched = batching._serial(kernel_fn, batch_size=batch_size)

    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

  # We also exclude tests for dropout + parallel. It is not clear what is the
  # best way to handle this case.
  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_train_shape={}_test_shape={}_network={}_{}'.format(
                  train, test, network, name),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'name': name,
          'kernel_fn': kernel_fn
      }
                                  for train, test, network in zip(
                                      TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                  for name, kernel_fn in KERNELS.items()))
  def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
    test_utils.stub_out_pmap(batching, 2)

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
    kernel_batched = batching._parallel(kernel_fn)

    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other, True)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'
              .format(train, test, network, name, batch_size),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'name': name,
          'kernel_fn': kernel_fn,
          'batch_size': batch_size
      }
                                  for train, test, network in zip(
                                      TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                  for name, kernel_fn in KERNELS.items()
                                  for batch_size in [2, 8]))
  def testComposition(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
    test_utils.stub_out_pmap(batching, 2)

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network)

    kernel_batched = batching._parallel(
        batching._serial(kernel_fn, batch_size=batch_size))
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

    kernel_batched = batching._serial(
        batching._parallel(kernel_fn), batch_size=batch_size)
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'
              .format(train, test, network, name, batch_size),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'name': name,
          'kernel_fn': kernel_fn,
          'batch_size': batch_size
      }
                                  for train, test, network in zip(
                                      TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                  for name, kernel_fn in KERNELS.items()
                                  for batch_size in [2, 8]))
  def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                    batch_size):
    test_utils.stub_out_pmap(batching, 2)

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network)

    kernel_batched = batching.batch(kernel_fn, batch_size=batch_size)
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

    kernel_batched = batching.batch(
        kernel_fn, batch_size=batch_size, store_on_device=False)
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

  def _test_analytic_kernel_composition(self, batching_fn):
    """Checks that batched per-block kernels compose like a serial network."""
    # Check Fully-Connected.
    rng = random.PRNGKey(0)
    rng_self, rng_other = random.split(rng)
    x_self = random.normal(rng_self, (8, 2))
    x_other = random.normal(rng_other, (2, 2))
    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    ker_fn = batching_fn(ker_fn)

    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(x_self))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batching._parallel:
      # In the parallel setting, `x1_is_x2` is not computed correctly
      # when x1==x2.
      composed_ker_out = composed_ker_out.replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)

    ker_out = ker_fn(ker_fn(x_self, x_other))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batching._parallel:
      composed_ker_out = composed_ker_out.replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)

    # Check convolutional + pooling.
    x_self = random.normal(rng, (8, 4, 4, 3))
    x_other = random.normal(rng, (2, 4, 4, 3))

    Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)
    block_ker_fn = batching_fn(block_ker_fn)
    readout_ker_fn = batching_fn(readout_ker_fn)

    ker_out = readout_ker_fn(block_ker_fn(x_self))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batching._parallel:
      composed_ker_out = composed_ker_out.replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)

    ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batching._parallel:
      composed_ker_out = composed_ker_out.replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_on_device={}_batch_size={}'.format(store_on_device,
                                                   batch_size),
          'store_on_device': store_on_device,
          'batch_size': batch_size
      }
                                  for store_on_device in [True, False]
                                  for batch_size in [2, 8]))
  def testAnalyticKernelComposeSerial(self, store_on_device, batch_size):
    self._test_analytic_kernel_composition(
        partial(batching._serial,
                batch_size=batch_size,
                store_on_device=store_on_device))

  def testAnalyticKernelComposeParallel(self):
    test_utils.stub_out_pmap(batching, 2)
    self._test_analytic_kernel_composition(batching._parallel)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_on_device={}_batch_size={}'.format(store_on_device,
                                                   batch_size),
          'store_on_device': store_on_device,
          'batch_size': batch_size
      }
                                  for store_on_device in [True, False]
                                  for batch_size in [2, 8]))
  def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
    test_utils.stub_out_pmap(batching, 2)
    self._test_analytic_kernel_composition(
        partial(batching.batch,
                batch_size=batch_size,
                store_on_device=store_on_device))

  def test_jit_or_pmap_broadcast(self):

    def kernel_fn(x1,
                  x2,
                  do_flip,
                  keys,
                  do_square,
                  params,
                  _unused=None,
                  p=0.65):
      res = np.abs(np.matmul(x1, x2))
      if do_square:
        res *= res
      if do_flip:
        res = -res

      res *= random.uniform(keys) * p
      return [res, params]

    params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
    x2 = np.arange(0, 10).reshape((10,))
    keys = random.PRNGKey(1)

    kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                        device_count=0)
    x1 = np.arange(0, 10).reshape((1, 10))
    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=0):
          res_1 = kernel_fn(
              x1, x2, do_flip, keys, do_square, params, _unused=True, p=0.65)
          res_2 = kernel_fn_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=True)
          self.assertAllClose(res_1, res_2)

    test_utils.stub_out_pmap(batching, 1)
    x1 = np.arange(0, 10).reshape((1, 10))
    kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                        device_count=1)
    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=1):
          res_1 = kernel_fn(
              x1, x2, do_flip, keys, do_square, params, _unused=False, p=0.65)
          res_2 = kernel_fn_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=None)
          self.assertAllClose(res_1[0], res_2[0])
          self.assertAllClose(
              tree_map(partial(np.expand_dims, axis=0), res_1[1]), res_2[1])

    kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                        device_count=2)
    x1 = np.arange(0, 20).reshape((2, 10))
    test_utils.stub_out_pmap(batching, 2)

    def broadcast(arg):
      return np.broadcast_to(arg, (2,) + arg.shape)

    for do_flip in [True, False]:
      for do_square in [True, False]:
        with self.subTest(do_flip=do_flip, do_square=do_square,
                          device_count=2):
          res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, p=0.2)
          res_2 = kernel_fn_pmapped(
              x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.2)
          self.assertAllClose(res_1[0][0], res_2[0][0])
          self.assertAllClose(res_1[0][1], res_2[0][1])
          self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1])

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_same_inputs={}'.format(same_inputs),
              'same_inputs': same_inputs
          } for same_inputs in [True, False]))
  def test_parallel_in_out(self, same_inputs):
    test_utils.stub_out_pmap(batching, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2 = random.split(rng, 2)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))
    x1 = (x1_1, (x1_2, x1_3))

    if same_inputs:
      x2 = None
    else:
      x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
      x2 = (x2_1, (x2_2, x2_3))

    N = WIDTH

    def net(N_out):
      return stax.parallel(
          stax.Dense(N_out),
          stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

    # Check NNGP.
    readin = net(N)
    readout = net(1)

    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(partial(readout[2], get='nngp'))

    batch_K_readin_fn = batching.batch(K_readin_fn, 2)
    batch_K_readout_fn = batching.batch(K_readout_fn, 2)

    test_utils.assert_close_matrices(
        self,
        K_readout_fn(K_readin_fn(x1, x2)),
        batch_K_readout_fn(batch_K_readin_fn(x1, x2)),
        RTOL)

    # Check Both.
    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

    batch_K_readin_fn = batching.batch(K_readin_fn, 2)
    batch_K_readout_fn = batching.batch(K_readout_fn, 2)

    test_utils.assert_close_matrices(
        self,
        K_readout_fn(K_readin_fn(x1, x2)),
        batch_K_readout_fn(batch_K_readin_fn(x1, x2)),
        RTOL)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_same_inputs={}'.format(same_inputs),
              'same_inputs': same_inputs
          } for same_inputs in [True, False]))
  def test_parallel_in_out_empirical(self, same_inputs):
    test_utils.stub_out_pmap(batching, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))
    x1 = (x1_1, (x1_2, x1_3))

    if same_inputs:
      x2 = None
    else:
      x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
      x2 = (x2_1, (x2_2, x2_3))

    def net(N_out):
      return stax.parallel(
          stax.Dense(N_out),
          stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

    # Check NNGP.
    init_fn, apply_fn, _ = net(WIDTH)
    _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

    kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
    batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, params),
        batch_kernel_fn(x1, x2, params),
        RTOL)

    # Check NTK.
    init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
    _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

    kernel_fn = jit(nt.empirical_ntk_fn(apply_fn))
    batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, params),
        batch_kernel_fn(x1, x2, params),
        RTOL)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          ({
              'testcase_name': (f'_same_inputs={same_inputs}'
                                f'_device_count={device_count}'
                                f'_trace_axes={trace_axes}'
                                f'_diagonal_axes={diagonal_axes}'),
              'same_inputs': same_inputs,
              'device_count': device_count,
              'trace_axes': trace_axes,
              'diagonal_axes': diagonal_axes
          }
           for same_inputs in [True, False]
           for device_count in [-1, 0, 1, 2]
           for trace_axes, diagonal_axes in zip([(-1,), (1, -1), ()],
                                                [(1,), (), (1, -1)]))))
  def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                          trace_axes, diagonal_axes):
    """Batched empirical NTK matches unbatched with trace/diagonal axes."""
    test_utils.stub_out_pmap(batching, 2)
    rng = random.PRNGKey(0)

    input_key1, input_key2, net_key = random.split(rng, 3)

    init_fn, apply_fn, _ = stax.serial(stax.Dense(5),
                                       stax.Relu(),
                                       stax.Dense(3))

    test_x1 = random.normal(input_key1, (12, 4, 4))
    # `x2=None` is the same-inputs case; only draw a distinct `x2` otherwise.
    # (The condition was previously inverted relative to the test label.)
    test_x2 = None
    if not same_inputs:
      test_x2 = random.normal(input_key2, (9, 4, 4))

    kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                    trace_axes=trace_axes,
                                    diagonal_axes=diagonal_axes,
                                    vmap_axes=0,
                                    implementation=2)

    _, params = init_fn(net_key, test_x1.shape)

    true_kernel = kernel_fn(test_x1, test_x2, params)
    batched_fn = batching.batch(kernel_fn, device_count=device_count,
                                batch_size=3)
    batch_kernel = batched_fn(test_x1, test_x2, params)
    self.assertAllClose(true_kernel, batch_kernel)
class MonteCarloTest(test_utils.NeuralTangentsTestCase):
  """Monte Carlo kernel estimates under batching, pmap and sample generators."""

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': (f'[batch_size={batch_size}, '
                            f'device_count={device_count} '
                            f'store_on_device={store_on_device} '
                            f'get={get} '
                            ']'),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get,
      }
                                  for batch_size in BATCH_SIZES
                                  for device_count in DEVICE_COUNTS
                                  for store_on_device in STORE_ON_DEVICE
                                  for get in ALL_GET))
  def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                             get):
    """Batching inside `_sample_once_kernel_fn` matches the unbatched draw."""
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    empirical_fn = nt.empirical_kernel_fn(apply_fn)

    unbatched_fn = monte_carlo._sample_once_kernel_fn(empirical_fn, init_fn)
    batched_fn = monte_carlo._sample_once_kernel_fn(
        empirical_fn, init_fn, batch_size, device_count, store_on_device)

    self.assertAllClose(unbatched_fn(x1, x2, key, get),
                        batched_fn(x1, x2, key, get))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': (f'[batch_size={batch_size}, '
                            f'device_count={device_count} '
                            f'store_on_device={store_on_device} '
                            f'get={get} '
                            ']'),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get,
      }
                                  for batch_size in BATCH_SIZES
                                  for device_count in DEVICE_COUNTS
                                  for store_on_device in STORE_ON_DEVICE
                                  for get in ALL_GET))
  def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                             get):
    """Wrapping a single-sample kernel in `batching.batch` is a no-op."""
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    empirical_fn = nt.empirical_kernel_fn(apply_fn)

    single_fn = monte_carlo._sample_once_kernel_fn(
        empirical_fn, init_fn, device_count=0)
    batched_single_fn = batching.batch(single_fn, batch_size, device_count,
                                       store_on_device)

    self.assertAllClose(single_fn(x1, x2, key, get),
                        batched_single_fn(x1, x2, key, get))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': (f'[batch_size={batch_size}, '
                            f'device_count={device_count} '
                            f'store_on_device={store_on_device} '
                            ']'),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
      }
                                  for batch_size in BATCH_SIZES
                                  for device_count in DEVICE_COUNTS
                                  for store_on_device in STORE_ON_DEVICE))
  def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                   store_on_device):
    """A 200-sample MC NNGP estimate is close to the analytic NNGP."""
    test_utils.stub_out_pmap(batching, device_count)

    on_tpu = jax.default_backend() == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, on_tpu)

    mc_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                              batch_size, device_count,
                                              store_on_device)

    empirical = mc_fn(x1, x2, 'nngp')
    analytic = stax_kernel_fn(x1, x2, 'nngp')
    test_utils.assert_close_matrices(self, analytic, empirical, 2e-2)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': (f'[batch_size={batch_size}, '
                            f'device_count={device_count} '
                            f'store_on_device={store_on_device} '
                            ']'),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
      }
                                  for batch_size in BATCH_SIZES
                                  for device_count in DEVICE_COUNTS
                                  for store_on_device in STORE_ON_DEVICE))
  def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                       store_on_device):
    """A 100-sample MC NTK estimate is close to the analytic NTK."""
    test_utils.stub_out_pmap(batching, device_count)

    on_tpu = jax.default_backend() == 'tpu'
    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 2, on_tpu)

    mc_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                              batch_size, device_count,
                                              store_on_device,
                                              vmap_axes=0)

    empirical = mc_fn(x1, x2, 'ntk')
    analytic = stax_kernel_fn(x1, x2, 'ntk')
    test_utils.assert_close_matrices(self, analytic, empirical, 2e-2)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': (f'[batch_size={batch_size}, '
                            f'device_count={device_count} '
                            f'store_on_device={store_on_device} '
                            f'get={get}'
                            ']'),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get
      }
                                  for batch_size in BATCH_SIZES
                                  for device_count in DEVICE_COUNTS
                                  for store_on_device in STORE_ON_DEVICE
                                  for get in ALL_GET))
  def test_monte_carlo_generator(self, batch_size, device_count,
                                 store_on_device, get):
    """A list of sample counts yields the same kernels as separate calls."""
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]

    # With `get=None` the MC functions are called without a `get` argument and
    # the analytic reference is asked for both kernels.
    get_args = () if get is None else (get,)
    analytic_get = ('nngp', 'ntk') if get is None else get

    sample_generator = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device, vmap_axes=0)

    samples_12 = sample_generator(x1, x2, *get_args)
    samples_34 = sample_generator(x3, x4, *get_args)

    count = 0
    for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
      sample_fn = monte_carlo.monte_carlo_kernel_fn(
          init_fn, apply_fn, key, n, batch_size, device_count,
          store_on_device, vmap_axes=0)
      sample_12 = sample_fn(x1, x2, *get_args)
      sample_34 = sample_fn(x3, x4, *get_args)
      self.assertAllClose(s_12, sample_12)
      self.assertAllClose(s_12, s_34)
      self.assertAllClose(s_12, sample_34)
      count += 1

    self.assertEqual(log_n_max, count)

    ker_analytic_12 = stax_kernel_fn(x1, x2, analytic_get)
    ker_analytic_34 = stax_kernel_fn(x3, x4, analytic_get)

    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': f'_same_inputs={same_inputs}_batch_size={batch_size}',
          'same_inputs': same_inputs,
          'batch_size': batch_size
      }
                                  for same_inputs in [True, False]
                                  for batch_size in [1, 2]))
  def test_parallel_in_out_mc(self, same_inputs, batch_size):
    """Batched MC NNGP on nested-tuple inputs matches the unbatched one."""
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    a1, b1, c1 = random.normal(input_key1, (3, 2, 5))
    x1 = (a1, (b1, c1))

    if same_inputs:
      x2 = None
    else:
      a2, b2, c2 = random.normal(input_key2, (3, 4, 5))
      x2 = (a2, (b2, c2))

    def net(N_out):
      return stax.parallel(stax.Dense(N_out),
                           stax.parallel(stax.Dense(N_out + 1),
                                         stax.Dense(N_out + 2)))

    # Check NNGP.
    init_fn, apply_fn, _ = net(WIDTH)

    shared = dict(n_samples=4, trace_axes=(-1,))
    nb_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, net_key, **shared)
    kernel_fn = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, net_key, batch_size=batch_size, **shared)

    self.assertAllClose(kernel_fn(x1, x2, 'nngp'),
                        nb_kernel_fn(x1, x2, 'nngp'))
class MaskingTest(test_utils.NeuralTangentsTestCase):
  """Compares exact masked kernels against Monte Carlo estimates.

  Inputs are masked via `test_utils.mask(x, mask_constant, mask_axis, key, p)`
  and both the analytic `kernel_fn` and the Monte Carlo kernel are evaluated
  with the same `mask_constant`.
  """

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              ' [{}_get={}_axis={}_mask={}_concat={}_p={}]'.format(
                  'same_inputs' if same_inputs else 'different_inputs',
                  get,
                  mask_axis,
                  mask_constant,
                  concat,
                  p,
              ),
          'same_inputs': same_inputs,
          'get': get,
          'mask_axis': mask_axis,
          'mask_constant': mask_constant,
          'concat': concat,
          'p': p,
      }
                                 for same_inputs in [False]
                                 for get in ['ntk']
                                 for concat in [None, 0, 1]
                                 for p in [0.5]
                                 for mask_axis in [(), (0, ), (1, 3)]
                                 for mask_constant in [10.]))
  def test_mask_fc(self, same_inputs, get, concat, p, mask_axis,
                   mask_constant):
    """Three-branch fully-connected network on masked 4-D inputs.

    Branches are merged with `FanInSum` when `concat is None`, otherwise with
    `FanInConcat(concat)`.
    """
    width = 512
    n_samples = 128
    tol = 0.04
    key = random.PRNGKey(1)

    x1 = random.normal(key, (4, 6, 5, 7))
    x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

    if same_inputs:
      x2 = None
    else:
      x2 = random.normal(key, (2, 6, 5, 7))
      x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

    nn = stax.serial(
        stax.Flatten(),
        stax.FanOut(3),
        stax.parallel(
            stax.serial(
                stax.Dense(width, 1., 0.1),
                stax.Abs(),
                stax.DotGeneral(lhs=-0.2),
                stax.Dense(width, 1.5, 0.01),
            ),
            stax.serial(
                stax.Dense(width, 1.1, 0.1),
                stax.DotGeneral(rhs=0.7),
                stax.Erf(),
                stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
            ),
            stax.serial(
                stax.DotGeneral(rhs=0.5),
                stax.Dense(width, 1.2),
                stax.ABRelu(-0.2, 0.4),
                # Different width on the feature axis exercises concatenation
                # of unequal-width branches when `concat == 1`.
                stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
            )),
        (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
        stax.Dense(width, 2., 0.01),
        stax.Relu())

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = stax.serial(
          nn, stax.Dense(width, 1., 0.1))
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.1))
    else:
      raise ValueError(get)

    # Activations are presumably 2-D after `stax.Flatten()`, so the batch axis
    # is 0 or -2. When branches are concatenated along it, the Monte Carlo
    # estimator is run without vmapping and without device splitting.
    concat_on_batch = concat in (0, -2)
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=0 if concat_on_batch else -1,
        implementation=2,
        vmap_axes=None if concat_on_batch else 0,
    )

    kernel_fn = jit(kernel_fn, static_argnums=(2, ))
    exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
    empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
    test_utils.assert_close_matrices(self, empirical, exact, tol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              ' [{}_get={}_axis={}_mask={}_concat={}_{}_p={}_n={}_{}]'.format(
                  'same_inputs' if same_inputs else 'different_inputs',
                  get,
                  mask_axis,
                  mask_constant,
                  concat,
                  proj,
                  p,
                  n,
                  'transpose' if transpose else ''),
          'same_inputs': same_inputs,
          'get': get,
          'mask_axis': mask_axis,
          'mask_constant': mask_constant,
          'concat': concat,
          'proj': proj,
          'p': p,
          'n': n,
          'transpose': transpose
      }
                                 for proj in ['flatten', 'avg']
                                 for same_inputs in [False]
                                 for get in ['ntk']
                                 for n in [0, 1]
                                 for concat in [None] + list(range(n + 1))
                                 for mask_constant in [10.]
                                 for p in [0.5]
                                 for transpose in [True, False]
                                 for mask_axis in [(), (0, ), (0, 1, 2, 3)]))
  def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                     proj, p, n, transpose):
    """Three-branch `n`-D (transposed) convolutional network on masked inputs.

    Inputs have shape `(batch,) + spatial_shape + (2,)`, i.e. rank `n + 2`
    with layout `'N' + spatial_spec + 'C'`.
    """
    test_utils.skip_test(self)
    if default_backend() == 'gpu' and n > 3:
      raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

    width = 256
    n_samples = 256
    tol = 0.03
    key = random.PRNGKey(1)

    spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
    filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
    strides = (2, 1, 3, 2, 3)[:n]
    spatial_spec = 'HWDZX'[:n]
    dimension_numbers = ('N' + spatial_spec + 'C',
                         'OI' + spatial_spec,
                         'N' + spatial_spec + 'C')

    x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
    x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

    if same_inputs:
      x2 = None
    else:
      x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
      x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

    def get_attn():
      # Global self-attention only for the 'avg' projection; identity otherwise.
      return stax.GlobalSelfAttention(
          n_chan_out=width,
          n_chan_key=width,
          n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
          n_heads=int(np.sqrt(width)),
      ) if proj == 'avg' else stax.Identity()

    conv = stax.ConvTranspose if transpose else stax.Conv

    def conv_layer(padding, W_std, b_std):
      # Every convolution shares dimension numbers, width, strides and filter.
      return conv(dimension_numbers=dimension_numbers,
                  out_chan=width,
                  strides=strides,
                  filter_shape=filter_shape,
                  padding=padding,
                  W_std=W_std,
                  b_std=b_std)

    nn = stax.serial(
        stax.FanOut(3),
        stax.parallel(
            stax.serial(
                conv_layer('CIRCULAR', 1.5, 0.2),
                stax.LayerNorm(axis=(1, -1)),
                stax.Abs(),
                stax.DotGeneral(rhs=0.9),
                conv_layer('VALID', 1.2, 0.1),
            ),
            stax.serial(
                conv_layer('SAME', 0.1, 0.3),
                stax.Relu(),
                stax.Dropout(0.7),
                conv_layer('VALID', 0.9, 1.),
            ),
            stax.serial(
                get_attn(),
                conv_layer('CIRCULAR', 1., 0.1),
                stax.Erf(),
                stax.Dropout(0.2),
                stax.DotGeneral(rhs=0.7),
                conv_layer('VALID', 1., 0.1),
            )),
        (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
        get_attn(),
        {
            'avg': stax.GlobalAvgPool(),
            'sum': stax.GlobalSumPool(),
            'flatten': stax.Flatten(),
        }[proj],
    )

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = stax.serial(
          nn, stax.Dense(width, 1., 0.))
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
    else:
      raise ValueError(get)

    # Activations have rank n + 2, so the batch axis is 0 or -(n + 2). (The
    # previous `-n` pointed at a spatial/channel axis, or aliased 0 for n=0.)
    # When branches are concatenated along the batch axis, the Monte Carlo
    # estimator is run without vmapping and without device splitting.
    concat_on_batch = concat in (0, -(n + 2))
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=0 if concat_on_batch else -1,
        implementation=2,
        vmap_axes=None if concat_on_batch else 0,
    )

    kernel_fn = jit(kernel_fn, static_argnums=(2, ))
    exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
    empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
    test_utils.assert_close_matrices(self, empirical, exact, tol)
# Module-level test setup. The exact effect of the tolerance update is defined
# in `test_utils`, which is not visible here.
test_utils.update_test_tolerance()
# NOTE(review): `prandom` is presumably Python's `random` module imported under
# an alias; seeding it makes Python-level randomness reproducible — confirm.
prandom.seed(1)


@parameterized.named_parameters(
    test_utils.cases_from_list(
        {
            # Name encodes input sameness plus the `__name__` of the first
            # element of the readout and readin layers.
            'testcase_name': ' [{}_out={}_in={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                readout[0].__name__,
                readin[0].__name__),
            'same_inputs': same_inputs,
            'readout': readout,
            'readin': readin
        }
        for same_inputs in [False, True]
        for readout in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()]
        for readin in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()]))
class DiagonalTest(test_utils.NeuralTangentsTestCase):
  """Test case parameterized (at class level) over inputs and layers.

  Candidate cases are the product of two `same_inputs` values with three
  readout and three readin layers (Flatten / GlobalAvgPool / Identity), passed
  through `test_utils.cases_from_list`.
  """

  def _get_kernel_fn(self, same_inputs, readin, readout):
    """Builds random inputs and the first layer of a network under test.

    NOTE(review): as visible in this copy, the method stops after computing
    `filter_shape`. It never uses `x1`, `x2`, `layers`, `filter_shape` or
    `readout`, and implicitly returns None. The body looks truncated —
    confirm against the upstream source before relying on it.
    """
    key = random.PRNGKey(1)
    # 4-D inputs: batch sizes 2 and 3 with trailing dims (5, 6, 3).
    x1 = random.normal(key, (2, 5, 6, 3))
    x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
    layers = [readin]
    # Non-empty filter shape only when the readin layer is `Identity`.
    filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
class EmpiricalTest(test_utils.NeuralTangentsTestCase):
  """Tests Taylor expansions and empirical NTK/NNGP kernel implementations."""

  @classmethod
  def f(cls, x, params, do_alter, do_shift_x=True):
    """Toy model: a quadratic function of `x` with a nested pytree output.

    Returns `[0.5 * x^T w1 x + w2 x + b, (w1 x, w2)]`.

    Args:
      x: column vector of shape `(shape[-1], 1)` (see `_get_init_data`).
      params: tuple `(w1, w2, b)`.
      do_alter: if True, evaluates with `b * 2`, `w1 + 5` and `w2 / 0.9`.
      do_shift_x: if True, evaluates at `2 * x + 1` instead of `x`.
    """
    w1, w2, b = params
    if do_alter:
      # Rebind rather than use augmented assignment, so the caller's arrays
      # are never mutated in place (matters if params are NumPy arrays).
      b = b * 2.
      w1 = w1 + 5.
      w2 = w2 / 0.9
    if do_shift_x:
      x = x * 2 + 1.
    return [
        0.5 * np.dot(np.dot(x.T, w1), x) + np.dot(w2, x) + b,
        (np.dot(w1, x), w2)
    ]

  @classmethod
  def f_lin_exact(cls, x0, x, params, do_alter, do_shift_x=True):
    """Closed-form first-order Taylor expansion of `f` around `x0`.

    Relies on `w1` being symmetric (as built in `_get_init_data`; adding a
    scalar preserves symmetry), so the gradient of `0.5 x^T w1 x` is
    `x^T w1`.
    """
    w1, w2, b = params
    f0 = EmpiricalTest.f(x0, params, do_alter, do_shift_x)
    if do_shift_x:
      x0 = x0 * 2 + 1.
      x = x * 2 + 1.
    dx = x - x0
    if do_alter:
      b = b * 2.
      w1 = w1 + 5.
      w2 = w2 / 0.9
    return tree_map(
        operator.add,
        f0,
        [np.dot(np.dot(x0.T, w1) + w2, dx), (np.dot(w1, dx), 0.)])

  @classmethod
  def _get_init_data(cls, shape):
    """Returns `(key, (w1, w2, b), x0)` with a symmetric `w1` of `shape`."""
    key = random.PRNGKey(0)
    key, s1, s2, s3, = random.split(key, 4)
    w1 = random.normal(s1, shape)
    w1 = 0.5 * (w1 + w1.T)
    w2 = random.normal(s2, shape)
    b = random.normal(s3, (1, ) * (len(shape) - 1) + (shape[-1], ))
    params = (w1, w2, b)
    key, split = random.split(key)
    x0 = random.normal(split, (shape[-1], 1))
    return key, params, x0

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '_{}'.format(shape),
          'shape': shape
      } for shape in TAYLOR_MATRIX_SHAPES))
  def testLinearization(self, shape):
    """`nt.linearize` matches the closed-form linearization of `f`."""
    key, params, x0 = self._get_init_data(shape)

    f_lin = nt.linearize(EmpiricalTest.f, x0)

    for _ in range(TAYLOR_RANDOM_SAMPLES):
      for do_alter in [True, False]:
        for do_shift_x in [True, False]:
          key, split = random.split(key)
          x = random.normal(split, (shape[-1], 1))
          self.assertAllClose(
              EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                        do_shift_x=do_shift_x),
              f_lin(x, params, do_alter, do_shift_x=do_shift_x))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '_{}'.format(shape),
          'shape': shape
      } for shape in TAYLOR_MATRIX_SHAPES))
  def testTaylorExpansion(self, shape):
    """`nt.taylor_expand` of orders 1 and 2 match closed-form expansions."""

    def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
      # Second-order expansion: linear terms plus 0.5 * dx^T w1 dx. Only `w1`
      # contributes to the second-order term, so `w2` and `b` are not needed.
      w1 = params[0]
      f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter, do_shift_x)
      if do_shift_x:
        x0 = x0 * 2 + 1.
        x = x * 2 + 1.
      if do_alter:
        w1 = w1 + 5.
      dx = x - x0
      return tree_map(operator.add,
                      f_lin,
                      [0.5 * np.dot(np.dot(dx.T, w1), dx), (0., 0.)])

    key, params, x0 = self._get_init_data(shape)

    f_lin = nt.taylor_expand(EmpiricalTest.f, x0, 1)
    f_2 = nt.taylor_expand(EmpiricalTest.f, x0, 2)

    for _ in range(TAYLOR_RANDOM_SAMPLES):
      for do_alter in [True, False]:
        for do_shift_x in [True, False]:
          key, split = random.split(key)
          x = random.normal(split, (shape[-1], 1))
          self.assertAllClose(
              EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                        do_shift_x=do_shift_x),
              f_lin(x, params, do_alter, do_shift_x=do_shift_x))
          self.assertAllClose(
              f_2_exact(x0, x, params, do_alter, do_shift_x=do_shift_x),
              f_2(x, params, do_alter, do_shift_x=do_shift_x))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '_train_shape={}_test_shape={}_network={}_{}'.format(
              train, test, network, name),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'name': name,
          'kernel_fn': kernel_fn
      } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                 for name, kernel_fn in KERNELS.items()))
  def testNTKAgainstDirect(
      self, train_shape, test_shape, network, name, kernel_fn):
    """Implicit, direct and vmapped kernels agree, on one and two inputs."""
    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    implicit, direct, _ = kernel_fn(key, train_shape[1:], network,
                                    diagonal_axes=(), trace_axes=())
    implicit_batched, direct_batched, _ = kernel_fn(key, train_shape[1:],
                                                    network,
                                                    diagonal_axes=(),
                                                    trace_axes=(),
                                                    vmap_axes=0)

    for x1, x2 in ((data_self, None), (data_other, data_self)):
      g = implicit(x1, x2)
      self.assertAllClose(g, direct(x1, x2))
      self.assertAllClose(g, implicit_batched(x1, x2))
      self.assertAllClose(g, direct_batched(x1, x2))

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_diagonal_axes={}_trace_axes={}'.format(
                  diagonal_axes, trace_axes),
              'diagonal_axes': diagonal_axes,
              'trace_axes': trace_axes,
          }
          for diagonal_axes in [(), (0, ), (0, 1), (0, 1, 2), (0, 1, 2, 3),
                                (-1, ), (-2, ), (0, -1), (1, -2), (2, 3),
                                (3, 0, 2)]
          for trace_axes in [(), (0, ), (0, 1), (-1, ), (1, ), (0, -1),
                             (-1, -2), (0, 1, 2, 3), (3, 1, 2, 0), (1, 2, 3),
                             (-3, -2), (-3, -1), (-2, -4), (2, 0, -1)]))
  def testAxes(self, diagonal_axes, trace_axes):
    """Kernel ranks/shapes and implementations agree for diagonal/trace axes."""
    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, (4, 5, 6, 3))
    data_other = random.normal(other_split, (2, 5, 6, 3))

    _diagonal_axes = tuple(d % data_self.ndim for d in diagonal_axes)
    _trace_axes = tuple(t % data_self.ndim for t in trace_axes)

    if any(d == c for d in _diagonal_axes for c in _trace_axes):
      raise absltest.SkipTest(
          'diagonal axes must be different from channel axes.')

    get_kernel = KERNELS['empirical_logits_3']
    kwargs = dict(
        key=key,
        input_shape=(5, 6, 3),
        network=CONV,
        diagonal_axes=diagonal_axes,
        trace_axes=trace_axes
    )

    implicit, direct, nngp = get_kernel(**kwargs)
    implicit_batched, direct_batched, _ = get_kernel(**kwargs, vmap_axes=0)

    n_marg = len(_diagonal_axes)
    n_chan = len(_trace_axes)
    # Traced axes are removed; each remaining axis appears once per input,
    # except diagonal axes, which appear once in total.
    expected_ndim = 2 * (data_self.ndim - n_chan) - n_marg

    def check(x1, x2):
      g_nngp = nngp(x1, x2)
      self.assertEqual(expected_ndim, g_nngp.ndim)

      g_direct = direct(x1, x2)
      self.assertEqual(g_nngp.shape, g_direct.shape)

      g_direct_batched = direct_batched(x1, x2)
      g = implicit(x1, x2)
      g_batched = implicit_batched(x1, x2)

      self.assertAllClose(g_direct, g)
      self.assertAllClose(g_direct, g_direct_batched)
      self.assertAllClose(g_direct, g_batched)

    check(data_self, None)

    # `data_other` and `data_self` differ in size along axis 0 (2 vs 4),
    # presumably why the cross-kernel is only checked when axis 0 is neither
    # traced nor diagonal.
    if 0 not in _trace_axes and 0 not in _diagonal_axes:
      check(data_other, data_self)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_same_inputs={}'.format(same_inputs),
              'same_inputs': same_inputs
          } for same_inputs in [True, False]))
  def test_parallel_in_out(self, same_inputs):
    """Empirical kernels of a `stax.parallel` network with tuple inputs."""
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2 = np.split(random.normal(input_key1, (3, 21)), (10, ), axis=1)
    x2_1, x2_2 = np.split(random.normal(input_key2, (4, 21)), (10, ), axis=1)

    x1 = (x1_1, x1_2)
    x2 = (x2_1, x2_2) if not same_inputs else None

    def layer(N_out):
      return stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1))

    init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1))

    _, params = init_fn(net_key, (x1_1.shape, x1_2.shape))

    implicit_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, implementation=2))
    direct_kernel_fn = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))
    implicit_batched_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, vmap_axes=(0, 0), implementation=2))
    direct_batched_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, vmap_axes=(0, 0), implementation=1))

    k_direct = direct_kernel_fn(x1, x2, params)

    self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params))
    self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params))
    self.assertAllClose(k_direct, implicit_batched_kernel_fn(x1, x2, params))

    nngp_kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
    nngp = nngp_kernel_fn(x1, x2, params)
    nngp_shape = (3, 3 if same_inputs else 4)
    self.assertEqual(len(nngp), 2)
    self.assertEqual(nngp[0].shape, nngp_shape)
    self.assertEqual(nngp[1].shape, nngp_shape)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_same_inputs={}'.format(same_inputs),
              'same_inputs': same_inputs
          } for same_inputs in [True, False]))
  def test_parallel_nested(self, same_inputs):
    """Empirical kernels of nested `stax.parallel` with nested inputs."""
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2, x1_3 = np.split(random.normal(input_key1, (3, 33)),
                                (10, 21), axis=1)
    x2_1, x2_2, x2_3 = np.split(random.normal(input_key2, (4, 33)),
                                (10, 21), axis=1)

    x1 = ([x1_1, x1_2], x1_3)
    x2 = ([x2_1, x2_2], x2_3) if not same_inputs else None

    def layer(N_out):
      return stax.parallel(
          stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1)),
          stax.Dense(N_out + 2))

    init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1))

    _, params = init_fn(net_key, tree_map(np.shape, x1))
    implicit_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, implementation=2))
    direct_kernel_fn = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))
    implicit_batched_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, vmap_axes=([0, 0], 0),
                            implementation=2))
    direct_batched_kernel_fn = jit(
        nt.empirical_ntk_fn(apply_fn, vmap_axes=([0, 0], 0),
                            implementation=1))

    k_direct = direct_kernel_fn(x1, x2, params)

    self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params))
    self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params))
    self.assertAllClose(k_direct, implicit_batched_kernel_fn(x1, x2, params))

    nngp_kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
    nngp = nngp_kernel_fn(x1, x2, params)

    self.assertEqual(len(nngp), 2)
    nngp_shape = (3, 3 if same_inputs else 4)
    self.assertEqual(nngp[0][0].shape, nngp_shape)
    self.assertEqual(nngp[0][1].shape, nngp_shape)
    self.assertEqual(nngp[1].shape, nngp_shape)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': '_same_inputs={}'.format(same_inputs),
              'same_inputs': same_inputs
          } for same_inputs in [True, False]))
  def test_vmap_axes(self, same_inputs):
    """Kernels agree when the batch axis sits at different input positions."""
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
      # Batch dimension `n` sits at axis 0, 1 and 2 of the three leaves.
      k1, k2, k3 = random.split(k, 3)
      x1 = random.normal(k1, (n, c1))
      x2 = random.normal(k2, (h2, n, c2))
      x3 = random.normal(k3, (c3, w3, n, h3))
      x = [(x1, x2), x3]
      return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2, ), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH', 'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )

    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(nt.empirical_ntk_fn(apply_fn, implementation=2))
    direct = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

    implicit_batched = jit(
        nt.empirical_ntk_fn(apply_fn,
                            vmap_axes=([(0, 1), 2], [-2, -3],
                                       dict(pattern=0)),
                            implementation=2))
    direct_batched = jit(
        nt.empirical_ntk_fn(apply_fn,
                            vmap_axes=([(-2, -2), -2], [0, 1],
                                       dict(pattern=-3)),
                            implementation=1))

    k = direct(x1, x2, params, pattern=(p1, p2))

    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
class JacobianRulesTest(test_utils.NeuralTangentsTestCase):
  """Checks `rules.JACOBIAN_RULES` / `rules.STRUCTURE_RULES` per primitive.

  For each primitive, `jax.jacfwd`/`jax.jacrev` Jacobians are compared with
  the rule-based Jacobian. The declared structure (diagonal, constant-diagonal
  and broadcast axes) is then verified on the resulting Jacobian.
  """

  def _assert_is_diagonal(self, j, axis1, axis2, constant_diagonal: bool):
    """Asserts `j` is zero off the diagonal spanned by `axis1` and `axis2`.

    The two axes must have equal size. If `constant_diagonal`, also asserts
    that all diagonal entries are equal: `np.diagonal` moves the diagonal to
    the last axis, and min/max over it must match.
    """
    c = j.shape[axis1]
    self.assertEqual(c, j.shape[axis2])
    # Identity mask broadcast along every axis other than axis1/axis2.
    mask_shape = [c if i in (axis1, axis2) else 1 for i in range(j.ndim)]
    mask = np.eye(c, dtype=np.bool_).reshape(mask_shape)

    # Check that removing the diagonal makes the array all 0.
    j_masked = np.where(mask, np.zeros((), j.dtype), j)
    self.assertAllClose(np.zeros_like(j, j.dtype), j_masked)

    if constant_diagonal:
      # Check that diagonal is constant.
      if j.size != 0:
        j_diagonals = np.diagonal(j, axis1=axis1, axis2=axis2)
        self.assertAllClose(np.min(j_diagonals, -1), np.max(j_diagonals, -1))

  def _assert_constant(self, j, axis):
    """Asserts every slice of `j` along `axis` equals the first slice.

    Does nothing when `axis` is None.
    """
    if axis is not None:
      j = np.moveaxis(j, axis, 0)
      j = list(j)
      for ji in j:
        self.assertAllClose(j[0], ji)

  def _compare_jacobians(self, j_fwd, j_rev, j_rule, primitive):
    """Compares forward-, reverse-mode and rule-based Jacobians.

    `j_rule` may be None (no rule registered), in which case only `j_fwd` and
    `j_rev` are compared, except for `convert_element_type` (see below).
    """
    if primitive == lax.convert_element_type_p:
      # Require that AT LEAST ONE of the fwd/rev Jacobians matches the rule.
      # Each mismatch is logged; a ValueError carrying both errors is raised
      # only if neither matches. fwd and rev are not compared to each other.
      e_fwd, e_rev = None, None
      try:
        self.assertAllClose(j_fwd, j_rule)
      except Exception as e:
        logging.exception('Forward-mode Jacobian does not match the rule.')
        e_fwd = e

      try:
        self.assertAllClose(j_rev, j_rule)
      except Exception as e:
        logging.exception('Reverse-mode Jacobian does not match the rule.')
        e_rev = e

      if e_fwd is not None and e_rev is not None:
        raise ValueError(e_fwd, e_rev)

    else:
      if primitive == lax.reshape_p:
        # Reshape Jacobian is special-case defined as identity.
        j_rule = j_rule.reshape(j_fwd.shape)

      self.assertAllClose(j_fwd, j_rev)
      if j_rule is not None:
        self.assertAllClose(j_fwd, j_rule)
        self.assertAllClose(j_rev, j_rule)

  def _test_primitive(self, primitive: Optional[Primitive], shapes, dtype,
                      params):
    """Runs Jacobian and structure checks for `primitive` on every argument.

    Args:
      primitive: the JAX primitive under test.
      shapes: one shape per positional argument.
      dtype: input dtype passed to `_get_inputs`.
      params: primitive parameters passed to `_get_f_and_eqn`.

    Raises:
      absltest.SkipTest: for unsupported argument positions. Because this is
        raised inside the loop, it also skips all later positions.
    """
    xs = _get_inputs(shapes, dtype)
    n = len(xs)
    eqn, f = _get_f_and_eqn(params, primitive, *xs)

    out = f(*xs)
    cts_in = ShapedArray(out.shape, out.dtype)

    argnums = tuple(range(n))
    # Reference Jacobians w.r.t. all arguments; each has shape
    # out.shape + xs[idx].shape.
    js_fwd = jax.jacfwd(f, argnums)(*xs)
    js_rev = jax.jacrev(f, argnums)(*xs)

    for idx in range(n):
      if primitive == lax.conv_general_dilated_p and idx == 0:
        raise absltest.SkipTest('Jacobian of CNN wrt inputs not implemented.')

      if primitive == lax.div_p and idx == 1:
        raise absltest.SkipTest('Division is linear only in the first arg.')

      invals = _get_invals(idx, *xs)
      j_fwd, j_rev = js_fwd[idx], js_rev[idx]

      if primitive in rules.JACOBIAN_RULES:
        j_rule = rules.JACOBIAN_RULES[primitive](eqn, idx, invals, cts_in)
      else:
        warnings.warn(f'Jacobian rule for {primitive} at position {idx} not '
                      f'found.')
        j_rule = None

      with self.subTest(f'Jacobian ({idx})'):
        self._compare_jacobians(j_fwd, j_rev, j_rule, primitive)

      structure = rules.STRUCTURE_RULES[primitive](eqn, idx, invals, cts_in)

      # Structure is verified on the rule Jacobian when available.
      j = j_fwd if j_rule is None else j_rule

      if primitive == lax.reshape_p:
        # Permute by the argsort of `structure.in_trace`, then reshape to
        # xs[0].shape + permuted xs[0].shape. The leading ("output") axes then
        # have the input's rank, so `out_ndim` is set to it.
        out_ndim = xs[0].ndim
        j = j.transpose(
            tuple(xs[0].ndim + i for i in onp.argsort(structure.in_trace)) +
            tuple(i for i in onp.argsort(structure.in_trace)))
        j = j.reshape(
            xs[0].shape +
            tuple(xs[0].shape[i] for i in onp.argsort(structure.in_trace)))

      else:
        out_ndim = out.ndim

      # Input axis `i` of `j` lives at position `out_ndim + i`.
      with self.subTest(f'Diagonal axes ({idx})'):
        for i, o in zip(structure.in_diagonal, structure.out_diagonal):
          self._assert_is_diagonal(
              j=j,
              axis1=out_ndim + i[idx],
              axis2=o,
              constant_diagonal=False)

      with self.subTest(f'Constant diagonal axes ({idx})'):
        for i, o in zip(structure.in_trace, structure.out_trace):
          self._assert_is_diagonal(
              j=j,
              axis1=out_ndim + i,
              axis2=o,
              constant_diagonal=True)

      # NOTE(review): unlike the input-axis checks above, `in_broadcast` axes
      # are passed without the `out_ndim` offset. Confirm whether
      # `structure.in_broadcast` indexes input axes or Jacobian axes.
      with self.subTest(f'Input broadcast axes ({idx})'):
        for i in structure.in_broadcast:
          self._assert_constant(j=j, axis=i)

      with self.subTest(f'Output broadcast axes ({idx})'):
        for i in structure.out_broadcast:
          self._assert_constant(j=j, axis=i)

  @parameterized.parameters(
      test_utils.cases_from_list(
          dict(
              primitive=primitive,
              shape=shape,
              dtype=dtype,
              params=params,
          )
          for shape in _SHAPES
          for dtype in _DTYPES
          for primitive in _UNARY_PRIMITIVES.keys()
          for params in _UNARY_PRIMITIVES[primitive](shape, dtype)))
  def test_unary(self, primitive: Optional[Primitive], shape, dtype, params):
    """Unary primitives; `device_put` is run once per available device."""
    if primitive == jax._src.dispatch.device_put_p:
      # Can't instantiate devices at test generation time; using subtests.
      for device in [None] + jax.devices() + jax.devices('cpu'):
        with self.subTest(device=device):
          params = {'device': device}
          self._test_primitive(primitive, [shape], dtype, params)

    else:
      self._test_primitive(primitive, [shape], dtype, params)

  @parameterized.parameters(
      test_utils.cases_from_list(
          dict(primitive=primitive,
               shape1=shape1,
               shape2=shape2,
               dtype=dtype,
               params=params)
          for shape1 in _SHAPES
          for shape2 in _SHAPES
          for dtype in _DTYPES
          for primitive in _BINARY_PRIMITIVES.keys()
          for params in _BINARY_PRIMITIVES[primitive](shape1, shape2)))
  def test_binary(self, primitive: Optional[Primitive], shape1, shape2, dtype,
                  params):
    """Binary primitives, with backend-specific skips for convolutions."""
    # TODO(romann): revisit when bugs below are fixed.
    if primitive == lax.conv_general_dilated_p:
      if jax.default_backend() == 'tpu':
        raise absltest.SkipTest('http://b/235167364')

      elif jax.default_backend() == 'gpu' and params['batch_group_count'] != 1:
        raise absltest.SkipTest('http://b/235485533')

    # Higher-rank cases defer to the project's `skip_test` policy.
    if len(shape1) > 3 or len(shape2) > 3:
      test_utils.skip_test(self)

    self._test_primitive(primitive, [shape1, shape2], dtype, params)

  @parameterized.parameters(
      test_utils.cases_from_list(
          dict(
              primitive=primitive,
              shapes=shapes,
              dtype=dtype,
              params=params)
          for shapes in _concat_shapes(4, *_SHAPES)
          for dtype in _DTYPES
          for primitive in _N_ARY_PRIMITIVES.keys()
          for params in _N_ARY_PRIMITIVES[primitive](*shapes)))
  def test_n_ary(self, primitive: Optional[Primitive], shapes, dtype, params):
    """N-ary primitives over shape combinations from `_concat_shapes`."""
    self._test_primitive(primitive, shapes, dtype, params)
class FanInTest(test_utils.NeuralTangentsTestCase):
  """Exact vs. Monte Carlo kernels for multi-branch networks with fan-in.

  Each test fans a shared first layer out into several branches of
  increasing depth and merges them with `FanInSum`, `FanInProd` or
  `FanInConcat(axis)`.
  """

  @classmethod
  def _get_phi(cls, i):
    """Nonlinearity for index `i`, cycling through Relu, Erf and Abs."""
    phi_ctors = (stax.Relu, stax.Erf, stax.Abs)
    return phi_ctors[i % len(phi_ctors)]()

  @classmethod
  def _make_fan_in_layer(cls, fan_in_mode, axis):
    """Returns the merge layer for `fan_in_mode` (concat uses `axis`)."""
    no_axis_ctors = {'FanInSum': stax.FanInSum, 'FanInProd': stax.FanInProd}
    ctor = no_axis_ctors.get(fan_in_mode)
    return stax.FanInConcat(axis) if ctor is None else ctor()

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': (
                  f' [{"same_inputs" if same_inputs else "different_inputs"}'
                  f'_axis={axis}_n_branches={n_branches}'
                  f'_{get}_{branch_in}_{fan_in_mode}]'),
              'same_inputs': same_inputs,
              'axis': axis,
              'n_branches': n_branches,
              'get': get,
              'branch_in': branch_in,
              'fan_in_mode': fan_in_mode,
          }
          for same_inputs in [False]
          for axis in [0, 1]
          for n_branches in [3]
          for get in ['ntk']
          for branch_in in ['dense_before_branch_in', 'dense_after_branch_in']
          for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd']))
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                     fan_in_mode):
    """Fully-connected branches on 2-D inputs."""
    axis_free_mode = fan_in_mode in ['FanInSum', 'FanInProd']
    if axis_free_mode:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                                'axis != 0.')
      axis = None

    dense_after = branch_in == 'dense_after_branch_in'
    dense_before = branch_in == 'dense_before_branch_in'

    if dense_after and (fan_in_mode == 'FanInSum' or axis == 0):
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                              'require `is_gaussian`.')

    if dense_before and (axis == 1 or fan_in_mode == 'FanInProd'):
      raise absltest.SkipTest(
          '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
          'after concatenation or Hadamard product.')

    fan_in_layer = FanInTest._make_fan_in_layer(fan_in_mode, axis)

    if n_branches != 2:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    inputs_1 = np.cos(random.normal(key, (4, 3)))
    inputs_2 = None if same_inputs else random.normal(key, (8, 3))

    width = 1024
    n_samples = 512
    tol = 0.07 if default_backend() == 'tpu' else 0.02

    dense = stax.Dense(width, 1.25, 0.1)

    def make_branch(depth):
      # Branch `depth` has `depth` extra Dense+phi pairs; feature widths grow
      # only when concatenating along the feature axis.
      layers = [FanInTest._get_phi(depth)]
      for i in range(depth):
        multiplier = (1 + 0.25 * i) if axis in (1, -1) else 1
        layers.extend([
            stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i),
        ])
      if dense_before:
        layers.append(dense)
      return stax.serial(*layers)

    branches = [make_branch(depth) for depth in range(n_branches)]

    head = [fan_in_layer]
    if dense_after:
      head.append(dense)
    head.append(stax.Relu())

    nn = stax.serial(dense, stax.FanOut(n_branches),
                     stax.parallel(*branches), *head)

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

    batched = axis not in (0, -2)
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=-1 if batched else 0,
        implementation=2,
        vmap_axes=0 if batched else None,
    )

    exact = kernel_fn(inputs_1, inputs_2, get=get)
    empirical = kernel_fn_mc(inputs_1, inputs_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name': (
                  f' [{"same_inputs" if same_inputs else "different_inputs"}'
                  f'_axis={axis}_n_branches={n_branches}'
                  f'_{get}_{branch_in}_{readout}_{fan_in_mode}]'),
              'same_inputs': same_inputs,
              'axis': axis,
              'n_branches': n_branches,
              'get': get,
              'branch_in': branch_in,
              'readout': readout,
              'fan_in_mode': fan_in_mode,
          }
          for same_inputs in [False]
          for axis in [0, 1, 2, 3]
          for n_branches in [2]
          for get in ['ntk']
          for branch_in in ['dense_before_branch_in', 'dense_after_branch_in']
          for readout in ['pool', 'flatten']
          for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd']))
  def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                       readout, fan_in_mode):
    """Convolutional branches on 4-D inputs with pooling/flatten readout."""
    test_utils.skip_test(self)

    axis_free_mode = fan_in_mode in ['FanInSum', 'FanInProd']
    if axis_free_mode:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd()` are skipped when '
                                'axis != 0.')
      axis = None

    dense_after = branch_in == 'dense_after_branch_in'
    dense_before = branch_in == 'dense_before_branch_in'

    if dense_after and (fan_in_mode == 'FanInSum' or axis in [0, 1, 2]):
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                              'require `is_gaussian`.')

    if dense_before and (axis == 3 or fan_in_mode == 'FanInProd'):
      raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                              'requires a dense layer after concatenation '
                              'or Hadamard product.')

    fan_in_layer = FanInTest._make_fan_in_layer(fan_in_mode, axis)

    key = random.PRNGKey(1)
    inputs_1 = random.normal(key, (2, 5, 6, 3))
    inputs_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    width, n_samples, tol = ((2048, 1024, 0.02)
                             if default_backend() == 'tpu'
                             else (1024, 512, 0.01))

    conv = stax.Conv(out_chan=width, filter_shape=(3, 3), padding='SAME',
                     W_std=1.25, b_std=0.1)

    def make_branch(depth):
      # Branch `depth` has `depth` extra Conv+phi pairs; channel counts grow
      # only when concatenating along the channel axis.
      layers = [FanInTest._get_phi(depth)]
      for i in range(depth):
        multiplier = (1 + 0.25 * i) if axis in (3, -1) else 1
        layers.extend([
            stax.Conv(out_chan=int(width * multiplier),
                      filter_shape=(i + 1, 4 - i),
                      padding='SAME',
                      W_std=1.25 + i,
                      b_std=0.1 + i),
            FanInTest._get_phi(i),
        ])
      if dense_before:
        layers.append(conv)
      return stax.serial(*layers)

    branches = [make_branch(depth) for depth in range(n_branches)]

    head = [fan_in_layer]
    if dense_after:
      head.append(conv)
    head.append(stax.Relu())
    head.append(stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten())

    nn = stax.serial(conv, stax.FanOut(n_branches),
                     stax.parallel(*branches), *head)

    readout_width = 1 if get == 'ntk' else width
    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(readout_width, 1.25, 0.5))

    batched = axis not in (0, -4)
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=-1 if batched else 0,
        implementation=2,
        vmap_axes=0 if batched else None,
    )

    exact = kernel_fn(inputs_1, inputs_2, get=get)
    empirical = kernel_fn_mc(inputs_1, inputs_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)