def test_composition_conv(self, avg_pool):
    """Checks that composing kernel fns layer-by-layer matches `stax.serial`.

    Uses the (older) `marginalization` API: with a pooling readout the block
    kernel must retain full spatial covariance (`'none'`), otherwise `'auto'`
    marginalization suffices.
    """
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (5, 10, 10, 3))
    x2 = random.normal(rng, (5, 10, 10, 3))
    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
        # Pooling readout needs the full spatial covariance from the block.
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
        marginalization = 'none'
    else:
        Readout = stax.serial(stax.Flatten(), stax.Dense(10))
        marginalization = 'auto'
    # Index [2] of a stax layer tuple is its kernel_fn.
    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)
    # Single-input composition.
    ker_out = readout_ker_fn(
        block_ker_fn(x1, marginalization=marginalization))
    composed_ker_out = composed_ker_fn(x1)
    self.assertAllClose(ker_out, composed_ker_out, True)
    if avg_pool:
        # Default marginalization drops the spatial covariance the pooling
        # readout requires, so this must raise.
        with self.assertRaises(ValueError):
            ker_out = readout_ker_fn(block_ker_fn(x1))
    # Two-input composition.
    ker_out = readout_ker_fn(
        block_ker_fn(x1, x2, marginalization=marginalization))
    composed_ker_out = composed_ker_fn(x1, x2)
    self.assertAllClose(ker_out, composed_ker_out, True)
def test_composition_conv(self, avg_pool, same_inputs):
    """Checks kernel composition under the `diagonal_spatial` flag.

    The composed `stax.serial(Block, Readout)` kernel must agree with manual
    composition of the block and readout kernel fns, both with the default and
    explicit `diagonal_spatial=False`. Requesting `diagonal_spatial=True`
    before a pooling readout must fail, since pooling needs the off-diagonal
    spatial covariance.
    """
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))
    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
        Readout = stax.serial(stax.Conv(256, (3, 3)),
                              stax.GlobalAvgPool(),
                              stax.Dense(10))
    else:
        Readout = stax.serial(stax.Flatten(), stax.Dense(10))
    # Index [2] of a stax layer tuple is its kernel_fn.
    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)
    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)
    if avg_pool:
        # Pooling requires full spatial covariance — diagonal-only must raise.
        with self.assertRaises(ValueError):
            ker_out = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=True))
    else:
        # Flatten readout works fine on the diagonal-only representation.
        ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                   diagonal_spatial=True))
        self.assertAllClose(composed_ker_out, ker_out_marg)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
    """Assembles a small test network: affine block (+ optional residual
    connection) followed by a pooling/flatten readout.

    Returns the `(init_fn, apply_fn, kernel_fn)` triple from `stax.serial`.
    """
    dense = partial(stax.Dense, W_std=W_std, b_std=b_std)
    conv = partial(stax.Conv, filter_shape=filter_shape, strides=strides,
                   padding=padding, W_std=W_std, b_std=b_std)

    # Main affine layer: convolutional or fully-connected.
    affine = conv(width) if is_conv else dense(width)

    # Residual branch: optional average pooling, then nonlinearity + affine.
    if use_pooling:
        pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
        branch_head = stax.AvgPool((2, 3), None, pool_padding)
    else:
        branch_head = stax.Identity()
    res_unit = stax.serial(branch_head, phi, affine)

    if is_res:
        block = stax.serial(affine,
                            stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum())
    else:
        block = stax.serial(affine, res_unit)

    # Readout: scalar output for NTK tests, `width` outputs otherwise.
    top = stax.GlobalAvgPool() if use_pooling else stax.Flatten()
    readout = stax.serial(top, dense(1 if is_ntk else width))
    return stax.serial(block, readout)
