def test_activation(
    self, same_inputs, model, phi_name, get, abc, approximate):
  """Checks parameterized activations (`Sin`/`Erf`/`Gelu`/`Sign`) via `_test_activation`.

  Skips every `(a, b, c)` combination except `(0.3, 1.5, -pi/4)` to bound
  test time, and skips `approximate=True` for activations other than `Gelu`,
  which is the only one taking an `approximate` parameter.
  """
  if abc != [0.3, 1.5, -np.pi/4]:
    test_utils.skip_test(self)

  if approximate and phi_name != 'Gelu':
    # Fixed message: the closing backtick around `approximate` was missing.
    raise absltest.SkipTest(
        f'{phi_name} does not have an `approximate` parameter.')

  a, b, c = abc

  if phi_name == 'Sin':
    activation = stax.Sin(a=a, b=b, c=c)
  elif phi_name == 'Erf':
    activation = stax.Erf(a=a, b=b, c=c)
  elif phi_name in ['Gelu', 'Sign', 'Cos']:
    # These activations take no (a, b, c) parameters, so only the default
    # combination is meaningful.
    if a != 0.3 or b != 0.3 or c != 0.:
      raise absltest.SkipTest('Skip `Gelu/Sign/Cos` test if '
                              ' (a, b, c) != (.3, .3, 0.).')
    # NOTE(review): `'Cos'` falls through to `stax.Sign()` here, which looks
    # unintentional — confirm whether `stax.Cos()` (or a phase-shifted
    # `stax.Sin`) was meant for the `'Cos'` case.
    activation = stax.Gelu() if phi_name == 'Gelu' else stax.Sign()
  else:
    raise NotImplementedError(f'Activation {phi_name} is not implemented.')

  self._test_activation(activation, same_inputs, model, get)
