def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)
  res_unit = stax.serial(
      (stax.AvgPool((2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
       if use_pooling else stax.Identity()),
      phi,
      affine)
  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)
  readout = stax.serial(
      stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
      fc(1 if is_ntk else width))
  net = stax.serial(block, readout)
  return net
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
               parameterization=parameterization)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std,
      parameterization=parameterization)
  affine = conv(width) if is_conv else fc(width)

  rate = onp.random.uniform(0.5, 0.9)  # `onp` is plain NumPy.
  dropout = stax.Dropout(rate, mode='train')
  ave_pool = stax.AvgPool((2, 3), None,
                          'SAME' if padding == 'SAME' else 'CIRCULAR')
  ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm))
  res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                         affine)
  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum(), layer_norm_or_identity)
  else:
    block = stax.serial(affine, res_unit, layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout)
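# A minimal usage sketch for the builder above. All argument values here are
# illustrative assumptions, not taken from the test suite; `_get_net` returns
# the usual stax `(init_fn, apply_fn, kernel_fn)` triple, so the analytic
# kernel can be queried directly.
init_fn, apply_fn, kernel_fn = _get_net(
    W_std=1.5, b_std=0.1, filter_shape=(3, 3), is_conv=True,
    use_pooling=True, is_res=False, padding='SAME', phi=stax.Relu(),
    strides=(1, 1), width=256, is_ntk=True, proj_into_2d='FLAT',
    layer_norm=None, parameterization='ntk', use_dropout=False)
x = random.normal(random.PRNGKey(0), (4, 8, 8, 3))
ntk = kernel_fn(x, None, 'ntk')  # Analytic infinite-width NTK, shape (4, 4).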
def get_attn():
  return stax.GlobalSelfAttention(
      n_chan_out=width,
      n_chan_key=width,
      n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
      n_heads=int(np.sqrt(width)),
  ) if proj == 'avg' else stax.Identity()
def test_ab_relu_id(self, same_inputs, do_stabilize):
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (3, 2))
  fc = stax.Dense(5, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (4, 2))

  # Test that ABRelu(a, a) == a * Identity.
  init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
  _, params = init_fn(key, input_shape=X0_1.shape)

  for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
    with self.subTest(a=a):
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
          fc, stax.ABRelu(a, a, do_stabilize=do_stabilize))

      X1_1_id = a * apply_id(params, X0_1)
      X1_1_ab_relu = apply_ab_relu(params, X0_1)
      self.assertAllClose(X1_1_id, X1_1_ab_relu)

      kernels_id = kernel_fn_id(X0_1 * a, None if X0_2 is None else a * X0_2)
      kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
      # Manually correct the value of `is_gaussian` because `ab_relu`
      # (incorrectly) sets `is_gaussian=False` when `a == b`.
      kernels_ab_relu = kernels_ab_relu.replace(is_gaussian=True)
      self.assertAllClose(kernels_id, kernels_ab_relu)
def test_hermite(self, same_inputs, degree, get, readout):
  key = random.PRNGKey(1)
  key1, key2, key = random.split(key, 3)

  if degree > 2:
    width = 10000
    n_samples = 5000
    test_utils.skip_test(self)
  else:
    width = 10000
    n_samples = 100

  x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
  x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

  conv_layers = [
      stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
      stax.LayerNorm(),
      stax.Hermite(degree),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
      stax.Dense(1) if get == 'ntk' else stax.Identity()]

  init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
  analytic_kernel = kernel_fn(x1, x2, get)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
  mc_kernel = mc_kernel_fn(x1, x2, get)
  rtol = degree / 2. * 1e-2  # Tolerance grows with the Hermite degree.
  test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
def test_nested_parallel(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1, input_key2, input_key3, input_key4,
   mask_key, mc_key) = random.split(rng, 6)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
def WideResnetBlocknt(channels, strides=(1, 1), channel_mismatch=False,
                      batchnorm='std', parameterization='ntk'):
  """A WideResnet block, with or without BatchNorm."""
  Main = stax_nt.serial(
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                   parameterization=parameterization),
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), padding='SAME',
                   parameterization=parameterization))
  Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
      channels, (3, 3), strides, padding='SAME',
      parameterization=parameterization)
  return stax_nt.serial(stax_nt.FanOut(2),
                        stax_nt.parallel(Main, Shortcut),
                        stax_nt.FanInSum())
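# Sketch of how blocks like the one above are typically chained into a group.
# The name `WideResnetGroupnt` and the channel-mismatch handling on the first
# block are assumptions following the usual WideResnet layout, not code from
# this snippet.
def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
  blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                              batchnorm=batchnorm,
                              parameterization=parameterization)]
  for _ in range(n - 1):
    blocks += [WideResnetBlocknt(channels, (1, 1), batchnorm=batchnorm,
                                 parameterization=parameterization)]
  return stax_nt.serial(*blocks)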
def _build_network(input_shape, network, out_logits, use_dropout):
  dropout = stax.Dropout(0.9, mode='train') if use_dropout else stax.Identity()
  if len(input_shape) == 1:
    assert network == FLAT
    return stax.serial(
        stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
        stax.Dense(out_logits, W_std=2.0, b_std=0.5))
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(), dropout,
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
          stax.Flatten(), dropout,
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif network == INTERMEDIATE_CONV:
      return stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05)
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
def test_ab_relu_id(self, same_inputs):
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (5, 7))
  fc = stax.Dense(10, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (9, 7))

  # Test that ABRelu(a, a) == a * Identity.
  init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
  # `init_fn` returns `(output_shape, params)`; keep only the params.
  _, params = init_fn(key, input_shape=(-1, 7))

  for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
    with self.subTest(a=a):
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(a, a))

      X1_1_id = a * apply_id(params, X0_1)
      X1_1_ab_relu = apply_ab_relu(params, X0_1)
      self.assertAllClose(X1_1_id, X1_1_ab_relu, True)

      # Request the same `get` from both kernel functions so the comparison
      # is between structures of the same type.
      kernels_id = kernel_fn_id(X0_1 * a,
                                None if X0_2 is None else a * X0_2,
                                ('nngp', 'ntk'))
      kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
      self.assertAllClose(kernels_id, kernels_ab_relu, True)
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)
  res_unit = stax.serial(
      (stax.AvgPool((2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
       if use_pooling else stax.Identity()),
      phi,
      affine)
  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME',
                parameterization='standard'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME',
                parameterization='standard'),
  )
  shortcut = (
      stax.Identity()
      if not channel_mismatch
      else stax.Conv(channels, (3, 3), strides, padding='SAME',
                     parameterization='standard')
  )
  return stax.serial(stax.FanOut(2),
                     stax.parallel(main, shortcut),
                     stax.FanInSum())
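# Sketch of composing the block above into a full WideResnet, following the
# standard layout used in the Neural Tangents examples; the group structure,
# widening factor `k`, and readout choices here are illustrative assumptions.
def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)


def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.05, parameterization='standard'))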
from neural_tangents._src.empirical import _DEFAULT_TESTING_NTK_IMPLEMENTATION
from tests import test_utils


config.parse_flags_with_absl()
config.update('jax_numpy_rank_promotion', 'raise')

test_utils.update_test_tolerance()

prandom.seed(1)


@parameterized.product(
    same_inputs=[False, True],
    readout=[stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()],
    readin=[stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()])
class DiagonalTest(test_utils.NeuralTangentsTestCase):

  def _get_kernel_fn(self, same_inputs, readin, readout):
    key = random.PRNGKey(1)
    x1 = random.normal(key, (2, 5, 6, 3))
    x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
    layers = [readin]
    filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
    layers += [
        stax.Conv(1, filter_shape, padding='SAME'),
        stax.Relu(),
        stax.Conv(1, filter_shape, padding='SAME'),
        stax.Erf(),
        readout
    ]
class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(self, same_inputs, phi, n, diagonal_batch,
                       diagonal_spatial):
    # `phi` is an `(init_fn, apply_fn, kernel_fn)` triple; index 1 is the
    # `apply_fn`, called here with empty parameters.
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12
      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2
      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12
      else:
        raise NotImplementedError(name)
      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
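# Standalone sketch of `stax.Elementwise` with a closed-form NNGP map, pulled
# out of the Sin case in the test above (`np` is `jax.numpy`; the helper name
# `sin_nngp` is illustrative). When only `nngp_fn` is supplied, the NTK map is
# derived from it by automatic differentiation.
def sin_nngp(cov12, var1, var2):
  s = var1 + var2
  return (np.exp(-0.5 * s + cov12) - np.exp(-0.5 * s - cov12)) / 2

_, _, sin_kernel_fn = stax.serial(
    stax.Dense(1), stax.Elementwise(np.sin, sin_nngp))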
@parameterized.named_parameters(
    test_utils.cases_from_list(
        {
            'testcase_name': ' [{}_out={}_in={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                readout[0].__name__,
                readin[0].__name__),
            'same_inputs': same_inputs,
            'readout': readout,
            'readin': readin
        }
        for same_inputs in [False, True]
        for readout in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()]
        for readin in [stax.Flatten(), stax.GlobalAvgPool(), stax.Identity()]))
class DiagonalTest(test_utils.NeuralTangentsTestCase):

  def _get_kernel_fn(self, same_inputs, readin, readout):
    key = random.PRNGKey(1)
    x1 = random.normal(key, (2, 5, 6, 3))
    x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
    layers = [readin]
    filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
    layers += [
        stax.Conv(1, filter_shape, padding='SAME'),
        stax.Relu(),
        stax.Conv(1, filter_shape, padding='SAME'),
        stax.Erf(),
        readout
    ]
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, pool_type,
             layer_norm, parameterization, use_dropout):
  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)
  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization)

  affine = conv(width) if is_conv else fc(width)

  spec = dimension_numbers[-1]

  rate = onp.random.uniform(0.5, 0.9)  # `onp` is plain NumPy.
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3), None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec),
        stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape
                     step=1)
sigma_w = st.slider("Sigma w for Residual Case ", 0.1, 3.0, 1.5, step=0.1)
sigma_b = st.slider("Sigma b for Residual Case", 0.0, 0.1, 0.05, step=0.01)
activation_fn = st.selectbox("Activation Function for Residual Case",
                             ("Erf", "ReLU", "None"))
activation_fn = activation_fn_dict[activation_fn]

sequence = ((activation_fn,
             stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b))
            if activation_fn
            else (stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),))

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.serial(*(sequence * depth)), stax.Identity()),
    stax.FanInSum(),
)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),
    ResBlock,
    ResBlock,
    activation_fn,
    stax.Dense(1, W_std=sigma_w, b_std=sigma_b),
)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
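# Sketch of driving exact infinite-width inference with the kernel above.
# `x_train`, `y_train`, and `x_test` are assumed to be defined elsewhere in
# the app, and `nt` is assumed to be the imported `neural_tangents` package.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train)
nngp_mean, nngp_cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)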
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, pool_type,
             layer_norm, parameterization, s, use_dropout):
  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
  else:
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]

    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')
  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc)

  def conv(out_chan, s):
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s)

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3), None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm,
                                                batch_axis=batch_axis,
                                                channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width,
                           (s, 1 if is_ntk else s)))

  device_count = -1 if spec.index('N') == 0 else 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc