def WideResnet(block_size, k, num_classes):
  """Assemble a wide residual network out of three `WideResnetGroup` stages.

  Args:
    block_size: forwarded as the first argument of every `WideResnetGroup`.
    k: widening factor; the group widths are `int(16 * k)`, `int(32 * k)` and
      `int(64 * k)`.
    num_classes: output size of the final `stax.Dense` layer.

  Returns:
    Whatever `stax.serial` returns for the stem convolution, the three groups,
    global average pooling, flattening and the dense readout, in that order.
  """
  stem = stax.Conv(16, (3, 3), padding='SAME')
  # Only the second and third groups receive the extra `(2, 2)` argument.
  groups = [WideResnetGroup(block_size, int(16 * k))]
  for base_width in (32, 64):
    groups.append(WideResnetGroup(block_size, int(base_width * k), (2, 2)))
  head = [
      stax.GlobalAvgPool(),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.),
  ]
  return stax.serial(stem, *groups, *head)
def _build_network(input_shape, network, out_logits):
  """Build a small test network chosen by input rank and `network` kind.

  Args:
    input_shape: shape of the test input; its length must be 1 or 3.
    network: one of the module-level constants `FLAT`, `POOLING` or `CONV`.
    out_logits: output size of the final `stax.Dense` layer (unused for
      `CONV`, which ends in a convolution).

  Returns:
    A single `stax.Dense` layer for rank-1 inputs, otherwise a `stax.serial`
    composition starting with a convolution.

  Raises:
    ValueError: if the input rank is neither 1 nor 3, or if `network` is not
      recognised for a rank-3 input.
  """
  rank = len(input_shape)

  if rank == 1:
    assert network == FLAT
    return stax.Dense(out_logits, W_std=2.0, b_std=0.5)

  if rank != 3:
    raise ValueError('Expected flat or image test input.')

  if network == POOLING:
    return stax.serial(
        stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
        stax.GlobalAvgPool(),
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))

  if network == CONV:
    return stax.serial(
        stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
        stax.Relu(),
        stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
    )

  if network == FLAT:
    return stax.serial(
        stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
        stax.Flatten(),
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))

  raise ValueError('Unexpected network type found: {}'.format(network))
def test_composition_conv(self, avg_pool):
  """Check that composing kernel functions matches the composed network.

  The kernel of `stax.serial(Block, Readout)` is compared with applying the
  readout kernel function to the output of the block kernel function, both for
  a single input and for a pair of distinct inputs.

  Args:
    avg_pool: if truthy the readout is `GlobalAvgPool -> Dense` and the block
      kernel is requested with `marginalization='none'`; otherwise the readout
      is `Flatten -> Dense` and `marginalization='auto'` is used.
  """
  # Use independent keys: drawing both inputs from the same key with the same
  # shape yields bit-identical arrays, so the `(x1, x2)` case below would not
  # exercise a genuinely different second input.
  key1, key2 = random.split(random.PRNGKey(0))
  x1 = random.normal(key1, (5, 10, 10, 3))
  x2 = random.normal(key2, (5, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())

  if avg_pool:
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
    marginalization = 'none'
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))
    marginalization = 'auto'

  # Index 2 of a stax layer triple is its kernel function.
  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  # Single-input case.
  ker_out = readout_ker_fn(
      block_ker_fn(x1, marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1)
  self.assertAllClose(ker_out, composed_ker_out, True)

  if avg_pool:
    # With the default marginalization the pooled readout is expected to
    # reject the block kernel.
    with self.assertRaises(ValueError):
      readout_ker_fn(block_ker_fn(x1))

  # Two-input case.
  ker_out = readout_ker_fn(
      block_ker_fn(x1, x2, marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1, x2)
  self.assertAllClose(ker_out, composed_ker_out, True)
def _build_network(input_shape, network, out_logits, use_dropout):
  """Build a small test network, optionally with a dropout layer.

  Args:
    input_shape: shape of the test input; its length must be 1 or 3.
    network: `'FLAT'` for rank-1 inputs; for rank-3 inputs one of the
      module-level constants `POOLING`, `FLAT` or `INTERMEDIATE_CONV`.
    out_logits: output size of the final `stax.Dense` layer (unused for
      `INTERMEDIATE_CONV`, which is a bare convolution).
    use_dropout: if truthy a `stax.Dropout(0.9, mode='train')` layer is placed
      before the readout, otherwise `stax.Identity()` takes its place.

  Returns:
    The stax layer for the requested configuration.

  Raises:
    ValueError: if the input rank is neither 1 nor 3, or if `network` is not
      recognised for a rank-3 input.
  """
  if use_dropout:
    dropout = stax.Dropout(0.9, mode='train')
  else:
    dropout = stax.Identity()

  rank = len(input_shape)

  if rank == 1:
    assert network == 'FLAT'
    return stax.serial(
        stax.Dense(WIDTH, W_std=2.0, b_std=0.5),
        dropout,
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))

  if rank != 3:
    raise ValueError('Expected flat or image test input.')

  if network == POOLING:
    return stax.serial(
        stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
        stax.GlobalAvgPool(),
        dropout,
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))

  if network == FLAT:
    return stax.serial(
        stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
        stax.Flatten(),
        dropout,
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))

  if network == INTERMEDIATE_CONV:
    return stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05)

  raise ValueError('Unexpected network type found: {}'.format(network))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  """Build a two-affine-layer test network followed by a dense readout.

  Args:
    W_std: weight standard deviation passed to every `Dense`/`Conv`.
    b_std: bias standard deviation passed to every `Dense`/`Conv`.
    filter_shape: convolution filter shape (used when `is_conv`).
    is_conv: if truthy the affine layer is `stax.Conv`, else `stax.Dense`.
    use_pooling: if truthy an `AvgPool` precedes the nonlinearity and the
      readout uses `GlobalAvgPool`; otherwise `Identity` and `Flatten`.
    is_res: if truthy the second unit is wrapped in a residual connection.
    padding: convolution padding; also selects the `AvgPool` padding.
    phi: the nonlinearity layer.
    strides: convolution strides.
    width: output width of the affine layers.
    is_ntk: if truthy the readout has a single output, else `width` outputs.

  Returns:
    The `stax.serial` composition of the block and the readout.
  """
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)

  if is_conv:
    affine = conv(width)
  else:
    affine = fc(width)

  if use_pooling:
    # Any padding other than 'SAME' is mapped to 'CIRCULAR' for the pool.
    pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
    pre_activation = stax.AvgPool((2, 3), None, pool_padding)
  else:
    pre_activation = stax.Identity()
  res_unit = stax.serial(pre_activation, phi, affine)

  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if use_pooling:
    projection = stax.GlobalAvgPool()
  else:
    projection = stax.Flatten()
  readout = stax.serial(projection, fc(1 if is_ntk else width))

  return stax.serial(block, readout)
