def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm):
  """Build a small test network: one (optionally residual) block + readout.

  Args:
    W_std: weight standard deviation for all affine layers.
    b_std: bias standard deviation for all affine layers.
    filter_shape: convolution filter shape (used when `is_conv`).
    is_conv: whether the main affine layer is convolutional or dense.
    use_pooling: whether to insert an `AvgPool` inside the residual branch.
    is_res: whether to wire the block as a residual (FanOut/FanInSum) unit.
    padding: `'SAME'` or other; non-`'SAME'` pooling uses `'CIRCULAR'`.
    phi: nonlinearity layer placed inside the residual branch.
    strides: convolution strides (used when `is_conv`).
    width: number of channels / units of hidden layers.
    is_ntk: if `True`, the readout is width-1 (scalar output for NTK tests).
    proj_into_2d: `'FLAT'`, `'POOL'`, or an `'ATTN*'` mode for projecting
      spatial activations down to 2D before the final dense layer.
    layer_norm: axes for `LayerNorm`, or `None` to skip normalization.

  Returns:
    The `(init_fn, apply_fn, kernel_fn)` triple from `stax.serial`.

  Raises:
    ValueError: if `proj_into_2d` is not a recognized mode.
  """
  dense = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv2d = partial(stax.Conv, filter_shape=filter_shape, strides=strides,
                   padding=padding, W_std=W_std, b_std=b_std)

  # Main affine transform of the block: conv or fully-connected.
  affine = conv2d(width) if is_conv else dense(width)

  # Optional normalization placed at the end of the block.
  if layer_norm is None:
    norm = stax.Identity()
  else:
    norm = stax.LayerNorm(axis=layer_norm)

  # Optional pooling at the start of the residual branch. Non-'SAME'
  # padding pairs with 'CIRCULAR' pooling to keep shapes compatible.
  if use_pooling:
    pool = stax.AvgPool((2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
  else:
    pool = stax.Identity()

  residual_branch = stax.serial(pool, phi, affine)

  if is_res:
    # affine -> [identity || pool->phi->affine] -> sum -> norm.
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), residual_branch),
                        stax.FanInSum(),
                        norm)
  else:
    # Plain feed-forward stacking of the same layers.
    block = stax.serial(affine, residual_branch, norm)

  # Projection from spatial activations down to a 2D (batch, feature) array.
  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    attention = stax.GlobalSelfAttention(
        width,
        n_chan_key=width,
        n_chan_val=n_chan_val,
        n_heads=n_heads,
        fixed=(proj_into_2d == 'ATTN_FIXED'),
        W_key_std=W_std,
        W_value_std=W_std,
        W_query_std=W_std,
        W_out_std=1.0,
        b_std=b_std)
    proj_layer = stax.serial(attention, stax.Flatten())
  else:
    raise ValueError(proj_into_2d)

  top_width = 1 if is_ntk else width
  readout = stax.serial(proj_layer, dense(top_width))
  return stax.serial(block, readout)