def test_nested_parallel(self, same_inputs, kernel_type):
    """Compares analytic vs Monte Carlo kernels of a nested-parallel network
    with masked, heterogeneous (pytree) inputs."""
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05
    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)
    # Four input leaves with different ranks, matching the network's pytree.
    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))
    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)
    # Randomly mask some entries with the `-1` mask constant.
    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
        x2_3 = test_utils.mask(
            x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
        x2_4 = test_utils.mask(
            x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)
    # Nest inputs to mirror the nested `stax.parallel` structure below.
    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None
    N_in = 2 ** 7
    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()
    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))
    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        # vmap over the batch axis of every input leaf (TPU only).
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )
    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
def _build_network(input_shape, network, out_logits):
    """Builds a small test network of the requested topology.

    Flat (rank-1) inputs get a single dense layer; image (rank-3) inputs get
    a conv stack whose readout depends on `network`.
    """
    ndim = len(input_shape)
    if ndim == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    if ndim != 3:
        raise ValueError('Expected flat or image test input.')
    if network == POOLING:
        return stax.serial(
            stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
            stax.GlobalAvgPool(),
            stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    if network == CONV:
        return stax.serial(
            stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
            stax.Relu(),
            stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
        )
    if network == FLAT:
        return stax.serial(
            stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
            stax.Flatten(),
            stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    raise ValueError('Unexpected network type found: {}'.format(network))
def test_ab_relu_id(self, same_inputs, do_stabilize):
    """ABRelu(a, a) must behave exactly like `a * Identity`, both in forward
    outputs and in the analytic kernel."""
    key = random.PRNGKey(1)
    inputs1 = random.normal(key, (3, 2))
    inputs2 = None if same_inputs else random.normal(key, (4, 2))
    fc = stax.Dense(5, 1, 0)
    init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
    _, params = init_fn(key, input_shape=inputs1.shape)
    for slope in [-5, -1, -0.5, 0, 0.5, 1, 5]:
        with self.subTest(a=slope):
            _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
                fc, stax.ABRelu(slope, slope, do_stabilize=do_stabilize))
            # Forward-pass equality: ABRelu(a, a)(x) == a * Identity(x).
            out_id = slope * apply_id(params, inputs1)
            out_ab_relu = apply_ab_relu(params, inputs1)
            self.assertAllClose(out_id, out_ab_relu)
            # Kernel equality: scale the identity network's inputs instead.
            scaled2 = None if inputs2 is None else slope * inputs2
            kernels_id = kernel_fn_id(inputs1 * slope, scaled2)
            kernels_ab_relu = kernel_fn_ab_relu(inputs1, inputs2)
            # Manually correct the value of `is_gaussian` because
            # `ab_relu` (incorrectly) sets `is_gaussian=False` when `a==b`.
            kernels_ab_relu = kernels_ab_relu.replace(is_gaussian=True)
            self.assertAllClose(kernels_id, kernels_ab_relu)
def _build_network(input_shape, network, out_logits, use_dropout):
    """Builds a test network with optional dropout.

    Args:
      input_shape: shape of a single input; rank 1 for flat, rank 3 for image.
      network: one of the module-level network-type constants.
      out_logits: number of output logits.
      use_dropout: whether to insert a train-mode Dropout layer.

    Raises:
      ValueError: on an unknown `network` or unsupported input rank.
    """
    dropout = stax.Dropout(0.9, mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        # Fix: compare against the module-level `FLAT` constant for
        # consistency with the branches below (was the literal string 'FLAT').
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5),
                           dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.Flatten(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == INTERMEDIATE_CONV:
            return stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05)
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
def WideResnetBlocknt(channels, strides=(1, 1), channel_mismatch=False,
                      batchnorm='std', parameterization='ntk'):
    """A WideResnet block, with or without BatchNorm."""
    # Main branch: (norm -> relu -> conv) twice; only the first conv strides.
    main = stax_nt.serial(
        _batch_norm_internal(batchnorm),
        stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                     parameterization=parameterization),
        _batch_norm_internal(batchnorm),
        stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3), padding='SAME',
                     parameterization=parameterization))
    # Shortcut: identity unless channel counts differ, then a strided conv.
    if channel_mismatch:
        shortcut = stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                                parameterization=parameterization)
    else:
        shortcut = stax_nt.Identity()
    return stax_nt.serial(stax_nt.FanOut(2),
                          stax_nt.parallel(main, shortcut),
                          stax_nt.FanInSum())