def test_exp_normalized(self):
  """Check kernel shapes and positive-definiteness with `ExpNormalized`.

  For every combination of `do_clip`, `gamma` and kernel type, a two-layer
  convolutional network with `stax.ExpNormalized` activations is built and its
  cross kernel and both self kernels are checked for the expected shape; the
  self kernels must additionally have a strictly positive smallest eigenvalue.
  """
  key = random.PRNGKey(0)
  x1 = random.normal(key, (2, 6, 7, 1))
  x2 = random.normal(key, (4, 6, 7, 1))

  for do_clip in (True, False):
    for gamma in (1., 2., 0.5):
      for get in ('nngp', 'ntk'):
        with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
          layers = (
              stax.Conv(1, (3, 3)),
              stax.ExpNormalized(gamma, do_clip),
              stax.Conv(1, (3, 3)),
              stax.ExpNormalized(gamma, do_clip),
              stax.GlobalAvgPool(),
              stax.Dense(1),
          )
          _, _, kernel_fn = stax.serial(*layers)

          # Cross kernel between the two batches.
          cross = kernel_fn(x1, x2, get=get)
          self.assertEqual(cross.shape, (x1.shape[0], x2.shape[0]))

          # Self kernels: square and positive definite.
          for x in (x1, x2):
            batch_size = x.shape[0]
            gram = kernel_fn(x, None, get=get)
            self.assertEqual(gram.shape, (batch_size, batch_size))
            self.assertGreater(np.min(np.linalg.eigvalsh(gram)), 0)
def test_hermite(self, same_inputs, degree, get, readout):
  """Compare the analytic `stax.Hermite` kernel with a Monte Carlo estimate.

  Args:
    same_inputs: if truthy `x2` is the same array as `x1`.
    degree: Hermite degree; degrees above 2 use more samples and go through
      `test_utils.skip_test`.
    get: kernel type passed to both kernel functions; `'ntk'` appends a
      `stax.Dense(1)` readout, anything else appends `stax.Identity()`.
    readout: `'pool'` selects `GlobalAvgPool`, anything else `Flatten`.
  """
  key1, key2, key = random.split(random.PRNGKey(1), 3)

  width = 10000
  high_degree = degree > 2
  n_samples = 5000 if high_degree else 100
  if high_degree:
    test_utils.skip_test(self)

  x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
  if same_inputs:
    x2 = x1
  else:
    x2 = np.cos(random.normal(key2, [3, 6, 6, 3]))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
      stax.LayerNorm(),
      stax.Hermite(degree),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
      stax.Dense(1) if get == 'ntk' else stax.Identity())

  analytic_kernel = kernel_fn(x1, x2, get)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
  mc_kernel = mc_kernel_fn(x1, x2, get)

  # The tolerance grows linearly with the degree.
  rot = degree / 2. * 1e-2
  test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rot)
def test_composition_conv(self, avg_pool, same_inputs):
  """Check kernel-function composition for a convolutional block and readout.

  The kernel of `stax.serial(Block, Readout)` must match the readout kernel
  function applied to the block kernel, for the default call and for
  `diagonal_spatial=False`. With `diagonal_spatial=True` the pooled readout is
  expected to raise `ValueError`, while the flattening readout must still
  match.

  Args:
    avg_pool: selects the `Conv -> GlobalAvgPool -> Dense` readout instead of
      `Flatten -> Dense`.
    same_inputs: if truthy `x2` is `None`.
  """
  rng = random.PRNGKey(0)
  x1 = random.normal(rng, (3, 5, 5, 3))
  x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  if avg_pool:
    Readout = stax.serial(
        stax.Conv(256, (3, 3)), stax.GlobalAvgPool(), stax.Dense(10))
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))

  block_ker_fn = Block[2]
  readout_ker_fn = Readout[2]
  composed_ker_fn = stax.serial(Block, Readout)[2]

  def two_stage(**block_kwargs):
    # Readout kernel applied to the block kernel of the test inputs.
    return readout_ker_fn(block_ker_fn(x1, x2, **block_kwargs))

  expected = composed_ker_fn(x1, x2)
  explicit_full = two_stage(diagonal_spatial=False)
  implicit = two_stage()
  self.assertAllClose(expected, explicit_full)
  self.assertAllClose(expected, implicit)

  if not avg_pool:
    self.assertAllClose(expected, two_stage(diagonal_spatial=True))
  else:
    with self.assertRaises(ValueError):
      two_stage(diagonal_spatial=True)
def test_nested_parallel(self, same_inputs, kernel_type):
  """Compare analytic and Monte Carlo kernels for nested `stax.parallel`.

  Four inputs of different shapes are arranged as `(((a, b), c), d)`, some of
  them are masked with `mask_constant=-1`, and the output of the analytic
  `kernel_fn` is compared against `nt.monte_carlo_kernel_fn` within `rtol`.

  Args:
    same_inputs: forwarded to `_get_inputs`; if truthy `x2` is `None` and the
      second inputs are not masked.
    kernel_type: passed as `get` to both kernel functions.
  """
  platform = default_backend()
  # TPU uses a fixed tolerance instead of the module-level `RTOL`.
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1,
   input_key2,
   input_key3,
   input_key4,
   mask_key,
   mc_key) = random.split(rng, 6)

  # One input pair per leaf of the nested network below.
  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  # Mask the first and second leaves of `x1` ...
  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
  # ... and the third and fourth leaves of `x2`, when a distinct `x2` is used.
  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  # The nesting of the inputs mirrors the nesting of `stax.parallel` below.
  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  # Each leaf gets a distinct width (N_in, N_in + 1, ...).
  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  # Explicit `vmap_axes` are supplied only on TPU; elsewhere `None` is passed.
  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
def _test_activation(self, activation_fn, same_inputs, model, get,
                     rbf_gamma=None):
  """Compare the analytic kernel of an activation with a Monte Carlo estimate.

  Args:
    activation_fn: stax layer triple of the activation under test; the name of
      its kernel function (`activation_fn[2].__name__`) selects the weight and
      bias scales and the number of samples.
    same_inputs: if truthy the second input is `None`.
    model: `'fc'` for a dense network; any other value builds a convolutional
      one, pooled when the string contains `'pool'`. Values containing
      `'conv'` go through `test_utils.skip_test`.
    get: kernel type; `'nngp'` uses a 1024-wide readout, otherwise width 1.
    rbf_gamma: if given (and `get == 'nngp'`, `model == 'fc'`), the analytic
      kernel is additionally compared with `self._RBF(rbf_gamma / input_dim)`.
  """
  if 'conv' in model:
    test_utils.skip_test(self)

  key, split = random.split(random.PRNGKey(1))
  output_dim = 1024 if get == 'nngp' else 1

  act_name = activation_fn[2].__name__
  if act_name == 'Rbf':
    W_std, b_std = 1.0, 0.0
  elif act_name == 'Sin':
    W_std, b_std = 0.9, 0.5
  else:
    W_std, b_std = 2.0, 0.5

  is_fc = model == 'fc'
  if is_fc:
    rtol = 0.04
    X0_1 = random.normal(key, (4, 2))
    X0_2 = None if same_inputs else random.normal(split, (2, 2))
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
    depth = 1
  else:
    rtol = 0.05
    X0_1 = random.normal(key, (2, 4, 4, 3))
    X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
    affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
    projection = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
    readout = stax.serial(projection, stax.Dense(output_dim))
    depth = 2

  if default_backend() == 'cpu':
    # Fewer samples on CPU, compensated by a doubled tolerance.
    num_samplings = 200
    rtol *= 2
  elif act_name in ('Sin', 'Rbf'):
    num_samplings = 500
  else:
    num_samplings = 300

  hidden_layers = [affine, activation_fn] * depth
  init_fn, apply_fn, kernel_fn = stax.serial(*hidden_layers, readout)
  analytic_kernel = kernel_fn(X0_1, X0_2, get)

  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, split, num_samplings,
      implementation=2,
      vmap_axes=0)
  empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
  test_utils.assert_close_matrices(self, analytic_kernel, empirical_kernel,
                                   rtol)

  # Check match with explicit RBF
  if rbf_gamma is not None and get == 'nngp' and is_fc:
    input_dim = X0_1.shape[1]
    rbf_kernel_fn = self._RBF(rbf_gamma / input_dim)[2]
    direct_rbf_kernel = rbf_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel, direct_rbf_kernel,
                                     rtol)
