def test_avg_pool(self):
    """Compare AvgPool with and without edge normalization.

    Checks that:
      * the edge-normalized layer matches ``stax.ostax.AvgPool``;
      * the unnormalized layer produces the expected closed-form output on
        all-ones inputs (edge windows that reach into SAME padding are < 1);
      * the edge-normalized kernel is all ones, and both kernels have the same
        shapes;
      * the unnormalized kernel equals the outer product of the per-example
        unnormalized output, broadcast to the kernel shapes.
    """
    x1 = np.ones((4, 2, 3, 2))
    x2 = np.ones((3, 2, 3, 2))
    inputs = (x1, x2)

    _, apply_raw, kernel_raw = stax.AvgPool(
        (2, 2), (1, 1), 'SAME', normalize_edges=False)
    _, apply_normed, kernel_normed = stax.AvgPool(
        (2, 2), (1, 1), 'SAME', normalize_edges=True)
    _, apply_ref = stax.ostax.AvgPool((2, 2), (1, 1), 'SAME')

    raw_outs = tuple(apply_raw((), x) for x in inputs)
    normed_outs = tuple(apply_normed((), x) for x in inputs)
    ref_outs = tuple(apply_ref((), x) for x in inputs)

    # Edge-normalized pooling must agree with the reference ostax layer.
    self.assertAllClose(ref_outs, normed_outs, True)

    # Expected unnormalized output for one example / one channel.
    single = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape((1, 2, 3, 1))
    expected_raw = tuple(np.broadcast_to(single, x.shape) for x in inputs)
    self.assertAllClose(expected_raw, raw_outs, True)

    k_raw = kernel_raw(x1, x2)
    k_normed = kernel_normed(x1, x2)

    fields = ('nngp', 'var1', 'var2')
    for field in fields:
        value = getattr(k_normed, field)
        self.assertAllClose(np.ones_like(value), value, True)
    for field in fields:
        self.assertEqual(getattr(k_normed, field).shape,
                         getattr(k_raw, field).shape)

    # Pairwise products of the single-example output, arranged as
    # (2, 2, 3, 3) before broadcasting to the kernel shapes.
    cov = np.transpose(np.outer(single, single).reshape((2, 3, 2, 3)),
                       axes=(0, 2, 1, 3))
    expected_nngp = np.broadcast_to(cov.reshape((1, 1) + cov.shape),
                                    k_raw.nngp.shape)
    expected_var1 = np.broadcast_to(np.expand_dims(cov, 0), k_raw.var1.shape)
    expected_var2 = np.broadcast_to(np.expand_dims(cov, 0), k_raw.var2.shape)
    self.assertAllClose((expected_nngp, expected_var1, expected_var2),
                        (k_raw.nngp, k_raw.var1, k_raw.var2), True)
def build_le_net(network_width):
    """Build a LeNet-style network with average pooling from ``stax`` layers.

    Every hidden layer size is scaled by ``network_width``. The final readout
    always has 10 units.

    Args:
      network_width: multiplier applied to the 6/16 conv channels and the
        120/84 dense units.

    Returns:
      The ``stax.serial`` combination of all layers.
    """
    def conv_pool(base_channels):
        # 3x3 VALID convolution -> ReLU -> 2x2 average pool with stride 2.
        return [
            stax.Conv(out_chan=base_channels * network_width,
                      filter_shape=(3, 3), strides=(1, 1), padding='VALID'),
            stax.Relu(),
            stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        ]

    layers = conv_pool(6) + conv_pool(16) + [stax.Flatten()]
    for units in (120, 84):
        layers += [stax.Dense(units * network_width), stax.Relu()]
    layers.append(stax.Dense(10))
    return stax.serial(*layers)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk):
    """Build a small test network: affine -> (optionally residual) unit -> readout.

    The unit is ``[AvgPool or Identity] -> phi -> affine``. The same ``affine``
    layer is used both before the unit and inside it.

    Args:
      W_std, b_std: weight / bias standard deviations for Dense and Conv layers.
      filter_shape, strides, padding: Conv settings, used when ``is_conv``.
        ``padding`` also picks the AvgPool padding: 'SAME' stays 'SAME', any
        other value becomes 'CIRCULAR'.
      is_conv: use ``stax.Conv`` instead of ``stax.Dense`` for the hidden layer.
      use_pooling: add AvgPool((2, 3)) in the unit and use GlobalAvgPool
        (instead of Flatten) in the readout.
      is_res: wrap the unit in FanOut / parallel(Identity, unit) / FanInSum.
      phi: activation layer placed inside the unit.
      width: hidden width.
      is_ntk: the readout has 1 output if True, else ``width`` outputs.

    Returns:
      ``stax.serial(block, readout)``.
    """
    init = dict(W_std=W_std, b_std=b_std)

    if is_conv:
        affine = stax.Conv(width, filter_shape=filter_shape, strides=strides,
                           padding=padding, **init)
    else:
        affine = stax.Dense(width, **init)

    if use_pooling:
        pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
        pool_or_identity = stax.AvgPool((2, 3), None, pool_padding)
    else:
        pool_or_identity = stax.Identity()
    res_unit = stax.serial(pool_or_identity, phi, affine)

    block = (
        stax.serial(affine, stax.FanOut(2),
                    stax.parallel(stax.Identity(), res_unit), stax.FanInSum())
        if is_res else
        stax.serial(affine, res_unit)
    )

    projection = stax.GlobalAvgPool() if use_pooling else stax.Flatten()
    out_units = 1 if is_ntk else width
    readout = stax.serial(projection, stax.Dense(out_units, **init))
    return stax.serial(block, readout)