def test_ab_relu_id(self, same_inputs):
    """ABRelu(a, a) must behave exactly like `a * Identity`.

    Fix: `init_fn` returns an `(output_shape, params)` pair; the original
    assigned the whole pair to `params`, which would break `apply_id(params,
    ...)` below (the newer variant of this test unpacks it correctly).
    """
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (5, 7))
    fc = stax.Dense(10, 1, 0)
    X0_2 = None if same_inputs else random.normal(key, (9, 7))
    # Test that ABRelu(a, a) == a * Identity
    init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
    # Unpack `(output_shape, params)` — only the params are needed.
    _, params = init_fn(key, input_shape=(-1, 7))
    for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
        with self.subTest(a=a):
            _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
                fc, stax.ABRelu(a, a))
            X1_1_id = a * apply_id(params, X0_1)
            X1_1_ab_relu = apply_ab_relu(params, X0_1)
            self.assertAllClose(X1_1_id, X1_1_ab_relu, True)
            # Kernel of `a * Identity` equals the identity kernel on inputs
            # scaled by `a`.
            kernels_id = kernel_fn_id(X0_1 * a,
                                      None if X0_2 is None else a * X0_2)
            kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
            self.assertAllClose(kernels_id, kernels_ab_relu, True)
def GP(x_train, y_train, x_test, y_test, w_std, b_std, l, C):
    """Returns NNGP test accuracy of a depth-`l` infinite ReLU network.

    Builds the network layer by layer, caching each layer's NNGP matrix in
    `k_layer`, then runs GP inference with diagonal regularization `C`.

    Fix: the loop variable previously reused (and clobbered) the depth
    parameter `l`; it is renamed to `layer` for clarity.
    """
    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]
    k_layer = []
    K = net0[2](x_train, None)
    k_layer.append(K.nngp)
    for layer in range(1, l + 1):
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        # Propagate the kernel through one more (Relu, Dense) pair.
        K = net_l[2](K)
        k_layer.append(K.nngp)
        nets += [stax.serial(nets[-1], net_l)]
    kernel_fn = nets[-1][2]
    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite
    # network.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(
        kernel_fn, x_train, y_train, x_test, get=('nngp', 'ntk'), diag_reg=C)
    fx_test_nngp.block_until_ready()
    duration = time.time() - start
    # print('Kernel construction and inference done in %s seconds.' % duration)
    return accuracy(y_test, fx_test_nngp)
def _test_activation(self, activation_fn, same_inputs, model, get,
                     rbf_gamma=None):
    """Compares the analytic kernel of `activation_fn` against a Monte Carlo
    estimate, for FC and conv topologies; optionally also against an explicit
    RBF kernel."""
    if 'conv' in model:
        test_utils.skip_test(self)
    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    # Per-activation weight scales (activation_fn[2] is its kernel_fn).
    if activation_fn[2].__name__ == 'Sin':
        W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
        W_std = 1.0
        b_std = 0.0
    if model == 'fc':
        rtol = 0.04
        X0_1 = random.normal(key, (4, 2))
        X0_2 = None if same_inputs else random.normal(split, (2, 2))
        affine = stax.Dense(1024, W_std, b_std)
        readout = stax.Dense(output_dim)
        depth = 1
    else:
        rtol = 0.05
        X0_1 = random.normal(key, (2, 4, 4, 3))
        X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
        affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std,
                           padding='SAME')
        readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model
                              else stax.Flatten(),
                              stax.Dense(output_dim))
        depth = 2
    # CPU gets fewer samples and a looser tolerance to stay fast.
    if default_backend() == 'cpu':
        num_samplings = 200
        rtol *= 2
    else:
        num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                         else 300)
    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)
    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
        input_dim = X0_1.shape[1]
        _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
        direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
        test_utils.assert_close_matrices(self, analytic_kernel,
                                         direct_rbf_kernel, rtol)