def _test_analytic_kernel_composition(self, batching_fn):
  """Check that batched kernel functions compose like the composed network.

  Args:
    batching_fn: wrapper applied to each kernel function before composing.
      When it equals `batch._parallel`, the `x1_is_x2` field of the reference
      kernel is overwritten with the batched one before comparing.
  """
  is_parallel = batching_fn == batch._parallel

  def check(ker_out, composed_ker_out):
    if is_parallel:
      # In the parallel setting, `x1_is_x2` is not computed correctly
      # when x1==x2.
      composed_ker_out = composed_ker_out.replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)

  # Check Fully-Connected.
  rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(rng)
  rng_self, rng_other = keys[0], keys[1]
  x_self = np.asarray(normal((8, 10), seed=rng_self))
  x_other = np.asarray(normal((2, 10), seed=rng_other))

  Block = stax.serial(stax.Dense(256), stax.Relu())
  ker_fn = batching_fn(Block[2])
  composed_ker_fn = stax.serial(Block, Block)[2]

  check(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
  check(ker_fn(ker_fn(x_self, x_other)), composed_ker_fn(x_self, x_other))

  # Check convolutional + pooling.
  x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
  x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  composed_ker_fn = stax.serial(Block, Readout)[2]
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  check(readout_ker_fn(block_ker_fn(x_self)), composed_ker_fn(x_self))
  check(readout_ker_fn(block_ker_fn(x_self, x_other)),
        composed_ker_fn(x_self, x_other))
def test_vmap_axes(self, same_inputs):
  """Check that batched empirical NTK functions agree with unbatched ones.

  A network with three input branches whose batch dimension sits on a
  different axis in each input is evaluated with the direct and implicit
  empirical NTK implementations, with and without `vmap_axes`; all four
  results must be close.

  Args:
    same_inputs: if truthy `x2` and the second pattern `p2` are `None`.
  """
  # Batch sizes, channel counts and spatial sizes of the three inputs.
  n1, n2 = 3, 4
  c1, c2, c3 = 9, 5, 7
  h2, h3, w3 = 6, 8, 2

  def get_x(n, k):
    # Three inputs with the size-`n` axis at position 0, 1 and 2 respectively.
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  # `pattern` keyword arguments, passed through to the network below
  # (presumably consumed by `stax.Aggregate` -- TODO confirm).
  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  # The nesting of `stax.parallel` mirrors the `[(x1, x2), x3]` input
  # structure; `# 1`..`# 3` mark the branch handling each input.
  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
      )
  )

  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
  implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
  direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

  # The two `vmap_axes` specs below use non-negative and negative indices
  # respectively; given the shapes built in `get_x` and the output layouts
  # above they appear to address the same axes -- NOTE(review): confirm.
  implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
      apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
  direct_batched = jit(empirical._empirical_direct_ntk_fn(
      apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

  # The unbatched direct implementation serves as the reference.
  k = direct(x1, x2, params, pattern=(p1, p2))

  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
  """Build a test network: an (optionally residual) block plus a readout.

  Args:
    W_std: weight standard deviation for the affine and attention layers.
    b_std: bias standard deviation for the affine and attention layers.
    filter_shape: convolution filter shape (used when `is_conv`).
    is_conv: if truthy the affine layer is `stax.Conv`, else `stax.Dense`.
    use_pooling: if truthy an `AvgPool` precedes the nonlinearity.
    is_res: if truthy the second unit is wrapped in a residual connection.
    padding: convolution padding; also selects the `AvgPool` padding.
    phi: the nonlinearity layer.
    strides: convolution strides.
    width: output width of the affine layers.
    is_ntk: if truthy the readout has a single output, else `width` outputs.
    proj_into_2d: `'FLAT'`, `'POOL'` or a string starting with `'ATTN'`
      (`'ATTN_FIXED'` sets `fixed=True` on the attention layer).
    layer_norm: `None` for no layer norm, otherwise the `axis` argument of
      `stax.LayerNorm`.
    parameterization: forwarded to every `Dense`/`Conv`.
    use_dropout: if truthy a training-mode dropout follows the nonlinearity.

  Returns:
    The `stax.serial` composition of the block and the readout.

  Raises:
    ValueError: if `proj_into_2d` is not one of the supported values.
  """
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
               parameterization=parameterization)
  conv = partial(stax.Conv,
                 filter_shape=filter_shape,
                 strides=strides,
                 padding=padding,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
  if is_conv:
    affine = conv(width)
  else:
    affine = fc(width)

  # The dropout rate is drawn on every call, whether or not dropout is used.
  rate = np.onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
  ave_pool = stax.AvgPool((2, 3), None, pool_padding)

  # Replace every disabled optional layer by an identity.
  pool_layer = ave_pool if use_pooling else stax.Identity()
  dropout_layer = dropout if use_dropout else stax.Identity()
  if layer_norm is None:
    norm_layer = stax.Identity()
  else:
    norm_layer = stax.LayerNorm(axis=layer_norm)

  res_unit = stax.serial(pool_layer, phi, dropout_layer, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        norm_layer)
  else:
    block = stax.serial(affine, res_unit, norm_layer)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    attention = stax.GlobalSelfAttention(
        width,
        n_chan_key=width,
        n_chan_val=n_chan_val,
        n_heads=n_heads,
        fixed=proj_into_2d == 'ATTN_FIXED',
        W_key_std=W_std,
        W_value_std=W_std,
        W_query_std=W_std,
        W_out_std=1.0,
        b_std=b_std)
    proj_layer = stax.serial(attention, stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  n_outputs = 1 if is_ntk else width
  readout = stax.serial(proj_layer, fc(n_outputs))
  return stax.serial(block, readout)
def _test_analytic_kernel_composition(self, batching_fn):
  """Check that batched kernel functions compose like the composed network.

  Args:
    batching_fn: wrapper applied to each kernel function before composing.
      When it equals `batch._parallel`, the `x1_is_x2` field of the reference
      kernel is overwritten with the batched one before comparing.
  """
  is_parallel = batching_fn == batch._parallel

  def check(ker_out, composed_ker_out):
    if is_parallel:
      # In the parallel setting, `x1_is_x2` is not computed correctly
      # when x1==x2.
      composed_ker_out = composed_ker_out._replace(
          x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out, True)

  # Check Fully-Connected.
  rng = random.PRNGKey(0)
  rng_self, rng_other = random.split(rng)
  x_self = random.normal(rng_self, (8, 10))
  x_other = random.normal(rng_other, (2, 10))

  Block = stax.serial(stax.Dense(256), stax.Relu())
  ker_fn = batching_fn(Block[2])
  composed_ker_fn = stax.serial(Block, Block)[2]

  check(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
  check(ker_fn(ker_fn(x_self, x_other)), composed_ker_fn(x_self, x_other))

  # Check convolutional + pooling.
  x_self = random.normal(rng, (8, 10, 10, 3))
  x_other = random.normal(rng, (2, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  composed_ker_fn = stax.serial(Block, Readout)[2]
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  check(readout_ker_fn(block_ker_fn(x_self, marginalization='none')),
        composed_ker_fn(x_self))
  check(readout_ker_fn(block_ker_fn(x_self, x_other, marginalization='none')),
        composed_ker_fn(x_self, x_other))
def _test_analytic_kernel_composition(self, batching_fn):
  """Check that batched kernel functions compose like the composed network.

  Args:
    batching_fn: wrapper applied to each kernel function before composing.
  """

  def check(ker_out, composed_ker_out):
    self.assertAllClose(ker_out, composed_ker_out, True)

  # Check Fully-Connected.
  rng = random.PRNGKey(0)
  rng_self, rng_other = random.split(rng)
  x_self = random.normal(rng_self, (8, 10))
  x_other = random.normal(rng_other, (20, 10))

  Block = stax.serial(stax.Dense(256), stax.Relu())
  ker_fn = batching_fn(Block[2])
  composed_ker_fn = stax.serial(Block, Block)[2]

  check(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
  check(ker_fn(ker_fn(x_self, x_other)), composed_ker_fn(x_self, x_other))

  # Check convolutional + pooling.
  x_self = random.normal(rng, (8, 10, 10, 3))
  x_other = random.normal(rng, (10, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  composed_ker_fn = stax.serial(Block, Readout)[2]
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  check(readout_ker_fn(block_ker_fn(x_self, marginalization='none')),
        composed_ker_fn(x_self))
  check(readout_ker_fn(block_ker_fn(x_self, x_other, marginalization='none')),
        composed_ker_fn(x_self, x_other))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  """Build a test network: an (optionally residual) block plus a readout.

  Args:
    W_std: weight standard deviation for the affine and attention layers.
    b_std: bias standard deviation for the affine and attention layers.
    filter_shape: convolution filter shape (used when `is_conv`).
    is_conv: if truthy the affine layer is `stax.Conv`, else `stax.Dense`.
    use_pooling: if truthy an `AvgPool` precedes the nonlinearity.
    is_res: if truthy the second unit is wrapped in a residual connection.
    padding: convolution padding; also selects the `AvgPool` padding.
    phi: the nonlinearity layer.
    strides: convolution strides.
    width: output width of the affine layers.
    is_ntk: if truthy the readout has a single output, else `width` outputs.
    proj_into_2d: `'FLAT'`, `'POOL'` or a string starting with `'ATTN'`
      (`'ATTN_FIXED'` sets `fixed=True` on the attention layer).

  Returns:
    The `stax.serial` composition of the block and the readout.

  Raises:
    ValueError: if `proj_into_2d` is not one of the supported values.
  """
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  if is_conv:
    affine = conv(width)
  else:
    affine = fc(width)

  if use_pooling:
    # Any padding other than 'SAME' is mapped to 'CIRCULAR' for the pool.
    pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
    pre_activation = stax.AvgPool((2, 3), None, pool_padding)
  else:
    pre_activation = stax.Identity()
  res_unit = stax.serial(pre_activation, phi, affine)

  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    attention = stax.GlobalSelfAttention(
        width,
        n_chan_key=width,
        n_chan_val=n_chan_val,
        n_heads=n_heads,
        fixed=proj_into_2d == 'ATTN_FIXED',
        W_key_std=W_std,
        W_value_std=W_std,
        W_query_std=W_std,
        W_out_std=1.0,
        b_std=b_std)
    proj_layer = stax.serial(attention, stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  n_outputs = 1 if is_ntk else width
  readout = stax.serial(proj_layer, fc(n_outputs))
  return stax.serial(block, readout)
def test_elementwise_numerical(self, same_inputs, model, phi, get):
  """Compare an analytic activation kernel with `stax.ElementwiseNumerical`.

  The same network is built twice: once with the activation layer `phi` and
  once with `stax.ElementwiseNumerical` wrapping `phi`'s apply function. The
  two kernels must agree within `rtol`.

  Args:
    same_inputs: if truthy the second input is `None`.
    model: `'fc'` for a dense network; any other value builds a convolutional
      one, pooled when the string contains `'pool'`. Values containing
      `'conv'` go through `test_utils.skip_test`.
    phi: stax layer triple of the activation under test.
    get: kernel type; `'ntk'` doubles the tolerance.
  """
  if 'conv' in model:
    test_utils.skip_test(self)

  key, split = random.split(random.PRNGKey(1))

  output_dim = 1
  W_std, b_std = 1.0, 0.01
  deg = 25

  # Base tolerance, doubled for the NTK and doubled again on TPU.
  rtol = 2e-3
  if get == 'ntk':
    rtol *= 2
  if default_backend() == 'tpu':
    rtol *= 2

  if model == 'fc':
    X0_1 = random.normal(key, (3, 7))
    X0_2 = None if same_inputs else random.normal(split, (5, 7))
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
    depth = 1
  else:
    X0_1 = random.normal(key, (2, 8, 8, 3))
    X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
    affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std,
                       padding='SAME')
    projection = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
    readout = stax.serial(projection, stax.Dense(output_dim))
    depth = 2

  analytic_layers = [affine, phi] * depth
  analytic_kernel_fn = stax.serial(*analytic_layers, readout)[2]
  analytic_kernel = analytic_kernel_fn(X0_1, X0_2, get)

  # `phi[1]` is the activation's apply function; it is called with empty
  # parameters.
  fn = lambda x: phi[1]((), x)
  numerical_layers = [affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth
  numerical_kernel_fn = stax.serial(*numerical_layers, readout)[2]
  numerical_activation_kernel = numerical_kernel_fn(X0_1, X0_2, get)

  test_utils.assert_close_matrices(self, analytic_kernel,
                                   numerical_activation_kernel, rtol)
def CNNStandard(n_channels, L, filter=(3, 3), data='cifar10', gap=True,
                nonlinearity='relu', parameterization='standard', order=None):
  """Build a plain CNN of `L` conv + nonlinearity layers with a dense readout.

  Args:
    n_channels: number of channels of every `MyConv` layer.
    L: number of conv + nonlinearity layers.
    filter: filter shape passed to `MyConv`.
    data: dataset name, `'cifar10'` (10 classes) or `'cifar100'` (100
      classes); determines the readout size.
    gap: if truthy the features are reduced with `stax.GlobalAvgPool` before
      the readout, otherwise with `stax.Flatten`.
    nonlinearity: `'relu'` or `'swish'`.
    parameterization: forwarded to `MyConv` and `MyDense`.
    order: forwarded to `MyConv` and `MyDense`.

  Returns:
    The `(init_fn, f)` pair produced by `jax_stax.serial`.

  Raises:
    ValueError: if `data` or `nonlinearity` is not one of the supported
      values. Previously such values left a local unbound and surfaced later
      as an `UnboundLocalError`.
  """
  class_counts = {'cifar10': 10, 'cifar100': 100}
  if data not in class_counts:
    raise ValueError(
        'Unsupported data {!r}; expected one of {}.'.format(
            data, sorted(class_counts)))
  num_classes = class_counts[data]

  if nonlinearity == 'relu':
    nonlin = Relu
  elif nonlinearity == 'swish':
    nonlin = Swish
  else:
    raise ValueError(
        "Unsupported nonlinearity {!r}; expected 'relu' or "
        "'swish'.".format(nonlinearity))

  init_fn, f = jax_stax.serial(*[
      jax_stax.serial(
          MyConv(n_channels, filter, parameterization=parameterization,
                 order=order),
          nonlin,
      )
      for _ in range(L)
  ])

  # Only the first two entries (init and apply functions) of the
  # neural-tangents layer triple are usable by `jax_stax.serial`.
  reduce_layer = stax.GlobalAvgPool() if gap else stax.Flatten()
  init_fn, f = jax_stax.serial(
      (init_fn, f),
      reduce_layer[:2],
      MyDense(num_classes, parameterization=parameterization, order=order))

  return init_fn, f
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout, fan_in_mode):
  """Compare analytic and Monte Carlo kernels of a branched CNN with fan-in.

  Args:
    same_inputs: if truthy the second input is `None`.
    axis: concatenation axis for `FanInConcat`; for `FanInSum`/`FanInProd`
      only `axis == 0` is run and `axis` is then reset to `None`.
    n_branches: number of parallel branches after `stax.FanOut`.
    get: kernel type; `'ntk'` uses a single-output readout.
    branch_in: `'dense_before_branch_in'` appends the shared conv to every
      branch, `'dense_after_branch_in'` inserts it right after the fan-in.
    readout: `'pool'` selects `GlobalAvgPool`, anything else `Flatten`.
    fan_in_mode: `'FanInSum'`, `'FanInProd'`, or anything else for
      `FanInConcat`.
  """
  test_utils.skip_test(self)

  if fan_in_mode in ['FanInSum', 'FanInProd']:
    # Sum/product take no axis, so only one `axis` value is exercised.
    if axis != 0:
      raise absltest.SkipTest('`FanInSum` and `FanInProd()` are skipped when '
                              'axis != 0.')
    axis = None
  if (fan_in_mode == 'FanInSum' or
      axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                            'require `is_gaussian`.')

  if ((axis == 3 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis '
        'requires a dense layer after concatenation '
        'or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  # TPU runs use a wider network, more samples and a looser tolerance.
  if default_backend() == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv,
                  stax.FanOut(n_branches)]

  # Branch `b` holds one nonlinearity followed by `b` conv + nonlinearity
  # pairs, so the branches differ in depth.
  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      # Widths vary between layers only when concatenating along axis 3 / -1.
      multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Conv(out_chan=int(width * multiplier),
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      fan_in_layer,
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    # Place the shared conv directly after the fan-in layer.
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  # Concatenation along axis 0 / -4 runs the Monte Carlo estimate with
  # `device_count=0` and without `vmap_axes`.
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -4) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if axis in (0, -4) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_kwargs(self, do_batch, mode):
  """Check that prediction functions forward keyword arguments consistently.

  Kernels computed directly via `kernel_fn(..., **kwargs)` and fed to
  `predict.gp_inference` / `predict.gradient_descent_mse` must give the same
  predictions as `predict.gradient_descent_mse_ensemble`, which is given the
  kernel function and the keyword arguments.

  Args:
    do_batch: whether to batch the kernel function (`'mc'` and `'analytic'`
      modes); the `'empirical'` mode is skipped when truthy.
    mode: `'mc'`, `'empirical'` or `'analytic'`.

  Raises:
    ValueError: if `mode` is not one of the supported values.
  """
  rng = random.PRNGKey(1)

  # NOTE(review): `rng` is reused for every draw below as well as for the
  # split -- confirm that this is intended.
  x_train = random.normal(rng, (8, 7, 10))
  x_test = random.normal(rng, (4, 7, 10))
  y_train = random.normal(rng, (8, 1))

  rng_train, rng_test = random.split(rng, 2)

  pattern_train = random.normal(rng, (8, 7, 7))
  pattern_test = random.normal(rng, (4, 7, 7))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(8),
      stax.Relu(),
      stax.Dropout(rate=0.4),
      stax.Aggregate(),
      stax.GlobalAvgPool(),
      stax.Dense(1)
  )

  # Keyword arguments for the train-train, test-train and test-test kernels.
  kw_dd = dict(pattern=(pattern_train, pattern_train))
  kw_td = dict(pattern=(pattern_test, pattern_train))
  kw_tt = dict(pattern=(pattern_test, pattern_test))

  if mode == 'mc':
    kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                      batch_size=2 if do_batch else 0)

  elif mode == 'empirical':
    kernel_fn = empirical_kernel_fn(apply_fn)
    if do_batch:
      raise absltest.SkipTest('Batching of empirical kernel is not '
                              'implemented with keyword arguments.')

    # The empirical kernel additionally needs parameters, the kernels to
    # compute and per-input random keys.
    for kw in (kw_dd, kw_td, kw_tt):
      kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                     get=('nngp', 'ntk')))

    kw_dd.update(dict(rng=(rng_train, None)))
    kw_td.update(dict(rng=(rng_test, rng_train)))
    kw_tt.update(dict(rng=(rng_test, None)))

  elif mode == 'analytic':
    if do_batch:
      kernel_fn = batch.batch(kernel_fn, batch_size=2)

  else:
    raise ValueError(mode)

  k_dd = kernel_fn(x_train, None, **kw_dd)
  k_td = kernel_fn(x_test, x_train, **kw_td)
  k_tt = kernel_fn(x_test, None, **kw_tt)

  # Infinite time NNGP/NTK.
  predict_fn_gp = predict.gp_inference(k_dd, y_train)
  out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

  if mode == 'empirical':
    # `get` is removed before the dicts are forwarded to the ensemble
    # predictor below.
    for kw in (kw_dd, kw_td, kw_tt):
      kw.pop('get')

  predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                              x_train,
                                                              y_train,
                                                              **kw_dd)
  out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
  self.assertAllClose(out_gp, out_ensemble)

  # Finite time NTK test.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
  out_mse = predict_fn_mse(t=1.,
                           fx_train_0=None,
                           fx_test_0=0.,
                           k_test_train=k_td.ntk)
  out_ensemble = predict_fn_ensemble(t=1.,
                                     get='ntk',
                                     x_test=x_test,
                                     compute_cov=False,
                                     **kw_tt)
  self.assertAllClose(out_mse, out_ensemble)

  # Finite time NNGP train.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
  out_mse = predict_fn_mse(t=2.,
                           fx_train_0=0.,
                           fx_test_0=None,
                           k_test_train=k_td.nngp)
  out_ensemble = predict_fn_ensemble(t=2.,
                                     get='nngp',
                                     x_test=None,
                                     compute_cov=False,
                                     **kw_dd)
  self.assertAllClose(out_mse, out_ensemble)
def test_input_req(self, same_inputs):
  """Check input-layout requirements of convolutional kernel functions.

  The test alternates between networks whose kernel function is expected to
  raise `ValueError` and networks whose analytic NNGP/NTK kernel is compared
  against a Monte Carlo estimate.

  Args:
    same_inputs: if truthy the second input is `None`.
  """
  test_utils.skip_test(self)

  key = random.PRNGKey(1)
  x1 = random.normal(key, (2, 7, 8, 4, 3))
  x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

  # Two stacked convolutions; this combination of `dimension_numbers` is
  # expected to be rejected.
  _, _, wrong_conv_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.Relu(),
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  # Accepted counterpart, checked against Monte Carlo.
  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
      stax.Relu(),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
      stax.Flatten(),
      stax.Dense(1024))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=400,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='nngp')
  K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

  # Convolution followed by a pooling layer with `channel_axis=2`; expected
  # to be rejected.
  _, _, wrong_conv_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.GlobalAvgPool(channel_axis=2))
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  # Accepted counterpart with explicit batch/channel axes on the pools.
  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
      stax.Relu(),
      stax.AvgPool((2, 1, 3), batch_axis=0, channel_axis=-2),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
      stax.Relu(),
      stax.GlobalAvgPool(channel_axis=2),
      stax.Dense(1024))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=300,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='nngp')
  K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

  # Convolution applied after flattening, a dense layer and a nonlinearity;
  # expected to be rejected.
  _, _, wrong_conv_fn = stax.serial(
      stax.Flatten(),
      stax.Dense(1),
      stax.Erf(),
      stax.Conv(out_chan=1, filter_shape=(1, 2),
                dimension_numbers=('CN', 'IO', 'NC')),
  )
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  # A convolution with an empty filter shape directly after flattening is
  # accepted; here the NTK is compared.
  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Flatten(),
      stax.Conv(out_chan=1024, filter_shape=()),
      stax.Relu(),
      stax.Dense(1))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=200,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='ntk')
  K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  """Compare analytic and Monte Carlo kernels of masked `n`-D conv networks.

  Args:
    same_inputs: if truthy the second input is `None`.
    get: `'nngp'` or `'ntk'`; selects the readout width.
    mask_axis: forwarded to `test_utils.mask`.
    mask_constant: forwarded to `test_utils.mask` and to both kernel
      functions.
    concat: `None` for `FanInSum`, otherwise the `FanInConcat` axis.
    proj: `'avg'`, `'sum'` or `'flatten'`; selects the final projection, and
      `'avg'` additionally enables the attention layers.
    p: forwarded to `test_utils.mask`.
    n: number of spatial dimensions (at most 5 are defined here).
    transpose: if truthy `stax.ConvTranspose` is used with smaller shapes.

  Raises:
    ValueError: if `get` is neither `'nngp'` nor `'ntk'`.
  """
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  # All per-dimension settings are truncated to the first `n` entries.
  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C',
                       'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    # Attention is only inserted for the 'avg' projection; otherwise the
    # layer is an identity.
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  # Three parallel branches with different paddings, nonlinearities and
  # scales, merged by sum or concatenation and then projected.
  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.5,
                   b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.2,
                   b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='SAME',
                   W_std=0.1,
                   b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=0.9,
                   b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.,
                   b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.,
                   b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  # Concatenation along the batch axis (0 or -n) runs the Monte Carlo
  # estimate with `device_count=0` and without `vmap_axes`.
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
from neural_tangents import stax from neural_tangents._src.empirical import _DEFAULT_TESTING_NTK_IMPLEMENTATION from tests import test_utils config.parse_flags_with_absl() config.update('jax_numpy_rank_promotion', 'raise') test_utils.update_test_tolerance() prandom.seed(1) @parameterized.product( same_inputs=[False, True], readout=[stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()], readin=[stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()]) class DiagonalTest(test_utils.NeuralTangentsTestCase): def _get_kernel_fn(self, same_inputs, readin, readout): key = random.PRNGKey(1) x1 = random.normal(key, (2, 5, 6, 3)) x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3)) layers = [readin] filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else () layers += [ stax.Conv(1, filter_shape, padding='SAME'), stax.Relu(), stax.Conv(1, filter_shape, padding='SAME'),
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  """Check the analytic kernel of a fan-in CNN against a Monte Carlo estimate.

  Args:
    same_inputs: if true, the second input is `None`.
    axis: `None` selects `stax.FanInSum`; otherwise the axis passed to
      `stax.FanInConcat`.
    n_branches: number of parallel branches created after `stax.FanOut`.
    get: 'ntk' gives a single-unit readout, anything else a `width`-unit one;
      also forwarded to both kernel functions.
    branch_in: 'dense_before_branch_in' appends the shared conv layer to every
      branch; 'dense_after_branch_in' places it right after the fan-in layer.
    readout: 'pool' selects `stax.GlobalAvgPool`, anything else
      `stax.Flatten`.
  """
  platform = xla_bridge.get_backend().platform
  if platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')

  dense_before = branch_in == 'dense_before_branch_in'
  dense_after = branch_in == 'dense_after_branch_in'

  if dense_after and axis in (None, 0, 1, 2):
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')

  if dense_before and axis == 3:
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  if same_inputs:
    X0_2 = None
  else:
    X0_2 = random.normal(key, (3, 5, 6, 3))

  # TPU runs use a wider network, more samples and a looser tolerance.
  if platform == 'tpu':
    width, n_samples, tol = 2048, 1024, 0.02
  else:
    width, n_samples, tol = 1024, 512, 0.01

  # The same conv layer object is reused wherever `conv` appears below.
  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  def make_branch(index):
    # Branch `index` holds `index` conv + nonlinearity pairs after its first
    # nonlinearity.
    stack = [FanInTest._get_phi(index)]
    for depth in range(index):
      stack.append(
          stax.Conv(out_chan=width,
                    filter_shape=(depth + 1, 4 - depth),
                    padding='SAME',
                    W_std=1.25 + depth,
                    b_std=0.1 + depth))
      stack.append(FanInTest._get_phi(depth))
    if dense_before:
      stack.append(conv)
    return stax.serial(*stack)

  input_layers = [conv, stax.FanOut(n_branches)]
  branches = [make_branch(b) for b in range(n_branches)]

  if axis is None:
    fan_in = stax.FanInSum()
  else:
    fan_in = stax.FanInConcat(axis)

  output_layers = [fan_in]
  if dense_after:
    output_layers.append(conv)
  output_layers.append(stax.Relu())
  output_layers.append(
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten())

  nn = stax.serial(*input_layers, stax.parallel(*branches), *output_layers)

  readout_width = 1 if get == 'ntk' else width
  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(readout_width, 1.25, 0.5))

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -4) else -1)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get).reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
@parameterized.named_parameters( test_utils.cases_from_list( { 'testcase_name': ' [{}_out={}_in={}]'.format( 'same_inputs' if same_inputs else 'different_inputs', readout[0].__name__, readin[0].__name__), 'same_inputs': same_inputs, 'readout': readout, 'readin': readin } for same_inputs in [False, True] for readout in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()] for readin in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()])) class DiagonalTest(test_utils.NeuralTangentsTestCase): def _get_kernel_fn(self, same_inputs, readin, readout): key = random.PRNGKey(1) x1 = random.normal(key, (2, 5, 6, 3)) x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3)) layers = [readin] filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else () layers += [ stax.Conv(1, filter_shape, padding='SAME'), stax.Relu(), stax.Conv(1, filter_shape, padding='SAME'),
def ResNet50(num_classes, batchnorm=True, parameterization='standard',
             nonlinearity='relu'):
  """Assemble a ResNet50 as a `jax_stax.serial` stack of layers.

  The network is a 7x7 strided convolution stem with max pooling, four stages
  of `ConvBlock` / `IdentityBlock` layers, global average pooling and a final
  dense layer.

  Args:
    num_classes: output size of the final dense layer.
    batchnorm: if true, `BatchNorm()` follows the stem convolution, otherwise
      `Identity` is used there. The flag is also forwarded to every block.
    parameterization: 'standard' builds the stem and readout from `GeneralConv`
      and `Dense`; 'ntk' builds them from `stax._GeneralConv` and `stax.Dense`,
      keeping only the first two elements of what those return (presumably the
      init and apply functions -- confirm against the `stax` module in use).
      The value is also forwarded to every block.
    nonlinearity: one of 'relu', 'swish', 'swishten' or 'softplus'.

  Returns:
    Whatever `jax_stax.serial` returns for the assembled layer list.

  Raises:
    ValueError: if `parameterization` or `nonlinearity` is not one of the
      supported values. Previously such values fell through and surfaced later
      as a `NameError` / `UnboundLocalError`.
  """
  # Validate both options before anything is built so that a typo fails fast
  # with a readable message.
  if parameterization not in ('standard', 'ntk'):
    raise ValueError(
        'Unexpected parameterization found: {!r}'.format(parameterization))
  if nonlinearity not in ('relu', 'swish', 'swishten', 'softplus'):
    raise ValueError(
        'Unexpected nonlinearity found: {!r}'.format(nonlinearity))

  # Define layer constructors
  if parameterization == 'standard':
    def MyGeneralConv(*args, **kwargs):
      return GeneralConv(*args, **kwargs)

    def MyDense(*args, **kwargs):
      return Dense(*args, **kwargs)
  else:  # parameterization == 'ntk'
    def MyGeneralConv(*args, **kwargs):
      return stax._GeneralConv(*args, **kwargs)[:2]

    def MyDense(*args, **kwargs):
      return stax.Dense(*args, **kwargs)[:2]

  # Define nonlinearity
  if nonlinearity == 'relu':
    nonlin = Relu
  elif nonlinearity == 'swish':
    nonlin = Swish
  elif nonlinearity == 'swishten':
    nonlin = Swishten
  else:  # nonlinearity == 'softplus'
    nonlin = Softplus

  # Keyword arguments shared by every residual block.
  block_kwargs = dict(batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin)

  return jax_stax.serial(
      # Stem.
      MyGeneralConv(('NHWC', 'HWIO', 'NHWC'), 64, (7, 7), strides=(2, 2),
                    padding='SAME'),
      BatchNorm() if batchnorm else Identity,
      nonlin,
      MaxPool((3, 3), strides=(2, 2)),
      # Stage 1: one conv block and two identity blocks.
      ConvBlock(3, [64, 64, 256], strides=(1, 1), **block_kwargs),
      IdentityBlock(3, [64, 64], **block_kwargs),
      IdentityBlock(3, [64, 64], **block_kwargs),
      # Stage 2: one conv block and three identity blocks.
      ConvBlock(3, [128, 128, 512], **block_kwargs),
      IdentityBlock(3, [128, 128], **block_kwargs),
      IdentityBlock(3, [128, 128], **block_kwargs),
      IdentityBlock(3, [128, 128], **block_kwargs),
      # Stage 3: one conv block and five identity blocks.
      ConvBlock(3, [256, 256, 1024], **block_kwargs),
      IdentityBlock(3, [256, 256], **block_kwargs),
      IdentityBlock(3, [256, 256], **block_kwargs),
      IdentityBlock(3, [256, 256], **block_kwargs),
      IdentityBlock(3, [256, 256], **block_kwargs),
      IdentityBlock(3, [256, 256], **block_kwargs),
      # Stage 4: one conv block and two identity blocks.
      ConvBlock(3, [512, 512, 2048], **block_kwargs),
      IdentityBlock(3, [512, 512], **block_kwargs),
      IdentityBlock(3, [512, 512], **block_kwargs),
      # Readout: drop the last element of the pooling layer tuple so it fits
      # `jax_stax.serial`.
      stax.GlobalAvgPool()[:-1],
      MyDense(num_classes))
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
  """Run infinite-width kernel inference on IMDb reviews and print results.

  Args:
    *args: unused.
    use_dummy_data: if true, take the data from `_get_dummy_data` instead of
      loading and embedding the IMDb dataset.
    **kwargs: unused.
  """
  # Value used to mark padding; passed to the embedding and to inference.
  mask_constant = 100.

  if use_dummy_data:
    x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)
  else:
    # Load the raw reviews.
    print('Loading IMDb data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        name='imdb_reviews',
        n_train=FLAGS.n_train,
        n_test=FLAGS.n_test,
        do_flatten_and_normalize=False,
        data_dir=FLAGS.imdb_path,
        input_key='text')

    # Turn words into GloVe embeddings and pad / truncate each sentence to a
    # fixed length.
    x_train, x_test = datasets.embed_glove(
        xs=[x_train, x_test],
        glove_path=FLAGS.glove_path,
        max_sentence_length=FLAGS.max_sentence_length,
        mask_constant=mask_constant)

  # Only the kernel function is used (no finite model), so every width is 1.
  conv = stax.Conv(out_chan=1,
                   filter_shape=(9,),
                   strides=(1,),
                   padding='VALID')
  attention = stax.GlobalSelfAttention(
      n_chan_out=1,
      n_chan_key=1,
      n_chan_val=1,
      pos_emb_type='SUM',
      W_pos_emb_std=1.,
      pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
      n_heads=1)
  _, _, kernel_fn = stax.serial(
      conv,
      stax.Relu(),
      attention,
      stax.Relu(),
      stax.GlobalAvgPool(),
      stax.Dense(out_dim=1))

  # Evaluate the kernel in batches across all available devices.
  kernel_fn = nt.batch(kernel_fn, device_count=-1,
                       batch_size=FLAGS.batch_size)

  start = time.time()

  # Bayesian (NNGP) and infinite-time gradient descent (NTK) predictions.
  predict = nt.predict.gradient_descent_mse_ensemble(
      kernel_fn=kernel_fn,
      x_train=x_train,
      y_train=y_train,
      diag_reg=1e-6,
      mask_constant=mask_constant)
  fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))

  # Wait for both results before stopping the clock.
  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print(f'Kernel construction and inference done in {duration} seconds.')

  def loss(fx, y_hat):
    # Half the mean squared error.
    return 0.5 * np.mean((fx - y_hat)**2)

  # Report accuracy and loss of both infinite-network predictions.
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)