Example #1
0
  def test_ab_relu_id(self, same_inputs, do_stabilize):
    """Checks that `ABRelu(a, a)` agrees with `a * Identity`.

    Both the finite-width forward pass and the infinite-width kernel are
    compared over a range of slopes `a`, with `do_stabilize` forwarded to
    `ABRelu`.
    """
    key = random.PRNGKey(1)
    x1 = random.normal(key, (3, 2))
    x2 = None if same_inputs else random.normal(key, (4, 2))
    dense = stax.Dense(5, 1, 0)

    # Reference network: the same dense layer followed by the identity.
    init_fn, apply_id, kernel_fn_id = stax.serial(dense, stax.Identity())
    _, params = init_fn(key, input_shape=x1.shape)

    for slope in (-5, -1, -0.5, 0, 0.5, 1, 5):
      with self.subTest(a=slope):
        _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
            dense, stax.ABRelu(slope, slope, do_stabilize=do_stabilize))

        self.assertAllClose(slope * apply_id(params, x1),
                            apply_ab_relu(params, x1))

        scaled_x2 = None if x2 is None else slope * x2
        expected = kernel_fn_id(x1 * slope, scaled_x2)
        # `ABRelu` (incorrectly) reports `is_gaussian=False` even when
        # `a == b`, so override the flag before comparing with the identity
        # kernel.
        actual = kernel_fn_ab_relu(x1, x2).replace(is_gaussian=True)
        self.assertAllClose(expected, actual)
    def test_ab_relu_id(self, same_inputs):
        """Checks that `ABRelu(a, a)` agrees with `a * Identity`.

        Compares finite-width outputs and the infinite-width NNGP/NTK kernels
        for several slopes `a`.
        """
        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (5, 7))
        fc = stax.Dense(10, 1, 0)

        X0_2 = None if same_inputs else random.normal(key, (9, 7))

        # Test that ABRelu(a, a) == a * Identity
        init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
        # `init_fn` returns a pair whose second element holds the parameters;
        # unpack it so the apply functions receive the initialized parameters.
        _, params = init_fn(key, input_shape=(-1, 7))

        # Request the same kernel outputs from both networks so that the
        # compared structures match.
        get = ('nngp', 'ntk')
        for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
            with self.subTest(a=a):
                _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
                    fc, stax.ABRelu(a, a))

                X1_1_id = a * apply_id(params, X0_1)
                X1_1_ab_relu = apply_ab_relu(params, X0_1)
                self.assertAllClose(X1_1_id, X1_1_ab_relu, True)

                kernels_id = kernel_fn_id(X0_1 * a,
                                          None if X0_2 is None else a * X0_2,
                                          get)
                kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, get)
                self.assertAllClose(kernels_id, kernels_ab_relu, True)
Example #3
0
  def test_abs(self, same_inputs):
    """Checks that `Abs` agrees with `ABRelu(-1, 1)` (outputs and kernels)."""
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (5, 7))
    fc = stax.Dense(10, 1, 0)

    X0_2 = None if same_inputs else random.normal(key, (9, 7))

    # Test that Abs == ABRelu(-1, 1)
    init_fn, apply_abs, kernel_fn_abs = stax.serial(fc, stax.Abs())
    _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(-1, 1))

    # `init_fn` returns a pair whose second element holds the parameters;
    # unpack it so the apply functions receive the initialized parameters.
    _, params = init_fn(key, input_shape=(-1, 7))
    X1_1_abs = apply_abs(params, X0_1)
    X1_1_ab_relu = apply_ab_relu(params, X0_1)
    self.assertAllClose(X1_1_abs, X1_1_ab_relu, True)

    kernels_abs = kernel_fn_abs(X0_1, X0_2, ('nngp', 'ntk'))
    kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
    self.assertAllClose(kernels_abs, kernels_ab_relu, True)