def test_elementwise(
    self, same_inputs, phi, n, diagonal_batch, diagonal_spatial
):
    """Checks `stax.Elementwise` with a hand-written closed-form `nngp_fn`
    against the library's built-in layer `phi`."""
    # Plain elementwise forward function extracted from the stax layer.
    fn = lambda x: phi[1]((), x)
    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
        # Closed-form NNGP maps for the activations under test; see e.g.
        # Cho & Saul (2009) for the Relu arc-cosine form.
        if 'Identity' in name:
            res = cov12
        elif 'Erf' in name:
            prod = (1 + 2 * var1) * (1 + 2 * var2)
            res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
        elif 'Sin' in name:
            sum_ = (var1 + var2)
            s1 = np.exp((-0.5 * sum_ + cov12))
            s2 = np.exp((-0.5 * sum_ - cov12))
            res = (s1 - s2) / 2
        elif 'Relu' in name:
            prod = var1 * var2
            # Clamp away from zero to keep the sqrt differentiable.
            sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
            angles = np.arctan2(sqrt, cov12)
            dot_sigma = (1 - angles / np.pi) / 2
            res = sqrt / (2 * np.pi) + dot_sigma * cov12
        else:
            raise NotImplementedError(name)
        return res

    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)
    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
        x2 = None
    elif same_inputs is True:
        x2 = x1
    else:
        x2 = random.normal(key, (6,) + shape)
    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)
    k = kernel_fn(x1, x2, **kwargs)
    # `Elementwise` marks outputs non-Gaussian; align the flag for comparison.
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
def _test_analytic_kernel_composition(self, batching_fn):
    """Checks that a batched kernel_fn composes like the serial network's
    kernel, for FC and conv+pooling stacks (TF-random-input variant)."""
    # Check Fully-Connected.
    rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                            dtype=tf.int32)
    keys = tf_random_split(rng)
    rng_self = keys[0]
    rng_other = keys[1]
    x_self = np.asarray(normal((8, 10), seed=rng_self))
    x_other = np.asarray(normal((2, 10), seed=rng_other))
    Block = stax.serial(stax.Dense(256), stax.Relu())
    _, _, ker_fn = Block
    ker_fn = batching_fn(ker_fn)
    _, _, composed_ker_fn = stax.serial(Block, Block)
    ker_out = ker_fn(ker_fn(x_self))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batch._parallel:
        # In the parallel setting, `x1_is_x2` is not computed correctly
        # when x1==x2.
        composed_ker_out = composed_ker_out.replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)
    ker_out = ker_fn(ker_fn(x_self, x_other))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out.replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)
    # Check convolutional + pooling.
    x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
    x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))
    Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)
    block_ker_fn = batching_fn(block_ker_fn)
    readout_ker_fn = batching_fn(readout_ker_fn)
    ker_out = readout_ker_fn(block_ker_fn(x_self))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out.replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)
    ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out.replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out)
def test_vmap_axes(self, same_inputs):
    """Checks that `vmap_axes` specs (including negative and per-leaf axes)
    give the same empirical NTK as the unbatched implementations.

    Inputs are pytrees of arrays whose batch dimension sits at a different
    axis in each leaf; the `vmap_axes` pytrees below must mirror that.
    """
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
        # Batch axis is 0 in x1, 1 in x2, and 2 in x3.
        k1, k2, k3 = random.split(k, 3)
        x1 = random.normal(k1, (n, c1))
        x2 = random.normal(k2, (h2, n, c2))
        x3 = random.normal(k3, (c3, w3, n, h3))
        x = [(x1, x2), x3]
        return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None
    # `pattern` kwarg for the Aggregate layer; batched along axis 0.
    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6),
                                                (n2, h2, h2))
    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2,), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH',
                                                         'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )
    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))
    # Same batch axes expressed with positive vs negative indices — both
    # must give identical results.
    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))
    k = direct(x1, x2, params, pattern=(p1, p2))
    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params,
                                          pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params,
                                            pattern=(p1, p2)))
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """A pre-activation ResNet block: relu-conv twice on the main branch,
    identity (or a strided conv on channel mismatch) on the shortcut."""
    main = stax.serial(
        stax.Relu(),
        stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding='SAME'))
    if channel_mismatch:
        shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        shortcut = stax.Identity()
    return stax.serial(stax.FanOut(2),
                       stax.parallel(main, shortcut),
                       stax.FanInSum())
def _test_analytic_kernel_composition(self, batching_fn):
    """Checks that a batched kernel_fn composes like the serial network's
    kernel (older `marginalization` / `_replace` API variant)."""
    # Check Fully-Connected.
    rng = random.PRNGKey(0)
    rng_self, rng_other = random.split(rng)
    x_self = random.normal(rng_self, (8, 10))
    x_other = random.normal(rng_other, (2, 10))
    Block = stax.serial(stax.Dense(256), stax.Relu())
    _, _, ker_fn = Block
    ker_fn = batching_fn(ker_fn)
    _, _, composed_ker_fn = stax.serial(Block, Block)
    ker_out = ker_fn(ker_fn(x_self))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batch._parallel:
        # In the parallel setting, `x1_is_x2` is not computed correctly
        # when x1==x2.
        composed_ker_out = composed_ker_out._replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out, True)
    ker_out = ker_fn(ker_fn(x_self, x_other))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out._replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out, True)
    # Check convolutional + pooling.
    x_self = random.normal(rng, (8, 10, 10, 3))
    x_other = random.normal(rng, (2, 10, 10, 3))
    Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)
    block_ker_fn = batching_fn(block_ker_fn)
    readout_ker_fn = batching_fn(readout_ker_fn)
    # Pooling readout requires full spatial covariance ('none').
    ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
    composed_ker_out = composed_ker_fn(x_self)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out._replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out, True)
    ker_out = readout_ker_fn(
        block_ker_fn(x_self, x_other, marginalization='none'))
    composed_ker_out = composed_ker_fn(x_self, x_other)
    if batching_fn == batch._parallel:
        composed_ker_out = composed_ker_out._replace(
            x1_is_x2=ker_out.x1_is_x2)
    self.assertAllClose(ker_out, composed_ker_out, True)
def test_composition(self):
    """Applying a block's kernel_fn twice equals the kernel of the block
    serially composed with itself."""
    xs = random.normal(random.PRNGKey(0), (10, 10))
    block = stax.serial(stax.Dense(256), stax.Relu())
    _, _, block_ker_fn = block
    _, _, stacked_ker_fn = stax.serial(block, block)
    twice_applied = block_ker_fn(block_ker_fn(xs))
    self.assertAllClose(twice_applied, stacked_ker_fn(xs), True)
def main(unused_argv):
    """Runs NNGP/NTK inference with an infinite fully-connected network on
    CIFAR-10 and prints test accuracy and loss."""
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)
    # Build the infinite network.
    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.05),
                                  stax.Relu(),
                                  stax.Dense(1, 2., 0.05))
    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=0,
                         batch_size=FLAGS.batch_size)
    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite
    # network.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                          y_train,
                                                          diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()
    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)
    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def testPredictOnCPU(self):
    """Checks batched-kernel GP inference across device/storage settings.

    For every combination of `store_on_device`, `device_count` and `get`,
    the infinite-time gradient-descent predictor (at t=None and t=inf) must
    agree with direct GP inference.
    """
    x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))
    y_train = random.uniform(random.PRNGKey(1), (10, 7))
    _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)),
                                  stax.Relu(),
                                  stax.Flatten(),
                                  stax.Dense(1))
    for store_on_device in [False, True]:
        for device_count in [0, 1]:
            for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                with self.subTest(store_on_device=store_on_device,
                                  device_count=device_count,
                                  get=get):
                    kernel_fn_batched = batch.batch(kernel_fn, 2,
                                                    device_count,
                                                    store_on_device)
                    predictor = predict.gradient_descent_mse_gp(
                        kernel_fn_batched, x_train, y_train, x_test, get,
                        0., True)
                    gp_inference = predict.gp_inference(
                        kernel_fn_batched, x_train, y_train, x_test, get,
                        0., True)
                    # t=None and t=inf are both the converged solution.
                    self.assertAllClose(predictor(None), predictor(np.inf),
                                        True)
                    self.assertAllClose(predictor(None), gp_inference, True)
def test_parallel_in_out_empirical(self, same_inputs):
    """Checks that batching empirical NNGP/NTK kernel fns preserves results
    for networks with pytree (parallel) inputs and outputs.

    NOTE(review): the `same_inputs` parameter is accepted but never used in
    the body — `x2` is always distinct from `x1`; verify against the test
    parameterization whether a `same_inputs` branch is missing.
    """
    test_utils.stub_out_pmap(batch, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)
    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
    x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
    # Nest inputs to mirror the nested `stax.parallel` structure below.
    x1 = (x1_1, (x1_2, x1_3))
    x2 = (x2_1, (x2_2, x2_3))

    def net(N_out):
        return stax.parallel(stax.Dense(N_out),
                             stax.parallel(stax.Dense(N_out + 1),
                                           stax.Dense(N_out + 2)))

    # Check NNGP.
    init_fn, apply_fn, _ = net(WIDTH)
    _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))
    kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
    batch_kernel_fn = jit(batch.batch(kernel_fn, 2))
    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, params),
                                     batch_kernel_fn(x1, x2, params), RTOL)
    # Check NTK.
    init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
    _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))
    kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
    batch_kernel_fn = jit(batch.batch(kernel_fn, 2))
    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, params),
                                     batch_kernel_fn(x1, x2, params), RTOL)
def WideResnet(block_size, k, num_classes):
    """A WideResnet: stem conv, three ResNet groups with widening factor `k`
    (the last two strided), then a dense readout."""
    group_widths = (int(16 * k), int(32 * k), int(64 * k))
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        ntk_generator.ResnetGroup(block_size, group_widths[0]),
        ntk_generator.ResnetGroup(block_size, group_widths[1], (2, 2)),
        ntk_generator.ResnetGroup(block_size, group_widths[2], (2, 2)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.))
def test_hermite(self, same_inputs, degree, get, readout):
    """Checks the analytic Hermite-activation kernel against Monte Carlo.

    Fix: the tolerance local was named `rot` — an apparent typo for `rtol`,
    since it is passed as the relative tolerance to `assert_close_matrices`.
    """
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)
    if degree > 2:
        width = 10000
        n_samples = 5000
        # High degrees are too expensive for regular runs.
        test_utils.skip_test(self)
    else:
        width = 10000
        n_samples = 100
    # `cos` keeps inputs bounded, where Hermite kernels are well-behaved.
    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))
    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]
    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                            n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    # Looser tolerance for higher Hermite degrees.
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
def main(unused_argv):
    """Trains an MNIST network and its linearization side by side, printing
    per-epoch losses and a final train/test comparison."""
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        'mnist', permute_train=True)
    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05),
                                stax.Erf(),
                                stax.Dense(10, 1., 0.05))
    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))
    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)
    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)
    state = opt_init(params)
    state_lin = opt_init(params)
    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)
    # Specialize the loss function to compute gradients for both linearized
    # and full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))
    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')
    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size
    for i, (x, y) in enumerate(
        datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                           FLAGS.train_epochs)):
        # Step the full network and its linearization with the same data.
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)
        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)
        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y),
                loss(f_lin(params_lin, x), y)))
            epoch += 1
    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
