def test_hermite(self, same_inputs, degree, get, readout):
  """Compares analytic and Monte Carlo kernels of a `stax.Hermite` CNN.

  The network is Conv -> LayerNorm -> Hermite(degree) -> readout, followed by
  a `Dense(1)` head when `get == 'ntk'` (an `Identity` otherwise).

  Args:
    same_inputs: if truthy, `x2` is the very same array as `x1`.
    degree: degree of the Hermite nonlinearity. Degrees above 2 use many more
      MC samples and are routed through `test_utils.skip_test`.
    get: which kernel to compare; forwarded to both kernel functions.
    readout: 'pool' selects `GlobalAvgPool`, anything else `Flatten`.
  """
  key1, key2, mc_key = random.split(random.PRNGKey(1), 3)

  width = 10000
  if degree > 2:
    n_samples = 5000
    test_utils.skip_test(self)
  else:
    n_samples = 100

  x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
  if same_inputs:
    x2 = x1
  else:
    x2 = np.cos(random.normal(key2, [3, 6, 6, 3]))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
      stax.LayerNorm(),
      stax.Hermite(degree),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
      stax.Dense(1) if get == 'ntk' else stax.Identity(),
  )

  analytic_kernel = kernel_fn(x1, x2, get)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, mc_key, n_samples)
  mc_kernel = mc_kernel_fn(x1, x2, get)

  # The relative tolerance grows linearly with the polynomial degree.
  rtol = degree / 2. * 1e-2
  test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
def test_parallel_out(self, same_inputs, kernel_type):
  """Checks the kernel of a one-input, two-output (`FanOut` + `parallel`) net.

  Args:
    same_inputs: forwarded to `_get_inputs`.
    kernel_type: kernel to compare. 'nngp' uses wide output layers, anything
      else uses a single output unit per branch.
  """
  rtol = 0.05 if default_backend() == 'tpu' else RTOL

  input_key, mc_key = random.split(random.PRNGKey(0), 2)
  x1, x2 = _get_inputs(input_key, same_inputs, (BATCH_SIZE, 1))

  width = 2 ** 10
  n_logits = width if kernel_type == 'nngp' else 1
  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(width),
      stax.FanOut(2),
      stax.parallel(stax.Dense(n_logits), stax.Dense(n_logits)))

  empirical_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=2,
      vmap_axes=(0, [0, 0], {}))

  expected = kernel_fn(x1, x2, kernel_type)
  actual = empirical_kernel_fn(x1, x2, kernel_type)
  test_utils.assert_close_matrices(self, expected, actual, rtol)
def test_nonlineariy(self, phi, same_inputs, a, b, n):
  """Compares analytic and MC NNGP/NTK of a depth-2 network using `phi`.

  Args:
    phi: layer constructor, called as `phi(a=a, b=b)` after each hidden Dense.
    same_inputs: `True` makes `x2` the same array as `x1`, `None` passes no
      second input, and any other value draws an independent second batch.
    a: forwarded to `phi`.
    b: forwarded to `phi`.
    n: how many leading entries of `(4, 3, 2)` form the per-example shape.
      A trailing singleton axis is always appended.
  """
  width = 2**10
  n_samples = 2**9

  hidden_layers = []
  for _ in range(2):
    hidden_layers += [stax.Dense(width), phi(a=a, b=b)]
  init_fn, apply_fn, kernel_fn = stax.serial(*hidden_layers, stax.Dense(1))

  key1, key2, key_mc = random.split(random.PRNGKey(1), 3)
  example_shape = (4, 3, 2)[:n] + (1,)
  x1 = np.cos(random.normal(key1, (2,) + example_shape))

  if same_inputs is True:
    x2 = x1
  elif same_inputs is None:
    x2 = None
  else:
    x2 = np.cos(random.normal(key2, (3,) + example_shape))

  k = kernel_fn(x1, x2)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key_mc, n_samples)
  k_mc = mc_kernel_fn(x1, x2, ('nngp', 'ntk'))

  for kernel_name in ('nngp', 'ntk'):
    test_utils.assert_close_matrices(
        self, getattr(k_mc, kernel_name), getattr(k, kernel_name), 6e-2)