Example #4
0
  def test_abs(self, same_inputs, do_stabilize):
    """Checks that `Abs` agrees with `ABRelu(-1, 1)` (outputs and kernels).

    Only the `Abs` layer receives `do_stabilize`; the `ABRelu` reference is
    built with its default settings.
    """
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (3, 2))
    fc = stax.Dense(5, 1, 0)

    X0_2 = None if same_inputs else random.normal(key, (4, 2))

    # Test that Abs == ABRelu(-1, 1)
    init_fn, apply_abs, kernel_fn_abs = stax.serial(
        fc, stax.Abs(do_stabilize=do_stabilize))
    _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(-1, 1))

    _, params = init_fn(key, input_shape=X0_1.shape)
    X1_1_abs = apply_abs(params, X0_1)
    X1_1_ab_relu = apply_ab_relu(params, X0_1)
    self.assertAllClose(X1_1_abs, X1_1_ab_relu)

    kernels_abs = kernel_fn_abs(X0_1, X0_2, ('nngp', 'ntk'))
    kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2, ('nngp', 'ntk'))
    self.assertAllClose(kernels_abs, kernels_ab_relu)
Example #5
0
  def test_leaky_relu(self, same_inputs):
    """Checks that `ABRelu(alpha, 1)` agrees with `LeakyRelu(alpha)`.

    Compares finite-width outputs and infinite-width kernels for several
    values of `alpha`.
    """
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (5, 7))
    fc = stax.Dense(10, 1, 0)

    X0_2 = None if same_inputs else random.normal(key, (9, 7))

    # Test that ABRelu(alpha, 1) == LeakyRelu(alpha)
    for a in [-2, -1, 0, 1, 2]:
      with self.subTest(alpha=a):
        init_fn, apply_leaky_relu, kernel_fn_leaky_relu = stax.serial(
            fc, stax.LeakyRelu(a))
        _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(a, 1))

        # `init_fn` returns a pair whose second element holds the parameters;
        # unpack it so the apply functions receive the initialized parameters.
        _, params = init_fn(key, input_shape=(-1, 7))
        X1_1_leaky_relu = apply_leaky_relu(params, X0_1)
        X1_1_ab_relu = apply_ab_relu(params, X0_1)
        self.assertAllClose(X1_1_leaky_relu, X1_1_ab_relu, True)

        kernels_leaky_relu = kernel_fn_leaky_relu(X0_1, X0_2)
        kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
        self.assertAllClose(kernels_leaky_relu, kernels_ab_relu, True)
Example #6
0
  def test_ab_relu_relu(self, same_inputs):
    """Checks `ABRelu(a, b)` with one zero slope against a scaled `Relu`.

    For each `(a, b)` pair, the finite-width output must equal
    `(b - a) * Relu` applied to `X0_1` (negated when `a != 0`). The
    infinite-width kernels must equal those of the plain `Relu` network.
    """
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (5, 7))
    fc = stax.Dense(10, 1, 0)

    # Test that ABRelu(0, 1) == ReLU
    init_fn, apply_relu, kernel_fn_relu = stax.serial(fc, stax.Relu())
    # `init_fn` returns a pair whose second element holds the parameters;
    # unpack it so the apply functions receive the initialized parameters.
    _, params = init_fn(key, input_shape=(-1, 7))

    X0_2 = None if same_inputs else random.normal(key, (9, 7))

    for a, b in [(0, 1), (0, -1), (-1, 0), (1, 0)]:
      with self.subTest(a=a, b=b):
        _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(fc, stax.ABRelu(a, b))

        # Negating the input to reflect the activation assumes `fc` is odd in
        # its input, presumably because the bias std given to `stax.Dense`
        # above is 0 -- confirm if that argument changes.
        X1_1_relu = (b - a) * apply_relu(params, X0_1 * (-1 if a != 0 else 1))
        X1_1_ab_relu = apply_ab_relu(params, X0_1)
        self.assertAllClose(X1_1_relu, X1_1_ab_relu, True)

        kernels_relu = kernel_fn_relu(X0_1, X0_2)
        kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
        self.assertAllClose(kernels_relu, kernels_ab_relu, True)
Example #7
0
    'SAME',
    'VALID',
    'CIRCULAR'
]