def test_nonlineariy(self, phi, same_inputs, a, b, n):
    """Compares the analytic NNGP/NTK of a three-dense-layer net with
    parameterized nonlinearity `phi(a, b)` against Monte Carlo estimates.

    (The method name keeps the original spelling so the test id is stable.)
    """
    width = 2**10
    n_samples = 2**9
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width), phi(a=a, b=b),
        stax.Dense(width), phi(a=a, b=b),
        stax.Dense(1))
    key1, key2, key_mc = random.split(random.PRNGKey(1), 3)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = np.cos(random.normal(key1, (2,) + shape))
    # `same_inputs` is tri-state: None -> x2=None, True -> x2 is x1,
    # False -> a fresh second batch.
    if same_inputs is None:
        x2 = None
    elif same_inputs is True:
        x2 = x1
    else:
        x2 = np.cos(random.normal(key2, (3,) + shape))
    analytic = kernel_fn(x1, x2)
    mc_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key_mc, n_samples)
    sampled = mc_fn(x1, x2, ('nngp', 'ntk'))
    for mc_k, exact_k in ((sampled.nngp, analytic.nngp),
                          (sampled.ntk, analytic.ntk)):
        test_utils.assert_close_matrices(self, mc_k, exact_k, 6e-2)
def test_exp_normalized(self):
    """Checks shapes and positive-definiteness of `ExpNormalized` kernels
    across clipping, gamma, and nngp/ntk settings."""
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))
    for do_clip in [True, False]:
        for gamma in [1., 2., 0.5]:
            for get in ['nngp', 'ntk']:
                with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
                    _, _, kernel_fn = stax.serial(
                        stax.Conv(1, (3, 3)),
                        stax.ExpNormalized(gamma, do_clip),
                        stax.Conv(1, (3, 3)),
                        stax.ExpNormalized(gamma, do_clip),
                        stax.GlobalAvgPool(),
                        stax.Dense(1)
                    )
                    # Cross-kernel has shape (|x1|, |x2|).
                    k_12 = kernel_fn(x1, x2, get=get)
                    self.assertEqual(k_12.shape,
                                     (x1.shape[0], x2.shape[0]))
                    # Self-kernels must be square and positive definite.
                    k_11 = kernel_fn(x1, None, get=get)
                    self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
                    self.assertGreater(
                        np.min(np.linalg.eigvalsh(k_11)), 0)
                    k_22 = kernel_fn(x2, None, get=get)
                    self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
                    self.assertGreater(
                        np.min(np.linalg.eigvalsh(k_22)), 0)
