def test_ab_relu_id(self, same_inputs, do_stabilize):
  """Checks that `ABRelu(a, a)` behaves like `a * Identity()`.

  For several slopes `a`, the network `Dense -> ABRelu(a, a)` is compared
  against `Dense -> Identity` on two levels:
    * finite-width outputs of `apply_fn` (identity output scaled by `a`);
    * infinite-width kernels (identity kernel evaluated on inputs scaled by
      `a`).
  """
  key = random.PRNGKey(1)
  # Third positional `Dense` argument is 0 -- assumed to be `b_std`, making
  # the layer bias-free and hence linear in its input.
  dense = stax.Dense(5, 1, 0)

  x1 = random.normal(key, (3, 2))
  x2 = None if same_inputs else random.normal(key, (4, 2))

  init_fn, apply_id, kernel_fn_id = stax.serial(dense, stax.Identity())
  # `init_fn` returns `(output_shape, params)`; only the params are used.
  _, params = init_fn(key, input_shape=x1.shape)

  slopes = [-5, -1, -0.5, 0, 0.5, 1, 5]
  for a in slopes:
    with self.subTest(a=a):
      activation = stax.ABRelu(a, a, do_stabilize=do_stabilize)
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(dense, activation)

      # Finite width: ABRelu(a, a)(z) should equal a * z.
      expected_out = a * apply_id(params, x1)
      actual_out = apply_ab_relu(params, x1)
      self.assertAllClose(expected_out, actual_out)

      # Infinite width: scaling the inputs by `a` before the (bias-free)
      # identity network stands in for scaling its outputs by `a`.
      scaled_x2 = a * x2 if x2 is not None else None
      expected_kernels = kernel_fn_id(x1 * a, scaled_x2)
      actual_kernels = kernel_fn_ab_relu(x1, x2)

      # `ab_relu` (incorrectly) reports `is_gaussian=False` when `a == b`,
      # so patch the flag manually before comparing full kernel objects.
      actual_kernels = actual_kernels.replace(is_gaussian=True)
      self.assertAllClose(expected_kernels, actual_kernels)
def test_ab_relu_id(self, same_inputs):
  """Checks that `ABRelu(a, a)` behaves like `a * Identity()`.

  For several slopes `a`, `Dense -> ABRelu(a, a)` is compared against
  `Dense -> Identity`:
    * finite-width outputs: the identity output scaled by `a`;
    * infinite-width NNGP/NTK: the identity kernel on inputs scaled by `a`.
  """
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (5, 7))
  # Third positional `Dense` argument is 0 -- assumed to be `b_std`, so the
  # layer is bias-free and linear in its input.
  fc = stax.Dense(10, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (9, 7))

  # Test that ABRelu(a, a) == a * Identity
  init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
  # `init_fn` returns `(output_shape, params)`; only `params` must be passed
  # to `apply_fn` (previously the whole tuple was bound to `params`).
  _, params = init_fn(key, input_shape=(-1, 7))

  get = ('nngp', 'ntk')
  for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
    with self.subTest(a=a):
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
          fc, stax.ABRelu(a, a))

      X1_1_id = a * apply_id(params, X0_1)
      X1_1_ab_relu = apply_ab_relu(params, X0_1)
      self.assertAllClose(X1_1_id, X1_1_ab_relu, True)

      # Request the same `get` on both sides so that like is compared with
      # like (previously only the ABRelu side passed `get`).
      kernels_id = kernel_fn_id(
          X0_1 * a, None if X0_2 is None else a * X0_2, get)
      kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, get)
      self.assertAllClose(kernels_id, kernels_ab_relu, True)
def test_abs(self, same_inputs):
  """Checks that `Abs()` is equivalent to `ABRelu(-1, 1)`.

  Both finite-width `apply_fn` outputs and infinite-width NNGP/NTK kernels
  of `Dense -> Abs` and `Dense -> ABRelu(-1, 1)` must agree.
  """
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (5, 7))
  fc = stax.Dense(10, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (9, 7))

  # Test that Abs == ABRelu(-1, 1)
  init_fn, apply_abs, kernel_fn_abs = stax.serial(fc, stax.Abs())
  _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(-1, 1))

  # `init_fn` returns `(output_shape, params)`; only `params` is passed to
  # the apply functions. Both networks share `fc`, and the activations are
  # assumed parameter-free, so the same params serve both networks.
  _, params = init_fn(key, input_shape=(-1, 7))
  X1_1_abs = apply_abs(params, X0_1)
  X1_1_ab_relu = apply_ab_relu(params, X0_1)
  self.assertAllClose(X1_1_abs, X1_1_ab_relu, True)

  kernels_abs = kernel_fn_abs(X0_1, X0_2, ('nngp', 'ntk'))
  kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
  self.assertAllClose(kernels_abs, kernels_ab_relu, True)
