def test_nested_parallel(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1,
   input_key2,
   input_key3,
   input_key4,
   mask_key,
   mc_key) = random.split(rng, 6)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
def test_vmap_axes(self, same_inputs):
  n1, n2 = 3, 4
  c1, c2, c3 = 9, 5, 7
  h2, h3, w3 = 6, 8, 2

  def get_x(n, k):
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
      )
  )

  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))

  implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
  direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

  implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
      apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
  direct_batched = jit(empirical._empirical_direct_ntk_fn(
      apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

  k = direct(x1, x2, params, pattern=(p1, p2))

  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial(
      (stax.AvgPool((2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
       if use_pooling else stax.Identity()),
      phi,
      affine)

  if is_res:
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  readout = stax.serial(
      stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
      fc(1 if is_ntk else width))

  net = stax.serial(block, readout)
  return net
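# A minimal sketch of how this helper might be invoked; the argument values
# below are illustrative assumptions, not settings taken from the tests.
# The returned `net` is a standard `(init_fn, apply_fn, kernel_fn)` triple.
#
# init_fn, apply_fn, kernel_fn = _get_net(
#     W_std=1.5, b_std=0.1, filter_shape=(3, 3), is_conv=True,
#     use_pooling=True, is_res=True, padding='SAME', phi=stax.Relu(),
#     strides=(1, 2), width=512, is_ntk=False)
# kernel_fn(x1, x2, 'nngp')  # closed-form NNGP kernel of the network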
def WideResnetBlocknt(channels, strides=(1, 1), channel_mismatch=False,
                      batchnorm='std', parameterization='ntk'):
  """A WideResnet block, with or without BatchNorm."""
  Main = stax_nt.serial(
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), strides, padding='SAME',
                   parameterization=parameterization),
      _batch_norm_internal(batchnorm),
      stax_nt.Relu(),
      stax_nt.Conv(channels, (3, 3), padding='SAME',
                   parameterization=parameterization))

  Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
      channels, (3, 3), strides, padding='SAME',
      parameterization=parameterization)

  return stax_nt.serial(stax_nt.FanOut(2),
                        stax_nt.parallel(Main, Shortcut),
                        stax_nt.FanInSum())
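# For context, a sketch of how such blocks are typically chained into a
# group. `WideResnetGroupnt` is a hypothetical helper (not defined in this
# file), assuming `stax_nt` is `neural_tangents.stax`: only the first block
# changes channel count / stride, hence `channel_mismatch=True` there.
def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
  blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                              batchnorm=batchnorm,
                              parameterization=parameterization)]
  for _ in range(n - 1):
    blocks += [WideResnetBlocknt(channels, (1, 1),
                                 batchnorm=batchnorm,
                                 parameterization=parameterization)]
  return stax_nt.serial(*blocks)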
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())
def test_parallel_in_out(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  input_key1, input_key2, mc_key = random.split(rng, 3)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2))

  x1 = (x1_1, x1_2)
  x2 = (x2_1, x2_2)

  N_in = 2 ** 10
  N_out = N_in if kernel_type == 'nngp' else 1

  readin = stax.serial(stax.parallel(stax.Dense(N_in), stax.Dense(N_in)),
                       stax.FanInSum())
  readout = stax.serial(stax.FanOut(3),
                        stax.parallel(stax.Dense(N_out),
                                      stax.Dense(N_out + 1),
                                      stax.Dense(N_out + 2)))
  init_fn, apply_fn, _ = stax.serial(readin, readout)

  K_readin_fn = jit(readin[2])
  K_readout_fn = jit(functools.partial(readout[2], get=kernel_type))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=2,
      vmap_axes=((0, 0), [0, 0, 0], {})
  )

  test_utils.assert_close_matrices(
      self,
      K_readout_fn(K_readin_fn(x1, x2)),
      kernel_fn_empirical(x1, x2, get=kernel_type),
      rtol)

  # Check Both (here we just want to make sure we _can_ compute the output).
  K_readin_fn = jit(readin[2])
  K_readout_fn = jit(functools.partial(readout[2], get=('nngp', 'ntk')))
  K_readout_fn(K_readin_fn(x1, x2))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
               parameterization=parameterization)
  conv = partial(stax.Conv,
                 filter_shape=filter_shape,
                 strides=strides,
                 padding=padding,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
  affine = conv(width) if is_conv else fc(width)

  rate = np.onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')
  ave_pool = stax.AvgPool((2, 3), None,
                          'SAME' if padding == 'SAME' else 'CIRCULAR')
  ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm))
  res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                         affine)
  if is_res:
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum(),
                        layer_norm_or_identity)
  else:
    block = stax.serial(affine, res_unit, layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(width,
                                 n_chan_key=width,
                                 n_chan_val=n_chan_val,
                                 n_heads=n_heads,
                                 fixed=fixed,
                                 W_key_std=W_std,
                                 W_value_std=W_std,
                                 W_query_std=W_std,
                                 W_out_std=1.0,
                                 b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial(
      (stax.AvgPool((2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
       if use_pooling else stax.Identity()),
      phi,
      affine)

  if is_res:
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  main = stax.serial(
      stax.Relu(),
      stax.Conv(channels, (3, 3), strides, padding='SAME',
                parameterization='standard'),
      stax.Relu(),
      stax.Conv(channels, (3, 3), padding='SAME',
                parameterization='standard'),
  )
  shortcut = (
      stax.Identity()
      if not channel_mismatch
      else stax.Conv(channels, (3, 3), strides, padding='SAME',
                     parameterization='standard')
  )
  return stax.serial(stax.FanOut(2),
                     stax.parallel(main, shortcut),
                     stax.FanInSum())
                 step=1)
sigma_w = st.slider("Sigma w for Residual Case", 0.1, 3.0, 1.5, step=0.1)
sigma_b = st.slider("Sigma b for Residual Case", 0.0, 0.1, 0.05, step=0.01)
activation_fn = st.selectbox("Activation Function for Residual Case",
                             ("Erf", "ReLU", "None"))
activation_fn = activation_fn_dict[activation_fn]

sequence = ((activation_fn,
             stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b))
            if activation_fn else
            (stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),))

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.serial(*(sequence * depth)), stax.Identity()),
    stax.FanInSum(),
)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),
    ResBlock,
    ResBlock,
    activation_fn,
    stax.Dense(1, W_std=sigma_w, b_std=sigma_b),
)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
def net(N_out):
  return stax.parallel(
      stax.Dense(N_out),
      stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                   fan_in_mode):
  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest(
          '`FanInSum` and `FanInProd` are skipped when axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis == 0) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                            'require `is_gaussian`.')

  if ((axis == 1 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
        'after concatenation or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  if n_branches != 2:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  X0_1 = np.cos(random.normal(key, (4, 3)))
  X0_2 = None if same_inputs else random.normal(key, (8, 3))

  width = 1024
  n_samples = 256 * 2

  if default_backend() == 'tpu':
    tol = 0.07
  else:
    tol = 0.02

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)
      ]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [fan_in_layer, stax.Relu()]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if axis in (0, -2) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if axis in (0, -2) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout, fan_in_mode):
  test_utils.skip_test(self)

  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest(
          '`FanInSum` and `FanInProd()` are skipped when axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                            'require `is_gaussian`.')

  if ((axis == 3 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                            'requires a dense layer after concatenation '
                            'or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if default_backend() == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Conv(out_chan=int(width * multiplier),
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)
      ]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      fan_in_layer,
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if axis in (0, -4) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if axis in (0, -4) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C',
                       'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='CIRCULAR', W_std=1.5, b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=1.2, b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='SAME', W_std=0.1, b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=0.9, b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='CIRCULAR', W_std=1., b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=1., b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant):
  width = 512
  n_samples = 128
  tol = 0.04
  key = random.PRNGKey(1)

  x1 = random.normal(key, (4, 6, 5, 7))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = random.normal(key, (2, 6, 5, 7))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  nn = stax.serial(
      stax.Flatten(),
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              stax.Dense(width, 1., 0.1),
              stax.Abs(),
              stax.DotGeneral(lhs=-0.2),
              stax.Dense(width, 1.5, 0.01),
          ),
          stax.serial(
              stax.Dense(width, 1.1, 0.1),
              stax.DotGeneral(rhs=0.7),
              stax.Erf(),
              stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
          ),
          stax.serial(
              stax.DotGeneral(rhs=0.5),
              stax.Dense(width, 1.2),
              stax.ABRelu(-0.2, 0.4),
              stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      stax.Dense(width, 2., 0.01),
      stax.Relu())

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.1))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.1))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -2) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -2) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def net(logits):
  return stax.serial(
      stax.parallel(stax.Dense(N), stax.Dense(N)),
      stax.serial(stax.FanInSum(), stax.Dense(logits)))
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, pool_type,
             layer_norm, parameterization, s, use_dropout):
  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
  else:
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]
    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')
  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan, s):
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s
    )

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm,
                                           batch_axis=batch_axis,
                                           channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width,
                           (s, 1 if is_ntk else s)))

  device_count = -1 if spec.index('N') == 0 else 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
  if axis in (None, 0) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                       'require `is_gaussian`.')

  if axis == 1 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (10, 20))
  X0_2 = None if same_inputs else random.normal(key, (8, 20))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 256
    tol = 0.01

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [stax.Dense(width, 1. + 2 * i, 0.5 + i),
                        FanInTest._get_phi(i)]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, device_count=0)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
def net(logits):
  return stax.serial(
      stax.Dense(N),
      stax.FanOut(2),
      stax.parallel(stax.Dense(logits), stax.Dense(logits)))
def bann_model(
    W_std,
    b_std,
    first_layer_width,
    second_layer_width,
    subNN_num,
    keep_rate,
    activation,
    parameterization
):
  """Construct a fully connected NN model and its infinite-width NTK & NNGP kernel functions.

  Args:
    W_std (float): Weight standard deviation.
    b_std (float): Bias standard deviation.
    first_layer_width (int): First hidden layer width.
    second_layer_width (int): Second hidden layer width.
    subNN_num (int): Number of sub neural networks in the architecture.
    keep_rate (float): 1 - dropout rate.
    activation (string): Activation function string, 'erf' or 'relu'.
    parameterization (string): Parameterization string, 'ntk' or 'standard'.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`
  """
  act = activation_fn(activation)

  # One sub-network per branch; the i-th branch's second layer is i times
  # wider. Building the branches in a loop keeps the number of parallel
  # branches consistent with `subNN_num`.
  def sub_nn(i):
    return stax.serial(
        Dense(first_layer_width, W_std, b_std,
              parameterization=parameterization),
        act(),
        Dense(i * second_layer_width, W_std, b_std,
              parameterization=parameterization),
        act(),
        stax.Dropout(keep_rate)
    )

  # multi-task learning
  # Computational Skeleton Block
  CSB = stax.serial(
      stax.FanOut(subNN_num),
      stax.parallel(*[sub_nn(i) for i in range(1, subNN_num + 1)]),
      stax.FanInConcat()
  )

  Additive = stax.serial(
      stax.FanOut(2),
      stax.parallel(
          stax.serial(CSB, stax.Dropout(keep_rate)),
          stax.serial(CSB, stax.Dropout(keep_rate))
      ),
      stax.FanInConcat()
  )

  init_fn, apply_fn, kernel_fn = stax.serial(
      Additive,
      Dense(1, W_std, b_std, parameterization=parameterization)
  )

  apply_fn = jit(apply_fn)

  return init_fn, apply_fn, kernel_fn
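# A minimal usage sketch of `bann_model`; the hyperparameter values below
# are illustrative assumptions, not settings from the source.
#
# init_fn, apply_fn, kernel_fn = bann_model(
#     W_std=1.5, b_std=0.05,
#     first_layer_width=512, second_layer_width=64,
#     subNN_num=10, keep_rate=0.8,
#     activation='erf', parameterization='ntk')
#
# The closed-form kernel can then be queried, e.g.:
# k_test_train = kernel_fn(x_test, x_train, 'nngp')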
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, pool_type,
             layer_norm, parameterization, use_dropout):
  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)
  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization
    )

  affine = conv(width) if is_conv else fc(width)

  spec = dimension_numbers[-1]
  rate = np.onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    globalPool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    globalPool_fn = stax.GlobalSumPool
  else:
    # Guard against unknown pool types; otherwise `pool_fn` and
    # `globalPool_fn` would be left unbound below.
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = globalPool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec),
        stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  if xla_bridge.get_backend().platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')

  if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')

  if axis == 3 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Conv(out_chan=width,
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if axis in (0, -4) else -1)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
def layer(N_out):
  return stax.parallel(
      stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1)),
      stax.Dense(N_out + 2))