def main(unused_argv):
    """Computes the empirical NTK of a small CNN with every available
    implementation and asserts they all agree."""
    key1, key2, key3 = random.split(random.PRNGKey(1), 3)
    x1 = random.normal(key1, (2, 8, 8, 3))
    x2 = random.normal(key2, (3, 8, 8, 3))
    # A vanilla CNN.
    init_fn, f, _ = stax.serial(
        stax.Conv(8, (3, 3)),
        stax.Relu(),
        stax.Conv(8, (3, 3)),
        stax.Relu(),
        stax.Conv(8, (3, 3)),
        stax.Flatten(),
        stax.Dense(10)
    )
    _, params = init_fn(key3, x1.shape)
    kwargs = dict(
        f=f,
        trace_axes=(),
        vmap_axes=0,
    )
    # Default, baseline Jacobian contraction.
    jacobian_contraction = nt.empirical_ntk_fn(
        **kwargs,
        implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
    # (6, 3, 10, 10) full `np.ndarray` test-train NTK
    ntk_jc = jacobian_contraction(x2, x1, params)
    # NTK-vector products-based implementation.
    ntk_vector_products = nt.empirical_ntk_fn(
        **kwargs,
        implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
    ntk_vp = ntk_vector_products(x2, x1, params)
    # Structured derivatives-based implementation.
    structured_derivatives = nt.empirical_ntk_fn(
        **kwargs,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
    ntk_sd = structured_derivatives(x2, x1, params)
    # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
    auto = nt.empirical_ntk_fn(
        **kwargs,
        implementation=nt.NtkImplementation.AUTO)
    ntk_auto = auto(x2, x1, params)
    # Check that implementations match
    for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
        for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
            diff = np.max(np.abs(ntk1 - ntk2))
            print(f'NTK implementation diff {diff}.')
            assert diff < (1e-4 if jax.default_backend() != 'tpu'
                           else 0.1), diff
    print('All NTK implementations match.')
def test_linear(
    self,
    get,
    s,
    depth,
    same_inputs,
    b_std,
    W_std,
    parameterization,
):
    """Checks analytic vs empirical kernels of deep linear networks under
    'standard'/'ntk' parameterizations with layer-width scaling factor `s`."""
    if parameterization == 'standard':
        # Shrink width so total parameter count stays comparable across `s`.
        width = 2**9 // s
    elif parameterization == 'ntk':
        if s != 2**9:
            raise absltest.SkipTest(
                '"ntk" parameterization does not depend on "s".')
        width = 2**10
    else:
        raise ValueError(parameterization)
    layers = []
    for i in range(depth + 1):
        # First layer has no input scaling; the final NTK layer has a scalar
        # output with no output scaling.
        s_in = 1 if i == 0 else s
        s_out = 1 if (i == depth and get == 'ntk') else s
        out_dim = 1 if (i == depth and get == 'ntk') else width * (i + 1)
        layers += [stax.Dense(out_dim,
                              W_std=W_std / (i + 1),
                              b_std=b_std if b_std is None
                              else b_std / (i + 1),
                              parameterization=parameterization,
                              s=(s_in, s_out))]
    net = stax.serial(*layers)
    # (network, input trailing shape, channel axis, output dim) tuple
    # expected by the helper below.
    net = net, (BATCH_SIZE, 3), -1, 1
    _check_agreement_with_empirical(self, net, same_inputs, False,
                                    get == 'ntk', rtol=0.02, atol=10)
def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                        trace_axes, diagonal_axes):
    """Batched empirical NTK with trace/diagonal output axes must match the
    unbatched computation.

    Fix: the second input was previously constructed when `same_inputs` was
    True (and left None otherwise) — the inverse of the convention used
    throughout these tests (`x2 = None if same_inputs else ...`).
    """
    test_utils.stub_out_pmap(batching, 2)
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)
    init_fn, apply_fn, _ = stax.serial(stax.Dense(5),
                                       stax.Relu(),
                                       stax.Dense(3))
    test_x1 = random.normal(input_key1, (12, 4, 4))
    test_x2 = None
    if not same_inputs:
        test_x2 = random.normal(input_key2, (9, 4, 4))
    kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                    trace_axes=trace_axes,
                                    diagonal_axes=diagonal_axes,
                                    vmap_axes=0,
                                    implementation=2)
    _, params = init_fn(net_key, test_x1.shape)
    true_kernel = kernel_fn(test_x1, test_x2, params)
    batched_fn = batching.batch(kernel_fn,
                                device_count=device_count,
                                batch_size=3)
    batch_kernel = batched_fn(test_x1, test_x2, params)
    self.assertAllClose(true_kernel, batch_kernel)
def Resnet(block_size, num_classes):
    """A ResNet: stem conv, four groups of doubling width (the last three
    strided), then a dense readout."""
    layers = [stax.Conv(64, (3, 3), padding='SAME'),
              ResnetGroup(block_size, 64)]
    for width in (128, 256, 512):
        layers.append(ResnetGroup(block_size, width, (2, 2)))
    layers += [stax.Flatten(), stax.Dense(num_classes, 1., 0.05)]
    return stax.serial(*layers)
def test_flatten_first(self, same_inputs):
    """Pre-flattening inputs to an FC network must match prepending a
    `stax.Flatten()` layer and feeding the raw inputs."""
    key = random.PRNGKey(1)
    x1 = random.normal(key, (5, 4, 3, 2))
    x2 = None if same_inputs else random.normal(key, (3, 4, 3, 2))
    flatten = lambda x: (None if x is None
                         else np.reshape(x, (x.shape[0], -1)))
    _, _, ker_fn_on_flat = stax.serial(stax.Dense(10, 2., 0.5), stax.Erf())
    _, _, ker_fn = stax.serial(stax.Flatten(),
                               stax.Dense(10, 2., 0.5),
                               stax.Erf())
    K_flat = ker_fn_on_flat(flatten(x1), flatten(x2))
    self.assertAllClose(K_flat, ker_fn(x1, x2), True)