def test_nested_parallel(self, same_inputs, kernel_type):
  """Compares analytic and MC kernels of a three-level nested `stax.parallel`.

  The network consumes a nested input tree `(((a, b), c), d)` whose leaves
  have different ranks. Leaves `a` and `b` of `x1` are always masked with the
  constant `-1`. When inputs differ, leaves `c` and `d` of `x2` are masked as
  well. `mask_constant=-1` is passed to both kernel functions.

  Args:
    same_inputs: if truthy, the second input tree is `None`.
    kernel_type: which kernel to compare; forwarded as `get`.
  """
  on_tpu = default_backend() == 'tpu'
  rtol = 0.05 if on_tpu else RTOL

  *input_keys, mask_key, mc_key = random.split(random.PRNGKey(0), 6)

  leaf_shapes = [
      (BATCH_SIZE, 5),
      (BATCH_SIZE, 2, 2, 2),
      (BATCH_SIZE, 2, 2, 3),
      (BATCH_SIZE, 3, 4),
  ]
  leaves1, leaves2 = [], []
  for leaf_key, leaf_shape in zip(input_keys, leaf_shapes):
    leaf1, leaf2 = _get_inputs(leaf_key, same_inputs, leaf_shape)
    leaves1.append(leaf1)
    leaves2.append(leaf2)

  leaf_mask_keys = random.split(mask_key, 4)

  def masked(x, axes, leaf_mask_key):
    return test_utils.mask(
        x, mask_constant=-1, mask_axis=axes, key=leaf_mask_key, p=0.5)

  leaves1[0] = masked(leaves1[0], (1,), leaf_mask_keys[0])
  leaves1[1] = masked(leaves1[1], (2, 3), leaf_mask_keys[1])
  if not same_inputs:
    leaves2[2] = masked(leaves2[2], (1, 3), leaf_mask_keys[2])
    leaves2[3] = masked(leaves2[3], (2,), leaf_mask_keys[3])

  def nest(a, b, c, d):
    return (((a, b), c), d)

  x1 = nest(*leaves1)
  x2 = None if same_inputs else nest(*leaves2)

  width = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Identity() if on_tpu else stax.Dropout(0.9)

  innermost = stax.parallel(
      stax.Dense(width),
      stax.serial(stax.Conv(width + 1, (2, 2)), stax.Flatten()))
  middle = stax.parallel(
      innermost,
      stax.serial(stax.Conv(width + 2, (2, 2)),
                  dropout_or_id,
                  stax.GlobalAvgPool()))
  init_fn, apply_fn, kernel_fn = stax.parallel(
      middle, stax.Conv(width + 3, (2,)))

  # Per-leaf batch axes for both input trees; only used on TPU.
  vmap_axes = ((((0, 0), 0), 0), (((0, 0), 0), 0), {}) if on_tpu else None
  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=vmap_axes)

  expected = kernel_fn(x1, x2, get=kernel_type, mask_constant=-1)
  actual = kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1)
  test_utils.assert_close_matrices(self, expected, actual, rtol)
def test_parallel_in(self, same_inputs, kernel_type):
  """Checks a two-input network: parallel Dense branches merged by `FanInSum`.

  Args:
    same_inputs: forwarded to `_get_inputs` for both input branches.
    kernel_type: kernel to compare. 'nngp' uses a wide readout, anything else
      uses a single output unit.
  """
  rtol = 0.05 if default_backend() == 'tpu' else RTOL

  key_a, key_b, mc_key = random.split(random.PRNGKey(0), 3)
  # Regroup the per-branch `(x1_i, x2_i)` pairs into `(x1_tree, x2_tree)`.
  x1, x2 = zip(_get_inputs(key_a, same_inputs, (BATCH_SIZE, 2)),
               _get_inputs(key_b, same_inputs, (BATCH_SIZE, 3)))

  width = 2 ** 7
  n_logits = width if kernel_type == 'nngp' else 1
  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.parallel(stax.Dense(width), stax.Dense(width)),
      stax.serial(stax.FanInSum(), stax.Dense(n_logits)))

  empirical_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=((0, 0), 0, {}))

  expected = kernel_fn(x1, x2, kernel_type)
  actual = empirical_kernel_fn(x1, x2, kernel_type)
  test_utils.assert_close_matrices(self, expected, actual, rtol)
def _test_activation(self, activation_fn, same_inputs, model, get,
                     rbf_gamma=None):
  """Checks an activation's analytic kernel against a Monte Carlo estimate.

  Args:
    activation_fn: a stax layer triple. The `__name__` of its third element
      ('Sin', 'Rbf', ...) selects the weight/bias scales and sample counts.
    same_inputs: if truthy, the second input is `None`.
    model: 'fc' builds one Dense + activation layer. Any other value builds
      two Conv + activation layers, read out with `GlobalAvgPool` when 'pool'
      is in `model` and with `Flatten` otherwise. Models containing 'conv' go
      through `test_utils.skip_test`.
    get: kernel to compare; forwarded to every kernel function.
    rbf_gamma: if not `None` (and `get == 'nngp'`, `model == 'fc'`), the
      analytic kernel is also compared to `self._RBF(rbf_gamma / input_dim)`.
  """
  if 'conv' in model:
    test_utils.skip_test(self)

  key, split = random.split(random.PRNGKey(1))
  activation_name = activation_fn[2].__name__
  output_dim = 1024 if get == 'nngp' else 1

  # (W_std, b_std) per activation; unlisted activations use (2.0, 0.5).
  W_std, b_std = {
      'Sin': (0.9, 0.5),
      'Rbf': (1.0, 0.0),
  }.get(activation_name, (2.0, 0.5))

  is_fc = model == 'fc'
  if is_fc:
    x1_shape, x2_shape = (4, 2), (2, 2)
  else:
    x1_shape, x2_shape = (2, 4, 4, 3), (4, 4, 4, 3)
  X0_1 = random.normal(key, x1_shape)
  X0_2 = None if same_inputs else random.normal(split, x2_shape)

  if is_fc:
    rtol, depth = 0.04, 1
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
  else:
    rtol, depth = 0.05, 2
    affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
    pooling = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
    readout = stax.serial(pooling, stax.Dense(output_dim))

  if default_backend() == 'cpu':
    num_samplings = 200
    rtol *= 2
  elif activation_name in ('Sin', 'Rbf'):
    num_samplings = 500
  else:
    num_samplings = 300

  init_fn, apply_fn, kernel_fn = stax.serial(
      *([affine, activation_fn] * depth), readout)
  analytic_kernel = kernel_fn(X0_1, X0_2, get)

  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, split, num_samplings, implementation=2, vmap_axes=0)
  empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
  test_utils.assert_close_matrices(self, analytic_kernel, empirical_kernel,
                                   rtol)

  # Check match with explicit RBF (FC NNGP only).
  if rbf_gamma is None or get != 'nngp' or model != 'fc':
    return
  input_dim = X0_1.shape[1]
  _, _, rbf_kernel_fn = self._RBF(rbf_gamma / input_dim)
  direct_rbf_kernel = rbf_kernel_fn(X0_1, X0_2, get)
  test_utils.assert_close_matrices(self, analytic_kernel, direct_rbf_kernel,
                                   rtol)