def test_abs(self, same_inputs, do_stabilize):
  """Checks that `Abs(do_stabilize=...)` is equivalent to `ABRelu(-1, 1)`.

  Both finite-width `apply_fn` outputs and infinite-width NNGP/NTK kernels
  of `Dense -> Abs` and `Dense -> ABRelu(-1, 1)` must agree. Only the `Abs`
  layer receives `do_stabilize`; the `ABRelu` reference uses its default.
  """
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (3, 2))
  fc = stax.Dense(5, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (4, 2))

  # Test that Abs == ABRelu(-1, 1)
  init_fn, apply_abs, kernel_fn_abs = stax.serial(
      fc, stax.Abs(do_stabilize=do_stabilize))
  _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(-1, 1))

  # Both networks share `fc`, and the activations are assumed
  # parameter-free, so params from the Abs network also fit the ABRelu one.
  _, params = init_fn(key, input_shape=X0_1.shape)
  X1_1_abs = apply_abs(params, X0_1)
  X1_1_ab_relu = apply_ab_relu(params, X0_1)
  self.assertAllClose(X1_1_abs, X1_1_ab_relu)

  kernels_abs = kernel_fn_abs(X0_1, X0_2, ('nngp', 'ntk'))
  kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
  self.assertAllClose(kernels_abs, kernels_ab_relu)
def test_leaky_relu(self, same_inputs):
  """Checks that `LeakyRelu(alpha)` is equivalent to `ABRelu(alpha, 1)`.

  For several `alpha`, both finite-width `apply_fn` outputs and the full
  infinite-width kernels of the two `Dense -> activation` networks must
  agree.
  """
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (5, 7))
  fc = stax.Dense(10, 1, 0)
  X0_2 = None if same_inputs else random.normal(key, (9, 7))

  # Test that ABRelu(alpha, 1) == LeakyRelu(alpha)
  for a in [-2, -1, 0, 1, 2]:
    with self.subTest(alpha=a):
      init_fn, apply_leaky_relu, kernel_fn_leaky_relu = stax.serial(
          fc, stax.LeakyRelu(a))
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
          fc, stax.ABRelu(a, 1))

      # `init_fn` returns `(output_shape, params)`; only `params` must be
      # passed to `apply_fn` (previously the whole tuple was bound).
      _, params = init_fn(key, input_shape=(-1, 7))
      X1_1_leaky_relu = apply_leaky_relu(params, X0_1)
      X1_1_ab_relu = apply_ab_relu(params, X0_1)
      self.assertAllClose(X1_1_leaky_relu, X1_1_ab_relu, True)

      kernels_leaky_relu = kernel_fn_leaky_relu(X0_1, X0_2)
      kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
      self.assertAllClose(kernels_leaky_relu, kernels_ab_relu, True)
def test_ab_relu_relu(self, same_inputs):
  """Checks `ABRelu(a, b)` with one zero slope against a (reflected) `Relu`.

  For each `(a, b)` pair below exactly one slope is zero. Assuming
  `ABRelu(a, b)(z) = a * min(z, 0) + b * max(z, 0)`, this gives:
    * `a == 0`: `ABRelu(a, b)(z) == (b - a) * relu(z)`;
    * `b == 0`: `ABRelu(a, b)(z) == (b - a) * relu(-z)`.
  Finite-width outputs are compared against this rescaled `Relu` network,
  and infinite-width kernels against the plain `Relu` kernels.
  """
  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (5, 7))
  fc = stax.Dense(10, 1, 0)

  # Compare ABRelu(a, b) against ReLU for (a, b) with one zero slope.
  init_fn, apply_relu, kernel_fn_relu = stax.serial(fc, stax.Relu())
  # `init_fn` returns `(output_shape, params)`; only `params` must be passed
  # to `apply_fn` (previously the whole tuple was bound).
  _, params = init_fn(key, input_shape=(-1, 7))

  X0_2 = None if same_inputs else random.normal(key, (9, 7))

  for a, b in [(0, 1), (0, -1), (-1, 0), (1, 0)]:
    with self.subTest(a=a, b=b):
      _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
          fc, stax.ABRelu(a, b))

      # Negating the input reflects the pre-activation when `a != 0`. This
      # assumes `fc` is bias-free (third argument 0, presumably `b_std`), so
      # that fc(-x) == -fc(x).
      sign = -1 if a != 0 else 1
      X1_1_relu = (b - a) * apply_relu(params, X0_1 * sign)
      X1_1_ab_relu = apply_ab_relu(params, X0_1)
      self.assertAllClose(X1_1_relu, X1_1_ab_relu, True)

      # Kernels are compared against plain Relu for every pair; the output
      # sign flip and input reflection are expected to leave them unchanged.
      kernels_relu = kernel_fn_relu(X0_1, X0_2)
      kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
      self.assertAllClose(kernels_relu, kernels_ab_relu, True)