def test_vmap_axes(self, same_inputs):
  """Checks empirical NTK with `vmap_axes` over a pytree of mixed-layout inputs.

  Builds inputs `x = [(x1, x2), x3]` whose batch dimension `n` sits at a
  different axis in each leaf, then verifies that the implicit and direct
  empirical NTK implementations agree with and without `vmap_axes` batching.
  """
  n1, n2 = 3, 4  # batch sizes of the two input sets
  c1, c2, c3 = 9, 5, 7  # channel counts per leaf
  h2, h3, w3 = 6, 8, 2  # spatial sizes

  def get_x(n, k):
    # Batch axis `n` is deliberately placed at axis 0, 1, and 2 of the
    # three leaves respectively, to exercise per-leaf `vmap_axes`.
    k1, k2, k3 = random.split(k, 3)
    x1 = random.normal(k1, (n, c1))
    x2 = random.normal(k2, (h2, n, c2))
    x3 = random.normal(k3, (c3, w3, n, h3))
    x = [(x1, x2), x3]
    return x

  x1 = get_x(n1, random.PRNGKey(1))
  x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

  # `pattern` kwarg consumed by `stax.Aggregate`; batched over axis 0.
  p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
  p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

  init_fn, apply_fn, _ = stax.serial(
      stax.parallel(
          stax.parallel(
              stax.serial(stax.Dense(4, 2., 0.1),
                          stax.Relu(),
                          stax.Dense(3, 1., 0.15)),  # 1: dense branch on x1
              stax.serial(stax.Conv(7, (2,), padding='SAME',
                                    dimension_numbers=('HNC', 'OIH', 'NHC')),
                          stax.Erf(),
                          stax.Aggregate(1, 0, -1),
                          stax.GlobalAvgPool(),
                          stax.Dense(3, 0.5, 0.2)),  # 2: conv branch on x2 ('HNC': batch at axis 1)
          ),
          stax.serial(
              stax.Conv(5, (2, 3), padding='SAME',
                        dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
              stax.Sin(),
          )  # 3: conv branch on x3 ('CWNH': batch at axis 2)
      ),
      stax.parallel(
          stax.FanInSum(),
          stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
      )
  )
  _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))

  # Unbatched references.
  implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
  direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

  # Batched variants: in_axes [(0, 1), 2] are the batch positions of the
  # three leaves above; out_axes/pattern axes given once positively and once
  # negatively (e.g. axis 0 of a rank-2 leaf == axis -2) to check that
  # negative axis indices are handled the same as positive ones.
  implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
      apply_fn,
      vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
  direct_batched = jit(empirical._empirical_direct_ntk_fn(
      apply_fn,
      vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

  k = direct(x1, x2, params, pattern=(p1, p2))
  # All four implementations must produce the same kernel.
  self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
  self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):
  """Checks `stax.ElementwiseNumerical` against closed-form activation kernels."""

  @parameterized.product(
      model=[
          'fc',
          'conv-pool',
          'conv-flatten'
      ],
      phi=[
          stax.Erf(),
          stax.Gelu(),
          stax.Sin(),
      ],
      same_inputs=[False, True],
      get=['nngp', 'ntk']
  )
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    W_std, b_std = 1.0, 0.01
    deg = 25
    # Loosen tolerance for the NTK and again on TPU (lower precision).
    rtol = 2e-3 * (2 if get == 'ntk' else 1) * (
        2 if default_backend() == 'tpu' else 1)

    if model == 'fc':
      depth = 1
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
    else:
      depth = 2
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std,
                         padding='SAME')
      top = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
      readout = stax.serial(top, stax.Dense(output_dim))

    # Closed-form (analytic) kernel of the network with activation `phi`.
    _, _, kernel_fn = stax.serial(*([affine, phi] * depth), readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    # Same network, but with the activation's kernel integrated numerically
    # via Gauss-Hermite quadrature of degree `deg`. `phi[1]` is the layer's
    # apply_fn; called with empty params `()` it is the plain activation.
    fn = lambda x: phi[1]((), x)
    numerical_phi = stax.ElementwiseNumerical(fn, deg=deg)
    _, _, kernel_fn = stax.serial(
        *([affine, numerical_phi] * depth), readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self,
                                     analytic_kernel,
                                     numerical_activation_kernel,
                                     rtol)
class ElementwiseTest(test_utils.NeuralTangentsTestCase):
  """Checks `stax.Elementwise` with user-supplied closed-form NNGP maps."""

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(
      self,
      same_inputs,
      phi,
      n,
      diagonal_batch,
      diagonal_spatial
  ):
    # `phi[1]` is the layer's apply_fn; with empty params it is the raw
    # elementwise nonlinearity.
    fn = lambda x: phi[1]((), x)
    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      # Closed-form map of (cov12, var1, var2) -> post-activation covariance
      # for each supported nonlinearity.
      if 'Identity' in name:
        return cov12

      if 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      if 'Sin' in name:
        sum_ = (var1 + var2)
        s_plus = np.exp((-0.5 * sum_ + cov12))
        s_minus = np.exp((-0.5 * sum_ - cov12))
        return (s_plus - s_minus) / 2

      if 'Relu' in name:
        prod = var1 * var2
        # Clamp to avoid sqrt of a tiny negative value from rounding.
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        return sqrt / (2 * np.pi) + dot_sigma * cov12

      raise NotImplementedError(name)

    # Network using the user-supplied closed form...
    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1),
                                  stax.Elementwise(fn, nngp_fn))
    # ...versus the library's built-in layer for the same activation.
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    else:
      x2 = x1 if same_inputs is True else random.normal(key, (6,) + shape)

    k = kernel_fn(x1, x2,
                  diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)
    # `Elementwise` cannot know the output is Gaussian, so align the flag
    # before comparing.
    k_manual = kernel_fn_manual(x1, x2,
                                diagonal_batch=diagonal_batch,
                                diagonal_spatial=diagonal_spatial)
    k_manual = k_manual.replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)