def WideResnetBlocknt(channels, strides=(1, 1), channel_mismatch=False,
                      batchnorm='std', parameterization='ntk'):
  """A WideResnet block, with or without BatchNorm."""
  Main = stax_nt.serial(
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                   parameterization=parameterization),
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), padding='SAME',
                   parameterization=parameterization))
  Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
      channels, (3, 3), strides, padding='SAME',
      parameterization=parameterization)
  return stax_nt.serial(stax_nt.FanOut(2),
                        stax_nt.parallel(Main, Shortcut),
                        stax_nt.FanInSum())
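# NOTE: `WideResnetGroupnt` is referenced by `WideResnetnt` below but is not
# defined in this collection. The following is a hypothetical sketch only,
# assuming the group stacks `n` WideResnetBlocknt blocks and applies the
# strides (with a channel-matching shortcut) in the first block.
def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
  """Sketch of a WideResnet group: `n` blocks, strided first block."""
  blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                              batchnorm=batchnorm,
                              parameterization=parameterization)]
  for _ in range(n - 1):
    blocks.append(WideResnetBlocknt(channels, (1, 1), channel_mismatch=False,
                                    batchnorm=batchnorm,
                                    parameterization=parameterization))
  return stax_nt.serial(*blocks)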
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK.
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)
  ntk_auto = auto(x2, x1, params)

  # Check that implementations match.
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
                          dtype=tf.int32)
  keys = tf_random_split(rng)
  rng_self = keys[0]
  rng_other = keys[1]
  x_self = np.asarray(normal((8, 10), seed=rng_self))
  x_other = np.asarray(normal((2, 10), seed=rng_other))
  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)

  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    # In the parallel setting, `x1_is_x2` is not computed correctly
    # when x1==x2.
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  # Check convolutional + pooling.
  x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
  x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)

  ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out.replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out)
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())
def build_dense_network(
    hidden_layers: Sequence[int],
    activations: Union[Sequence, str] = "erf",
    w_std: float = 2.5,
    b_std: float = 1.0,
) -> NTModel:
    """Utility function to build a simple feedforward network with the
    neural tangents library.

    Args:
        hidden_layers (Sequence[int]): Iterable with the number of neurons
            per layer. For example, [512, 512].
        activations (Union[Sequence, str], optional): Iterable with
            neural_tangents.stax activations, or "relu" or "erf".
            Defaults to "erf".
        w_std (float): Standard deviation of the weight distribution.
        b_std (float): Standard deviation of the bias distribution.

    Returns:
        NTModel: jitted init, apply and kernel functions,
            predict_function (None).
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from neural_tangents import stax  # pylint:disable=import-outside-toplevel

    assert len(hidden_layers) >= 1, "You must provide at least one hidden layer"
    if activations is None:
        activations = [stax.Relu() for _ in hidden_layers]
    elif isinstance(activations, str):
        if activations.lower() == "relu":
            activations = [stax.Relu() for _ in hidden_layers]
        elif activations.lower() == "erf":
            activations = [stax.Erf() for _ in hidden_layers]
    else:
        for activation in activations:
            assert callable(activation), "You need to provide `neural_tangents.stax` activations"
        assert len(activations) == len(
            hidden_layers
        ), "The number of hidden layers should match the number of nonlinearities"

    stack = []
    for hidden_layer, activation in zip(hidden_layers, activations):
        stack.append(stax.Dense(hidden_layer, W_std=w_std, b_std=b_std))
        stack.append(activation)
    stack.append(stax.Dense(1, W_std=w_std, b_std=b_std))

    init_fn, apply_fn, kernel_fn = stax.serial(*stack)

    return NTModel(init_fn, jit(apply_fn), jit(kernel_fn, static_argnums=(2,)), None)
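# Minimal usage sketch for `build_dense_network` (illustrative only).
# Assumes `NTModel` exposes a `kernel_fn` field for the jitted kernel function;
# the arrays and shapes below are made up for illustration.
import jax.numpy as jnp

model = build_dense_network([512, 512], activations="erf", w_std=2.5, b_std=1.0)

x_train = jnp.ones((16, 8))  # 16 points, 8 features
x_test = jnp.ones((4, 8))

# The third positional argument selects the kernel ("nngp" or "ntk");
# it is the static argument marked via static_argnums=(2,) above.
k_test_train = model.kernel_fn(x_test, x_train, "nngp")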
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = random.PRNGKey(0)
  rng_self, rng_other = random.split(rng)
  x_self = random.normal(rng_self, (8, 10))
  x_other = random.normal(rng_other, (2, 10))
  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)

  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    # In the parallel setting, `x1_is_x2` is not computed correctly
    # when x1==x2.
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  # Check convolutional + pooling.
  x_self = random.normal(rng, (8, 10, 10, 3))
  x_other = random.normal(rng, (2, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = readout_ker_fn(
      block_ker_fn(x_self, x_other, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  if batching_fn == batch._parallel:
    composed_ker_out = composed_ker_out._replace(x1_is_x2=ker_out.x1_is_x2)
  self.assertAllClose(ker_out, composed_ker_out, True)
def build_le_net(network_width):
  """Construct the LeNet of width network_width with average pooling using
  neural tangent's stax."""
  return stax.serial(
      stax.Conv(out_chan=6 * network_width, filter_shape=(3, 3),
                strides=(1, 1), padding='VALID'),
      stax.Relu(),
      stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
      stax.Conv(out_chan=16 * network_width, filter_shape=(3, 3),
                strides=(1, 1), padding='VALID'),
      stax.Relu(),
      stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
      stax.Flatten(),
      stax.Dense(120 * network_width),
      stax.Relu(),
      stax.Dense(84 * network_width),
      stax.Relu(),
      stax.Dense(10))
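# Usage sketch for `build_le_net` (illustrative only). Assumes NHWC image
# inputs, e.g. 28x28x1 MNIST-like data; the batch below is a dummy placeholder.
from jax import random
import jax.numpy as jnp

init_fn, apply_fn, kernel_fn = build_le_net(network_width=1)

key = random.PRNGKey(0)
x = jnp.zeros((8, 28, 28, 1))      # dummy NHWC batch
_, params = init_fn(key, x.shape)  # initialize finite-width parameters
logits = apply_fn(params, x)       # (8, 10) class logits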
def test_parameterizations(self, model, width, same_inputs, is_ntk,
                           filter_shape, proj_into_2d, parameterization):
  is_conv = 'conv' in model

  W_std, b_std = 2.**0.5, 0.5**0.5
  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  layer_norm = None
  pool_type = 'AVG'
  use_dropout = False

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
  elif proj_into_2d != PROJECTIONS[0]:
    raise jtu.SkipTest('FC models do not have these parameters.')

  net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                 padding, phi, strides, width, is_ntk, proj_into_2d,
                 pool_type, layer_norm, parameterization, use_dropout)
  self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                       is_ntk, proj_into_2d)
def _get_net_pool(width, is_ntk, pool_type, padding,
                  filter_shape, strides, normalize_edges):
  W_std, b_std = 2.**0.5, 0.5**0.5
  phi = stax.Relu()
  parameterization = 'ntk'

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)
  conv = partial(
      stax.Conv,
      filter_shape=(3, 2),
      strides=None,
      padding='SAME',
      W_std=W_std,
      b_std=b_std,
      parameterization=parameterization)

  if pool_type == 'AVG':
    pool_fn = partial(stax.AvgPool, normalize_edges=normalize_edges)
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool

  pool = pool_fn(filter_shape, strides, padding)

  return stax.serial(
      conv(width), phi, pool,
      conv(width), phi, global_pool_fn(),
      fc(1 if is_ntk else width)), INPUT_SHAPE
def testPredictOnCPU(self):
  x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

  y_train = random.uniform(random.PRNGKey(1), (4, 2))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        for x in [None, 'x_test']:
          with self.subTest(
              store_on_device=store_on_device,
              device_count=device_count,
              get=get,
              x=x):
            kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                            store_on_device)
            predictor = predict.gradient_descent_mse_ensemble(
                kernel_fn_batched, x_train, y_train)

            x = x if x is None else x_test
            predict_none = predictor(None, x, get, compute_cov=True)
            predict_inf = predictor(np.inf, x, get, compute_cov=True)
            self.assertAllClose(predict_none, predict_inf)

            if x is not None:
              on_cpu = (not store_on_device or
                        xla_bridge.get_backend().platform == 'cpu')
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
def WideResnetnt(
    block_size, k, num_classes,
    batchnorm='std'):  # , batch_norm=None, layer_norm=None, freezelast=None):
  """Based off of WideResnet from paper, with or without BatchNorm.

  (Set config.wrn_block_size=3, config.wrn_widening_f=10 in that case).
  Uses default weight and bias init.
  """
  parameterization = 'standard'
  layers_lst = [
      stax_nt.Conv(16, (3, 3), padding='SAME',
                   parameterization=parameterization),
      WideResnetGroupnt(block_size, 16 * k,
                        parameterization=parameterization,
                        batchnorm=batchnorm),
      WideResnetGroupnt(block_size, 32 * k, (2, 2),
                        parameterization=parameterization,
                        batchnorm=batchnorm),
      WideResnetGroupnt(block_size, 64 * k, (2, 2),
                        parameterization=parameterization,
                        batchnorm=batchnorm)
  ]
  layers_lst += [_batch_norm_internal(batchnorm), stax_nt.Relu()]
  layers_lst += [
      stax_nt.AvgPool((8, 8)),
      stax_nt.Flatten(),
      stax_nt.Dense(num_classes, parameterization=parameterization)
  ]
  return stax_nt.serial(*layers_lst)
def GP(x_train, y_train, x_test, y_test, w_std, b_std, l, C):
  net0 = stax.Dense(1, w_std, b_std)
  nets = [net0]

  k_layer = []
  K = net0[2](x_train, None)
  k_layer.append(K.nngp)

  for _ in range(1, l + 1):
    net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
    K = net_l[2](K)
    k_layer.append(K.nngp)
    nets += [stax.serial(nets[-1], net_l)]

  kernel_fn = nets[-1][2]

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite
  # network.
  fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                      x_train, y_train, x_test,
                                                      get=('nngp', 'ntk'),
                                                      diag_reg=C)
  fx_test_nngp.block_until_ready()
  duration = time.time() - start
  # print('Kernel construction and inference done in %s seconds.' % duration)

  return accuracy(y_test, fx_test_nngp)
def main(unused_argv):
  train_size = FLAGS.train_size
  x_train, y_train, x_test, y_test = pickle.load(
      open("data_" + str(train_size) + ".p", "rb"))
  print("Got data")
  sys.stdout.flush()

  # Build the network
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      # stax.Erf(),
      stax.Relu(),
      stax.Dense(1, 1., 0.05))

  # initialize the network first time, to compute NTK
  randnnn = numpy.random.random_integers(np.iinfo(np.int32).min,
                                         high=np.iinfo(np.int32).max,
                                         size=2)[0]
  key = random.PRNGKey(randnnn)
  _, params = init_fn(key, (-1, 784))

  # Create an MSE predictor to solve the NTK equation in function space.
  # We assume that the NTK is approximately the same for any sample of
  # parameters (true in the limit of infinite width).
  print("Making NTK")
  sys.stdout.flush()
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)
  g_dd = ntk(x_train, None, params)
  pickle.dump(g_dd, open("ntk_train_" + str(FLAGS.train_size) + ".p", "wb"))
  g_td = ntk(x_test, x_train, params)
  pickle.dump(g_td,
              open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "wb"))
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
def _build_network(input_shape, network, out_logits):
  if len(input_shape) == 1:
    assert network == FLAT
    return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == CONV:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
          stax.Relu(),
          stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
      )
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.Flatten(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  _, _, kernel_fn = stax.serial(
      stax.Dense(1, 2., 0.05),
      stax.Relu(),
      stax.Dense(1, 2., 0.05))

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite
  # network.
  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                        y_train,
                                                        diag_reg=1e-3)
  fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
def test_layernorm(self, model, width, same_inputs, is_ntk,
                   proj_into_2d, layer_norm):
  is_conv = 'conv' in model
  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    test_utils.skip_test(self)
  elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
    raise absltest.SkipTest('FC models do not have these parameters.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  filter_shape = FILTER_SHAPES[0]
  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  parameterization = 'ntk'
  pool_type = 'AVG'
  use_dropout = False

  net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                 padding, phi, strides, width, is_ntk, proj_into_2d,
                 pool_type, layer_norm, parameterization, 1, use_dropout)
  _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk,
                                  0.07)
def testPredictOnCPU(self):
  x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
  x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

  y_train = random.uniform(random.PRNGKey(1), (10, 7))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        with self.subTest(
            store_on_device=store_on_device,
            device_count=device_count,
            get=get):
          kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                          store_on_device)
          predictor = predict.gradient_descent_mse_gp(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)
          gp_inference = predict.gp_inference(
              kernel_fn_batched, x_train, y_train, x_test, get, 0., True)

          self.assertAllClose(predictor(None), predictor(np.inf), True)
          self.assertAllClose(predictor(None), gp_inference, True)
def test_composition_conv(self, avg_pool, same_inputs):
  rng = random.PRNGKey(0)
  x1 = random.normal(rng, (3, 5, 5, 3))
  x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  if avg_pool:
    Readout = stax.serial(stax.Conv(256, (3, 3)),
                          stax.GlobalAvgPool(),
                          stax.Dense(10))
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  composed_ker_out = composed_ker_fn(x1, x2)
  ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                diagonal_spatial=False))
  ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
  self.assertAllClose(composed_ker_out, ker_out_no_marg)
  self.assertAllClose(composed_ker_out, ker_out_default)

  if avg_pool:
    with self.assertRaises(ValueError):
      ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
  else:
    ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                               diagonal_spatial=True))
    self.assertAllClose(composed_ker_out, ker_out_marg)
def test_composition_conv(self, avg_pool):
  rng = random.PRNGKey(0)
  x1 = random.normal(rng, (5, 10, 10, 3))
  x2 = random.normal(rng, (5, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  if avg_pool:
    Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
    marginalization = 'none'
  else:
    Readout = stax.serial(stax.Flatten(), stax.Dense(10))
    marginalization = 'auto'

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  ker_out = readout_ker_fn(block_ker_fn(x1, marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1)
  self.assertAllClose(ker_out, composed_ker_out, True)

  if avg_pool:
    with self.assertRaises(ValueError):
      ker_out = readout_ker_fn(block_ker_fn(x1))

  ker_out = readout_ker_fn(
      block_ker_fn(x1, x2, marginalization=marginalization))
  composed_ker_out = composed_ker_fn(x1, x2)
  self.assertAllClose(ker_out, composed_ker_out, True)
def test_layernorm(self, model, width, same_inputs, is_ntk,
                   proj_into_2d, layer_norm):
  is_conv = 'conv' in model
  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
  elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
    raise jtu.SkipTest('FC models do not have these parameters.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  filter_size = FILTER_SIZES[0]
  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  parameterization = 'ntk'
  use_dropout = False

  self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv,
                                       is_ntk, is_res, layer_norm, padding,
                                       phi, proj_into_2d, same_inputs,
                                       strides, use_pooling, width,
                                       parameterization, use_dropout)
def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                        trace_axes, diagonal_axes):
  test_utils.stub_out_pmap(batching, 2)
  rng = random.PRNGKey(0)
  input_key1, input_key2, net_key = random.split(rng, 3)

  init_fn, apply_fn, _ = stax.serial(stax.Dense(5), stax.Relu(),
                                     stax.Dense(3))

  test_x1 = random.normal(input_key1, (12, 4, 4))
  test_x2 = None
  if same_inputs:
    test_x2 = random.normal(input_key2, (9, 4, 4))

  kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                  trace_axes=trace_axes,
                                  diagonal_axes=diagonal_axes,
                                  vmap_axes=0,
                                  implementation=2)

  _, params = init_fn(net_key, test_x1.shape)

  true_kernel = kernel_fn(test_x1, test_x2, params)
  batched_fn = batching.batch(kernel_fn, device_count=device_count,
                              batch_size=3)
  batch_kernel = batched_fn(test_x1, test_x2, params)
  self.assertAllClose(true_kernel, batch_kernel)
def test_nonlinear(
    self,
    model,
    width,
    same_inputs,
    is_ntk,
    filter_shape,
    proj_into_2d,
    b_std,
    W_std,
    parameterization,
    s
):
  is_conv = 'conv' in model

  if parameterization == 'standard':
    width //= s

  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  layer_norm = None
  pool_type = 'AVG'
  use_dropout = False

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    test_utils.skip_test(self)
  elif proj_into_2d != PROJECTIONS[0] or filter_shape != FILTER_SHAPES[0]:
    raise absltest.SkipTest('FC models do not have these parameters.')

  net = _get_net(W_std=W_std,
                 b_std=b_std,
                 filter_shape=filter_shape,
                 is_conv=is_conv,
                 use_pooling=use_pooling,
                 is_res=is_res,
                 padding=padding,
                 phi=phi,
                 strides=strides,
                 width=width,
                 is_ntk=is_ntk,
                 proj_into_2d=proj_into_2d,
                 pool_type=pool_type,
                 layer_norm=layer_norm,
                 parameterization=parameterization,
                 s=s,
                 use_dropout=use_dropout)

  _check_agreement_with_empirical(
      self,
      net=net,
      same_inputs=same_inputs,
      use_dropout=use_dropout,
      is_ntk=is_ntk,
      rtol=0.015,
      atol=1000
  )
def testGpInference(self):
  reg = 1e-5
  key = random.PRNGKey(1)
  x_train = random.normal(key, (4, 2))
  init_fn, apply_fn, kernel_fn_analytic = stax.serial(
      stax.Dense(32, 2., 0.5),
      stax.Relu(),
      stax.Dense(10, 2., 0.5))
  y_train = random.normal(key, (4, 10))
  for kernel_fn_is_analytic in [True, False]:
    if kernel_fn_is_analytic:
      kernel_fn = kernel_fn_analytic
    else:
      _, params = init_fn(key, x_train.shape)
      kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

      def kernel_fn(x1, x2, get):
        return kernel_fn_empirical(x1, x2, get, params)

    for get in [None,
                'nngp', 'ntk',
                ('nngp',), ('ntk',),
                ('nngp', 'ntk'), ('ntk', 'nngp')]:
      k_dd = kernel_fn(x_train, None, get)

      gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
      gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)
      for x_test in [None, 'x_test']:
        x_test = None if x_test is None else random.normal(key, (8, 2))
        k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

        for compute_cov in [True, False]:
          with self.subTest(
              kernel_fn_is_analytic=kernel_fn_is_analytic,
              get=get,
              x_test=x_test if x_test is None else 'x_test',
              compute_cov=compute_cov):
            if compute_cov:
              nngp_tt = (True if x_test is None else
                         kernel_fn(x_test, None, 'nngp'))
            else:
              nngp_tt = None

            out_ens = gd_ensemble(None, x_test, get, compute_cov)
            out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
            self._assertAllClose(out_ens_inf, out_ens, 0.08)

            if (get is not None and
                'nngp' not in get and
                compute_cov and
                k_td is not None):
              with self.assertRaises(ValueError):
                out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
            else:
              out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
              self.assertAllClose(out_ens, out_gp_inf)
def test_vmap_axes(self, same_inputs):
  n1, n2 = 3, 4
  c1, c2, c3 = 9, 5, 7
  h2, h3, w3 = 6, 8, 2

  def get_x(n, k):
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
      )
  )

  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
  implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
  direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

  implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
      apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
  direct_batched = jit(empirical._empirical_direct_ntk_fn(
      apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

  k = direct(x1, x2, params, pattern=(p1, p2))

  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
def _get_inputs_and_model(width=1, n_classes=2):
  key = random.PRNGKey(1)
  key, split = random.split(key)
  x1 = random.normal(key, (8, 4, 3, 2))
  x2 = random.normal(split, (4, 4, 3, 2))
  init_fun, apply_fun, ker_fun = stax.serial(
      stax.Conv(width, (3, 3)),
      stax.Relu(),
      stax.Flatten(),
      stax.Dense(n_classes, 2., 0.5))
  return x1, x2, init_fun, apply_fun, ker_fun, key
def create_network(depth, width):
  layers = []
  for _ in range(depth):
    layers += [
        stax.Dense(width, W_std=1.5, b_std=0.0, parameterization='ntk'),
        stax.Relu()
    ]
  layers += [stax.Dense(1, W_std=1.5, b_std=0, parameterization='ntk')]
  return stax.serial(*layers)
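# Brief usage sketch for `create_network`, showing how the returned
# (init_fn, apply_fn, kernel_fn) triple is typically consumed to evaluate the
# infinite-width kernels. Input shapes below are illustrative only.
from jax import random

init_fn, apply_fn, kernel_fn = create_network(depth=3, width=512)

key = random.PRNGKey(0)
x = random.normal(key, (10, 4))  # 10 inputs with 4 features each

ntk_train_train = kernel_fn(x, None, 'ntk')    # (10, 10) analytic NTK
nngp_train_train = kernel_fn(x, None, 'nngp')  # (10, 10) NNGP kernel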
def testPredictOnCPU(self):
  key1 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  key2 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  key3 = stateless_uniform(shape=[2], seed=[1, 1], minval=None, maxval=None,
                           dtype=tf.int32)
  x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
  x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

  y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

  _, _, kernel_fn = stax.serial(
      stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

  for store_on_device in [False, True]:
    for device_count in [0, 1]:
      for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
        for x in [None, 'x_test']:
          with self.subTest(
              store_on_device=store_on_device,
              device_count=device_count,
              get=get,
              x=x):
            kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                            store_on_device)
            predictor = predict.gradient_descent_mse_ensemble(
                kernel_fn_batched, x_train, y_train)

            x = x if x is None else x_test
            predict_none = predictor(None, x, get, compute_cov=True)
            predict_inf = predictor(np.inf, x, get, compute_cov=True)
            self.assertAllClose(predict_none, predict_inf)

            if x is not None:
              on_cpu = (not store_on_device or
                        xla_bridge.get_backend().platform == 'cpu')
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
def _test_analytic_kernel_composition(self, batching_fn):
  # Check Fully-Connected.
  rng = random.PRNGKey(0)
  rng_self, rng_other = random.split(rng)
  x_self = random.normal(rng_self, (8, 10))
  x_other = random.normal(rng_other, (20, 10))
  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  ker_fn = batching_fn(ker_fn)

  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(x_self))
  composed_ker_out = composed_ker_fn(x_self)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = ker_fn(ker_fn(x_self, x_other))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  self.assertAllClose(ker_out, composed_ker_out, True)

  # Check convolutional + pooling.
  x_self = random.normal(rng, (8, 10, 10, 3))
  x_other = random.normal(rng, (10, 10, 10, 3))

  Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
  Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

  block_ker_fn, readout_ker_fn = Block[2], Readout[2]
  _, _, composed_ker_fn = stax.serial(Block, Readout)

  block_ker_fn = batching_fn(block_ker_fn)
  readout_ker_fn = batching_fn(readout_ker_fn)

  ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self)
  self.assertAllClose(ker_out, composed_ker_out, True)

  ker_out = readout_ker_fn(
      block_ker_fn(x_self, x_other, marginalization='none'))
  composed_ker_out = composed_ker_fn(x_self, x_other)
  self.assertAllClose(ker_out, composed_ker_out, True)
def test_composition(self):
  rng = random.PRNGKey(0)
  xs = random.normal(rng, (10, 10))

  Block = stax.serial(stax.Dense(256), stax.Relu())

  _, _, ker_fn = Block
  _, _, composed_ker_fn = stax.serial(Block, Block)

  ker_out = ker_fn(ker_fn(xs))
  composed_ker_out = composed_ker_fn(xs)
  self.assertAllClose(ker_out, composed_ker_out, True)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME',
                parameterization='standard'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME',
                parameterization='standard'),
  )
  shortcut = (
      stax.Identity()
      if not channel_mismatch
      else stax.Conv(channels, (3, 3), strides, padding='SAME',
                     parameterization='standard')
  )
  return stax.serial(stax.FanOut(2),
                     stax.parallel(main, shortcut),
                     stax.FanInSum())