'SAME', 'VALID', 'CIRCULAR' ] STRIDES = [ None, (1, 2), (2, 1), ] ACTIVATIONS = { # TODO: investigate poor erf convergence. stax.Erf(): 'erf', stax.Relu(): 'Relu', stax.ABRelu(-0.5, 0.7): 'ABRelu(-0.5, 0.7)' } PROJECTIONS = [ 'FLAT', 'POOL', 'ATTN_FIXED', 'ATTN_PARAM' ] LAYER_NORM = [ (-1,), (1, 3), (1, 2, 3) ]
def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant):
  """Compares exact and Monte Carlo kernels of a masked multi-branch FC net.

  Inputs are masked via `test_utils.mask` and fed through a network that
  flattens, fans out into three differently-built fully-connected branches,
  and merges them by sum (`concat is None`) or concatenation along `concat`.
  A `get`-dependent readout layer is then appended. The exact (jitted)
  kernel is checked against a Monte Carlo estimate within `tol`.

  Raises:
    ValueError: if `get` is neither `'nngp'` nor `'ntk'`.
  """
  width, n_samples, tol = 512, 128, 0.04
  key = random.PRNGKey(1)

  def masked_normal(shape):
    # Gaussian sample from `key`, then masked with the test parameters.
    sample = random.normal(key, shape)
    return test_utils.mask(sample, mask_constant, mask_axis, key, p)

  x1 = masked_normal((4, 6, 5, 7))
  x2 = None if same_inputs else masked_normal((2, 6, 5, 7))

  # Output widths of the last two branches; they may only differ from
  # `width` when concatenating along axis 1.
  erf_out_width = width if concat != 1 else 512
  ab_relu_out_width = width if concat != 1 else 1024

  abs_branch = stax.serial(
      stax.Dense(width, 1., 0.1),
      stax.Abs(),
      stax.DotGeneral(lhs=-0.2),
      stax.Dense(width, 1.5, 0.01),
  )
  erf_branch = stax.serial(
      stax.Dense(width, 1.1, 0.1),
      stax.DotGeneral(rhs=0.7),
      stax.Erf(),
      stax.Dense(erf_out_width, 1.5, 0.1),
  )
  ab_relu_branch = stax.serial(
      stax.DotGeneral(rhs=0.5),
      stax.Dense(width, 1.2),
      stax.ABRelu(-0.2, 0.4),
      stax.Dense(ab_relu_out_width, 1.3, 0.2),
  )

  if concat is None:
    merge = stax.FanInSum()
  else:
    merge = stax.FanInConcat(concat)

  body = stax.serial(
      stax.Flatten(),
      stax.FanOut(3),
      stax.parallel(abs_branch, erf_branch, ab_relu_branch),
      merge,
      stax.Dense(width, 2., 0.01),
      stax.Relu())

  if get == 'nngp':
    readout = stax.Dense(width, 1., 0.1)
  elif get == 'ntk':
    readout = stax.Dense(1, 1., 0.1)
  else:
    raise ValueError(get)
  init_fn, apply_fn, kernel_fn = stax.serial(body, readout)

  # After `Flatten`, axes 0 and -2 presumably both denote the batch axis;
  # concatenating along it rules out batch vmapping and multi-device use.
  concat_on_batch = concat in (0, -2)
  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if concat_on_batch else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat_on_batch else 0,
  )

  exact_kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = exact_kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = mc_kernel_fn(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
'SAME', 'VALID', 'CIRCULAR' ] STRIDES = [ None, (1, 2), (2, 1), ] ACTIVATIONS = { # TODO(romann): investigate poor erf convergence. stax.Erf(): 'erf', stax.Relu(): 'Relu', stax.ABRelu(-3, 2): 'ABRelu(-3, 2)' } def _get_inputs(key, is_conv, same_inputs, input_shape, fun=np.cos): key, split = random.split(key) shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:])) x1 = fun(random.normal(key, shape)) x2 = None if same_inputs else 2 * fun(random.normal(split, shape)) return x1, x2 def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk): fc = partial(stax.Dense, W_std=W_std, b_std=b_std) conv = partial(