def _get_empirical(n_samples, get):
  """Returns the `get` Monte Carlo kernel of the enclosing test's network.

  Relies on free variables of the enclosing scope (not visible here):
  `init_fn`, `apply_fn`, `key`, `device_count`, `channel_axis`, `batch_size`,
  `same_inputs`, `x1` and `x2`.

  Args:
    n_samples: number of MC samples.
    get: which kernel(s) to return.
  """
  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=device_count,
      trace_axes=(channel_axis,),
      batch_size=batch_size,
      implementation=2,
  )
  # Sanity check: with shared inputs, the caller must pass no second input.
  assert not same_inputs or x2 is None
  return mc_kernel_fn(x1, x2, get)
def test_sparse_inputs(self, act, kernel, do_stabilize):
  """Compares analytic and MC kernels when some input rows are all zeros.

  The first `sparse_count` input rows are zeroed. Only the kernel block
  between the remaining (non-zero) rows is compared, and the analytic kernel
  must contain no NaNs.

  Args:
    act: 'relu' for `stax.Relu`, anything else for `stax.Erf`.
    kernel: kernel to compare ('ntk' uses a single output unit).
    do_stabilize: forwarded to `stax.Relu`; only supported for 'relu'.
  """
  if do_stabilize and act != 'relu':
    raise absltest.SkipTest('Stabilization possible only in Relu.')

  key = random.PRNGKey(1)
  input_count, sparse_count, input_size, width = 4, 2, 3, 1024

  # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
  if default_backend() == 'gpu':
    tol = 5e-4
    samples = 100 * N_SAMPLES
  else:
    tol = {onp.dtype(onp.float32): 5e-2, onp.dtype(onp.float64): 5e-3}
    samples = N_SAMPLES

  # A batch of dense inputs whose first `sparse_count` rows are zeroed out.
  x_dense = random.normal(key, (input_count, input_size))
  x_sparse = x_dense.at[:sparse_count, :].set(0.)

  if act == 'relu':
    activation = stax.Relu(do_stabilize=do_stabilize)
  else:
    activation = stax.Erf()

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(width),
      activation,
      stax.Dense(1 if kernel == 'ntk' else width))
  exact = kernel_fn(x_sparse, None, kernel)

  mc_key, _ = random.split(key, 2)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, samples,
      vmap_axes=0,
      device_count=-1,
      implementation=2)
  mc = np.reshape(mc_kernel_fn(x_sparse, None, kernel), exact.shape)

  assert not np.any(np.isnan(exact))
  dense_rows = slice(sparse_count, None)
  self.assertAllClose(exact[dense_rows, dense_rows],
                      mc[dense_rows, dense_rows],
                      rtol=tol, atol=tol)
