Example #1
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
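    # Estimate the same kernel empirically from `n_samples` random finite-width
    # networks and compare it to the closed-form result; the tolerance below is
    # loosened for higher Hermite degrees.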
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
Example #2
  def test_layernorm(self,
                     model,
                     width,
                     same_inputs,
                     is_ntk,
                     proj_into_2d,
                     layer_norm):
    is_conv = 'conv' in model
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
      raise absltest.SkipTest('FC models do not have these parameters.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_shape = FILTER_SHAPES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    parameterization = 'ntk'
    pool_type = 'AVG'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk,
                                    0.07)
Example #3
  def test_activation(
      self,
      same_inputs,
      model,
      phi_name,
      get,
      abc,
      approximate
  ):
    if abc != [0.3, 1.5, -np.pi/4]:
      test_utils.skip_test(self)

    if approximate and phi_name != 'Gelu':
      raise absltest.SkipTest(
          f'{phi_name} does not have an `approximate` parameter.')

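    # `a`, `b`, `c` parameterize the affine transform of the nonlinearity,
    # e.g. `stax.Sin(a, b, c)` computes `a * sin(b * x + c)`.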
    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
    elif phi_name in ['Gelu', 'Sign', 'Cos']:
      if a != 0.3 or b != 0.3 or c != 0.:
        raise absltest.SkipTest('Skip `Gelu/Sign/Cos` test if '
                                '(a, b, c) != (.3, .3, 0.).')
      if phi_name == 'Gelu':
        activation = stax.Gelu(approximate=approximate)
      else:
        activation = stax.Sign() if phi_name == 'Sign' else stax.Cos()
    else:
      raise NotImplementedError(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)
Example #4
  def test_nonlinear(
      self,
      model,
      width,
      same_inputs,
      is_ntk,
      filter_shape,
      proj_into_2d,
      b_std,
      W_std,
      parameterization,
      s
  ):
    is_conv = 'conv' in model

    if parameterization == 'standard':
      width //= s

    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or filter_shape != FILTER_SHAPES[0]:
      raise absltest.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std=W_std,
                   b_std=b_std,
                   filter_shape=filter_shape,
                   is_conv=is_conv,
                   use_pooling=use_pooling,
                   is_res=is_res,
                   padding=padding,
                   phi=phi,
                   strides=strides,
                   width=width,
                   is_ntk=is_ntk,
                   proj_into_2d=proj_into_2d,
                   pool_type=pool_type,
                   layer_norm=layer_norm,
                   parameterization=parameterization,
                   s=s,
                   use_dropout=use_dropout)

    _check_agreement_with_empirical(
        self,
        net=net,
        same_inputs=same_inputs,
        use_dropout=use_dropout,
        is_ntk=is_ntk,
        rtol=0.015,
        atol=1000
    )
Example #5
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

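    # Build `depth` blocks of affine layer + activation followed by the readout,
    # and compare the analytic kernel against a Monte Carlo estimate.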
    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
Example #6
  def test_pool(self, width, same_inputs, is_ntk, pool_type,
                padding, filter_shape, strides, normalize_edges):
    use_dropout = False
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    test_utils.skip_test(self)
    if pool_type == 'SUM' and normalize_edges:
      raise absltest.SkipTest('normalize_edges not applicable to SumPool.')

    net = _get_net_pool(width, is_ntk, pool_type,
                        padding, filter_shape, strides, normalize_edges)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk)
Example #7
  def _skip_test(self, filter_shape, is_conv, is_res, padding, proj_into_2d,
                 strides, use_pooling):
    if is_conv:
      test_utils.skip_test(self)

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_shape !=
                                   (1, 1)))):
        raise absltest.SkipTest('Different paths in a residual model need to '
                                'return outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise absltest.SkipTest('FC models do not have these parameters.')
Example #8
    def test_binary(self, primitive: Optional[Primitive], shape1, shape2,
                    dtype, params):
        # TODO(romann): revisit when bugs below are fixed.
        if primitive == lax.conv_general_dilated_p:
            if jax.default_backend() == 'tpu':
                raise absltest.SkipTest('http://b/235167364')

            elif jax.default_backend(
            ) == 'gpu' and params['batch_group_count'] != 1:
                raise absltest.SkipTest('http://b/235485533')

        if len(shape1) > 3 or len(shape2) > 3:
            test_utils.skip_test(self)

        self._test_primitive(primitive, [shape1, shape2], dtype, params)
Example #9
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

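    # Rebuild the same architecture with the closed-form activation replaced by
    # `stax.ElementwiseNumerical`, which approximates the activation's kernel
    # via `deg`-point Gaussian quadrature.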
    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
Example #10
    def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                       fan_in_mode):
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd` are skipped when '
                    'axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis == 0) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                                    'require `is_gaussian`.')

        if ((axis == 1 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
                'after concatenation or Hadamard product.')
        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        if n_branches != 2:
            test_utils.skip_test(self)

        key = random.PRNGKey(1)
        X0_1 = np.cos(random.normal(key, (4, 3)))
        X0_2 = None if same_inputs else random.normal(key, (8, 3))

        width = 1024
        n_samples = 256 * 2

        if default_backend() == 'tpu':
            tol = 0.07
        else:
            tol = 0.02

        dense = stax.Dense(width, 1.25, 0.1)
        input_layers = [dense, stax.FanOut(n_branches)]

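        # Branch `b` applies its own activation followed by `b` extra dense
        # layers; widths vary across branches only when concatenating along the
        # feature axis, where they are not required to match.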
        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [dense]
            branches += [stax.serial(*branch_layers)]

        output_layers = [fan_in_layer, stax.Relu()]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, dense)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = nn
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(1, 1.25, 0.5))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -2) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -2) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #11
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout, fan_in_mode):
        test_utils.skip_test(self)
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd()` are skipped when '
                    'axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                                    'require `is_gaussian`.')

        if ((axis == 3 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis '
                'requires a dense layer after concatenation '
                'or Hadamard product.')

        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if default_backend() == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Conv(out_chan=int(width * multiplier),
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            fan_in_layer,
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

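        # When concatenating along the batch axis, the empirical kernel is
        # computed without vmapping over examples (`vmap_axes=None`).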
        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -4) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #12
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

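        # Randomly mask the inputs with `mask_constant` along `mask_axis`;
        # `p` controls the amount of masking.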
        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #13
    def test_input_req(self, same_inputs):
        test_utils.skip_test(self)

        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 7, 8, 4, 3))
        x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

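        # The layers below use incompatible `dimension_numbers`, so kernel
        # computation is expected to fail with a `ValueError`.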
        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.Relu(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
            stax.Relu(),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
            stax.Flatten(), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=400,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.GlobalAvgPool(channel_axis=2))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
            stax.Relu(), stax.AvgPool((2, 1, 3), batch_axis=0,
                                      channel_axis=-2),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
            stax.Relu(), stax.GlobalAvgPool(channel_axis=2), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=300,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Flatten(),
            stax.Dense(1),
            stax.Erf(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2),
                      dimension_numbers=('CN', 'IO', 'NC')),
        )
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Flatten(), stax.Conv(out_chan=1024, filter_shape=()),
            stax.Relu(), stax.Dense(1))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=200,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='ntk')
        K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
Example #14
  def test_activations(
      self,
      get,
      parameterization,
      parameterization_out,
      x1_type,
      x2_type,
      b_std,
      phi,
      do_jit
  ):
    """Tests forward- and reverse-mode autodiff for nonlinearities."""
    if phi == stax.ABRelu:
      phi_ = phi(0.25, 0.5)
    else:
      phi_ = phi()

    if phi not in [stax.Relu]:
      test_utils.skip_test(self)

    n_out = 1 if get == 'ntk' else 1024
    width = 832

    W_std_in = width**(-0.5) if parameterization_out == 'standard' else 1.
    if phi == stax.Exp:
      W_std_in /= 10.

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(
            width,
            W_std=W_std_in,
            b_std=b_std,
            parameterization=parameterization),
        phi_,
        stax.Dense(
            n_out,
            b_std=b_std,
            parameterization=parameterization_out
        ),
    )

    def get_x(x_type, key):
      shape = (1, 2)
      if x_type == 'zeros':
        x = np.zeros(shape)
      elif x_type == 'ones':
        x = np.ones(shape)
      elif x_type == 'random':
        x = random.normal(random.PRNGKey(key), shape)
      elif x_type == 'sin':
        x = np.sin(random.normal(random.PRNGKey(key), shape))
      elif x_type == 'none':
        return None
      else:
        raise ValueError(x_type)
      return x

    x1 = get_x(x1_type, 1)
    if x2_type == 'x1':
      x2 = x1
    else:
      x2 = get_x(x2_type, 2)

    def kernel_scalar(x1, x2):
      return kernel_fn(x1, x2, get)[0, 0]

    if do_jit:
      kernel_scalar = jit(kernel_scalar)

    k1 = kernel_scalar(x1, x2)
    k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2)

    # Compare to forward-mode.
    k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
    k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2_fwd)
    self.assertAllClose(k2_grad, k2_grad_fwd)

    # `stax.ExpNormalized` has no forward pass.
    # `stax.Sign` is discontinuous at `0`, so NTK MC kernel does not converge to
    # infinite-width kernel.
    if phi == stax.ExpNormalized or (get == 'ntk' and phi == stax.Sign):
      raise absltest.SkipTest('Not comparing against MC kernels.')

    _kernel_scalar_mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key=random.PRNGKey(3),
        n_samples=1,
        device_count=0,
    )

    def kernel_scalar_mc(x1, x2):
      return _kernel_scalar_mc(x1, x2, get)[0, 0]

    k_mc = kernel_scalar_mc(x1, x2)
    k_mc2, k_mc2_grad = value_and_grad(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2)

    # Compare MC to forward-mode.
    k_mc2_fwd, _ = jvp(kernel_scalar_mc, (x1, x2), (x1, x2))
    k_mc2_grad_fwd = jacfwd(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2_fwd)
    self.assertAllClose(k_mc2_grad, k_mc2_grad_fwd)

    def kernel_fn_emp(x1, x2, get, params):
      return nt.empirical_kernel_fn(apply_fn)(x1, x2, get, params)[0, 0]

    kernel_fn_emp_g = jit(value_and_grad(kernel_fn_emp), static_argnames='get')

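    # Average the empirical kernel and its gradient over many random parameter
    # draws to approximate the infinite-width kernel and its gradient.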
    def kernel_scalar_mc_grad_mean(x1, x2):
      key = random.PRNGKey(4)
      n_samples = 2**9
      k, k_grad = 0., 0.

      for _ in range(n_samples):
        _, params = init_fn(key, x1.shape)
        k_mc2, k_mc2_grad = kernel_fn_emp_g(x1, x2, get, params)
        k += k_mc2
        k_grad += k_mc2_grad
        key, _ = random.split(key)

      k /= n_samples
      k_grad /= n_samples
      return k, k_grad

    k_mc2_mean, k_mc2_grad_mean = kernel_scalar_mc_grad_mean(x1, x2)

    # Compare kernels.
    self.assertAllClose(k1, k_mc2_mean, atol=4e-3, rtol=4e-2)

    if phi == stax.Sign and get == 'nngp':
      raise absltest.SkipTest('Derivative of the empirical NNGP of a '
                              'discontinuous function does not converge '
                              'to the derivative of the infinite width NNGP.')

    if (phi in [stax.Abs, stax.Relu, stax.LeakyRelu, stax.ABRelu] and
        get == 'ntk'):
      raise absltest.SkipTest('Derivative of the empirical NTK of a '
                              'non-differentiable function does not converge '
                              'to the derivative of the infinite width NTK.')

    atol = 1e-2

    # Compare gradient of the analytic kernel to empirical kernel.
    if np.max(np.abs(k2_grad - k_mc2_grad_mean)) > atol:
      test_utils.assert_close_matrices(self,
                                       k_mc2_grad_mean,
                                       k2_grad,
                                       rtol=0.05,
                                       atol=10.)