def WideResnet(block_size, k, num_classes):
    """Wide ResNet classifier assembled from ``WideResnetGroup`` stages.

    Layout:
      * a 3x3 SAME Conv with 16 channels;
      * three WideResnetGroup stages of widths int(16*k), int(32*k), int(64*k).
        The last two stages also receive (2, 2), presumably a stride; confirm
        against WideResnetGroup's signature.
      * AvgPool((8, 8)), then Flatten;
      * a Dense readout with ``num_classes`` outputs and positional (1., 0.),
        presumably W_std and b_std.

    Args:
      block_size: forwarded to every WideResnetGroup.
      k: widening factor.
      num_classes: number of outputs.

    Returns:
      The ``stax.serial`` combination of all layers.
    """
    # (base width, extra positional args for WideResnetGroup)
    stages = (
        (16, ()),
        (32, ((2, 2),)),
        (64, ((2, 2),)),
    )
    layers = [stax.Conv(16, (3, 3), padding='SAME')]
    for base_width, extra_args in stages:
        layers.append(WideResnetGroup(block_size, int(base_width * k), *extra_args))
    layers += [stax.AvgPool((8, 8)), stax.Flatten(), stax.Dense(num_classes, 1., 0.)]
    return stax.serial(*layers)
def WideResnetnt(block_size, k, num_classes, batchnorm='std'):
    """Wide ResNet built from ``stax_nt`` layers, with or without BatchNorm.

    Based on WideResnet from the paper (set config.wrn_block_size=3,
    config.wrn_widening_f=10 in that case). Uses default weight and bias
    initialization with the 'standard' parameterization.

    Args:
      block_size: forwarded to every WideResnetGroupnt stage.
      k: widening factor. Stage widths are int(16 * k), int(32 * k) and
        int(64 * k). The int() cast keeps channel counts integral when k is a
        float, matching the WideResnet builders.
      num_classes: number of outputs of the final Dense layer.
      batchnorm: forwarded to WideResnetGroupnt and _batch_norm_internal
        (default 'std').

    Returns:
      ``stax_nt.serial`` of the full network.
    """
    parameterization = 'standard'
    layers_lst = [
        stax_nt.Conv(16, (3, 3), padding='SAME',
                     parameterization=parameterization),
        WideResnetGroupnt(block_size, int(16 * k),
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size, int(32 * k), (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size, int(64 * k), (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm),
    ]
    # Layer chosen by `batchnorm` (presumably a normalization or identity
    # layer; see _batch_norm_internal), followed by ReLU before the readout.
    layers_lst += [_batch_norm_internal(batchnorm), stax_nt.Relu()]
    layers_lst += [
        stax_nt.AvgPool((8, 8)),
        stax_nt.Flatten(),
        stax_nt.Dense(num_classes, parameterization=parameterization),
    ]
    return stax_nt.serial(*layers_lst)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    """Build a test network: affine block with optional residual unit, then readout.

    Args:
      W_std, b_std: weight / bias standard deviations for Dense and Conv. They
        are also used for the attention key/value/query weights and bias.
      filter_shape, strides, padding: Conv settings, used when ``is_conv``.
        ``padding`` also picks the AvgPool padding: 'SAME' stays 'SAME', any
        other value becomes 'CIRCULAR'.
      is_conv: use ``stax.Conv`` instead of ``stax.Dense`` for the hidden layer.
      use_pooling: put AvgPool((2, 3)) at the start of the unit.
      is_res: wrap the unit in FanOut / parallel(Identity, unit) / FanInSum.
      phi: activation layer placed inside the unit.
      width: hidden width.
      is_ntk: the readout has 1 output if True, else ``width`` outputs.
      proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'
        ('ATTN_FIXED' passes fixed=True to GlobalSelfAttention).
      layer_norm: None for no LayerNorm, otherwise the ``axis`` given to
        stax.LayerNorm at the end of the block.
      parameterization: forwarded to Dense and Conv.
      use_dropout: insert a train-mode Dropout after ``phi``.

    Returns:
      ``stax.serial(block, readout)``.

    Raises:
      ValueError: if ``proj_into_2d`` is not 'FLAT', 'POOL' or 'ATTN*'.
    """
    affine_kwargs = dict(W_std=W_std, b_std=b_std,
                         parameterization=parameterization)
    if is_conv:
        affine = stax.Conv(width, filter_shape=filter_shape, strides=strides,
                           padding=padding, **affine_kwargs)
    else:
        affine = stax.Dense(width, **affine_kwargs)

    # The rate is drawn on every call, even when use_dropout is False. Keep
    # the draw unconditional so callers' RNG consumption stays the same.
    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    pool = stax.AvgPool((2, 3), None,
                        'SAME' if padding == 'SAME' else 'CIRCULAR')

    maybe_pool = pool if use_pooling else stax.Identity()
    maybe_dropout = dropout if use_dropout else stax.Identity()
    maybe_norm = (stax.Identity() if layer_norm is None
                  else stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(maybe_pool, phi, maybe_dropout, affine)

    if is_res:
        body = [affine, stax.FanOut(2),
                stax.parallel(stax.Identity(), res_unit), stax.FanInSum()]
    else:
        body = [affine, res_unit]
    block = stax.serial(*body, maybe_norm)

    def projection():
        # Reduces the block output to one vector per example before the
        # final Dense.
        if proj_into_2d == 'FLAT':
            return stax.Flatten()
        if proj_into_2d == 'POOL':
            return stax.GlobalAvgPool()
        if not proj_into_2d.startswith('ATTN'):
            raise ValueError(proj_into_2d)
        n_heads = int(np.sqrt(width))
        attention = stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=int(np.round(float(width) / n_heads)),
            n_heads=n_heads,
            fixed=proj_into_2d == 'ATTN_FIXED',
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std)
        return stax.serial(attention, stax.Flatten())

    readout = stax.serial(projection(),
                          stax.Dense(1 if is_ntk else width, **affine_kwargs))
    return stax.serial(block, readout)
def _MyrtleNetwork(width, depth, W_std=jnp.sqrt(2.0), b_std=0.0):
    """Build a Myrtle-style CNN of depth 5, 7 or 10.

    The network has three stages of 3x3 SAME Conv + ReLU pairs. ``depth``
    selects how many pairs each stage gets: 5 -> (2, 1, 1), 7 -> (2, 2, 2),
    10 -> (3, 3, 3). Each stage is followed by 2x2 / stride-2 AvgPool; the
    last stage is followed by three of them. The network ends with Flatten
    and a 10-unit Dense readout.

    Args:
      width: number of channels of every Conv.
      depth: 5, 7 or 10; any other value raises KeyError.
      W_std: weight std for Conv and Dense. The default jnp.sqrt(2.0) is
        evaluated once, when the function is defined.
      b_std: bias std for Conv and Dense.

    Returns:
      The ``stax.serial`` combination of all layers.
    """
    convs_per_stage = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}[depth]
    relu = stax.Relu()
    conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding="SAME")
    last_stage = len(convs_per_stage) - 1

    layers = []
    for stage, n_convs in enumerate(convs_per_stage):
        # List multiplication repeats the same conv / ReLU layer objects.
        layers.extend([conv(width, (3, 3)), relu] * n_convs)
        n_pools = 3 if stage == last_stage else 1
        layers.extend([stax.AvgPool((2, 2), strides=(2, 2))] * n_pools)
    layers.extend([stax.Flatten(), stax.Dense(10, W_std, b_std)])
    return stax.serial(*layers)
