def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = batch(get_ntk_fun_empirical(f), batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = f(params, x_train)
  fx_test = f(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
  util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
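# A hedged aside (not part of the example above) on the MSE predictor.
# `predict.gradient_descent_mse` solves gradient flow in function space: given
# the train-train NTK `g_dd` and test-train NTK `g_td`, the returned
# `predictor(t, fx_train_0, fx_test_0)` gives train/test outputs after
# continuous time t of descent on the MSE loss, starting from the network's
# initial outputs. The toy kernel and data below are illustrative assumptions,
# mirroring the call pattern of the (older) neural_tangents API used above.
import jax.numpy as np
from jax import random
from neural_tangents import predict

key = random.PRNGKey(0)
x = random.normal(key, (6, 2))
g_dd = np.exp(-((x[:, None] - x[None]) ** 2).sum(-1))  # toy PSD kernel, (6, 6)
g_td = g_dd[:3]                                        # pretend 3 test points
y = random.normal(key, (6, 1))

predictor = predict.gradient_descent_mse(g_dd, y, g_td)
fx_train_t, fx_test_t = predictor(1e8, np.zeros((6, 1)), np.zeros((3, 1)))
# As t -> infinity, the train outputs should interpolate the labels.
print(np.max(np.abs(fx_train_t - y)))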
def _build_network(input_shape, network, out_logits):
  if len(input_shape) == 1:
    assert network == FLAT
    return stax.serial(
        stax.Dense(4096, W_std=1.2, b_std=0.05),
        stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.05))
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.Flatten(),
          stax.Dense(out_logits, W_std=2.0, b_std=0.05))
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
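# A hedged usage sketch for `_build_network` above. `POOLING`, `FLAT` and
# `CONVOLUTION_CHANNELS` are module-level constants assumed to be defined
# elsewhere in this file; the shapes below are illustrative only.
import jax.numpy as np
from jax import random

init_fn, apply_fn, kernel_fn = _build_network(
    input_shape=(8, 8, 3), network=POOLING, out_logits=10)

x = random.normal(random.PRNGKey(0), (4, 8, 8, 3))
_, params = init_fn(random.PRNGKey(1), (-1, 8, 8, 3))
fx = apply_fn(params, x)              # finite-width forward pass: (4, 10)
k_nngp = kernel_fn(x, None, 'nngp')   # infinite-width NNGP kernel: (4, 4)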
def test_sparse_inputs(self, act, kernel):
  key = random.PRNGKey(1)

  input_count = 4
  sparse_count = 2
  input_size = 128
  width = 4096

  # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
  samples = N_SAMPLES

  if xla_bridge.get_backend().platform == 'gpu':
    jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
    samples = 100 * N_SAMPLES
  else:
    jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
    jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3

  # A batch of dense inputs.
  x_dense = random.normal(key, (input_count, input_size))
  x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)

  activation = stax.Relu() if act == 'relu' else stax.Erf()

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(width),
      activation,
      stax.Dense(1 if kernel == 'ntk' else width))
  exact = kernel_fn(x_sparse, None, kernel)

  mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, random.split(key, 2)[0], samples)(
          x_sparse, None, kernel)
  mc = np.reshape(mc, exact.shape)

  assert not np.any(np.isnan(exact))
  self.assertAllClose(exact[sparse_count:, sparse_count:],
                      mc[sparse_count:, sparse_count:], True)
PADDINGS = [
    'SAME',
    'VALID',
    'CIRCULAR'
]

STRIDES = [
    None,
    (1, 2),
    (2, 1),
]

ACTIVATIONS = {
    # TODO: investigate poor erf convergence.
    stax.Erf(): 'erf',
    stax.Relu(): 'Relu',
    stax.ABRelu(-0.5, 0.7): 'ABRelu(-0.5, 0.7)'
}

PROJECTIONS = [
    'FLAT',
    'POOL',
    'ATTN_FIXED',
    'ATTN_PARAM'
]

LAYER_NORM = [
    (-1,),
    (1, 3),
    (1, 2, 3)
]
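# A hedged sketch of how grids like the ones above are typically consumed:
# each test case is one element of the Cartesian product. The loop below only
# builds the (padding, strides, activation) combinations; `PROJECTIONS` and
# `LAYER_NORM` would feed analogous axes of the product. Illustrative only.
import itertools

for padding, strides, (phi, phi_name) in itertools.product(
    PADDINGS, STRIDES, ACTIVATIONS.items()):
  _, _, kernel_fn = stax.serial(
      stax.Conv(128, (3, 3), strides=strides, padding=padding),
      phi,
      stax.Flatten(),
      stax.Dense(1))
  # `kernel_fn` for this combination could now be compared against a
  # Monte Carlo estimate, as the tests in this file do.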
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  if xla_bridge.get_backend().platform == 'gpu' and config.read(
      'jax_enable_x64'):
    raise jtu.SkipTest('Not running GPU x64 to save time.')

  training_steps = 5000
  learning_rate = 1.0
  ensemble_size = 50

  init_fn, apply_fn, ker_fn = stax.serial(
      stax.Dense(1024, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key = random.PRNGKey(0)
  key, = random.split(key, 1)

  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)
  train = (x_train, y_train)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  ensemble_key = random.split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)

  ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
  ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
  ensemble_loss = np.mean(ensemble_loss)
  self.assertLess(ensemble_loss, 1e-5, True)

  mean_emp = np.mean(ensemble_fx, axis=0)
  mean_subtracted = ensemble_fx - mean_emp
  cov_emp = np.einsum(
      'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
          mean_subtracted.shape[0] * mean_subtracted.shape[-1])

  reg = 1e-7
  ntk_predictions = predict.gp_inference(
      ker_fn, x_train, y_train, x_test, 'ntk', reg, compute_cov=True)

  self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
  self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL, ATOL)
def _get_phi(cls, i):
  return {0: stax.Relu(), 1: stax.Erf(), 2: stax.Abs()}[i % 3]
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.5,
                   b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.2,
                   b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='SAME',
                   W_std=0.1,
                   b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=0.9,
                   b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.,
                   b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.,
                   b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant):
  width = 512
  n_samples = 128
  tol = 0.04
  key = random.PRNGKey(1)

  x1 = random.normal(key, (4, 6, 5, 7))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = random.normal(key, (2, 6, 5, 7))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  nn = stax.serial(
      stax.Flatten(),
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              stax.Dense(width, 1., 0.1),
              stax.Abs(),
              stax.DotGeneral(lhs=-0.2),
              stax.Dense(width, 1.5, 0.01),
          ),
          stax.serial(
              stax.Dense(width, 1.1, 0.1),
              stax.DotGeneral(rhs=0.7),
              stax.Erf(),
              stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
          ),
          stax.serial(
              stax.DotGeneral(rhs=0.5),
              stax.Dense(width, 1.2),
              stax.ABRelu(-0.2, 0.4),
              stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      stax.Dense(width, 2., 0.01),
      stax.Relu())

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.1))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.1))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if concat in (0, -2) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -2) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_input_req(self, same_inputs):
  test_utils.skip_test(self)

  key = random.PRNGKey(1)
  x1 = random.normal(key, (2, 7, 8, 4, 3))
  x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

  _, _, wrong_conv_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.Relu(),
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
      stax.Relu(),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
      stax.Flatten(),
      stax.Dense(1024))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=400,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='nngp')
  K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

  _, _, wrong_conv_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.GlobalAvgPool(channel_axis=2))
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
      stax.Relu(),
      stax.AvgPool((2, 1, 3), batch_axis=0, channel_axis=-2),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
      stax.Relu(),
      stax.GlobalAvgPool(channel_axis=2),
      stax.Dense(1024))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=300,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='nngp')
  K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

  _, _, wrong_conv_fn = stax.serial(
      stax.Flatten(),
      stax.Dense(1),
      stax.Erf(),
      stax.Conv(out_chan=1, filter_shape=(1, 2),
                dimension_numbers=('CN', 'IO', 'NC')),
  )
  with self.assertRaises(ValueError):
    wrong_conv_fn(x1, x2)

  init_fn, apply_fn, correct_conv_fn = stax.serial(
      stax.Flatten(),
      stax.Conv(out_chan=1024, filter_shape=()),
      stax.Relu(),
      stax.Dense(1))

  correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn=init_fn,
      apply_fn=apply_fn,
      key=key,
      n_samples=200,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=0)
  K = correct_conv_fn(x1, x2, get='ntk')
  K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
  self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

  import numpy

  # Collapse the 10 MNIST classes into a binary (odd / even digit) label.
  # An earlier two-logit one-hot version of this relabeling:
  # y_train_tmp = numpy.zeros((y_train.shape[0], 2))
  # y_train_tmp[np.arange(y_train.shape[0]), numpy.argmax(y_train, 1) % 2] = 1
  # y_train = y_train_tmp
  # y_test_tmp = numpy.zeros((y_test.shape[0], 2))
  # y_test_tmp[np.arange(y_test.shape[0]), numpy.argmax(y_test, 1) % 2] = 1
  # y_test = y_test_tmp
  y_train_tmp = numpy.argmax(y_train, 1) % 2
  y_train = np.expand_dims(y_train_tmp, 1)
  y_test_tmp = numpy.argmax(y_test, 1) % 2
  y_test = np.expand_dims(y_test_tmp, 1)

  # Build the network. (An earlier, deeper Relu variant:
  # stax.Dense(2048, 1., 0.05), stax.Relu(),
  # stax.Dense(2048, 1., 0.05), stax.Relu(), stax.Dense(10, 1., 0.05).)
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(1, 1., 0.05))

  # Use a fresh random seed instead of the fixed random.PRNGKey(0).
  randnnn = numpy.random.random_integers(
      np.iinfo(np.int32).min, high=np.iinfo(np.int32).max, size=2)[0]
  key = random.PRNGKey(randnnn)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '_{}_{}_{}_{}'.format(
              model,
              phi[0].__name__,
              'Same_inputs' if same_inputs else 'Different_inputs',
              get),
          'model': model,
          'phi': phi,
          'same_inputs': same_inputs,
          'get': get,
      } for model in ['fc', 'conv-pool', 'conv-flatten']
        for phi in [stax.Erf(), stax.Gelu(), stax.Sin()]
        for same_inputs in [False, True]
        for get in ['nngp', 'ntk']))
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std,
                         padding='SAME')
      readout = stax.serial(
          stax.GlobalAvgPool() if 'pool' in model else stax.Flatten(),
          stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
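# A minimal hedged sketch of the mechanism the test above exercises:
# `stax.ElementwiseNumerical` approximates the NNGP/NTK of an arbitrary scalar
# nonlinearity via Gauss-Hermite quadrature of the given `fn`, so with
# `fn = erf` its kernel should closely match the closed-form `stax.Erf()`
# kernel. Shapes and degree are illustrative assumptions.
import jax.numpy as np
from jax import random
from jax.scipy.special import erf
from neural_tangents import stax

x = random.normal(random.PRNGKey(0), (3, 5))
_, _, kernel_closed = stax.serial(stax.Dense(1), stax.Erf())
_, _, kernel_numerical = stax.serial(
    stax.Dense(1), stax.ElementwiseNumerical(erf, deg=25))
print(np.max(np.abs(kernel_closed(x, None, 'nngp') -
                    kernel_numerical(x, None, 'nngp'))))  # expected ~0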
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  training_steps = 1000
  learning_rate = 0.1
  ensemble_size = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      kernel_fn, x_train, y_train, learning_rate=learning_rate, diag_reg=0.)

  train = (x_train, y_train)
  ensemble_key = random.split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)
  rtol = 0.08

  for x in [None, 'x_test']:
    with self.subTest(x=x):
      x = x if x is None else x_test
      x_fin = x_train if x is None else x_test
      ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

      mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
      mean_subtracted = ensemble_fx - mean_emp
      cov_emp = np.einsum(
          'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
              mean_subtracted.shape[0] * mean_subtracted.shape[-1])

      ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, ntk.mean, rtol)
      self._assertAllClose(cov_emp, ntk.covariance, rtol)
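# A hedged aside on the einsum used above: with `mean_subtracted` of shape
# (ensemble, test, logits), 'ijk,ilk->jl' sums outer products over ensemble
# members i and logits k, so dividing by (ensemble * logits) yields the
# empirical covariance between test points, averaged over output logits.
# Tiny plain-numpy check, illustrative only:
import numpy as onp

e = onp.random.randn(8, 5, 3)              # (ensemble, test, logits)
e = e - e.mean(0, keepdims=True)
cov_einsum = onp.einsum('ijk,ilk->jl', e, e) / (e.shape[0] * e.shape[-1])
cov_loop = sum(e[i, :, k][:, None] * e[i, :, k][None, :]
               for i in range(8) for k in range(3)) / (8 * 3)
assert onp.allclose(cov_einsum, cov_loop)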
def test_vmap_axes(self, same_inputs):
  n1, n2 = 3, 4
  c1, c2, c3 = 9, 5, 7
  h2, h3, w3 = 6, 8, 2

  def get_x(n, k):
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))))

  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
  implicit = jit(nt.empirical_ntk_fn(apply_fn, implementation=2))
  direct = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

  implicit_batched = jit(
      nt.empirical_ntk_fn(
          apply_fn,
          vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0)),
          implementation=2))
  direct_batched = jit(
      nt.empirical_ntk_fn(
          apply_fn,
          vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3)),
          implementation=1))

  k = direct(x1, x2, params, pattern=(p1, p2))

  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
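# For the common case of a single array input with a leading batch axis, the
# structured `vmap_axes` above reduces to `vmap_axes=0`. A minimal hedged
# sketch (network and shapes are assumptions for illustration):
import jax.numpy as np
from jax import jit, random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(32), stax.Relu(), stax.Dense(1))
x = random.normal(random.PRNGKey(0), (4, 7))
_, params = init_fn(random.PRNGKey(1), x.shape)

ntk_fn = jit(nt.empirical_ntk_fn(apply_fn, vmap_axes=0, implementation=2))
k = ntk_fn(x, None, params)  # (4, 4) empirical NTK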
FLAGS["train_size"] = 128 FLAGS["test_size"] = 128 FLAGS["train_time"] = 10000.0 class Struct: def __init__(self, **entries): self.__dict__.update(entries) FLAGS=Struct(**FLAGS) print('Loading data.') x_train, y_train, x_test, y_test = \ datasets.mnist(FLAGS.train_size, FLAGS.test_size) # Build the network init_fn, apply_fn, _ = stax.serial( stax.Dense(2048, 1., 0.05), stax.Erf(), stax.Dense(10, 1., 0.05)) key = random.PRNGKey(0) _, params = init_fn(key, (-1, 784)) # params # Create and initialize an optimizer. opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate) state = opt_init(params) # state #%% # Create an mse loss function and a gradient function.
st.header("Define your (Finite) Network")
# st.sidebar.markdown("## Network")
n_hidden = st.slider("Hidden width", 16, 2048, 512, step=16)
depth = st.slider("Network depth (excludes input and output)", 1, 10, 2,
                  step=1)
sigma_w = st.slider("Sigma w", 0.1, 3.0, 1.5, step=0.1)
sigma_b = st.slider("Sigma b", 0.01, 0.1, 0.05, step=0.01)
activation_fn = st.selectbox("Activation Function", ("Erf", "ReLU", "None"))
activation_fn_dict = {"Erf": stax.Erf(), "ReLU": stax.Relu(), "None": None}
activation_fn = activation_fn_dict[activation_fn]

sequence = ((stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),
             activation_fn) if activation_fn else
            (stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),))

init_fn, apply_fn, kernel_fn = stax.serial(
    *(sequence * depth), stax.Dense(1, W_std=sigma_w, b_std=sigma_b))
apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

st.markdown("""
We define our network using the **Neural Tangents** `stax` module. It allows
us to define the architecture and initialisation, and returns the network
function plus the (infinite-width) kernel function.
""")
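# A hedged sketch of one way the app could use `kernel_fn` above: exact
# Bayesian (NNGP) posterior-mean regression on toy 1-D data via plain linear
# algebra, independent of the `neural_tangents.predict` API version. The data
# and the ridge regularizer are illustrative assumptions.
import jax.numpy as np

x_train = np.linspace(-np.pi, np.pi, 10)[:, None]
y_train = np.sin(x_train)
x_test = np.linspace(-np.pi, np.pi, 100)[:, None]

k_dd = kernel_fn(x_train, x_train, 'nngp')
k_td = kernel_fn(x_test, x_train, 'nngp')
mean = k_td @ np.linalg.solve(k_dd + 1e-6 * np.eye(10), y_train)
# st.line_chart(mean[:, 0])  # e.g. plot the posterior mean in the app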
class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(self, same_inputs, phi, n, diagonal_batch,
                       diagonal_spatial):
    fn = lambda x: phi[1]((), x)
    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12
      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2
      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12
      else:
        raise NotImplementedError(name)
      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
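# A hedged numerical check of the Erf closed form in `nngp_fn` above: for
# (u, v) jointly Gaussian with variances var1, var2 and covariance cov12,
# E[erf(u) erf(v)] = (2/pi) * arcsin(2 cov12 / sqrt((1 + 2 var1)(1 + 2 var2))).
# Plain numpy/scipy Monte Carlo, illustrative tolerance.
import numpy as onp
from scipy.special import erf

var1, var2, cov12 = 1.0, 0.5, 0.3
cov = onp.array([[var1, cov12], [cov12, var2]])
u, v = onp.random.default_rng(0).multivariate_normal([0, 0], cov, 1_000_000).T
mc = onp.mean(erf(u) * erf(v))
closed = 2 / onp.pi * onp.arcsin(
    2 * cov12 / onp.sqrt((1 + 2 * var1) * (1 + 2 * var2)))
print(abs(mc - closed))  # should be ~1e-3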
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset(
      'mnist', permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.stateless_random_uniform(
      shape=[2], seed=[0, 0], minval=None, maxval=None, dtype=np.int32)
  _, params = init_fn(key, (1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)
  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(
      datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                         FLAGS.train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', y_test, f(params, x_test),
                     f_lin(params_lin, x_test), loss)
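# A brief hedged aside on `nt.linearize` as used in `main` above: it returns
# the first-order Taylor expansion of the network around the given parameters,
# f_lin(p, x) = f(p0, x) + J_p f(p0, x) @ (p - p0), so f_lin agrees with f
# exactly at p = p0. Self-contained sketch (the toy network is an assumption):
import jax.numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, f_toy, _ = stax.serial(stax.Dense(8), stax.Erf(), stax.Dense(2))
_, p0 = init_fn(random.PRNGKey(0), (-1, 4))
f_lin_toy = nt.linearize(f_toy, p0)
x_toy = random.normal(random.PRNGKey(1), (3, 4))
assert np.allclose(f_toy(p0, x_toy), f_lin_toy(p0, x_toy))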
def weight_space(train_embedding, test_embedding, data_set):
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      # 2 is the number of classes.
      stax.Dense(2, 1., 0.05))

  key = random.PRNGKey(0)
  # (-1, 135): 135 is the feature length, here 9 * 15 = 135.
  _, params = init_fn(key, (-1, 135))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
  opt_apply = jit(opt_apply)
  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  batch_size = 64
  train_epochs = 10
  steps_per_epoch = 100

  for i, (x, y) in enumerate(
      datasets.mini_batch(train_embedding, data_set['Y_train'], batch_size,
                          train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                        loss(f_lin(params_lin, x), y)))
      epoch += 1

    if i / steps_per_epoch == train_epochs:
      break

  # Print out summary data comparing the linear / nonlinear model.
  x, y = train_embedding[:10000], data_set['Y_train'][:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                     f_lin(params_lin, test_embedding), loss)