def test_hermite(self, same_inputs, degree, get, readout):
  """Check the analytic Hermite kernel against a Monte-Carlo estimate."""
  master_key = random.PRNGKey(1)
  subkey1, subkey2, master_key = random.split(master_key, 3)

  # Width is the same in both regimes; only the sample count differs.
  width = 10000
  if degree > 2:
    n_samples = 5000
    # High-degree runs are expensive; skip on constrained setups.
    test_utils.skip_test(self)
  else:
    n_samples = 100

  # Inputs are bounded via cosine so the Hermite expansion is well-behaved.
  x1 = np.cos(random.normal(subkey1, [2, 6, 6, 3]))
  if same_inputs:
    x2 = x1
  else:
    x2 = np.cos(random.normal(subkey2, [3, 6, 6, 3]))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
      stax.LayerNorm(),
      stax.Hermite(degree),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
      stax.Dense(1) if get == 'ntk' else stax.Identity())

  exact_kernel = kernel_fn(x1, x2, get)
  sample_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, master_key,
                                       n_samples)
  sampled_kernel = sample_fn(x1, x2, get)

  # Looser tolerance for higher polynomial degrees.
  tol = degree / 2. * 1e-2
  test_utils.assert_close_matrices(self, sampled_kernel, exact_kernel, tol)
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat, proj, p, n, transpose):
  """Tests masked inputs through a deep fan-out conv net vs. Monte Carlo.

  Builds a 3-branch (Fan-Out/Fan-In) convolutional network over n spatial
  dimensions, masks a fraction `p` of the inputs along `mask_axis` with
  `mask_constant`, and compares the analytic kernel to an empirical
  Monte-Carlo estimate.
  """
  # Concatenating along an axis beyond the spatial rank is meaningless.
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')
  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  # Transposed convs use small spatial extents; regular convs larger ones.
  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  # Channels-last layout: (batch, *spatial, channel).
  dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec, 'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    # Attention is only exercised in the 'avg' projection configuration;
    # otherwise it is replaced by an identity layer.
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  # Three parallel branches with different paddings, nonlinearities and
  # dropout rates, fanned back in by sum or concatenation.
  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='CIRCULAR', W_std=1.5, b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=1.2, b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='SAME', W_std=0.1, b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=0.9, b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='CIRCULAR', W_std=1., b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers, out_chan=width,
                   strides=strides, filter_shape=filter_shape,
                   padding='VALID', W_std=1., b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  # NNGP uses a wide final layer; NTK needs a scalar output.
  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  # Batch-axis concatenation (0 or -n) is incompatible with device
  # parallelism / vmap over the batch, hence the special-cased settings.
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, s, use_dropout):
  """Builds a test network with randomized dimension orders.

  Randomly permutes filter and activation dimension orders (via the
  module-level `prandom` RNG) to exercise arbitrary `dimension_numbers`,
  then assembles an (optionally residual) block followed by a readout.

  Returns:
    A tuple `(net, input_shape, device_count, channel_axis_fc)` where `net`
    is the `stax.serial` triple, `input_shape` matches the chosen dimension
    order, and `device_count` is -1 when the batch axis leads (allowing
    device parallelism) and 0 otherwise.
  """
  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    # Re-order filter shape / strides to match the permuted spec.
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
  else:
    # Dense nets: flatten all non-batch dims into a single channel dim.
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]
    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')
  # Axes for fully-connected layers: spec restricted to N and C.
  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')
  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    # Translate layer-norm spec characters into axis indices.
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    # Dense layer bound to this net's std-devs, parameterization and axes.
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan, s):
    # Conv layer bound to this net's filter, padding and dimension order.
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s
    )

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  # Bottom layer maps from input scale 1 into scale `s`.
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    # Non-'SAME' padding pairs with 'CIRCULAR' pooling for shape safety.
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm,
                                                batch_axis=batch_axis,
                                                channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    # affine_bottom -> [identity || res_unit] -> sum -> norm -> phi.
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  # Projection down to a 2D (batch, feature) array before readout.
  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width, (s, 1 if is_ntk else s)))
  # Device parallelism only works when the batch axis leads.
  device_count = -1 if spec.index('N') == 0 else 0
  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout, dimension_numbers):
  """Builds a test network for explicit `dimension_numbers`.

  Assembles one (optionally residual) block of conv/dense + pooling +
  dropout + layer norm, followed by a 2D projection and a dense readout.

  Args:
    W_std: weight standard deviation for all affine layers.
    b_std: bias standard deviation for all affine layers.
    filter_shape: convolution filter shape (used when `is_conv`).
    is_conv: whether the main affine layer is convolutional or dense.
    use_pooling: whether to insert an `AvgPool` in the residual branch.
    is_res: whether to wire the block as a residual (FanOut/FanInSum) unit.
    padding: `'SAME'` or other; non-`'SAME'` pooling uses `'CIRCULAR'`.
    phi: nonlinearity layer inside the residual branch.
    strides: convolution strides (used when `is_conv`).
    width: number of channels / units of hidden layers.
    is_ntk: if `True`, readout has width 1 (scalar output for NTK tests).
    proj_into_2d: `'FLAT'`, `'POOL'`, or an `'ATTN*'` projection mode.
    layer_norm: axes for `LayerNorm`, or `None` to skip normalization.
    parameterization: parameterization string passed to affine layers.
    use_dropout: whether to include dropout in the residual branch.
    dimension_numbers: `lax`-style conv dimension numbers; the last entry
      is the activation spec forwarded to pooling/flattening layers.

  Returns:
    The `(init_fn, apply_fn, kernel_fn)` triple from `stax.serial`.

  Raises:
    ValueError: if `proj_into_2d` is not a recognized mode.
  """
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
               parameterization=parameterization)

  def conv(out_chan):
    # GeneralConv bound to this net's filter, padding and dimension order.
    return stax.GeneralConv(dimension_numbers=dimension_numbers,
                            out_chan=out_chan,
                            filter_shape=filter_shape,
                            strides=strides,
                            padding=padding,
                            W_std=W_std,
                            b_std=b_std,
                            parameterization=parameterization)

  affine = conv(width) if is_conv else fc(width)
  # Output/activation spec, forwarded to layers that need the layout.
  spec = dimension_numbers[-1]

  # Bug fix: was `np.onp.random.uniform(...)`, which reached through
  # `jax.numpy`'s private internal numpy alias and breaks on modern JAX.
  # Use plain numpy (`onp`), consistent with the other `_get_net` variants.
  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  # Non-'SAME' padding pairs with 'CIRCULAR' pooling for shape safety.
  avg_pool_or_identity = stax.AvgPool(
      window_shape=(2, 3),
      strides=None,
      padding='SAME' if padding == 'SAME' else 'CIRCULAR',
      spec=spec) if use_pooling else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))

  res_unit = stax.serial(avg_pool_or_identity, phi, dropout_or_identity,
                         affine)
  if is_res:
    # affine -> [identity || res_unit] -> sum -> norm.
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum(),
                        layer_norm_or_identity)
  else:
    block = stax.serial(affine, res_unit, layer_norm_or_identity)

  # Projection down to a 2D (batch, feature) array before readout.
  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(width,
                                 n_chan_key=width,
                                 n_chan_val=n_chan_val,
                                 n_heads=n_heads,
                                 fixed=fixed,
                                 W_key_std=W_std,
                                 W_value_std=W_std,
                                 W_query_std=W_std,
                                 W_out_std=1.0,
                                 b_std=b_std,
                                 spec=spec),
        stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout)
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d, pool_type,
             layer_norm, parameterization, use_dropout):
  """Builds a test network with a randomized conv dimension order.

  Picks a random `NHWC`-permutation (via the module-level `prandom` RNG)
  for convolutional nets, then assembles an (optionally residual) block of
  conv/dense + pooling + dropout + layer norm, a 2D projection, and a
  dense readout.

  Returns:
    A tuple `(net, input_shape)` where `net` is the `stax.serial` triple
    and `input_shape` matches the chosen dimension order.

  Raises:
    ValueError: if `pool_type` or `proj_into_2d` is not recognized.
  """
  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)
    if layer_norm:
      # Translate layer-norm spec characters into axis indices.
      layer_norm = tuple(spec.index(c) for c in layer_norm)
  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std,
               parameterization=parameterization)

  def conv(out_chan):
    # GeneralConv bound to this net's filter, padding and dimension order.
    return stax.GeneralConv(dimension_numbers=dimension_numbers,
                            out_chan=out_chan,
                            filter_shape=filter_shape,
                            strides=strides,
                            padding=padding,
                            W_std=W_std,
                            b_std=b_std,
                            parameterization=parameterization)

  affine = conv(width) if is_conv else fc(width)
  spec = dimension_numbers[-1]

  # Bug fix: was `np.onp.random.uniform(...)`, which reached through
  # `jax.numpy`'s private internal numpy alias and breaks on modern JAX.
  # Use plain numpy (`onp`), consistent with the other `_get_net` variants.
  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  # Renamed `globalPool_fn` -> `global_pool_fn` (PEP 8); added the missing
  # `else` so an invalid `pool_type` raises `ValueError` immediately
  # instead of a later `NameError` (consistent with the sibling variant).
  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    # Non-'SAME' padding pairs with 'CIRCULAR' pooling for shape safety.
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))

  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    # affine -> [identity || res_unit] -> sum -> norm.
    block = stax.serial(affine,
                        stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum(),
                        layer_norm_or_identity)
  else:
    block = stax.serial(affine, res_unit, layer_norm_or_identity)

  # Projection down to a 2D (batch, feature) array before readout.
  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec),
        stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))
  return stax.serial(block, readout), input_shape