def WideResnet(block_size, k, num_classes, W_std=1., b_std=0.):
    """Wide ResNet with configurable weight / bias standard deviations.

    Layout:
      * a 3x3 SAME Conv with 16 channels;
      * three WideResnetGroup stages of widths int(16*k), int(32*k), int(64*k).
        The last two stages also receive (2, 2), presumably a stride.
      * AvgPool((7, 7)), then Flatten;
      * a Dense readout with ``num_classes`` outputs.

    Every Conv, WideResnetGroup and Dense receives the same W_std / b_std.

    Returns:
      The ``stax.serial`` combination of all layers.
    """
    std = {'W_std': W_std, 'b_std': b_std}
    group_specs = [(16, ()), (32, ((2, 2),)), (64, ((2, 2),))]
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME', **std),
        *[WideResnetGroup(block_size, int(base * k), *extra, **std)
          for base, extra in group_specs],
        stax.AvgPool((7, 7)),
        stax.Flatten(),
        stax.Dense(num_classes, **std))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d):
    """Build a test network: affine block with optional residual unit, then readout.

    The unit is ``[AvgPool or Identity] -> phi -> affine``. The same ``affine``
    layer is used both before the unit and inside it.

    Args:
      W_std, b_std: weight / bias standard deviations for Dense and Conv. They
        are also used for the attention key/value/query weights and bias.
      filter_shape, strides, padding: Conv settings, used when ``is_conv``.
        ``padding`` also picks the AvgPool padding: 'SAME' stays 'SAME', any
        other value becomes 'CIRCULAR'.
      is_conv: use ``stax.Conv`` instead of ``stax.Dense`` for the hidden layer.
      use_pooling: put AvgPool((2, 3)) at the start of the unit.
      is_res: wrap the unit in FanOut / parallel(Identity, unit) / FanInSum.
      phi: activation layer placed inside the unit.
      width: hidden width.
      is_ntk: the readout has 1 output if True, else ``width`` outputs.
      proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'
        ('ATTN_FIXED' passes fixed=True to GlobalSelfAttention).

    Returns:
      ``stax.serial(block, readout)``.

    Raises:
      ValueError: if ``proj_into_2d`` is not 'FLAT', 'POOL' or 'ATTN*'.
    """
    dense_layer = partial(stax.Dense, W_std=W_std, b_std=b_std)
    conv_layer = partial(stax.Conv, filter_shape=filter_shape, strides=strides,
                         padding=padding, W_std=W_std, b_std=b_std)
    affine = (conv_layer if is_conv else dense_layer)(width)

    if use_pooling:
        pool_or_identity = stax.AvgPool(
            (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
    else:
        pool_or_identity = stax.Identity()
    res_unit = stax.serial(pool_or_identity, phi, affine)

    if not is_res:
        block = stax.serial(affine, res_unit)
    else:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum())

    if proj_into_2d in ('FLAT', 'POOL'):
        proj_layer = (stax.Flatten() if proj_into_2d == 'FLAT'
                      else stax.GlobalAvgPool())
    elif proj_into_2d.startswith('ATTN'):
        heads = int(np.sqrt(width))
        value_channels = int(np.round(float(width) / heads))
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(
                width,
                n_chan_key=width,
                n_chan_val=value_channels,
                n_heads=heads,
                fixed=(proj_into_2d == 'ATTN_FIXED'),
                W_key_std=W_std,
                W_value_std=W_std,
                W_query_std=W_std,
                W_out_std=1.0,
                b_std=b_std),
            stax.Flatten())
    else:
        raise ValueError(proj_into_2d)

    return stax.serial(
        block, stax.serial(proj_layer, dense_layer(1 if is_ntk else width)))