def test_parallel_in_out(self, same_inputs, kernel_type):
  """Checks a two-input, three-output network built from two halves.

  The analytic kernel is computed by composing the jitted kernel functions of
  the read-in half (parallel Dense + `FanInSum`) and the read-out half
  (`FanOut(3)` + parallel Dense). It is then compared to a Monte Carlo
  estimate of the full network.

  Args:
    same_inputs: forwarded to `_get_inputs` for both input branches.
    kernel_type: kernel to compare. 'nngp' uses wide output layers.
  """
  rtol = 0.05 if default_backend() == 'tpu' else RTOL

  key_a, key_b, mc_key = random.split(random.PRNGKey(0), 3)
  # Regroup the per-branch `(x1_i, x2_i)` pairs into `(x1_tree, x2_tree)`.
  x1, x2 = zip(_get_inputs(key_a, same_inputs, (BATCH_SIZE, 1)),
               _get_inputs(key_b, same_inputs, (BATCH_SIZE, 2)))

  n_in = 2 ** 10
  n_out = n_in if kernel_type == 'nngp' else 1

  readin = stax.serial(stax.parallel(stax.Dense(n_in), stax.Dense(n_in)),
                       stax.FanInSum())
  readout = stax.serial(
      stax.FanOut(3),
      stax.parallel(*[stax.Dense(n_out + offset) for offset in range(3)]))
  init_fn, apply_fn, _ = stax.serial(readin, readout)

  readin_kernel_fn = jit(readin[2])
  readout_kernel_fn = jit(functools.partial(readout[2], get=kernel_type))

  empirical_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=2,
      vmap_axes=((0, 0), [0, 0, 0], {}))

  analytic = readout_kernel_fn(readin_kernel_fn(x1, x2))
  empirical = empirical_kernel_fn(x1, x2, get=kernel_type)
  test_utils.assert_close_matrices(self, analytic, empirical, rtol)

  # Check both kernels at once (here we just want to make sure we _can_
  # compute the output). The jitted read-in kernel is reused as-is.
  both_readout_kernel_fn = jit(
      functools.partial(readout[2], get=('nngp', 'ntk')))
  both_readout_kernel_fn(readin_kernel_fn(x1, x2))
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                   fan_in_mode):
  """Compares analytic and MC kernels of a dense fan-out / fan-in network.

  A shared `Dense` layer feeds `FanOut(n_branches)`. Branch `b` applies
  `FanInTest._get_phi(b)` followed by `b` (Dense, phi) pairs. The branches are
  merged with `fan_in_mode` and passed through a ReLU. A `Dense(1)` head is
  added when `get == 'ntk'`.

  Args:
    same_inputs: if truthy, the second input is `None`.
    axis: `FanInConcat` axis. Must be 0 for 'FanInSum' / 'FanInProd' (it is
      then reset to `None`), otherwise the case is skipped.
    n_branches: number of parallel branches. Values other than 2 go through
      `test_utils.skip_test`.
    get: 'nngp' or 'ntk'. Any other value raises `ValueError`.
    branch_in: 'dense_before_branch_in' appends the shared Dense to each
      branch. 'dense_after_branch_in' inserts it right after the fan-in.
    fan_in_mode: 'FanInSum', 'FanInProd', or anything else for
      `FanInConcat(axis)`.
  """
  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest(
          '`FanInSum` and `FanInProd` are skipped when '
          'axis != 0.')
    axis = None

  dense_after = branch_in == 'dense_after_branch_in'
  dense_before = branch_in == 'dense_before_branch_in'

  if dense_after and (fan_in_mode == 'FanInSum' or axis == 0):
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                            'require `is_gaussian`.')

  if dense_before and (axis == 1 or fan_in_mode == 'FanInProd'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
        'after concatenation or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  if n_branches != 2:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  # NOTE(review): `X0_2` reuses `key` (which also seeds the MC parameters
  # below) and, unlike `X0_1`, is not passed through `np.cos`. Confirm this
  # is intentional.
  X0_1 = np.cos(random.normal(key, (4, 3)))
  X0_2 = None if same_inputs else random.normal(key, (8, 3))

  width = 1024
  n_samples = 256 * 2
  tol = 0.07 if default_backend() == 'tpu' else 0.02

  dense = stax.Dense(width, 1.25, 0.1)
  fan_out = stax.FanOut(n_branches)

  def make_branch(depth):
    # phi_depth, then `depth` (Dense, phi_i) pairs. When concatenating on the
    # feature axis (1 / -1), the Dense widths vary with `i`.
    layers = [FanInTest._get_phi(depth)]
    for i in range(depth):
      scale = (1 + 0.25 * i) if axis in (1, -1) else 1
      layers += [
          stax.Dense(int(width * scale), 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i),
      ]
    if dense_before:
      layers.append(dense)
    return stax.serial(*layers)

  branches = [make_branch(b) for b in range(n_branches)]

  if dense_after:
    output_layers = [fan_in_layer, dense, stax.Relu()]
  else:
    output_layers = [fan_in_layer, stax.Relu()]

  nn = stax.serial(dense, fan_out, stax.parallel(*branches), *output_layers)

  if get == 'ntk':
    nn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  elif get != 'nngp':
    raise ValueError(get)
  init_fn, apply_fn, kernel_fn = nn

  # When concatenating along the batch axis (0 / -2), the MC estimate is
  # neither split across devices nor vmapped.
  concat_on_batch = axis in (0, -2)
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat_on_batch else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat_on_batch else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout, fan_in_mode):
  """Compares analytic and MC kernels of a convolutional fan-out / fan-in net.

  A shared SAME-padded 3x3 `Conv` feeds `FanOut(n_branches)`. Branch `b`
  applies `FanInTest._get_phi(b)` followed by `b` (Conv, phi) pairs. The
  branches are merged with `fan_in_mode`, then passed through ReLU, a
  `readout` projection and a final Dense layer (1 unit for 'ntk', `width`
  units otherwise). Inputs have shape `(batch, 5, 6, 3)`.

  Args:
    same_inputs: if truthy, the second input is `None`.
    axis: `FanInConcat` axis. Must be 0 for 'FanInSum' / 'FanInProd' (it is
      then reset to `None`), otherwise the case is skipped.
    n_branches: number of parallel branches.
    get: kernel to compare.
    branch_in: 'dense_before_branch_in' appends the shared Conv to each
      branch. 'dense_after_branch_in' inserts it right after the fan-in.
    readout: 'pool' for `GlobalAvgPool`, anything else for `Flatten`.
    fan_in_mode: 'FanInSum', 'FanInProd', or anything else for
      `FanInConcat(axis)`.
  """
  test_utils.skip_test(self)

  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest(
          '`FanInSum` and `FanInProd()` are skipped when '
          'axis != 0.')
    axis = None

  dense_after = branch_in == 'dense_after_branch_in'
  dense_before = branch_in == 'dense_before_branch_in'

  if dense_after and (fan_in_mode == 'FanInSum' or axis in [0, 1, 2]):
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                            'require `is_gaussian`.')

  if dense_before and (axis == 3 or fan_in_mode == 'FanInProd'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis '
        'requires a dense layer after concatenation '
        'or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if default_backend() == 'tpu':
    width, n_samples, tol = 2048, 1024, 0.02
  else:
    width, n_samples, tol = 1024, 512, 0.01

  conv = stax.Conv(out_chan=width, filter_shape=(3, 3), padding='SAME',
                   W_std=1.25, b_std=0.1)
  fan_out = stax.FanOut(n_branches)

  def make_branch(depth):
    # phi_depth, then `depth` (Conv, phi_i) pairs. When concatenating on the
    # channel axis (3 / -1), the Conv widths vary with `i`.
    layers = [FanInTest._get_phi(depth)]
    for i in range(depth):
      scale = (1 + 0.25 * i) if axis in (3, -1) else 1
      layers += [
          stax.Conv(out_chan=int(width * scale),
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i),
      ]
    if dense_before:
      layers.append(conv)
    return stax.serial(*layers)

  branches = [make_branch(b) for b in range(n_branches)]

  if dense_after:
    output_layers = [fan_in_layer, conv, stax.Relu()]
  else:
    output_layers = [fan_in_layer, stax.Relu()]
  output_layers.append(
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten())

  nn = stax.serial(conv, fan_out, stax.parallel(*branches), *output_layers)
  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  # When concatenating along the batch axis (0 / -4), the MC estimate is
  # neither split across devices nor vmapped.
  concat_on_batch = axis in (0, -4)
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat_on_batch else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat_on_batch else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  """Compares analytic and MC kernels of a masked, branched n-D CNN.

  Inputs of rank `n + 2` (batch, `n` spatial axes, channels) are passed
  through `test_utils.mask`. They are then fanned out into three
  convolutional branches (`stax.ConvTranspose` when `transpose` is set) and
  merged with `FanInSum` or `FanInConcat(concat)`. Next come an optional
  global self-attention layer, the `proj` projection and a Dense readout.

  Args:
    same_inputs: if truthy, the second input is `None`.
    get: 'nngp' or 'ntk'. Any other value raises `ValueError`.
    mask_axis: forwarded to `test_utils.mask`.
    mask_constant: forwarded to `test_utils.mask` and to both kernel
      functions.
    concat: `None` for `FanInSum`, otherwise the `FanInConcat` axis. Integer
      values greater than `n` are skipped.
    proj: 'avg' (which also enables the self-attention layers), 'sum' or
      'flatten'.
    p: forwarded to `test_utils.mask`.
    n: number of spatial dimensions.
    transpose: use `stax.ConvTranspose` instead of `stax.Conv`.
  """
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]

  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C',
                       'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    # Self-attention is only used with the 'avg' projection.
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  def conv_layer(padding, W_std, b_std):
    # Every convolution shares the layout, width, strides and filter shape.
    return conv(dimension_numbers=dimension_numbers, out_chan=width,
                strides=strides, filter_shape=filter_shape, padding=padding,
                W_std=W_std, b_std=b_std)

  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv_layer('CIRCULAR', 1.5, 0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv_layer('VALID', 1.2, 0.1),
          ),
          stax.serial(
              conv_layer('SAME', 0.1, 0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv_layer('VALID', 0.9, 1.),
          ),
          stax.serial(
              get_attn(),
              conv_layer('CIRCULAR', 1., 0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv_layer('VALID', 1., 0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    readout = stax.Dense(width, 1., 0.)
  elif get == 'ntk':
    readout = stax.Dense(1, 1., 0.)
  else:
    raise ValueError(get)
  init_fn, apply_fn, kernel_fn = stax.serial(nn, readout)

  # Concatenating along the batch axis makes outputs no longer one-per-input,
  # so the MC estimate must not be split across devices or vmapped over the
  # batch. Inputs have rank `n + 2`, so the batch axis is `0` or, counted
  # from the end, `-(n + 2)`. (`-n` would be a spatial or channel axis.)
  concat_on_batch = concat in (0, -(n + 2))
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat_on_batch else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat_on_batch else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant):
  """Compares analytic and MC kernels of a masked, branched dense network.

  Inputs of shape `(batch, 6, 5, 7)` are passed through `test_utils.mask`,
  flattened and fanned out into three dense branches. The branches are
  merged with `FanInSum` (when `concat is None`) or `FanInConcat(concat)`,
  then followed by Dense + ReLU and a Dense readout.

  Args:
    same_inputs: if truthy, the second input is `None`.
    get: 'nngp' or 'ntk'. Any other value raises `ValueError`.
    concat: `None` or the `FanInConcat` axis.
    p: forwarded to `test_utils.mask`.
    mask_axis: forwarded to `test_utils.mask`.
    mask_constant: forwarded to `test_utils.mask` and to both kernel
      functions.
  """
  width = 512
  n_samples = 128
  tol = 0.04
  key = random.PRNGKey(1)

  def masked_inputs(batch_size):
    x = random.normal(key, (batch_size, 6, 5, 7))
    return test_utils.mask(x, mask_constant, mask_axis, key, p)

  x1 = masked_inputs(4)
  x2 = None if same_inputs else masked_inputs(2)

  branch_a = stax.serial(
      stax.Dense(width, 1., 0.1),
      stax.Abs(),
      stax.DotGeneral(lhs=-0.2),
      stax.Dense(width, 1.5, 0.01),
  )
  # NOTE(review): with `width == 512` the `concat != 1` conditional below is
  # a no-op. Confirm which output width was intended for feature concat.
  branch_b = stax.serial(
      stax.Dense(width, 1.1, 0.1),
      stax.DotGeneral(rhs=0.7),
      stax.Erf(),
      stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
  )
  branch_c = stax.serial(
      stax.DotGeneral(rhs=0.5),
      stax.Dense(width, 1.2),
      stax.ABRelu(-0.2, 0.4),
      stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
  )
  fan_in = stax.FanInSum() if concat is None else stax.FanInConcat(concat)

  nn = stax.serial(
      stax.Flatten(),
      stax.FanOut(3),
      stax.parallel(branch_a, branch_b, branch_c),
      fan_in,
      stax.Dense(width, 2., 0.01),
      stax.Relu())

  if get == 'nngp':
    head_width = width
  elif get == 'ntk':
    head_width = 1
  else:
    raise ValueError(get)
  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(head_width, 1., 0.1))

  # When concatenating along the batch axis (0 / -2), the MC estimate is
  # neither split across devices nor vmapped.
  concat_on_batch = concat in (0, -2)
  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat_on_batch else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat_on_batch else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_input_req(self, same_inputs):
  """Checks layout validation and MC agreement for unusual dimension numbers.

  Three scenarios run in order. In each one:
    * a network built from the "wrong" layers must make its kernel function
      raise `ValueError`;
    * a network built from the "correct" layers must have an analytic kernel
      close to its Monte Carlo estimate (`atol=0.01`, `rtol=0.05`).

  Args:
    same_inputs: if truthy, the second input is `None`.
  """
  test_utils.skip_test(self)

  key = random.PRNGKey(1)
  x1 = random.normal(key, (2, 7, 8, 4, 3))
  x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

  def assert_kernel_rejects(*layers):
    # The network is built outside `assertRaises`; only the kernel call is
    # expected to raise.
    _, _, wrong_kernel_fn = stax.serial(*layers)
    with self.assertRaises(ValueError):
      wrong_kernel_fn(x1, x2)

  def assert_kernel_matches_mc(n_samples, get, *layers):
    init_fn, apply_fn, kernel_fn = stax.serial(*layers)
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn=init_fn, apply_fn=apply_fn, key=key, n_samples=n_samples,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION, vmap_axes=0)
    k = kernel_fn(x1, x2, get=get)
    k_mc = kernel_fn_mc(x1, x2, get=get)
    self.assertAllClose(k, k_mc, atol=0.01, rtol=0.05)

  # Scenario 1: two stacked 3D convolutions.
  assert_kernel_rejects(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.Relu(),
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
  assert_kernel_matches_mc(
      400, 'nngp',
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
      stax.Relu(),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
      stax.Flatten(),
      stax.Dense(1024))

  # Scenario 2: pooling layers with non-default channel axes.
  assert_kernel_rejects(
      stax.Conv(out_chan=1, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
      stax.GlobalAvgPool(channel_axis=2))
  assert_kernel_matches_mc(
      300, 'nngp',
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
      stax.Relu(),
      stax.AvgPool((2, 1, 3), batch_axis=0, channel_axis=-2),
      stax.Conv(out_chan=1024, filter_shape=(1, 2, 3),
                dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
      stax.Relu(),
      stax.GlobalAvgPool(channel_axis=2),
      stax.Dense(1024))

  # Scenario 3: convolutions applied to flattened inputs.
  assert_kernel_rejects(
      stax.Flatten(),
      stax.Dense(1),
      stax.Erf(),
      stax.Conv(out_chan=1, filter_shape=(1, 2),
                dimension_numbers=('CN', 'IO', 'NC')),
  )
  assert_kernel_matches_mc(
      200, 'ntk',
      stax.Flatten(),
      stax.Conv(out_chan=1024, filter_shape=()),
      stax.Relu(),
      stax.Dense(1))
def test_activations(
    self,
    get,
    parameterization,
    parameterization_out,
    x1_type,
    x2_type,
    b_std,
    phi,
    do_jit
):
  """Tests forward- and reverse-mode autodiff for nonlinearities.

  For a Dense -> phi -> Dense network this checks, in order:
    1. the analytic kernel value agrees between a direct call,
       `value_and_grad` and `jvp`, and its gradient agrees between
       `value_and_grad` and `jacfwd`;
    2. the same consistency for a single-sample Monte Carlo kernel;
    3. the analytic kernel matches the mean of 2**9 empirical kernels;
    4. the analytic kernel's gradient matches the mean empirical gradient.
  Some activations skip the later steps (see the `SkipTest` raises below).
  """
  if phi == stax.ABRelu:
    phi_ = phi(0.25, 0.5)
  else:
    phi_ = phi()

  # Non-Relu activations go through `test_utils.skip_test` (which may skip
  # them depending on its own conditions).
  if phi not in [stax.Relu]:
    test_utils.skip_test(self)

  n_out = 1 if get == 'ntk' else 1024
  width = 832

  # First-layer weight scale depends on the *output* layer's
  # parameterization, and is further reduced for `stax.Exp`.
  W_std_in = width**(-0.5) if parameterization_out == 'standard' else 1.
  if phi == stax.Exp:
    W_std_in /= 10.

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(
          width,
          W_std=W_std_in,
          b_std=b_std,
          parameterization=parameterization),
      phi_,
      stax.Dense(
          n_out,
          b_std=b_std,
          parameterization=parameterization_out
      ),
  )

  def get_x(x_type, key):
    # Builds a (1, 2) input of the requested kind; `key` is an int seed.
    # 'none' yields no input at all.
    shape = (1, 2)
    if x_type == 'zeros':
      x = np.zeros(shape)
    elif x_type == 'ones':
      x = np.ones(shape)
    elif x_type == 'random':
      x = random.normal(random.PRNGKey(key), shape)
    elif x_type == 'sin':
      x = np.sin(random.normal(random.PRNGKey(key), shape))
    elif x_type == 'none':
      return None
    else:
      raise ValueError(x_type)
    return x

  x1 = get_x(x1_type, 1)
  if x2_type == 'x1':
    x2 = x1
  else:
    x2 = get_x(x2_type, 2)

  def kernel_scalar(x1, x2):
    # Scalar `[0, 0]` entry of the analytic kernel, so it can be differentiated.
    return kernel_fn(x1, x2, get)[0, 0]

  if do_jit:
    kernel_scalar = jit(kernel_scalar)

  # `value_and_grad` differentiates w.r.t. the first argument (`x1`).
  k1 = kernel_scalar(x1, x2)
  k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
  self.assertAllClose(k1, k2)

  # Compare to forward-mode.
  k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
  k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
  self.assertAllClose(k1, k2_fwd)
  self.assertAllClose(k2_grad, k2_grad_fwd)

  # `stax.ExpNormalized` has no forward pass.
  # `stax.Sign` is discontinuous at `0`, so NTK MC kernel does not converge to
  # infinite-width kernel.
  if phi == stax.ExpNormalized or (get == 'ntk' and phi == stax.Sign):
    raise absltest.SkipTest('Not comparing against MC kernels.')

  # A single-sample MC kernel: only used to check autodiff consistency, not
  # closeness to the analytic kernel.
  _kernel_scalar_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key=random.PRNGKey(3),
      n_samples=1,
      device_count=0,
  )

  def kernel_scalar_mc(x1, x2):
    return _kernel_scalar_mc(x1, x2, get)[0, 0]

  k_mc = kernel_scalar_mc(x1, x2)
  k_mc2, k_mc2_grad = value_and_grad(kernel_scalar_mc)(x1, x2)
  self.assertAllClose(k_mc, k_mc2)

  # Compare MC to forward-mode.
  k_mc2_fwd, _ = jvp(kernel_scalar_mc, (x1, x2), (x1, x2))
  k_mc2_grad_fwd = jacfwd(kernel_scalar_mc)(x1, x2)
  self.assertAllClose(k_mc, k_mc2_fwd)
  self.assertAllClose(k_mc2_grad, k_mc2_grad_fwd)

  def kernel_fn_emp(x1, x2, get, params):
    # Scalar `[0, 0]` entry of the empirical kernel for fixed `params`.
    return nt.empirical_kernel_fn(apply_fn)(x1, x2, get, params)[0, 0]

  kernel_fn_emp_g = jit(value_and_grad(kernel_fn_emp), static_argnames='get')

  def kernel_scalar_mc_grad_mean(x1, x2):
    # Averages the empirical kernel value and its `x1`-gradient over
    # `n_samples` parameter draws, advancing the key after each draw.
    key = random.PRNGKey(4)
    n_samples = 2**9
    k, k_grad = 0., 0.

    for _ in range(n_samples):
      _, params = init_fn(key, x1.shape)
      k_mc2, k_mc2_grad = kernel_fn_emp_g(x1, x2, get, params)
      k += k_mc2
      k_grad += k_mc2_grad
      key, _ = random.split(key)

    k /= n_samples
    k_grad /= n_samples
    return k, k_grad

  k_mc2_mean, k_mc2_grad_mean = kernel_scalar_mc_grad_mean(x1, x2)

  # Compare kernels.
  self.assertAllClose(k1, k_mc2_mean, atol=4e-3, rtol=4e-2)

  if phi == stax.Sign and get == 'nngp':
    raise absltest.SkipTest('Derivative of the empirical NNGP of a '
                            'discontinuous function does not converge '
                            'to the derivative of the infinite width NNGP.')

  if (phi in [stax.Abs, stax.Relu, stax.LeakyRelu, stax.ABRelu] and
      get == 'ntk'):
    raise absltest.SkipTest('Derivative of the empirical NTK of a '
                            'non-differentiable function does not converge '
                            'to the derivative of the infinite width NTK.')

  # Gradients within `atol` element-wise are accepted outright; otherwise fall
  # back to the relative matrix comparison below.
  atol = 1e-2

  # Compare gradient of the analytic kernel to empirical kernel.
  if np.max(np.abs(k2_grad - k_mc2_grad_mean)) > atol:
    test_utils.assert_close_matrices(self, k_mc2_grad_mean, k2_grad,
                                     rtol=0.05, atol=10.)
def test_kwargs(self, do_batch, mode):
  """Checks that layer keyword arguments flow through kernels and `predict`.

  The network contains `stax.Aggregate`, which receives a `pattern` keyword
  argument per input pair. Train/test kernels are computed with these
  keyword arguments. Infinite-time and finite-time predictions from the
  kernels are then compared against `predict.gradient_descent_mse_ensemble`
  driven by the same `kernel_fn` and keyword arguments.

  Args:
    do_batch: batch the kernel computation (`batch_size=2`). Skipped for
      the empirical kernel.
    mode: 'mc' (Monte Carlo kernel), 'empirical' (empirical kernel at fixed
      parameters), or 'analytic'. Anything else raises `ValueError`.
  """
  rng = random.PRNGKey(1)

  # NOTE(review): all inputs, targets and patterns below are drawn from the
  # same `rng`. Confirm this correlation is intended.
  x_train = random.normal(rng, (8, 7, 10))
  x_test = random.normal(rng, (4, 7, 10))
  y_train = random.normal(rng, (8, 1))

  rng_train, rng_test = random.split(rng, 2)

  pattern_train = random.normal(rng, (8, 7, 7))
  pattern_test = random.normal(rng, (4, 7, 7))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(8),
      stax.Relu(),
      stax.Dropout(rate=0.4),
      stax.Aggregate(),
      stax.GlobalAvgPool(),
      stax.Dense(1)
  )

  # Keyword arguments for the train-train, test-train and test-test kernels.
  # Each `pattern` is a pair matching the (x1, x2) arguments.
  kw_dd = dict(pattern=(pattern_train, pattern_train))
  kw_td = dict(pattern=(pattern_test, pattern_train))
  kw_tt = dict(pattern=(pattern_test, pattern_test))

  if mode == 'mc':
    kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                      batch_size=2 if do_batch else 0)

  elif mode == 'empirical':
    kernel_fn = empirical_kernel_fn(apply_fn)
    if do_batch:
      raise absltest.SkipTest('Batching of empirical kernel is not '
                              'implemented with keyword arguments.')

    # The empirical kernel needs explicit parameters and `get`. The dicts
    # are mutated in place and reused by the `predict` calls below.
    for kw in (kw_dd, kw_td, kw_tt):
      kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                     get=('nngp', 'ntk')))

    # Per-input rngs (presumably for `Dropout`). `None` for the second input
    # when it is absent.
    kw_dd.update(dict(rng=(rng_train, None)))
    kw_td.update(dict(rng=(rng_test, rng_train)))
    kw_tt.update(dict(rng=(rng_test, None)))

  elif mode == 'analytic':
    if do_batch:
      kernel_fn = batch.batch(kernel_fn, batch_size=2)

  else:
    raise ValueError(mode)

  k_dd = kernel_fn(x_train, None, **kw_dd)
  k_td = kernel_fn(x_test, x_train, **kw_td)
  k_tt = kernel_fn(x_test, None, **kw_tt)

  # Infinite time NNGP/NTK.
  predict_fn_gp = predict.gp_inference(k_dd, y_train)
  out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

  # The ensemble predictor calls below pass `get=` explicitly, so `get` must
  # be removed from the keyword dicts to avoid a duplicate keyword argument.
  if mode == 'empirical':
    for kw in (kw_dd, kw_td, kw_tt):
      kw.pop('get')

  predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                              x_train,
                                                              y_train,
                                                              **kw_dd)
  out_ensemble = predict_fn_ensemble(x_test=x_test,
                                     compute_cov=True,
                                     **kw_tt)
  self.assertAllClose(out_gp, out_ensemble)

  # Finite time NTK test.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
  out_mse = predict_fn_mse(t=1.,
                           fx_train_0=None,
                           fx_test_0=0.,
                           k_test_train=k_td.ntk)
  out_ensemble = predict_fn_ensemble(t=1.,
                                     get='ntk',
                                     x_test=x_test,
                                     compute_cov=False,
                                     **kw_tt)
  self.assertAllClose(out_mse, out_ensemble)

  # Finite time NNGP train.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
  out_mse = predict_fn_mse(t=2.,
                           fx_train_0=0.,
                           fx_test_0=None,
                           k_test_train=k_td.nngp)
  out_ensemble = predict_fn_ensemble(t=2.,
                                     get='nngp',
                                     x_test=None,
                                     compute_cov=False,
                                     **kw_dd)
  self.assertAllClose(out_mse, out_ensemble)