def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK.
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)
  ntk_auto = auto(x2, x1, params)

  # Check that implementations match.
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
def _build_network(input_shape, network, out_logits):
  if len(input_shape) == 1:
    assert network == FLAT
    return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == CONV:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
          stax.Relu(),
          stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
      )
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.Flatten(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
def test_exp_normalized(self):
  key = random.PRNGKey(0)
  x1 = random.normal(key, (2, 6, 7, 1))
  x2 = random.normal(key, (4, 6, 7, 1))

  for do_clip in [True, False]:
    for gamma in [1., 2., 0.5]:
      for get in ['nngp', 'ntk']:
        with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
          _, _, kernel_fn = stax.serial(
              stax.Conv(1, (3, 3)),
              stax.ExpNormalized(gamma, do_clip),
              stax.Conv(1, (3, 3)),
              stax.ExpNormalized(gamma, do_clip),
              stax.GlobalAvgPool(),
              stax.Dense(1)
          )

          k_12 = kernel_fn(x1, x2, get=get)
          self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

          k_11 = kernel_fn(x1, None, get=get)
          self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
          self.assertGreater(np.min(np.linalg.eigvalsh(k_11)), 0)

          k_22 = kernel_fn(x2, None, get=get)
          self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
          self.assertGreater(np.min(np.linalg.eigvalsh(k_22)), 0)
def test_nested_parallel(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1,
   input_key2,
   input_key3,
   input_key4,
   mask_key,
   mc_key) = random.split(rng, 6)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
def _build_network(input_shape, network, out_logits, use_dropout):
  dropout = stax.Dropout(0.9, mode='train') if use_dropout else stax.Identity()
  if len(input_shape) == 1:
    assert network == 'FLAT'
    return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5),
                       dropout,
                       stax.Dense(out_logits, W_std=2.0, b_std=0.5))
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(),
          dropout,
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
          stax.Flatten(),
          dropout,
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == INTERMEDIATE_CONV:
      return stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05)
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
def WideResnetBlocknt(channels, strides=(1, 1), channel_mismatch=False,
                      batchnorm='std', parameterization='ntk'):
  """A WideResnet block, with or without BatchNorm."""
  Main = stax_nt.serial(
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                   parameterization=parameterization),
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), padding='SAME',
                   parameterization=parameterization))

  Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
      channels, (3, 3), strides, padding='SAME',
      parameterization=parameterization)

  return stax_nt.serial(stax_nt.FanOut(2),
                        stax_nt.parallel(Main, Shortcut),
                        stax_nt.FanInSum())
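# The group that stacks these blocks is not shown in this section, although
# `WideResnetnt` below calls it. A minimal sketch, assuming the usual
# WideResnet layout (the first block absorbs the stride/channel change, the
# remaining blocks use stride (1, 1)); the name `WideResnetGroupnt` and its
# signature are inferred from the call sites below, so treat this as
# illustrative rather than the original helper.
def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
  blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                              batchnorm=batchnorm,
                              parameterization=parameterization)]
  for _ in range(n - 1):
    blocks += [WideResnetBlocknt(channels, (1, 1), channel_mismatch=False,
                                 batchnorm=batchnorm,
                                 parameterization=parameterization)]
  return stax_nt.serial(*blocks)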
def test_composition_conv(self, avg_pool, same_inputs):
  rng = random.PRNGKey(0)
  x1 = random.normal(rng, (3, 5, 5, 3))
  x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  if avg_pool:
    Readout = stax.serial(stax.Conv(256, (3, 3)),
                          stax.GlobalAvgPool(),
                          stax.Dense(10))
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  composed_ker_out = composed_ker_fn(x1, x2)
  ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                diagonal_spatial=False))
  ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
  self.assertAllClose(composed_ker_out, ker_out_no_marg)
  self.assertAllClose(composed_ker_out, ker_out_default)

  if avg_pool:
    with self.assertRaises(ValueError):
      ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
  else:
    ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                               diagonal_spatial=True))
    self.assertAllClose(composed_ker_out, ker_out_marg)
def test_vmap_axes(self, same_inputs):
  n1, n2 = 3, 4
  c1, c2, c3 = 9, 5, 7
  h2, h3, w3 = 6, 8, 2

  def get_x(n, k):
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
      )
  )

  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))

  implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
  direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

  implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
      apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
  direct_batched = jit(empirical._empirical_direct_ntk_fn(
      apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

  k = direct(x1, x2, params, pattern=(p1, p2))

  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(stax.Relu(),
                     stax.Conv(channels, (3, 3), strides, padding='SAME'),
                     stax.Relu(),
                     stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())
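# `ResnetGroup`, used by the `Resnet` constructor further below, is not
# defined in this section. A minimal sketch, assuming the standard layout in
# which the first block handles the stride/channel change and the remaining
# blocks use stride (1, 1); illustrative only, not the original helper.
def ResnetGroup(n, channels, strides=(1, 1)):
  blocks = [ResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [ResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)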
def testGradientDescentMseEnsembleTrain(self):
  key = random.PRNGKey(1)
  x = random.normal(key, (8, 4, 6, 3))
  _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)),
                                stax.Relu(),
                                stax.Conv(1, (2, 1)))
  y = random.normal(key, (8, 2, 5, 1))

  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

  for t in [None, np.array([0., 1., 10.])]:
    with self.subTest(t=t):
      y_none = predictor(t, None, None, compute_cov=True)
      y_x = predictor(t, x, None, compute_cov=True)
      self._assertAllClose(y_none, y_x, 0.04)
def _get_kernel_fn(self, same_inputs, readin, readout):
  key = random.PRNGKey(1)
  x1 = random.normal(key, (2, 5, 6, 3))
  x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
  layers = [readin]
  filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
  layers += [
      stax.Conv(1, filter_shape, padding='SAME'),
      stax.Relu(),
      stax.Conv(1, filter_shape, padding='SAME'),
      stax.Erf(),
      readout
  ]
  _, _, kernel_fn = stax.serial(*layers)
  return kernel_fn, x1, x2
def build_le_net(network_width):
  """Constructs a LeNet of width `network_width` with average pooling,
  using neural tangents' `stax`."""
  return stax.serial(
      stax.Conv(out_chan=6 * network_width, filter_shape=(3, 3),
                strides=(1, 1), padding='VALID'),
      stax.Relu(),
      stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
      stax.Conv(out_chan=16 * network_width, filter_shape=(3, 3),
                strides=(1, 1), padding='VALID'),
      stax.Relu(),
      stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
      stax.Flatten(),
      stax.Dense(120 * network_width),
      stax.Relu(),
      stax.Dense(84 * network_width),
      stax.Relu(),
      stax.Dense(10))
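# A hedged usage sketch (not part of the original source): `build_le_net`
# returns the standard `neural_tangents.stax` triple, so the analytic NTK can
# be evaluated directly from `kernel_fn`. The 32x32x3 (CIFAR-like) input
# shape below is an assumption for illustration.
from jax import random  # assumed available, as in the snippets above

init_fn, apply_fn, kernel_fn = build_le_net(network_width=1)
x_train = random.normal(random.PRNGKey(0), (4, 32, 32, 3))
x_test = random.normal(random.PRNGKey(1), (2, 32, 32, 3))
k_test_train = kernel_fn(x_test, x_train, 'ntk')  # shape (2, 4)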
def testPredictOnCPU(self):
  x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

  y_train = random.uniform(random.PRNGKey(1), (10, 7))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        with self.subTest(store_on_device=store_on_device,
                          device_count=device_count,
                          get=get):
          kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                          store_on_device)

          predictor = predict.gradient_descent_mse_gp(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)
          gp_inference = predict.gp_inference(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)

          self.assertAllClose(predictor(None), predictor(np.inf), True)
          self.assertAllClose(predictor(None), gp_inference, True)
def WideResnetnt(block_size, k, num_classes, batchnorm='std'):
  """Based on WideResnet from the paper, with or without BatchNorm.

  (Set config.wrn_block_size=3, config.wrn_widening_f=10 in that case.)
  Uses default weight and bias init.
  """
  parameterization = 'standard'
  layers_lst = [
      stax_nt.Conv(16, (3, 3), padding='SAME',
                   parameterization=parameterization),
      WideResnetGroupnt(block_size, 16 * k,
                        parameterization=parameterization,
                        batchnorm=batchnorm),
      WideResnetGroupnt(block_size, 32 * k, (2, 2),
                        parameterization=parameterization,
                        batchnorm=batchnorm),
      WideResnetGroupnt(block_size, 64 * k, (2, 2),
                        parameterization=parameterization,
                        batchnorm=batchnorm)
  ]
  layers_lst += [_batch_norm_internal(batchnorm), stax_nt.Relu()]
  layers_lst += [
      stax_nt.AvgPool((8, 8)),
      stax_nt.Flatten(),
      stax_nt.Dense(num_classes, parameterization=parameterization)
  ]
  return stax_nt.serial(*layers_lst)
def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      ntk_generator.ResnetGroup(block_size, int(16 * k)),
      ntk_generator.ResnetGroup(block_size, int(32 * k), (2, 2)),
      ntk_generator.ResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))
def testPredictOnCPU(self):
  x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

  y_train = random.uniform(random.PRNGKey(1), (4, 2))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        for x in [None, 'x_test']:
          with self.subTest(
              store_on_device=store_on_device,
              device_count=device_count,
              get=get,
              x=x):
            kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                            store_on_device)
            predictor = predict.gradient_descent_mse_ensemble(
                kernel_fn_batched, x_train, y_train)

            x = x if x is None else x_test
            predict_none = predictor(None, x, get, compute_cov=True)
            predict_inf = predictor(np.inf, x, get, compute_cov=True)
            self.assertAllClose(predict_none, predict_inf)

            if x is not None:
              on_cpu = (not store_on_device or
                        xla_bridge.get_backend().platform == 'cpu')
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
def test_composition_conv(self, avg_pool):
  rng = random.PRNGKey(0)
  x1 = random.normal(rng, (5, 10, 10, 3))
  x2 = random.normal(rng, (5, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  if avg_pool:
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
    marginalization = 'none'
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))
    marginalization = 'auto'

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  ker_out = readout_ker_fn(block_ker_fn(x1, marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1)
  self.assertAllClose(ker_out, composed_ker_out, True)

  if avg_pool:
    with self.assertRaises(ValueError):
      ker_out = readout_ker_fn(block_ker_fn(x1))

  ker_out = readout_ker_fn(block_ker_fn(x1, x2,
                                        marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1, x2)
  self.assertAllClose(ker_out, composed_ker_out, True)
def wide_resnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      wide_resnet_group(block_size, int(16 * k)),
      wide_resnet_group(block_size, int(32 * k), (2, 2)),
      wide_resnet_group(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))
def Resnet(block_size, num_classes):
  return stax.serial(
      stax.Conv(64, (3, 3), padding='SAME'),
      ResnetGroup(block_size, 64),
      ResnetGroup(block_size, 128, (2, 2)),
      ResnetGroup(block_size, 256, (2, 2)),
      ResnetGroup(block_size, 512, (2, 2)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.05))
def test_hermite(self, same_inputs, degree, get, readout):
  key = random.PRNGKey(1)
  key1, key2, key = random.split(key, 3)

  if degree > 2:
    width = 10000
    n_samples = 5000
    test_utils.skip_test(self)
  else:
    width = 10000
    n_samples = 100

  x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
  x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

  conv_layers = [
      stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
      stax.LayerNorm(),
      stax.Hermite(degree),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
      stax.Dense(1) if get == 'ntk' else stax.Identity()]

  init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
  analytic_kernel = kernel_fn(x1, x2, get)

  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
  mc_kernel = mc_kernel_fn(x1, x2, get)
  rtol = degree / 2. * 1e-2
  test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(rng)
  rng_self = keys[0]
  rng_other = keys[1]
  x_self = np.asarray(normal((8, 10), seed=rng_self))
  x_other = np.asarray(normal((2, 10), seed=rng_other))
  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)

  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    # In the parallel setting, `x1_is_x2` is not computed correctly
    # when x1 == x2.
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  # Check convolutional + pooling.
  x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
  x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)
def _test_activation(self, activation_fn, same_inputs, model, get,
                     rbf_gamma=None):
  if 'conv' in model:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  key, split = random.split(key)
  output_dim = 1024 if get == 'nngp' else 1
  b_std = 0.5
  W_std = 2.0
  if activation_fn[2].__name__ == 'Sin':
    W_std = 0.9
  if activation_fn[2].__name__ == 'Rbf':
    W_std = 1.0
    b_std = 0.0

  if model == 'fc':
    rtol = 0.04
    X0_1 = random.normal(key, (4, 2))
    X0_2 = None if same_inputs else random.normal(split, (2, 2))
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
    depth = 1
  else:
    rtol = 0.05
    X0_1 = random.normal(key, (2, 4, 4, 3))
    X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
    affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
    readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                          stax.Flatten(),
                          stax.Dense(output_dim))
    depth = 2

  if default_backend() == 'cpu':
    num_samplings = 200
    rtol *= 2
  else:
    num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                     else 300)

  init_fn, apply_fn, kernel_fn = stax.serial(
      *[affine, activation_fn] * depth, readout)
  analytic_kernel = kernel_fn(X0_1, X0_2, get)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, split, num_samplings,
      implementation=2,
      vmap_axes=0
  )
  empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
  test_utils.assert_close_matrices(self, analytic_kernel,
                                   empirical_kernel, rtol)

  # Check match with explicit RBF.
  if rbf_gamma is not None and get == 'nngp' and model == 'fc':
    input_dim = X0_1.shape[1]
    _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
    direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     direct_rbf_kernel, rtol)
def testGradientDescentMseEnsembleTrain(self):
  key = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                          dtype=tf.int32)
  x = np.asarray(normal((8, 4, 6, 3), seed=key))
  _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)),
                                stax.Relu(),
                                stax.Conv(1, (2, 1)))
  y = np.asarray(normal((8, 2, 5, 1), seed=key))

  predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

  for t in [None, np.array([0., 1., 10.])]:
    with self.subTest(t=t):
      y_none = predictor(t, None, None, compute_cov=True)
      y_x = predictor(t, x, None, compute_cov=True)
      self._assertAllClose(y_none, y_x, 0.04)
def MyConv(*args, parameterization='standard', order=None, **kwargs):
  """Wrapper for a convolutional layer with different parameterizations."""
  if parameterization == 'standard':
    return jax_stax.Conv(*args, **kwargs)
  elif parameterization == 'ntk':
    # Drop the kernel_fn and return only (init_fn, apply_fn).
    return stax.Conv(*args, b_std=1.0, **kwargs)[:2]
  elif parameterization == 'taylor':
    return TaylorConv(*args, b_std=1.0, order=order, **kwargs)
  else:
    raise ValueError(f'Unknown parameterization: {parameterization}.')
def _get_inputs_and_model(width=1, n_classes=2):
  key = random.PRNGKey(1)
  key, split = random.split(key)
  x1 = random.normal(key, (8, 4, 3, 2))
  x2 = random.normal(split, (4, 4, 3, 2))
  init_fun, apply_fun, ker_fun = stax.serial(
      stax.Conv(width, (3, 3)),
      stax.Relu(),
      stax.Flatten(),
      stax.Dense(n_classes, 2., 0.5))
  return x1, x2, init_fun, apply_fun, ker_fun, key
def testPredictOnCPU(self):
  key1 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  key2 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  key3 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
  x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

  y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        for x in [None, 'x_test']:
          with self.subTest(store_on_device=store_on_device,
                            device_count=device_count,
                            get=get,
                            x=x):
            kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                            store_on_device)
            predictor = predict.gradient_descent_mse_ensemble(
                kernel_fn_batched, x_train, y_train)

            x = x if x is None else x_test
            predict_none = predictor(None, x, get, compute_cov=True)
            predict_inf = predictor(np.inf, x, get, compute_cov=True)
            self.assertAllClose(predict_none, predict_inf)

            if x is not None:
              on_cpu = (not store_on_device or
                        xla_bridge.get_backend().platform == 'cpu')
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8), padding='SAME'),
      stax.Flatten(),
      stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
  )
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = random.PRNGKey(0)
  rng_self, rng_other = random.split(rng)
  x_self = random.normal(rng_self, (8, 10))
  x_other = random.normal(rng_other, (2, 10))
  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)

  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    # In the parallel setting, `x1_is_x2` is not computed correctly
    # when x1 == x2.
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  # Check convolutional + pooling.
  x_self = random.normal(rng, (8, 10, 10, 3))
  x_other = random.normal(rng, (2, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)
  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = readout_ker_fn(
      block_ker_fn(x_self, x_other, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME',
                parameterization='standard'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME',
                parameterization='standard'),
  )
  shortcut = (
      stax.Identity()
      if not channel_mismatch
      else stax.Conv(channels, (3, 3), strides, padding='SAME',
                     parameterization='standard')
  )
  return stax.serial(stax.FanOut(2),
                     stax.parallel(main, shortcut),
                     stax.FanInSum())
def conv(out_chan, s):
  return stax.Conv(
      out_chan=out_chan,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std,
      dimension_numbers=dimension_numbers,
      parameterization=parameterization,
      s=s
  )