def test_input_req(self, same_inputs):
    """Check that kernel_fn rejects incompatible layer layouts.

    Each section first builds a network that the test expects to be
    rejected: calling its kernel_fn must raise ValueError. It then builds a
    compatible network and compares the analytic kernel against a Monte Carlo
    estimate from nt.monte_carlo_kernel_fn (atol=0.01, rtol=0.05).

    Args:
      same_inputs: if True, x2 is None and only x1 is passed as data.
    """
    # May skip this test depending on the test environment (see
    # test_utils.skip_test).
    test_utils.skip_test(self)

    key = random.PRNGKey(1)
    # 5-D inputs: 2 examples in x1 and 4 in x2, trailing dims (7, 8, 4, 3).
    x1 = random.normal(key, (2, 7, 8, 4, 3))
    # NOTE(review): x2 is drawn with the same `key` as x1; a split key would
    # make the inputs independent. Confirm whether this reuse is intended.
    x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

    # Case 1, rejected: the first Conv outputs 'NCDWH' while the second Conv
    # expects 'NHDWC' input.
    _, _, wrong_conv_fn = stax.serial(
        stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                  dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
        stax.Relu(),
        stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                  dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
    with self.assertRaises(ValueError):
        wrong_conv_fn(x1, x2)

    # Case 1, accepted: output 'NCWDH' and next input 'NCHDW' both put N at
    # axis 0 and C at axis 1 (only the spatial letters differ).
    init_fn, apply_fn, correct_conv_fn = stax.serial(
        stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                  dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
        stax.Relu(),
        stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                  dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
        stax.Flatten(),
        stax.Dense(1024))

    # Monte Carlo estimate of the same kernel (n_samples=400).
    correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn=init_fn,
        apply_fn=apply_fn,
        key=key,
        n_samples=400,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=0)
    K = correct_conv_fn(x1, x2, get='nngp')
    K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
    self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

    # Case 2, rejected: the Conv outputs 'NCDWH' (channel at axis 1) but
    # GlobalAvgPool is given channel_axis=2.
    _, _, wrong_conv_fn = stax.serial(
        stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                  dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
        stax.GlobalAvgPool(channel_axis=2))
    with self.assertRaises(ValueError):
        wrong_conv_fn(x1, x2)

    # Case 2, accepted: each layer's axis arguments agree with the layout
    # produced before it. 'NDWCH' has C at axis 3 (== -2), matching AvgPool's
    # channel_axis=-2 and the next Conv's 'NDHCW' input. 'NDCHW' has C at
    # axis 2, matching GlobalAvgPool(channel_axis=2).
    init_fn, apply_fn, correct_conv_fn = stax.serial(
        stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                  dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
        stax.Relu(),
        stax.AvgPool((2, 1, 3), batch_axis=0, channel_axis=-2),
        stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                  dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
        stax.Relu(),
        stax.GlobalAvgPool(channel_axis=2),
        stax.Dense(1024))
    correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn=init_fn,
        apply_fn=apply_fn,
        key=key,
        n_samples=300,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=0)
    K = correct_conv_fn(x1, x2, get='nngp')
    K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
    self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

    # Case 3, rejected: a Conv with dimension_numbers ('CN', 'IO', 'NC') and
    # filter_shape (1, 2), placed after Flatten / Dense / Erf.
    _, _, wrong_conv_fn = stax.serial(
        stax.Flatten(),
        stax.Dense(1),
        stax.Erf(),
        stax.Conv(out_chan=1, filter_shape=(1, 2),
                  dimension_numbers=('CN', 'IO', 'NC')),
    )
    with self.assertRaises(ValueError):
        wrong_conv_fn(x1, x2)

    # Case 3, accepted: a Conv with an empty filter_shape after Flatten. This
    # case compares the NTK rather than the NNGP.
    init_fn, apply_fn, correct_conv_fn = stax.serial(
        stax.Flatten(),
        stax.Conv(out_chan=1024, filter_shape=()),
        stax.Relu(),
        stax.Dense(1))
    correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn=init_fn,
        apply_fn=apply_fn,
        key=key,
        n_samples=200,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=0)
    K = correct_conv_fn(x1, x2, get='ntk')
    K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
    self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
X, _, Xtest, _ = prep_data('CIFAR10', False, noise_index)
n = X.shape[0]
ntest = Xtest.shape[0]
W_std = 1.0
b_std = 0.0
# Number of rows generated at each job. This used to be `onp.int(200)`, but
# the `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24, where
# it raises AttributeError. A plain int is equivalent.
m = 200

if model_name == 'Myrtle':
    # Four 3x3 SAME Convs with 512 channels, each followed by ReLU; 2x2 /
    # stride-2 VALID AvgPool after the 2nd and 3rd Convs and three of them
    # after the 4th; then Flatten and a 10-unit Dense readout.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Flatten(),
        stax.Dense(10, W_std, b_std))
else:
    # ValueError subclasses Exception, so existing `except Exception` handlers
    # still match; the message now names the rejected model.
    raise ValueError('Invalid Input Error: unsupported model_name %r' % (model_name,))

apply_fn = jit(apply_fn)
# Argument index 2 of kernel_fn (presumably `get`, e.g. 'nngp' / 'ntk') is a
# Python value, so it is marked static for jit.
kernel_fn = jit(kernel_fn, static_argnums=(2,))