# Stride settings to parameterize over; `None` presumably selects the layer's
# default stride (TODO confirm against the consuming tests).
STRIDES = [None, (1, 2), (2, 1)]

# Activation layers to parameterize over, mapped to display names (presumably
# used to label parameterized test cases).
# TODO: investigate poor erf convergence.
ACTIVATIONS = dict([
    (stax.Erf(), 'erf'),
    (stax.Relu(), 'Relu'),
    (stax.ABRelu(-0.5, 0.7), 'ABRelu(-0.5, 0.7)'),
])

# Readout/projection variants to parameterize over.
PROJECTIONS = ['FLAT', 'POOL', 'ATTN_FIXED', 'ATTN_PARAM']

# Axis tuples to parameterize over (presumably passed as the normalization
# axes of a layer-norm layer -- TODO confirm).
LAYER_NORM = [(-1,), (1, 3), (1, 2, 3)]
    def test_mask_fc(self, same_inputs, get, concat, p, mask_axis,
                     mask_constant):
        """Compares exact and Monte Carlo kernels of an FC net on masked inputs.

        Inputs are masked with `test_utils.mask`. The network fans out into
        three dense branches, merges them (sum, or concatenation along
        `concat`), and ends with a readout whose width depends on `get`.
        The exact kernel must match the Monte Carlo estimate to within `tol`.
        """
        width = 512
        n_samples = 128
        tol = 0.04
        key = random.PRNGKey(1)

        def masked_normal(shape):
            # Gaussian sample with entries masked according to the test params.
            return test_utils.mask(random.normal(key, shape), mask_constant,
                                   mask_axis, key, p)

        x1 = masked_normal((4, 6, 5, 7))
        x2 = None if same_inputs else masked_normal((2, 6, 5, 7))

        abs_branch = stax.serial(
            stax.Dense(width, 1., 0.1),
            stax.Abs(),
            stax.DotGeneral(lhs=-0.2),
            stax.Dense(width, 1.5, 0.01),
        )
        erf_branch = stax.serial(
            stax.Dense(width, 1.1, 0.1),
            stax.DotGeneral(rhs=0.7),
            stax.Erf(),
            stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
        )
        ab_relu_branch = stax.serial(
            stax.DotGeneral(rhs=0.5),
            stax.Dense(width, 1.2),
            stax.ABRelu(-0.2, 0.4),
            stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
        )
        merge = stax.FanInSum() if concat is None else stax.FanInConcat(concat)

        nn = stax.serial(
            stax.Flatten(),
            stax.FanOut(3),
            stax.parallel(abs_branch, erf_branch, ab_relu_branch),
            merge,
            stax.Dense(width, 2., 0.01),
            stax.Relu(),
        )

        if get == 'nngp':
            readout_width = width
        elif get == 'ntk':
            readout_width = 1
        else:
            raise ValueError(get)
        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(readout_width, 1., 0.1))

        # For `concat` in (0, -2) the Monte Carlo estimate runs without device
        # parallelism and without `vmap_axes` (presumably because these axes
        # mix examples together -- TODO confirm).
        batch_concat = concat in (0, -2)
        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if batch_concat else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if batch_concat else 0,
        )

        exact = jit(kernel_fn, static_argnames='get')(
            x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
    'SAME',
    'VALID',
    'CIRCULAR'
]

# Stride settings to parameterize over; `None` presumably selects the layer's
# default stride (TODO confirm against the consuming tests).
STRIDES = [None, (1, 2), (2, 1)]

# Activation layers to parameterize over, mapped to display names (presumably
# used to label parameterized test cases).
# TODO(romann): investigate poor erf convergence.
ACTIVATIONS = dict([
    (stax.Erf(), 'erf'),
    (stax.Relu(), 'Relu'),
    (stax.ABRelu(-3, 2), 'ABRelu(-3, 2)'),
])


def _get_inputs(key, is_conv, same_inputs, input_shape, fun=np.cos):
  key, split = random.split(key)
  shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:]))
  x1 = fun(random.normal(key, shape))
  x2 = None if same_inputs else 2 * fun(random.normal(split, shape))
  return x1, x2


def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(