Example #1
class ParallelInOutTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type
      }
                          for same_inputs in [True, False]
                          for kernel_type in ['ntk']))
  def test_parallel_in(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, input_key2, mc_key = random.split(rng, 3)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 2))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 3))

    x1 = (x1_1, x1_2)
    x2 = (x2_1, x2_2)

    N = 2 ** 7

    def net(logits):
      return stax.serial(
          stax.parallel(stax.Dense(N), stax.Dense(N)),
          stax.serial(stax.FanInSum(), stax.Dense(logits)))

    init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=2,
        vmap_axes=((0, 0), 0, {})
    )
    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, kernel_type),
                                     kernel_fn_empirical(x1, x2, kernel_type),
                                     rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_parallel_out(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, mc_key = random.split(rng, 2)

    x1, x2 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))

    N = 2 ** 10

    def net(logits):
      return stax.serial(
          stax.Dense(N),
          stax.FanOut(2),
          stax.parallel(stax.Dense(logits), stax.Dense(logits)))

    init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=2,
        vmap_axes=(0, [0, 0], {}))

    test_utils.assert_close_matrices(self,
                                     kernel_fn(x1, x2, kernel_type),
                                     kernel_fn_empirical(x1, x2, kernel_type),
                                     rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type,
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_parallel_in_out(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    input_key1, input_key2, mc_key = random.split(rng, 3)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2))

    x1 = (x1_1, x1_2)
    x2 = (x2_1, x2_2)

    N_in = 2 ** 10
    N_out = N_in if kernel_type == 'nngp' else 1

    readin = stax.serial(stax.parallel(stax.Dense(N_in), stax.Dense(N_in)),
                         stax.FanInSum())
    readout = stax.serial(stax.FanOut(3),
                          stax.parallel(stax.Dense(N_out),
                                        stax.Dense(N_out + 1),
                                        stax.Dense(N_out + 2)))
    init_fn, apply_fn, _ = stax.serial(readin, readout)

    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(functools.partial(readout[2], get=kernel_type))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,),
        implementation=2,
        vmap_axes=((0, 0), [0, 0, 0], {})
    )

    test_utils.assert_close_matrices(
        self,
        K_readout_fn(K_readin_fn(x1, x2)),
        kernel_fn_empirical(x1, x2, get=kernel_type),
        rtol)

    # Check Both (here we just want to make sure we _can_ compute the output).
    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(functools.partial(readout[2], get=('nngp', 'ntk')))

    K_readout_fn(K_readin_fn(x1, x2))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_kernel_type={kernel_type}',
          'same_inputs': same_inputs,
          'kernel_type': kernel_type,
      } for same_inputs in [True, False] for kernel_type in ['ntk']))
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, implementation=2,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
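
A standalone sketch of the multi-input pattern exercised by `ParallelInOutTest`: `stax.parallel` branches consume a tuple of inputs and `stax.FanInSum` merges them, so the closed-form `kernel_fn` accepts tuples directly. Widths and input shapes below are illustrative assumptions, not values taken from the tests.

from jax import random
from neural_tangents import stax

key1, key2 = random.split(random.PRNGKey(0))
# A tuple of two inputs with different feature dimensions, one per branch.
x1 = (random.normal(key1, (3, 2)), random.normal(key2, (3, 5)))

_, _, kernel_fn = stax.serial(
    stax.parallel(stax.Dense(128), stax.Dense(128)),
    stax.FanInSum(),
    stax.Dense(1))
ntk = kernel_fn(x1, None, 'ntk')  # (3, 3) NTK over the tuple-valued inputs
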
Example #2
class StaxTest(test_utils.NeuralTangentsTestCase):

  def _skip_test(self, filter_shape, is_conv, is_res, padding, proj_into_2d,
                 strides, use_pooling):
    if is_conv:
      test_utils.skip_test(self)

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_shape !=
                                   (1, 1)))):
        raise absltest.SkipTest('Different paths in a residual model need to '
                                'return outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise absltest.SkipTest('FC models do not have these parameters.')

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width, 'same_inputs'
                  if same_inputs else 'different_inputs', 'filter_shape=%s' %
                  str(filter_shape), 'padding=%s' % padding, 'strides=%s' %
                  str(strides), 'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP', 'RESNET' if is_res else 'serial',
                  proj_into_2d),
          'model':
              model,
          'width':
              width,
          'strides':
              strides,
          'padding':
              padding,
          'phi':
              phi,
          'same_inputs':
              same_inputs,
          'filter_shape':
              filter_shape,
          'use_pooling':
              use_pooling,
          'is_ntk':
              is_ntk,
          'is_res':
              is_res,
          'proj_into_2d':
              proj_into_2d
      }
                          for model in MODELS
                          for width in WIDTHS
                          for phi, phi_name in ACTIVATIONS.items()
                          for same_inputs in [False]
                          for padding in PADDINGS for strides in STRIDES
                          for filter_shape in FILTER_SHAPES
                          for use_pooling in [False, True]
                          for is_ntk in [False, True]
                          for is_res in [False, True]
                          for proj_into_2d in PROJECTIONS))
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_shape, use_pooling, is_ntk, is_res, proj_into_2d):
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    pool_type = 'AVG'
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(
        self, net, same_inputs, use_dropout, is_ntk, RTOL, 1.1)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}'.format(
                  model,
                  width,
                  'same_inputs' if same_inputs else 'different_inputs',
                  'NTK' if is_ntk else 'NNGP',
                  proj_into_2d,
                  'layer_norm=%s' % str(layer_norm)),
          'model':
              model,
          'width':
              width,
          'same_inputs':
              same_inputs,
          'is_ntk':
              is_ntk,
          'proj_into_2d':
              proj_into_2d,
          'layer_norm':
              layer_norm
      }
                          for model in MODELS
                          for width in WIDTHS
                          for same_inputs in [False]
                          for is_ntk in [False, True]
                          for proj_into_2d in PROJECTIONS[:2]
                          for layer_norm in LAYER_NORM))
  def test_layernorm(self,
                     model,
                     width,
                     same_inputs,
                     is_ntk,
                     proj_into_2d,
                     layer_norm):
    is_conv = 'conv' in model
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
      raise absltest.SkipTest('FC models do not have these parameters.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_shape = FILTER_SHAPES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    parameterization = 'ntk'
    pool_type = 'AVG'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk,
                                    0.07)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  width, 'same_inputs' if same_inputs else 'different_inputs',
                  'filter_shape=%s' % str(filter_shape), 'padding=%s' %
                  padding, 'strides=%s' % str(strides),
                  'NTK' if is_ntk else 'NNGP', 'pool_type=%s' %
                  str(pool_type), 'normalize_edges=%s' % str(normalize_edges)),
          'width':
              width,
          'same_inputs':
              same_inputs,
          'is_ntk':
              is_ntk,
          'pool_type':
              pool_type,
          'padding':
              padding,
          'filter_shape':
              filter_shape,
          'strides':
              strides,
          'normalize_edges':
              normalize_edges
      } for width in WIDTHS for same_inputs in [False]
                          for is_ntk in [False, True]
                          for pool_type in POOL_TYPES for padding in PADDINGS
                          for filter_shape in FILTER_SHAPES
                          for strides in STRIDES
                          for normalize_edges in [True, False]))
  def test_pool(self, width, same_inputs, is_ntk, pool_type,
                padding, filter_shape, strides, normalize_edges):
    use_dropout = False
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    test_utils.skip_test(self)
    if pool_type == 'SUM' and normalize_edges:
      raise absltest.SkipTest('normalize_edges not applicable to SumPool.')

    net = _get_net_pool(width, is_ntk, pool_type,
                        padding, filter_shape, strides, normalize_edges)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk)

  def test_avg_pool(self):
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                          normalize_edges=False)
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                                    normalize_edges=True)
    _, apply_fn_stax = ostax.AvgPool((2, 2), (1, 1), 'SAME')

    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)

    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)

    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm))

    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)

    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2))

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp)
    self.assertAllClose(np.ones_like(ker_norm.cov1), ker_norm.cov1)
    self.assertAllClose(np.ones_like(ker_norm.cov2), ker_norm.cov2)

    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.cov1.shape, ker.cov1.shape)
    self.assertEqual(ker_norm.cov2.shape, ker.cov2.shape)

    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    cov1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov1.shape)
    cov2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.cov2.shape)
    self.assertAllClose((nngp, cov1, cov2), (ker.nngp, ker.cov1, ker.cov2))

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
                  model, phi_name, width, 'same_inputs'
                  if same_inputs else 'different_inputs', 'filter_shape=%s' %
                  str(filter_shape), 'padding=%s' % padding, 'strides=%s' %
                  str(strides), 'pool' if use_pooling else 'flatten',
                  'NTK' if is_ntk else 'NNGP', proj_into_2d),
          'model':
              model,
          'width':
              width,
          'same_inputs':
              same_inputs,
          'is_ntk':
              is_ntk,
          'padding':
              padding,
          'strides':
              strides,
          'filter_shape':
              filter_shape,
          'phi':
              phi,
          'use_pooling':
              use_pooling,
          'proj_into_2d':
              proj_into_2d
      } for model in MODELS for width in WIDTHS
                          for same_inputs in [True, False]
                          for phi, phi_name in ACTIVATIONS.items()
                          for padding in ['SAME'] for strides in STRIDES
                          for filter_shape in [(2, 1)]
                          for is_ntk in [True, False]
                          for use_pooling in [True, False]
                          for proj_into_2d in ['FLAT', 'POOL']))
  def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                   filter_shape, phi, use_pooling, proj_into_2d):
    pool_type = 'AVG'
    use_dropout = True
    is_conv = 'conv' in model
    is_res = False
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d,
                    strides, use_pooling)

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_act={act}_kernel={kern}_do_stabilize={do_stabilize}',
          'act': act,
          'kernel': kern,
          'do_stabilize': do_stabilize
      }
                          for act in ['erf', 'relu']
                          for do_stabilize in [True, False]
                          for kern in ['nngp', 'ntk']))
  def test_sparse_inputs(self, act, kernel, do_stabilize):
    if do_stabilize and act != 'relu':
      raise absltest.SkipTest('Stabilization possible only in Relu.')

    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 3
    width = 1024

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    if default_backend() == 'gpu':
      tol = 5e-4
      samples = 100 * N_SAMPLES
    else:
      tol = {onp.dtype(onp.float32): 5e-2, onp.dtype(onp.float64): 5e-3}

    # a batch of dense inputs
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = x_dense.at[:sparse_count, :].set(0.)

    activation = (stax.Relu(do_stabilize=do_stabilize) if act == 'relu'
                  else stax.Erf())

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)

    mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        random.split(key, 2)[0],
        samples,
        vmap_axes=0,
        device_count=-1,
        implementation=2
    )(x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    assert not np.any(np.isnan(exact))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:],
                        rtol=tol, atol=tol)

  def test_composition_dense(self):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (2, 3))
    x2 = random.normal(rng, (4, 3))

    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(x1))
    composed_ker_out = composed_ker_fn(x1)
    self.assertAllClose(ker_out, composed_ker_out)

    ker_out = ker_fn(ker_fn(x1, x2))
    composed_ker_out = composed_ker_fn(x1, x2)
    self.assertAllClose(ker_out, composed_ker_out)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '_avg_pool={}_same_inputs={}'.format(avg_pool,
                                                                same_inputs),
          'avg_pool': avg_pool,
          'same_inputs': same_inputs
      } for avg_pool in [True, False] for same_inputs in [True, False]))
  def test_composition_conv(self, avg_pool, same_inputs):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
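
The composition property checked in `test_composition_dense` and `test_composition_conv` above, as a minimal standalone sketch: applying a block's `kernel_fn` to its own output `Kernel` should agree with the kernel of the serially composed network. The width and input shape below are illustrative.

from jax import random
from neural_tangents import stax

x = random.normal(random.PRNGKey(0), (4, 3))

block = stax.serial(stax.Dense(256), stax.Relu())
_, _, block_kernel_fn = block
_, _, composed_kernel_fn = stax.serial(block, block)

k_two_step = block_kernel_fn(block_kernel_fn(x))  # kernel_fn applied to a Kernel
k_composed = composed_kernel_fn(x)                # kernel of the composed network
# The two should agree numerically; this is what `test_composition_dense` asserts.
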
Example #3
class ParameterizationTest(test_utils.NeuralTangentsTestCase):


  @parameterized.named_parameters(
      test_utils.cases_from_list({
                              'testcase_name':
                                f'_get={get}'
                                f'_s={s}'
                                f'_depth={depth}'
                                f'_same_inputs={same_inputs}'
                                f'_b_std={b_std}_'
                                f'_W_std={W_std}'
                                f'_param={parameterization}',
                              'get':
                                get,
                              's':
                                s,
                              'depth':
                                depth,
                              'same_inputs':
                                same_inputs,
                              'b_std':
                                b_std,
                              'W_std':
                                W_std,
                              'parameterization':
                                parameterization,
                          }
                          for get in ['nngp', 'ntk']
                          for s in [2**9, 2**8, 2**7]
                          for depth in [0, 1, 2]
                          for same_inputs in [True, False]
                          for W_std in [0., 1., 2.]
                          for b_std in [None, 0., 0.5**0.5, 2]
                          for parameterization in ['ntk', 'standard']))
  def test_linear(
      self,
      get,
      s,
      depth,
      same_inputs,
      b_std,
      W_std,
      parameterization,
  ):
    if parameterization == 'standard':
      width = 2**9 // s
    elif parameterization == 'ntk':
      if s != 2**9:
        raise absltest.SkipTest(
            '"ntk" parameterization does not depend on "s".')
      width = 2**10
    else:
      raise ValueError(parameterization)

    layers = []
    for i in range(depth + 1):
      s_in = 1 if i == 0 else s
      s_out = 1 if (i == depth and get == 'ntk') else s
      out_dim = 1 if (i == depth and get == 'ntk') else width * (i + 1)
      layers += [stax.Dense(out_dim,
                            W_std=W_std / (i + 1),
                            b_std=b_std if b_std is None else b_std / (i + 1),
                            parameterization=parameterization,
                            s=(s_in, s_out))]

    net = stax.serial(*layers)
    net = net, (BATCH_SIZE, 3), -1, 1
    _check_agreement_with_empirical(self, net, same_inputs, False, get == 'ntk',
                                    rtol=0.02, atol=10)


  @parameterized.named_parameters(
      test_utils.cases_from_list({
                              'testcase_name':
                                f'_model={model}'
                                f'_width={width}'
                                f'_same_inputs={same_inputs}'
                                f'_filter_shape={filter_shape}'
                                f'_proj={proj_into_2d}_'
                                f'_is_ntk={is_ntk}_'
                                f'_b_std={b_std}_'
                                f'_W_std={W_std}'
                                f'_param={parameterization}'
                                f'_s={s}',
                              'model':
                                model,
                              'width':
                                width,
                              'same_inputs':
                                same_inputs,
                              'filter_shape':
                                filter_shape,
                              'proj_into_2d':
                                proj_into_2d,
                              'is_ntk':
                                is_ntk,
                              'b_std':
                                b_std,
                              'W_std':
                                W_std,
                              'parameterization':
                                parameterization,
                              's':
                                s
                          }
                          for model in MODELS
                          for width in [2**11]
                          for same_inputs in [False]
                          for is_ntk in [False, True]
                          for filter_shape in FILTER_SHAPES
                          for proj_into_2d in PROJECTIONS[:2]
                          for W_std in [0., 1., 2.]
                          for b_std in [None, 0., 0.5**0.5]
                          for parameterization in ['ntk', 'standard']
                          for s in [2**10]))
  def test_nonlinear(
      self,
      model,
      width,
      same_inputs,
      is_ntk,
      filter_shape,
      proj_into_2d,
      b_std,
      W_std,
      parameterization,
      s
  ):
    is_conv = 'conv' in model

    if parameterization == 'standard':
      width //= s

    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or filter_shape != FILTER_SHAPES[0]:
      raise absltest.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std=W_std,
                   b_std=b_std,
                   filter_shape=filter_shape,
                   is_conv=is_conv,
                   use_pooling=use_pooling,
                   is_res=is_res,
                   padding=padding,
                   phi=phi,
                   strides=strides,
                   width=width,
                   is_ntk=is_ntk,
                   proj_into_2d=proj_into_2d,
                   pool_type=pool_type,
                   layer_norm=layer_norm,
                   parameterization=parameterization,
                   s=s,
                   use_dropout=use_dropout)

    _check_agreement_with_empirical(
        self,
        net=net,
        same_inputs=same_inputs,
        use_dropout=use_dropout,
        is_ntk=is_ntk,
        rtol=0.015,
        atol=1000
    )
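
A minimal sketch of the `parameterization`/`s` arguments that `test_linear` above sweeps over: in the 'standard' parameterization the test shrinks the actual layer width as `s` grows (`width = 2**9 // s`) and passes `s=(s_in, s_out)` to `stax.Dense`. The concrete numbers and standard-deviations below are illustrative assumptions, mirroring the depth-1 NTK case of the test.

from jax import random
from neural_tangents import stax

x = random.normal(random.PRNGKey(0), (4, 3))

s = 2 ** 7
width = 2 ** 9 // s  # mirrors `test_linear` for the 'standard' parameterization

_, _, kernel_fn = stax.serial(
    stax.Dense(width, W_std=1., b_std=0.1,
               parameterization='standard', s=(1, s)),
    stax.Dense(1, W_std=0.5, b_std=0.05,
               parameterization='standard', s=(s, 1)))
ntk = kernel_fn(x, None, 'ntk')
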
Example #4
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}'.format(
                  model,
                  phi[0].__name__,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get),
          'model': model,
          'phi': phi,
          'same_inputs': same_inputs,
          'get': get,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for phi in [
                              stax.Erf(),
                              stax.Gelu(),
                              stax.Sin(),
                          ]
                          for same_inputs in [False, True]
                          for get in ['nngp', 'ntk']))
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
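
A hedged standalone sketch of what `test_elementwise_numerical` verifies: an activation supplied only as a scalar function to `stax.ElementwiseNumerical` (approximated numerically with quadrature of degree `deg`) should produce nearly the same kernel as the corresponding closed-form layer. The choice of `erf`, the width, and the input shape below are illustrative assumptions.

from jax import random
from jax.scipy.special import erf
from neural_tangents import stax

x = random.normal(random.PRNGKey(0), (3, 7))

_, _, kernel_exact = stax.serial(stax.Dense(1024), stax.Erf(), stax.Dense(1))
_, _, kernel_numerical = stax.serial(
    stax.Dense(1024),
    stax.ElementwiseNumerical(erf, deg=25),
    stax.Dense(1))

k_exact = kernel_exact(x, None, 'ntk')
k_numerical = kernel_numerical(x, None, 'ntk')  # should closely match k_exact
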
Example #5
class AutodiffTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': f'{get}-{same_inputs}-{phi.__name__}',
          'get': get,
          'same_inputs': same_inputs,
          'phi': phi,
      }
                          for get in [
                              'ntk',
                              'nngp'
                          ]
                          for same_inputs in [True, False, None]
                          for phi in [
                              stax.Erf,
                              stax.Sin,
                              stax.Gelu,
                              stax.Relu,
                              stax.ElementwiseNumerical
                          ]))
  def test_autodiff(self, get, same_inputs, phi):
    x1 = np.cos(random.normal(random.PRNGKey(1), (3, 1, 2, 3)))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(random.PRNGKey(2), (4, 1, 2, 3)))

    name = phi.__name__
    if name == 'LeakyRelu':
      phi = phi(0.1)
    elif name == 'ElementwiseNumerical':
      phi = phi(fn=np.cos, deg=25)
    else:
      phi = phi()

    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.01), phi,
                                  stax.Dense(1, 2., 0.01), phi)

    def k(x1, x2):
      return kernel_fn(x1, x2, get)

    dx1 = random.normal(random.PRNGKey(3), x1.shape) * 0.01
    if x2 is None:
      dx2 = None
    else:
      dx2 = random.normal(random.PRNGKey(4), x2.shape) * 0.01

    def dk(x1, x2):
      return jvp(k, (x1, x2), (dx1, dx2))[1]

    def d2k(x1, x2):
      return jvp(dk, (x1, x2), (dx1, dx2))[1]

    _dk = dk(x1, x2)

    if (same_inputs is not False and
        get == 'ntk' and
        ('Relu' in name or 'Abs' in name)):
      # TODO(romann): revisit numerical issues of second derivative of `Relu`
      _d2k = 0
      tol = 0.01
    else:
      _d2k = d2k(x1, x2)
      tol = 2e-3 if name == 'ElementwiseNumerical' else 1e-4

    def assert_close(x, y, tol=3e-5):
      if default_backend() == 'tpu':
        # TODO(romann): understand why TPUs have high errors.
        tol = 0.21
      self.assertLess(
          np.max(np.abs(x - y)) / (np.mean(np.abs(x)) + np.mean(np.abs(y))),
          tol)

    # k(x + dx) ~ k(x) + dk(x) dx + dx^T d2k(x) dx
    assert_close(k(x1 + dx1, None if same_inputs is None else x2 + dx2),
                 k(x1, x2) + _dk + _d2k / 2,
                 tol=tol)

    # d/dx1
    k_fwd_0 = jacfwd(k)(x1, x2)
    k_rev_0 = jacrev(k)(x1, x2)
    assert_close(k_fwd_0, k_rev_0)

    if same_inputs is not None:
      # d/dx2
      k_fwd_1 = jacfwd(k, 1)(x1, x2)
      k_rev_1 = jacrev(k, 1)(x1, x2)
      assert_close(k_fwd_1, k_rev_1)

      # dk(x2, x1)/dx2 = dk(x1, x2)/dx1
      k_fwd_01 = jacfwd(k, 1)(x2, x1)
      k_rev_01 = jacrev(k, 1)(x2, x1)
      assert_close(np.moveaxis(k_fwd_0, (0, 2, 4), (1, 3, 5)), k_fwd_01)
      assert_close(np.moveaxis(k_rev_0, (0, 2, 4), (1, 3, 5)), k_rev_01)

      # dk(x2, x1)/dx1 = dk(x1, x2)/dx2
      k_fwd_10 = jacfwd(k)(x2, x1)
      k_rev_10 = jacrev(k)(x2, x1)
      assert_close(np.moveaxis(k_fwd_1, (0, 2, 4), (1, 3, 5)), k_fwd_10)
      assert_close(np.moveaxis(k_rev_1, (0, 2, 4), (1, 3, 5)), k_rev_10)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'get={get}-'
              f'param={parameterization}-'
              f'param_out={parameterization_out}-'
              f'x1={x1_type}-'
              f'x2={x2_type}-'
              f'phi={phi.__name__}-'
              f'b_std={b_std}-'
              f'jit={do_jit}-',
          'get': get,
          'parameterization': parameterization,
          'parameterization_out': parameterization_out,
          'x1_type': x1_type,
          'x2_type': x2_type,
          'phi': phi,
          'b_std': b_std,
          'do_jit': do_jit
      }
                          for get in [
                              'ntk',
                              'nngp'
                          ]
                          for parameterization in [
                              'standard',
                              'ntk'
                          ]
                          for parameterization_out in [
                              'ntk'
                          ]
                          for do_jit in [
                              True,
                          ]
                          for x1_type in [
                              'zeros',
                              'ones',
                              'random',
                          ]
                          for x2_type in [
                              'zeros',
                              'ones',
                              'random',
                              'x1',
                              'none',
                          ]
                          for b_std in [
                              None,
                              0.1,
                          ]
                          for phi in [
                              stax.Identity,
                              stax.Erf,
                              stax.Abs,
                              stax.Gelu,
                              stax.Relu,
                              stax.Sigmoid_like,
                              stax.ABRelu,
                              stax.Exp,
                              stax.ExpNormalized,
                              stax.Gaussian,
                              stax.Sign,
                              stax.Rbf,
                              stax.Cos,
                              stax.Sin
                          ]))
  def test_activations(
      self,
      get,
      parameterization,
      parameterization_out,
      x1_type,
      x2_type,
      b_std,
      phi,
      do_jit
  ):
    """Tests forward- and reverse-mode autodiff for nonlinearities."""
    if phi == stax.ABRelu:
      phi_ = phi(0.25, 0.5)
    else:
      phi_ = phi()

    if phi not in [stax.Relu]:
      test_utils.skip_test(self)

    n_out = 1 if get == 'ntk' else 1024
    width = 2**10

    W_std_in = width**(-0.5) if parameterization_out == 'standard' else 1.
    if phi == stax.Exp:
      W_std_in /= 10.

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(
            width,
            W_std=W_std_in,
            b_std=b_std,
            parameterization=parameterization),
        phi_,
        stax.Dense(
            n_out,
            b_std=b_std,
            parameterization=parameterization_out
        ),
    )

    def get_x(x_type, key):
      shape = (1, 2)
      if x_type == 'zeros':
        x = np.zeros(shape)
      elif x_type == 'ones':
        x = np.ones(shape)
      elif x_type == 'random':
        x = random.normal(random.PRNGKey(key), shape)
      elif x_type == 'sin':
        x = np.sin(random.normal(random.PRNGKey(key), shape))
      elif x_type == 'none':
        return None
      else:
        raise ValueError(x_type)
      return x

    x1 = get_x(x1_type, 1)
    if x2_type == 'x1':
      x2 = x1
    else:
      x2 = get_x(x2_type, 2)

    def kernel_scalar(x1, x2):
      return kernel_fn(x1, x2, get)[0, 0]

    if do_jit:
      kernel_scalar = jit(kernel_scalar)

    k1 = kernel_scalar(x1, x2)
    k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2)

    # Compare to forward-mode.
    k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
    k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2_fwd)
    self.assertAllClose(k2_grad, k2_grad_fwd)

    # `stax.ExpNormalized` has no forward pass.
    # `stax.Sign` is discontinuous at `0`, so NTK MC kernel does not converge to
    # infinite-width kernel.
    if phi == stax.ExpNormalized or (get == 'ntk' and phi == stax.Sign):
      raise absltest.SkipTest('Not comparing against MC kernels.')

    _kernel_scalar_mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key=random.PRNGKey(3),
        n_samples=1,
        device_count=0,
    )

    def kernel_scalar_mc(x1, x2):
      return _kernel_scalar_mc(x1, x2, get)[0, 0]

    k_mc = kernel_scalar_mc(x1, x2)
    k_mc2, k_mc2_grad = value_and_grad(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2)

    # Compare MC to forward-mode.
    k_mc2_fwd, _ = jvp(kernel_scalar_mc, (x1, x2), (x1, x2))
    k_mc2_grad_fwd = jacfwd(kernel_scalar_mc)(x1, x2)
    self.assertAllClose(k_mc, k_mc2_fwd)
    self.assertAllClose(k_mc2_grad, k_mc2_grad_fwd)

    def kernel_fn_emp(x1, x2, get, params):
      return nt.empirical_kernel_fn(apply_fn)(x1, x2, get, params)[0, 0]

    kernel_fn_emp_g = jit(value_and_grad(kernel_fn_emp), static_argnums=(2,))

    def kernel_scalar_mc_grad_mean(x1, x2):
      key = random.PRNGKey(4)
      n_samples = 2**9
      k, k_grad = 0., 0.

      for _ in range(n_samples):
        _, params = init_fn(key, x1.shape)
        k_mc2, k_mc2_grad = kernel_fn_emp_g(x1, x2, get, params)
        k += k_mc2
        k_grad += k_mc2_grad
        key, _ = random.split(key)

      k /= n_samples
      k_grad /= n_samples
      return k, k_grad

    k_mc2_mean, k_mc2_grad_mean = kernel_scalar_mc_grad_mean(x1, x2)

    # Compare kernels.
    self.assertAllClose(k1, k_mc2_mean, atol=4e-3, rtol=4e-2)

    if phi == stax.Sign and get == 'nngp':
      raise absltest.SkipTest('Derivative of the empirical NNGP of a '
                              'discontinuous function does not converge '
                              'to the derivative of the infinite width NNGP.')

    if (phi in [stax.Abs, stax.Relu, stax.LeakyRelu, stax.ABRelu] and
        get == 'ntk'):
      raise absltest.SkipTest('Derivative of the empirical NTK of a '
                              'non-differentiable function does not converge '
                              'to the derivative of the infinite width NTK.')

    atol = 1e-2

    # Compare gradient of the analytic kernel to empirical kernel.
    if np.max(np.abs(k2_grad - k_mc2_grad_mean)) > atol:
      test_utils.assert_close_matrices(self,
                                       k_mc2_grad_mean,
                                       k2_grad,
                                       rtol=0.05,
                                       atol=10.)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'get={get}-'
              f'architecture={architecture}-'
              f'jit={do_jit}-',
          'get': get,
          'architecture': architecture,
          'do_jit': do_jit
      }
                          for architecture in [
                              'conv',
                              'wrn'
                          ]
                          for get in [
                              'ntk',
                              'nngp'
                          ]
                          for do_jit in [
                              True,
                          ]))
  def test_issue_123(
      self,
      get,
      architecture,
      do_jit
  ):
    """Tests https://github.com/google/neural-tangents/issues/123."""
    if architecture == 'wrn':
      # https://github.com/google/neural-tangents/issues/123#issue-992927376
      def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
        main = stax.serial(
            stax.Relu(),
            stax.Conv(
                channels, (3, 3), strides, padding='SAME',
                parameterization='standard'
            ),
            stax.Relu(),
            stax.Conv(channels, (3, 3), padding='SAME',
                      parameterization='standard'),
        )
        shortcut = (
            stax.Identity()
            if not channel_mismatch
            else stax.Conv(
                channels, (3, 3), strides, padding='SAME',
                parameterization='standard'
            )
        )
        return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                           stax.FanInSum())

      def WideResnetGroup(n, channels, strides=(1, 1)):
        blocks = []
        blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
        for _ in range(n - 1):
          blocks += [WideResnetBlock(channels, (1, 1))]
        return stax.serial(*blocks)

      def WideResnet(block_size, k, num_classes):
        return stax.serial(
            stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
            WideResnetGroup(block_size, int(16 * k)),
            WideResnetGroup(block_size, int(32 * k), (2, 2)),
            WideResnetGroup(block_size, int(64 * k), (2, 2)),
            stax.AvgPool((8, 8), padding='SAME'),
            stax.Flatten(),
            stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
        )

      init_fn, apply_fn, kernel_fn = WideResnet(block_size=1,
                                                k=1,
                                                num_classes=1)

    elif architecture == 'conv':
      # https://github.com/google/neural-tangents/issues/123#issuecomment-932809224
      init_fn, apply_fn, kernel_fn = stax.serial(
          stax.Conv(
              1,
              (3, 3)
          ),
          stax.Relu(),
          stax.Flatten(),
      )

    else:
      raise ValueError(architecture)

    x1 = x2 = np.zeros((1, 8, 8, 3))

    def kernel_scalar(x1, x2):
      return kernel_fn(x1, x2, get)[0, 0]

    if do_jit:
      kernel_scalar = jit(kernel_scalar)

    # Compare forward pass to `value_and_grad`.
    k1 = kernel_scalar(x1, x2)
    k2, k2_grad = value_and_grad(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2)

    # Compare to forward-mode.
    k2_fwd, _ = jvp(kernel_scalar, (x1, x2), (x1, x2))
    k2_grad_fwd = jacfwd(kernel_scalar)(x1, x2)
    self.assertAllClose(k1, k2_fwd)
    self.assertAllClose(k2_grad, k2_grad_fwd)

    # Compare to 0.
    self.assertAllClose(grad(kernel_scalar)(x1, x2), np.zeros_like(x1))
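
A minimal sketch of the property `AutodiffTest` exercises: the closed-form `kernel_fn` is differentiable with respect to its inputs, so `jax.grad` and `jax.jvp` can be applied to a scalar kernel entry. The architecture loosely mirrors the one in `test_autodiff`; the shapes and tangent scaling are illustrative assumptions.

from jax import grad, jvp, random
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.01), stax.Erf(),
                              stax.Dense(1, 2., 0.01))

def k00(x1, x2):
  return kernel_fn(x1, x2, 'ntk')[0, 0]  # a scalar kernel entry

x1 = random.normal(random.PRNGKey(1), (3, 2))
x2 = random.normal(random.PRNGKey(2), (4, 2))

dk_dx1 = grad(k00)(x1, x2)                          # reverse mode
_, dk = jvp(k00, (x1, x2), (0.01 * x1, 0.01 * x2))  # forward mode
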
Example #6
class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_n={}_diag_batch={}_spatial={}'.format(
                  phi[0].__name__, same_inputs, n, diagonal_batch,
                  diagonal_spatial),
          'phi': phi,
          'same_inputs': same_inputs,
          'n': n,
          'diagonal_batch': diagonal_batch,
          'diagonal_spatial': diagonal_spatial
      }
                          for phi in [
                              stax.Identity(),
                              stax.Erf(),
                              stax.Sin(),
                              stax.Relu(),
                          ]
                          for same_inputs in [False, True, None]
                          for n in [0, 1, 2]
                          for diagonal_batch in [True, False]
                          for diagonal_spatial in [True, False]))
  def test_elementwise(self, same_inputs, phi, n, diagonal_batch,
                       diagonal_spatial):
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
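
A standalone sketch of the `stax.Elementwise` usage tested above: supplying a scalar `fn` together with its `nngp_fn` lets the library derive the NTK contribution from `nngp_fn` automatically. The Erf closed form below is copied from `nngp_fn` in `test_elementwise`; the comparison against the built-in `stax.Erf` layer is an assumption of this sketch.

import jax.numpy as jnp
from jax import random
from jax.scipy.special import erf
from neural_tangents import stax

def erf_nngp(cov12, var1, var2):
  prod = (1 + 2 * var1) * (1 + 2 * var2)
  return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi

_, _, kernel_fn = stax.serial(stax.Dense(1),
                              stax.Elementwise(fn=erf, nngp_fn=erf_nngp))
_, _, kernel_fn_closed_form = stax.serial(stax.Dense(1), stax.Erf())

x = random.normal(random.PRNGKey(1), (5, 3))
# kernel_fn(x, None) should match kernel_fn_closed_form(x, None) up to
# Kernel metadata such as `is_gaussian`.
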
Example #7
class ActivationTest(test_utils.NeuralTangentsTestCase):

  @stax.layer
  def _RBF(self, gamma):
    init_fn = lambda key, input_shape: (input_shape, ())
    def apply_fn(unused_params, unused_xs, **kwargs):
      raise NotImplementedError()
    def kernel_fn(kernels, **kwargs):
      if kernels.ntk is not None:
        raise ValueError('RBF Kernel does not have an associated NTK.')

      if kernels.nngp.ndim > 2:
        raise ValueError(
            ('RBF Kernel is not defined for covariance matrices with dimension'
             ' greater than two.'))

      input_dim = kernels.shape1[1]
      cov1 = kernels.cov1
      cov1 = np.reshape(cov1, (cov1.shape[0], 1))
      cov2 = cov1 if kernels.cov2 is None else kernels.cov2
      cov2 = np.reshape(cov2, (1, cov2.shape[0]))
      nngp = kernels.nngp

      # TODO(schsam): Update cov1 and cov2 if we want to compose this kernel
      # with other kernels.
      return kernels.replace(
          nngp=np.exp(-input_dim * gamma * (cov1 + cov2 - 2 * nngp)))
    return init_fn, apply_fn, kernel_fn

  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_model={}_phi={}_{}_get={}_abc={}_approximate={}'.format(
                  model,
                  phi_name,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  abc,
                  approximate),
          'model': model,
          'phi_name': phi_name,
          'same_inputs': same_inputs,
          'get': get,
          'abc': abc,
          'approximate': approximate
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for phi_name in [
                              'Sin',
                              'Cos',
                              'Erf',
                              'Gelu',
                              'Sign',
                          ]
                          for same_inputs in [False]
                          for get in ['nngp', 'ntk']
                          for approximate in [True, False]
                          for abc in itertools.product(
                              [2., 0.3],
                              [1.5, 0.3],
                              [0., -np.pi/4., np.pi/2.]
                              )))
  def test_activation(
      self,
      same_inputs,
      model,
      phi_name,
      get,
      abc,
      approximate
  ):
    # `abc` is a tuple (from `itertools.product`), so compare against a tuple.
    if abc != (0.3, 1.5, -np.pi/4):
      test_utils.skip_test(self)

    if approximate and phi_name != 'Gelu':
      raise absltest.SkipTest(
          f'{phi_name} does not have an `approximate` parameter.')

    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
    elif phi_name in ['Gelu', 'Sign', 'Cos']:
      if a != 0.3 or b != 0.3 or c != 0.:
        raise absltest.SkipTest('Skip `Gelu/Sign/Cos` test if '
                                ' (a, b, c) != (.3, .3, 0.).')
      activation = stax.Gelu() if phi_name == 'Gelu' else stax.Sign()
    else:
      raise NotImplementedError(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_Rbf_{}_{}_{}'.format(
                  model,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get,
                  gamma),
          'model': model,
          'same_inputs': same_inputs,
          'get': get,
          'gamma': gamma,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for same_inputs in [False, True]
                          for get in ['nngp', 'ntk']
                          for gamma in [1e-6, 1e-4, 1e-2, 1.0, 2.]
                          ))
  def test_rbf(self, same_inputs, model, get, gamma):
    activation = stax.Rbf(gamma)
    self._test_activation(activation, same_inputs, model, get,
                          rbf_gamma=gamma)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': f'{phi.__name__}_{same_inputs}_a={a}_b={b}_n={n}',
          'same_inputs': same_inputs,
          'a': a,
          'b': b,
          'n': n,
          'phi': phi
      }
                          for a in [-0.5, 0.25]
                          for b in [-0.5, -0.1, 0.1]
                          for phi in [stax.Gaussian, stax.Exp]
                          for same_inputs in [False, True, None]
                          for n in [0]))
  def test_nonlinearity(self, phi, same_inputs, a, b, n):
    width = 2**10
    n_samples = 2**9
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(1))

    key1, key2, key_mc = random.split(random.PRNGKey(1), 3)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = np.cos(random.normal(key1, (2,) + shape))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(key2, (3,) + shape))

    k = kernel_fn(x1, x2)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key_mc,
                                            n_samples)
    k_mc = mc_kernel_fn(x1, x2, ('nngp', 'ntk'))
    test_utils.assert_close_matrices(self, k_mc.nngp, k.nngp, 6e-2)
    test_utils.assert_close_matrices(self, k_mc.ntk, k.ntk, 6e-2)

  def test_exp_normalized(self):
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))

    for do_clip in [True, False]:
      for gamma in [1., 2., 0.5]:
        for get in ['nngp', 'ntk']:
          with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
            _, _, kernel_fn = stax.serial(
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.GlobalAvgPool(),
                stax.Dense(1)
            )
            k_12 = kernel_fn(x1, x2, get=get)
            self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

            k_11 = kernel_fn(x1, None, get=get)
            self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_11)), 0)

            k_22 = kernel_fn(x2, None, get=get)
            self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_22)), 0)

  def test_exp_normalized_ntk(self):
    def nngp_fn(cov12, var1, var2):
      prod = np.sqrt(var1 * var2)
      return prod * np.exp(cov12 / prod - 1)

    _, _, kernel_fn = stax.serial(stax.Dense(1),
                                  stax.Elementwise(nngp_fn=nngp_fn))

    _, _, kernel_fn_manual = stax.serial(stax.Dense(1),
                                         stax.ExpNormalized())

    key = random.PRNGKey(1)
    x1 = random.normal(key, (5, 4, 3, 1))
    x2 = random.normal(key, (6, 4, 3, 1))

    k = kernel_fn(x1, x2)
    k_manual = kernel_fn_manual(x1, x2)
    self.assertAllClose(k_manual, k)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_degree={}_get={}_readout={}'.format(
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  degree,
                  get,
                  readout
              ),
          'same_inputs': same_inputs,
          'degree': degree,
          'get': get,
          'readout': readout
      }
                          for same_inputs in [False, True]
                          for degree in [1, 2, 3, 4, 5, 6]
                          for get in ['ntk', 'nngp']
                          for readout in ['pool', 'flatten']))
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
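
# A minimal, illustrative sketch (not part of the tests above) of the pattern
# they all follow: the closed-form `kernel_fn` returned by `stax.serial` is
# compared against a Monte-Carlo estimate from `nt.monte_carlo_kernel_fn`,
# and the two should agree up to a tolerance that shrinks as width and the
# number of samples grow. Widths, shapes and keys below are arbitrary.
sketch_init_fn, sketch_apply_fn, sketch_kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))
sketch_x1 = random.normal(random.PRNGKey(1), (3, 4))
sketch_x2 = random.normal(random.PRNGKey(2), (5, 4))
sketch_mc_kernel_fn = nt.monte_carlo_kernel_fn(
    sketch_init_fn, sketch_apply_fn, random.PRNGKey(3), n_samples=128)
sketch_k_analytic = sketch_kernel_fn(sketch_x1, sketch_x2, 'ntk')      # (3, 5)
sketch_k_empirical = sketch_mc_kernel_fn(sketch_x1, sketch_x2, 'ntk')  # (3, 5)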
Example #8
class BatchTest(test_utils.NeuralTangentsTestCase):
    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                   for name, kernel_fn in KERNELS.items()
                                   for batch_size in [2, 8]))
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)
        kernel_fn = kernel_fn(key, train_shape[1:], network)
        kernel_batched = batching._serial(kernel_fn, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    # We also exclude tests for dropout + parallel, since it is unclear how
    # best to handle that case.
    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}'.format(
                train, test, network, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                   for name, kernel_fn in KERNELS.items()))
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batching, 2)
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batching._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                   for name, kernel_fn in KERNELS.items()
                                   for batch_size in [2, 8]))
    def testComposition(self, train_shape, test_shape, network, name,
                        kernel_fn, batch_size):
        test_utils.stub_out_pmap(batching, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batching._parallel(
            batching._serial(kernel_fn, batch_size=batch_size))
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batching._serial(batching._parallel(kernel_fn),
                                          batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                   for name, kernel_fn in KERNELS.items()
                                   for batch_size in [2, 8]))
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batching, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batching.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batching.batch(kernel_fn,
                                        batch_size=batch_size,
                                        store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 2))
        x_other = random.normal(rng_other, (2, 2))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batching._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batching._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 4, 4, 3))
        x_other = random.normal(rng, (2, 4, 4, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batching._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batching._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_on_device={}_batch_size={}'.format(store_on_device, batch_size),
            'store_on_device':
            store_on_device,
            'batch_size':
            batch_size
        } for store_on_device in [True, False] for batch_size in [2, 8]))
    def testAnalyticKernelComposeSerial(self, store_on_device, batch_size):
        self._test_analytic_kernel_composition(
            partial(batching._serial,
                    batch_size=batch_size,
                    store_on_device=store_on_device))

    def testAnalyticKernelComposeParallel(self):
        test_utils.stub_out_pmap(batching, 2)
        self._test_analytic_kernel_composition(batching._parallel)

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_on_device={}_batch_size={}'.format(store_on_device, batch_size),
            'store_on_device':
            store_on_device,
            'batch_size':
            batch_size
        } for store_on_device in [True, False] for batch_size in [2, 8]))
    def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
        test_utils.stub_out_pmap(batching, 2)
        self._test_analytic_kernel_composition(
            partial(batching.batch,
                    batch_size=batch_size,
                    store_on_device=store_on_device))

    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batching, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batching._jit_or_pmap_broadcast(kernel_fn,
                                                            device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batching, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name': '_same_inputs={}'.format(same_inputs),
                'same_inputs': same_inputs
            } for same_inputs in [True, False]))
    def test_parallel_in_out(self, same_inputs):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2 = random.split(rng, 2)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))

        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
            x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.

        readin = net(N)
        readout = net(1)

        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get='nngp'))

        batch_K_readin_fn = batching.batch(K_readin_fn, 2)
        batch_K_readout_fn = batching.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check Both.
        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

        batch_K_readin_fn = batching.batch(K_readin_fn, 2)
        batch_K_readout_fn = batching.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name': '_same_inputs={}'.format(same_inputs),
                'same_inputs': same_inputs
            } for same_inputs in [True, False]))
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))
        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
            x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            ({
                'testcase_name': (f'_same_inputs={same_inputs}'
                                  f'_device_count={device_count}'
                                  f'_trace_axes={trace_axes}'
                                  f'_diagonal_axes={diagonal_axes}'),
                'same_inputs':
                same_inputs,
                'device_count':
                device_count,
                'trace_axes':
                trace_axes,
                'diagonal_axes':
                diagonal_axes
            } for same_inputs in [True, False]
             for device_count in [-1, 0, 1, 2]
             for trace_axes, diagonal_axes in zip(
                 [(-1,), (1, -1), ()],
                 [(1,), (), (1, -1)]))))
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(5), stax.Relu(),
                                           stax.Dense(3))

        test_x1 = random.normal(input_key1, (12, 4, 4))
        test_x2 = None
        if not same_inputs:
            test_x2 = random.normal(input_key2, (9, 4, 4))

        kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                        trace_axes=trace_axes,
                                        diagonal_axes=diagonal_axes,
                                        vmap_axes=0,
                                        implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batching.batch(kernel_fn,
                                    device_count=device_count,
                                    batch_size=3)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
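
# A minimal, illustrative sketch (not part of the test class above) of the
# behaviour BatchTest checks: `batching.batch` (also exposed as `nt.batch`)
# wraps a kernel function so that inputs are processed in tiles of
# `batch_size` (and across devices when available) while returning the same
# result as the unbatched call. Sizes below are arbitrary.
_, _, sketch_kernel_fn = stax.serial(
    stax.Dense(256), stax.Relu(), stax.Dense(1))
sketch_batched_fn = batching.batch(sketch_kernel_fn, batch_size=2)
sketch_x = random.normal(random.PRNGKey(0), (8, 3))
sketch_k_full = sketch_kernel_fn(sketch_x, None, 'ntk')    # computed in one pass
sketch_k_tiled = sketch_batched_fn(sketch_x, None, 'ntk')  # computed 2 rows at a time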

Example #9
class MonteCarloTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '[batch_size={}, '
                           'device_count={} '
                           'store_on_device={} '
                           'get={} '
                           ']'.format(batch_size,
                                      device_count,
                                      store_on_device,
                                      get),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get,
      } for batch_size in BATCH_SIZES for device_count in DEVICE_COUNTS
                          for store_on_device in STORE_ON_DEVICE
                          for get in ALL_GET))
  def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                             get):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    kernel_fn = nt.empirical_kernel_fn(apply_fn)

    sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn)
    sample_once_batch_fn = monte_carlo._sample_once_kernel_fn(
        kernel_fn, init_fn, batch_size, device_count, store_on_device)

    one_sample = sample_once_fn(x1, x2, key, get)
    one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
    self.assertAllClose(one_sample, one_sample_batch)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '[batch_size={}, '
                           'device_count={} '
                           'store_on_device={} '
                           'get={} '
                           ']'.format(batch_size, device_count, store_on_device,
                                      get),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get,
      } for batch_size in BATCH_SIZES for device_count in DEVICE_COUNTS
                          for store_on_device in STORE_ON_DEVICE
                          for get in ALL_GET))
  def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                             get):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
    kernel_fn = nt.empirical_kernel_fn(apply_fn)
    sample_once_fn = monte_carlo._sample_once_kernel_fn(
        kernel_fn, init_fn, device_count=0)
    batch_sample_once_fn = batching.batch(sample_once_fn, batch_size,
                                          device_count, store_on_device)
    one_sample = sample_once_fn(x1, x2, key, get)
    one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
    self.assertAllClose(one_sample, one_batch_sample)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '[batch_size={}, '
                           'device_count={} '
                           'store_on_device={} '
                           ']'.format(batch_size, device_count,
                                      store_on_device),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
      } for batch_size in BATCH_SIZES for device_count in DEVICE_COUNTS
                          for store_on_device in STORE_ON_DEVICE))
  def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                   store_on_device):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 256, jax.default_backend() == 'tpu')

    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                               batch_size, device_count,
                                               store_on_device)

    ker_empirical = sample(x1, x2, 'nngp')
    ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '[batch_size={}, '
                           'device_count={} '
                           'store_on_device={} '
                           ']'.format(batch_size, device_count,
                                      store_on_device),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
      } for batch_size in BATCH_SIZES for device_count in DEVICE_COUNTS
                          for store_on_device in STORE_ON_DEVICE))
  def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                       store_on_device):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
        WIDTH, 2, jax.default_backend() == 'tpu')

    sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                               batch_size, device_count,
                                               store_on_device,
                                               vmap_axes=0)

    ker_empirical = sample(x1, x2, 'ntk')
    ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

    test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name': '[batch_size={}, '
                           'device_count={} '
                           'store_on_device={} '
                           'get={}'
                           ']'.format(batch_size, device_count, store_on_device,
                                      get),
          'batch_size': batch_size,
          'device_count': device_count,
          'store_on_device': store_on_device,
          'get': get
      } for batch_size in BATCH_SIZES for device_count in DEVICE_COUNTS
                          for store_on_device in STORE_ON_DEVICE
                          for get in ALL_GET))
  def test_monte_carlo_generator(self, batch_size, device_count,
                                 store_on_device, get):
    test_utils.stub_out_pmap(batching, device_count)

    x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(8, 1)
    x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

    log_n_max = 4
    n_samples = [2**k for k in range(log_n_max)]
    sample_generator = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, batch_size, device_count,
        store_on_device, vmap_axes=0)

    if get is None:
      samples_12 = sample_generator(x1, x2)
      samples_34 = sample_generator(x3, x4)

      count = 0
      for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        sample_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                                      n, batch_size,
                                                      device_count,
                                                      store_on_device,
                                                      vmap_axes=0)
        sample_12 = sample_fn(x1, x2)
        sample_34 = sample_fn(x3, x4)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

      self.assertEqual(log_n_max, count)

      ker_analytic_12 = stax_kernel_fn(x1, x2, ('nngp', 'ntk'))
      ker_analytic_34 = stax_kernel_fn(x3, x4, ('nngp', 'ntk'))

    else:
      samples_12 = sample_generator(x1, x2, get)
      samples_34 = sample_generator(x3, x4, get)

      count = 0
      for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
        sample_fn = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n, batch_size,
            device_count, store_on_device, vmap_axes=0)
        sample_12 = sample_fn(x1, x2, get)
        sample_34 = sample_fn(x3, x4, get)
        self.assertAllClose(s_12, sample_12)
        self.assertAllClose(s_12, s_34)
        self.assertAllClose(s_12, sample_34)
        count += 1

      self.assertEqual(log_n_max, count)

      ker_analytic_12 = stax_kernel_fn(x1, x2, get)
      ker_analytic_34 = stax_kernel_fn(x3, x4, get)

    self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
    self.assertAllClose(ker_analytic_12, ker_analytic_34)

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              f'_same_inputs={same_inputs}_batch_size={batch_size}',
          'same_inputs': same_inputs,
          'batch_size': batch_size
      } for same_inputs in [True, False] for batch_size in [1, 2]))
  def test_parallel_in_out_mc(self, same_inputs, batch_size):
    rng = random.PRNGKey(0)
    input_key1, input_key2, net_key = random.split(rng, 3)

    x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 2, 5))
    x1 = (x1_1, (x1_2, x1_3))

    if same_inputs:
      x2 = None
    else:
      x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 4, 5))
      x2 = (x2_1, (x2_2, x2_3))

    def net(N_out):
      return stax.parallel(stax.Dense(N_out),
                           stax.parallel(stax.Dense(N_out + 1),
                                         stax.Dense(N_out + 2)))

    # Check NNGP.
    init_fn, apply_fn, _ = net(WIDTH)

    nb_kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                     apply_fn,
                                                     net_key,
                                                     n_samples=4,
                                                     trace_axes=(-1,))

    kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                  apply_fn,
                                                  net_key,
                                                  n_samples=4,
                                                  batch_size=batch_size,
                                                  trace_axes=(-1,))

    self.assertAllClose(kernel_fn(x1, x2, 'nngp'), nb_kernel_fn(x1, x2, 'nngp'))
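
# A minimal, illustrative sketch (not part of the test class above) of the
# generator behaviour `test_monte_carlo_generator` checks: when `n_samples`
# is a collection rather than an int, the returned kernel function yields one
# progressively refined estimate per entry instead of a single array.
# Widths, shapes and keys below are arbitrary.
gen_init_fn, gen_apply_fn, _ = stax.serial(
    stax.Dense(128), stax.Erf(), stax.Dense(1))
gen_x = random.normal(random.PRNGKey(0), (4, 3))
gen_kernel_fn = monte_carlo.monte_carlo_kernel_fn(
    gen_init_fn, gen_apply_fn, random.PRNGKey(1), n_samples=[2, 4, 8])
for gen_k in gen_kernel_fn(gen_x, None, 'ntk'):
  pass  # `gen_k` has shape (4, 4); the loop body runs once per sample count.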
Example #10
class MaskingTest(test_utils.NeuralTangentsTestCase):
    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            ' [{}_get={}_axis={}_mask={}_concat={}_p={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                get,
                mask_axis,
                mask_constant,
                concat,
                p,
            ),
            'same_inputs':
            same_inputs,
            'get':
            get,
            'mask_axis':
            mask_axis,
            'mask_constant':
            mask_constant,
            'concat':
            concat,
            'p':
            p,
        } for same_inputs in [False] for get in ['ntk']
                                   for concat in [None, 0, 1] for p in [0.5]
                                   for mask_axis in [(), (0, ), (1, 3)]
                                   for mask_constant in [10.]))
    def test_mask_fc(self, same_inputs, get, concat, p, mask_axis,
                     mask_constant):
        width = 512
        n_samples = 128
        tol = 0.04
        key = random.PRNGKey(1)

        x1 = random.normal(key, (4, 6, 5, 7))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = random.normal(key, (2, 6, 5, 7))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        nn = stax.serial(
            stax.Flatten(), stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    stax.Dense(width, 1., 0.1),
                    stax.Abs(),
                    stax.DotGeneral(lhs=-0.2),
                    stax.Dense(width, 1.5, 0.01),
                ),
                stax.serial(
                    stax.Dense(width, 1.1, 0.1),
                    stax.DotGeneral(rhs=0.7),
                    stax.Erf(),
                    stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
                ),
                stax.serial(
                    stax.DotGeneral(rhs=0.5),
                    stax.Dense(width, 1.2),
                    stax.ABRelu(-0.2, 0.4),
                    stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            stax.Dense(width, 2., 0.01), stax.Relu())

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.1))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.1))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -2) else -1,
            implementation=2,
            vmap_axes=None if concat in (0, -2) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnums=(2, ))
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            ' [{}_get={}_axis={}_mask={}_concat={}_{}_p={}_n={}_{}]'
            ''.format('same_inputs' if same_inputs else 'different_inputs',
                      get, mask_axis, mask_constant, concat, proj, p, n,
                      'transpose' if transpose else ''),
            'same_inputs':
            same_inputs,
            'get':
            get,
            'mask_axis':
            mask_axis,
            'mask_constant':
            mask_constant,
            'concat':
            concat,
            'proj':
            proj,
            'p':
            p,
            'n':
            n,
            'transpose':
            transpose
        } for proj in ['flatten', 'avg'] for same_inputs in [False]
                                   for get in ['ntk'] for n in [0, 1]
                                   for concat in [None] + list(range(n + 1))
                                   for mask_constant in [10.] for p in [0.5]
                                   for transpose in [True, False]
                                   for mask_axis in [(), (0, ), (0, 1, 2, 3)]))
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=2,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnums=(2, ))
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
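
# A minimal, illustrative sketch (not part of the test class above) of the
# masking mechanism exercised here: input entries equal to `mask_constant`
# are treated as missing, and passing the same constant to the infinite-width
# `kernel_fn` makes it compute the kernel as if those positions were absent.
# The constant 10. and all shapes below are arbitrary.
_, _, mask_kernel_fn = stax.serial(
    stax.Dense(256), stax.Relu(), stax.Dense(1))
mask_x = random.normal(random.PRNGKey(0), (4, 6))
mask_x = mask_x.at[0, 2].set(10.)  # mark one entry of the first example as masked
mask_k = mask_kernel_fn(mask_x, None, 'ntk', mask_constant=10.)  # shape (4, 4)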
Example #11
test_utils.update_test_tolerance()

prandom.seed(1)


@parameterized.named_parameters(
    test_utils.cases_from_list(
        {
            'testcase_name':
            ' [{}_out={}_in={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                readout[0].__name__, readin[0].__name__),
            'same_inputs':
            same_inputs,
            'readout':
            readout,
            'readin':
            readin
        } for same_inputs in [False, True] for readout in
        [stax.Flatten(), stax.GlobalAvgPool(),
         stax.Identity()]
        for readin in [stax.Flatten(),
                       stax.GlobalAvgPool(),
                       stax.Identity()]))
class DiagonalTest(test_utils.NeuralTangentsTestCase):
    def _get_kernel_fn(self, same_inputs, readin, readout):
        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 5, 6, 3))
        x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
        layers = [readin]
        filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
Example #12
class EmpiricalTest(test_utils.NeuralTangentsTestCase):

    # We use a simple quadratic function of the inputs (linear in the
    # parameters) for testing.
    @classmethod
    def f(cls, x, params, do_alter, do_shift_x=True):
        w1, w2, b = params
        if do_alter:
            b *= 2.
            w1 += 5.
            w2 /= 0.9
        if do_shift_x:
            x = x * 2 + 1.
        return [
            0.5 * np.dot(np.dot(x.T, w1), x) + np.dot(w2, x) + b,
            (np.dot(w1, x), w2)
        ]

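    # `f_lin_exact` is the first-order Taylor expansion of `f` in `x` around
    # `x0`, written out by hand: f_lin(x) = f(x0) + (df/dx at x0) (x - x0).
    # The linearization tests below check that `nt.linearize(EmpiricalTest.f,
    # x0)` reproduces it exactly.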
    @classmethod
    def f_lin_exact(cls, x0, x, params, do_alter, do_shift_x=True):
        w1, w2, b = params
        f0 = EmpiricalTest.f(x0, params, do_alter, do_shift_x)
        if do_shift_x:
            x0 = x0 * 2 + 1.
            x = x * 2 + 1.
        dx = x - x0
        if do_alter:
            b *= 2.
            w1 += 5.
            w2 /= 0.9
        return tree_map(
            operator.add, f0,
            [np.dot(np.dot(x0.T, w1) + w2, dx), (np.dot(w1, dx), 0.)])

    @classmethod
    def _get_init_data(cls, shape):
        key = random.PRNGKey(0)
        key, s1, s2, s3, = random.split(key, 4)
        w1 = random.normal(s1, shape)
        w1 = 0.5 * (w1 + w1.T)
        w2 = random.normal(s2, shape)
        b = random.normal(s3, (1, ) * (len(shape) - 1) + (shape[-1], ))
        params = (w1, w2, b)
        key, split = random.split(key)
        x0 = random.normal(split, (shape[-1], 1))
        return key, params, x0

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name': '_{}'.format(shape),
            'shape': shape
        } for shape in TAYLOR_MATRIX_SHAPES))
    def testLinearization(self, shape):
        key, params, x0 = self._get_init_data(shape)

        f_lin = nt.linearize(EmpiricalTest.f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    key, split = random.split(key)
                    x = random.normal(split, (shape[-1], 1))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))

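    # `nt.taylor_expand(f, x0, degree)` generalizes `nt.linearize` to higher
    # orders; `f_2_exact` below adds the hand-written second-order term
    # 0.5 * dx^T w1 dx to the exact linearization for comparison.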
    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name': '_{}'.format(shape),
            'shape': shape
        } for shape in TAYLOR_MATRIX_SHAPES))
    def testTaylorExpansion(self, shape):
        def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
            w1, w2, b = params
            f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                              do_shift_x)
            if do_shift_x:
                x0 = x0 * 2 + 1.
                x = x * 2 + 1.
            if do_alter:
                b *= 2.
                w1 += 5.
                w2 /= 0.9
            dx = x - x0
            return tree_map(operator.add, f_lin,
                            [0.5 * np.dot(np.dot(dx.T, w1), dx), (0., 0.)])

        key, params, x0 = self._get_init_data(shape)

        f_lin = nt.taylor_expand(EmpiricalTest.f, x0, 1)
        f_2 = nt.taylor_expand(EmpiricalTest.f, x0, 2)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    key, split = random.split(key)
                    x = random.normal(split, (shape[-1], 1))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
                    self.assertAllClose(
                        f_2_exact(x0,
                                  x,
                                  params,
                                  do_alter,
                                  do_shift_x=do_shift_x),
                        f_2(x, params, do_alter, do_shift_x=do_shift_x))

    @parameterized.named_parameters(
        test_utils.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}'.format(
                train, test, network, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                                   for name, kernel_fn in KERNELS.items()))
    def testNTKAgainstDirect(self, train_shape, test_shape, network, name,
                             kernel_fn):
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        implicit, direct, _ = kernel_fn(key,
                                        train_shape[1:],
                                        network,
                                        diagonal_axes=(),
                                        trace_axes=())

        implicit_batched, direct_batched, _ = kernel_fn(key,
                                                        train_shape[1:],
                                                        network,
                                                        diagonal_axes=(),
                                                        trace_axes=(),
                                                        vmap_axes=0)

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        g_batched = implicit_batched(data_self, None)
        g_direct_batched = direct_batched(data_self, None)
        self.assertAllClose(g, g_direct)
        self.assertAllClose(g, g_batched)
        self.assertAllClose(g, g_direct_batched)

        g = implicit(data_other, data_self)
        g_direct = direct(data_other, data_self)
        g_batched = implicit_batched(data_other, data_self)
        g_direct_batched = direct_batched(data_other, data_self)
        self.assertAllClose(g, g_direct)
        self.assertAllClose(g, g_batched)
        self.assertAllClose(g, g_direct_batched)

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name':
                '_diagonal_axes={}_trace_axes={}'.format(
                    diagonal_axes, trace_axes),
                'diagonal_axes':
                diagonal_axes,
                'trace_axes':
                trace_axes,
            } for diagonal_axes in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3),
                                    (-1,), (-2,), (0, -1), (1, -2), (2, 3),
                                    (3, 0, 2)]
            for trace_axes in [(), (0,), (0, 1), (-1,), (1,), (0, -1),
                               (-1, -2), (0, 1, 2, 3), (3, 1, 2, 0), (1, 2, 3),
                               (-3, -2), (-3, -1), (-2, -4), (2, 0, -1)])
    )
    def testAxes(self, diagonal_axes, trace_axes):
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, (4, 5, 6, 3))
        data_other = random.normal(other_split, (2, 5, 6, 3))

        _diagonal_axes = tuple(d % data_self.ndim for d in diagonal_axes)
        _trace_axes = tuple(t % data_self.ndim for t in trace_axes)

        if any(d == c for d in _diagonal_axes for c in _trace_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        get_kernel = KERNELS['empirical_logits_3']
        kwargs = dict(key=key,
                      input_shape=(5, 6, 3),
                      network=CONV,
                      diagonal_axes=diagonal_axes,
                      trace_axes=trace_axes)

        implicit, direct, nngp = get_kernel(**kwargs)
        implicit_batched, direct_batched, _ = get_kernel(**kwargs, vmap_axes=0)

        n_marg = len(_diagonal_axes)
        n_chan = len(_trace_axes)

        g_nngp = nngp(data_self, None)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

        g_direct = direct(data_self, None)
        self.assertEqual(g_nngp.shape, g_direct.shape)

        g_direct_batched = direct_batched(data_self, None)
        g = implicit(data_self, None)
        g_batched = implicit_batched(data_self, None)

        self.assertAllClose(g_direct, g)
        self.assertAllClose(g_direct, g_direct_batched)
        self.assertAllClose(g_direct, g_batched)

        if 0 not in _trace_axes and 0 not in _diagonal_axes:
            g_nngp = nngp(data_other, data_self)
            self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg,
                             g_nngp.ndim)

            g_direct = direct(data_other, data_self)
            self.assertEqual(g_nngp.shape, g_direct.shape)

            g_direct_batched = direct_batched(data_other, data_self)
            g = implicit(data_other, data_self)
            g_batched = implicit_batched(data_other, data_self)

            self.assertAllClose(g_direct, g)
            self.assertAllClose(g_direct, g_direct_batched)
            self.assertAllClose(g_direct, g_batched)

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name': '_same_inputs={}'.format(same_inputs),
                'same_inputs': same_inputs
            } for same_inputs in [True, False]))
    def test_parallel_in_out(self, same_inputs):
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2 = np.split(random.normal(input_key1, (3, 21)), (10, ),
                              axis=1)
        x2_1, x2_2 = np.split(random.normal(input_key2, (4, 21)), (10, ),
                              axis=1)

        x1 = (x1_1, x1_2)
        x2 = (x2_1, x2_2) if not same_inputs else None

        def layer(N_out):
            return stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1))

        init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1))

        _, params = init_fn(net_key, (x1_1.shape, x1_2.shape))

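        # The `direct_*` (implementation=1) and `implicit_*` (implementation=2)
        # variants are two ways of computing the same empirical NTK and must
        # agree, both with and without `vmap_axes` batching over examples.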
        implicit_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct_kernel_fn = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))
        implicit_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn, vmap_axes=(0, 0), implementation=2))
        direct_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn, vmap_axes=(0, 0), implementation=1))

        k_direct = direct_kernel_fn(x1, x2, params)

        self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct,
                            implicit_batched_kernel_fn(x1, x2, params))

        nngp_kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        nngp = nngp_kernel_fn(x1, x2, params)
        self.assertEqual(len(nngp), 2)
        self.assertEqual(nngp[0].shape, (3, 3 if same_inputs else 4))
        self.assertEqual(nngp[1].shape, (3, 3 if same_inputs else 4))

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name': '_same_inputs={}'.format(same_inputs),
                'same_inputs': same_inputs
            } for same_inputs in [True, False]))
    def test_parallel_nested(self, same_inputs):
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = np.split(random.normal(input_key1, (3, 33)),
                                    (10, 21),
                                    axis=1)
        x2_1, x2_2, x2_3 = np.split(random.normal(input_key2, (4, 33)),
                                    (10, 21),
                                    axis=1)

        x1 = ([x1_1, x1_2], x1_3)
        x2 = ([x2_1, x2_2], x2_3) if not same_inputs else None

        def layer(N_out):
            return stax.parallel(
                stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1)),
                stax.Dense(N_out + 2))

        init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1))

        _, params = init_fn(net_key, tree_map(np.shape, x1))
        implicit_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct_kernel_fn = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

        implicit_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([0, 0], 0),
                                implementation=2))
        direct_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([0, 0], 0),
                                implementation=1))

        k_direct = direct_kernel_fn(x1, x2, params)

        self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct,
                            implicit_batched_kernel_fn(x1, x2, params))

        nngp_kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        nngp = nngp_kernel_fn(x1, x2, params)

        self.assertEqual(len(nngp), 2)
        nngp_shape = (3, 3 if same_inputs else 4)
        self.assertEqual(nngp[0][0].shape, nngp_shape)
        self.assertEqual(nngp[0][1].shape, nngp_shape)
        self.assertEqual(nngp[1].shape, nngp_shape)

    @parameterized.named_parameters(
        test_utils.cases_from_list(
            {
                'testcase_name': '_same_inputs={}'.format(same_inputs),
                'same_inputs': same_inputs
            } for same_inputs in [True, False]))
    def test_vmap_axes(self, same_inputs):
        n1, n2 = 3, 4
        c1, c2, c3 = 9, 5, 7
        h2, h3, w3 = 6, 8, 2

        def get_x(n, k):
            k1, k2, k3 = random.split(k, 3)
            x1 = random.normal(k1, (n, c1))
            x2 = random.normal(k2, (h2, n, c2))
            x3 = random.normal(k3, (c3, w3, n, h3))
            x = [(x1, x2), x3]
            return x

        x1 = get_x(n1, random.PRNGKey(1))
        x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

        p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
        p2 = None if same_inputs else random.normal(random.PRNGKey(6),
                                                    (n2, h2, h2))

        init_fn, apply_fn, _ = stax.serial(
            stax.parallel(
                stax.parallel(
                    stax.serial(stax.Dense(4, 2., 0.1), stax.Relu(),
                                stax.Dense(3, 1., 0.15)),  # 1
                    stax.serial(
                        stax.Conv(7, (2, ),
                                  padding='SAME',
                                  dimension_numbers=('HNC', 'OIH', 'NHC')),
                        stax.Erf(), stax.Aggregate(1, 0, -1),
                        stax.GlobalAvgPool(), stax.Dense(3, 0.5, 0.2)),  # 2
                ),
                stax.serial(
                    stax.Conv(5, (2, 3),
                              padding='SAME',
                              dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                    stax.Sin(),
                )  # 3
            ),
            stax.parallel(
                stax.FanInSum(),
                stax.Conv(2, (2, 1),
                          dimension_numbers=('HWCN', 'OIHW', 'HNWC'))))

        _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
        implicit = jit(nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

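        # `vmap_axes = (input axes, output axes, kwargs axes)`: the batch-axis
        # position of every leaf of the input pytree `[(x1, x2), x3]`, of every
        # output leaf, and of the `pattern` keyword argument.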
        implicit_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(0, 1), 2], [-2, -3],
                                           dict(pattern=0)),
                                implementation=2))
        direct_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(-2, -2), -2], [0, 1],
                                           dict(pattern=-3)),
                                implementation=1))

        k = direct(x1, x2, params, pattern=(p1, p2))

        self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
        self.assertAllClose(k, direct_batched(x1, x2, params,
                                              pattern=(p1, p2)))
        self.assertAllClose(k,
                            implicit_batched(x1, x2, params, pattern=(p1, p2)))
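A minimal, self-contained sketch of the convention exercised by the tests above: when inputs form a pytree, `vmap_axes` mirrors that structure, giving the batch axis of every input leaf, every output leaf, and every batched keyword argument. It is not part of the test suite; the widths and shapes below are illustrative.

from jax import random
import neural_tangents as nt
from neural_tangents import stax

# Two-branch network: tuple input, single (summed) output.
init_fn, apply_fn, _ = stax.serial(
    stax.parallel(stax.Dense(128), stax.Dense(128)),
    stax.FanInSum(),
    stax.Dense(1))

key = random.PRNGKey(0)
x1 = (random.normal(key, (3, 10)), random.normal(key, (3, 11)))
_, params = init_fn(key, (x1[0].shape, x1[1].shape))

# (input axes, output axes, kwargs axes): batch axis 0 in both input leaves,
# axis 0 in the single output, no batched keyword arguments.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=((0, 0), 0, {}))
k = ntk_fn(x1, None, params)  # (3, 3) NTK matrix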
Example #13
class JacobianRulesTest(test_utils.NeuralTangentsTestCase):
    def _assert_is_diagonal(self, j, axis1, axis2, constant_diagonal: bool):
        c = j.shape[axis1]
        self.assertEqual(c, j.shape[axis2])
        mask_shape = [c if i in (axis1, axis2) else 1 for i in range(j.ndim)]
        mask = np.eye(c, dtype=np.bool_).reshape(mask_shape)

        # Check that removing the diagonal makes the array all 0.
        j_masked = np.where(mask, np.zeros((), j.dtype), j)
        self.assertAllClose(np.zeros_like(j, j.dtype), j_masked)

        if constant_diagonal:
            # Check that diagonal is constant.
            if j.size != 0:
                j_diagonals = np.diagonal(j, axis1=axis1, axis2=axis2)
                self.assertAllClose(np.min(j_diagonals, -1),
                                    np.max(j_diagonals, -1))

    def _assert_constant(self, j, axis):
        if axis is not None:
            j = np.moveaxis(j, axis, 0)
            j = list(j)
            for ji in j:
                self.assertAllClose(j[0], ji)

    def _compare_jacobians(self, j_fwd, j_rev, j_rule, primitive):
        if primitive == lax.convert_element_type_p:
            # Check that at least one of fwd/rev Jacobians matches the rule.
            e_fwd, e_rev = None, None
            try:
                self.assertAllClose(j_fwd, j_rule)
            except Exception as e:
                logging.exception(
                    'Forward-mode Jacobian does not match the rule.')
                e_fwd = e

            try:
                self.assertAllClose(j_rev, j_rule)
            except Exception as e:
                logging.exception(
                    'Reverse-mode Jacobian does not match the rule.')
                e_rev = e

            if e_fwd is not None and e_rev is not None:
                raise ValueError(e_fwd, e_rev)

        else:
            if primitive == lax.reshape_p:
                # Reshape Jacobian is special-case defined as identity.
                j_rule = j_rule.reshape(j_fwd.shape)

            self.assertAllClose(j_fwd, j_rev)
            if j_rule is not None:
                self.assertAllClose(j_fwd, j_rule)
                self.assertAllClose(j_rev, j_rule)

    def _test_primitive(self, primitive: Optional[Primitive], shapes, dtype,
                        params):
        xs = _get_inputs(shapes, dtype)
        n = len(xs)
        eqn, f = _get_f_and_eqn(params, primitive, *xs)

        out = f(*xs)
        cts_in = ShapedArray(out.shape, out.dtype)

        argnums = tuple(range(n))
        js_fwd = jax.jacfwd(f, argnums)(*xs)
        js_rev = jax.jacrev(f, argnums)(*xs)

        for idx in range(n):
            if primitive == lax.conv_general_dilated_p and idx == 0:
                raise absltest.SkipTest(
                    'Jacobian of CNN wrt inputs not implemented.')

            if primitive == lax.div_p and idx == 1:
                raise absltest.SkipTest(
                    'Division is linear only in the first arg.')

            invals = _get_invals(idx, *xs)
            j_fwd, j_rev = js_fwd[idx], js_rev[idx]

            if primitive in rules.JACOBIAN_RULES:
                j_rule = rules.JACOBIAN_RULES[primitive](eqn, idx, invals,
                                                         cts_in)
            else:
                warnings.warn(
                    f'Jacobian rule for {primitive} at position {idx} not '
                    f'found.')
                j_rule = None

            with self.subTest(f'Jacobian ({idx})'):
                self._compare_jacobians(j_fwd, j_rev, j_rule, primitive)

            structure = rules.STRUCTURE_RULES[primitive](eqn, idx, invals,
                                                         cts_in)

            j = j_fwd if j_rule is None else j_rule

            if primitive == lax.reshape_p:
                out_ndim = xs[0].ndim
                j = j.transpose(
                    tuple(xs[0].ndim + i
                          for i in onp.argsort(structure.in_trace)) +
                    tuple(i for i in onp.argsort(structure.in_trace)))
                j = j.reshape(xs[0].shape +
                              tuple(xs[0].shape[i]
                                    for i in onp.argsort(structure.in_trace)))

            else:
                out_ndim = out.ndim

            with self.subTest(f'Diagonal axes ({idx})'):
                for i, o in zip(structure.in_diagonal, structure.out_diagonal):
                    self._assert_is_diagonal(j=j,
                                             axis1=out_ndim + i[idx],
                                             axis2=o,
                                             constant_diagonal=False)

            with self.subTest(f'Constant diagonal axes ({idx})'):
                for i, o in zip(structure.in_trace, structure.out_trace):
                    self._assert_is_diagonal(j=j,
                                             axis1=out_ndim + i,
                                             axis2=o,
                                             constant_diagonal=True)

            with self.subTest(f'Input broadcast axes ({idx})'):
                for i in structure.in_broadcast:
                    self._assert_constant(j=j, axis=i)

            with self.subTest(f'Output broadcast axes ({idx})'):
                for i in structure.out_broadcast:
                    self._assert_constant(j=j, axis=i)

    @parameterized.parameters(
        test_utils.cases_from_list(
            dict(
                primitive=primitive,
                shape=shape,
                dtype=dtype,
                params=params,
            ) for shape in _SHAPES for dtype in _DTYPES
            for primitive in _UNARY_PRIMITIVES.keys()
            for params in _UNARY_PRIMITIVES[primitive](shape, dtype)))
    def test_unary(self, primitive: Optional[Primitive], shape, dtype, params):
        if primitive == jax._src.dispatch.device_put_p:
            # Can't instantiate devices at test generation time; using subtests.
            for device in [None] + jax.devices() + jax.devices('cpu'):
                with self.subTest(device=device):
                    params = {'device': device}
                    self._test_primitive(primitive, [shape], dtype, params)

        else:
            self._test_primitive(primitive, [shape], dtype, params)

    @parameterized.parameters(
        test_utils.cases_from_list(
            dict(primitive=primitive,
                 shape1=shape1,
                 shape2=shape2,
                 dtype=dtype,
                 params=params) for shape1 in _SHAPES for shape2 in _SHAPES
            for dtype in _DTYPES for primitive in _BINARY_PRIMITIVES.keys()
            for params in _BINARY_PRIMITIVES[primitive](shape1, shape2)))
    def test_binary(self, primitive: Optional[Primitive], shape1, shape2,
                    dtype, params):
        # TODO(romann): revisit when bugs below are fixed.
        if primitive == lax.conv_general_dilated_p:
            if jax.default_backend() == 'tpu':
                raise absltest.SkipTest('http://b/235167364')

            elif (jax.default_backend() == 'gpu'
                  and params['batch_group_count'] != 1):
                raise absltest.SkipTest('http://b/235485533')

        if len(shape1) > 3 or len(shape2) > 3:
            test_utils.skip_test(self)

        self._test_primitive(primitive, [shape1, shape2], dtype, params)

    @parameterized.parameters(
        test_utils.cases_from_list(
            dict(
                primitive=primitive, shapes=shapes, dtype=dtype, params=params)
            for shapes in _concat_shapes(4, *_SHAPES) for dtype in _DTYPES
            for primitive in _N_ARY_PRIMITIVES.keys()
            for params in _N_ARY_PRIMITIVES[primitive](*shapes)))
    def test_n_ary(self, primitive: Optional[Primitive], shapes, dtype,
                   params):
        self._test_primitive(primitive, shapes, dtype, params)
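The class above automates, per primitive, the comparison that the following minimal sketch spells out by hand for `lax.exp` (the shape and input values are illustrative): forward- and reverse-mode Jacobians must agree, and for an elementwise primitive the Jacobian is diagonal over the input axes, which is the structure `_assert_is_diagonal` verifies.

import jax
from jax import lax
import jax.numpy as np

f = lambda x: lax.exp(x)
x = np.linspace(-1., 1., 6).reshape(2, 3)

j_fwd = jax.jacfwd(f)(x)  # shape (2, 3, 2, 3)
j_rev = jax.jacrev(f)(x)
assert np.allclose(j_fwd, j_rev)

# Elementwise primitive: J[i, j, k, l] is nonzero only when (i, j) == (k, l).
mask = (np.eye(2)[:, None, :, None] * np.eye(3)[None, :, None, :]).astype(bool)
assert np.allclose(np.where(mask, 0., j_fwd), 0.)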
Example #14
class FanInTest(test_utils.NeuralTangentsTestCase):

  @classmethod
  def _get_phi(cls, i):
    return {
        0: stax.Relu(),
        1: stax.Erf(),
        2: stax.Abs()
    }[i % 3]

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name':
                  ' [{}_axis={}_n_branches={}_{}_{}_{}]'.format(
                      'same_inputs' if same_inputs else 'different_inputs',
                      axis,
                      n_branches,
                      get,
                      branch_in,
                      fan_in_mode),
              'same_inputs':
                  same_inputs,
              'axis':
                  axis,
              'n_branches':
                  n_branches,
              'get':
                  get,
              'branch_in':
                  branch_in,
              'fan_in_mode':
                  fan_in_mode,
          }
          for same_inputs in [False]
          for axis in [0, 1]
          for n_branches in [3] for get in ['ntk']
          for branch_in in ['dense_before_branch_in',
                            'dense_after_branch_in']
          for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd']))
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                     fan_in_mode):
    if fan_in_mode in ['FanInSum', 'FanInProd']:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                                'axis != 0.')
      axis = None
    if (fan_in_mode == 'FanInSum' or
        axis == 0) and branch_in == 'dense_after_branch_in':
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                              'require `is_gaussian`.')

    if ((axis == 1 or fan_in_mode == 'FanInProd') and
        branch_in == 'dense_before_branch_in'):
      raise absltest.SkipTest(
          '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
          'after concatenation or Hadamard product.')
    if fan_in_mode == 'FanInSum':
      fan_in_layer = stax.FanInSum()
    elif fan_in_mode == 'FanInProd':
      fan_in_layer = stax.FanInProd()
    else:
      fan_in_layer = stax.FanInConcat(axis)

    if n_branches != 2:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    X0_1 = np.cos(random.normal(key, (4, 3)))
    X0_2 = None if same_inputs else random.normal(key, (8, 3))

    width = 1024
    n_samples = 256 * 2

    if default_backend() == 'tpu':
      tol = 0.07
    else:
      tol = 0.02

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
        branch_layers += [
            stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [dense]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        fan_in_layer,
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

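    # `FanInConcat(0)` concatenates along the batch axis, so per-example
    # vmapping and device parallelism over that axis are disabled below.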
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=0 if axis in (0, -2) else -1,
        implementation=2,
        vmap_axes=None if axis in (0, -2) else 0,
    )

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)

  @parameterized.named_parameters(
      test_utils.cases_from_list(
          {
              'testcase_name':
                  ' [{}_axis={}_n_branches={}_{}_{}_{}_{}]'.format(
                      'same_inputs' if same_inputs else 'different_inputs',
                      axis,
                      n_branches,
                      get,
                      branch_in,
                      readout,
                      fan_in_mode),
              'same_inputs':
                  same_inputs,
              'axis':
                  axis,
              'n_branches':
                  n_branches,
              'get':
                  get,
              'branch_in':
                  branch_in,
              'readout':
                  readout,
              'fan_in_mode':
                  fan_in_mode,
          }
          for same_inputs in [False]
          for axis in [0, 1, 2, 3]
          for n_branches in [2] for get in ['ntk']
          for branch_in in ['dense_before_branch_in', 'dense_after_branch_in']
          for readout in ['pool', 'flatten']
          for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd']))
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout,
                       fan_in_mode):
    test_utils.skip_test(self)
    if fan_in_mode in ['FanInSum', 'FanInProd']:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                                'axis != 0.')
      axis = None
    if (fan_in_mode == 'FanInSum' or
        axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                              'require `is_gaussian`.')

    if ((axis == 3 or fan_in_mode == 'FanInProd') and
        branch_in == 'dense_before_branch_in'):
      raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                              'requires a dense layer after concatenation '
                              'or Hadamard product.')

    if fan_in_mode == 'FanInSum':
      fan_in_layer = stax.FanInSum()
    elif fan_in_mode == 'FanInProd':
      fan_in_layer = stax.FanInProd()
    else:
      fan_in_layer = stax.FanInConcat(axis)

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if default_backend() == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
        branch_layers += [
            stax.Conv(
                out_chan=int(width * multiplier),
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        fan_in_layer,
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1,
        implementation=2,
        vmap_axes=None if axis in (0, -4) else 0,
    )

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)
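For reference, a minimal sketch of the branch-and-merge pattern these tests assemble: `FanOut` duplicates activations into parallel branches, `FanInConcat` on the feature axis re-joins them and is followed by a dense readout, and the analytic kernel can be compared against a Monte Carlo estimate. The widths, activation choices and sample count below are illustrative, not the settings used in the tests above.

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.25, b_std=0.1),
    stax.FanOut(2),
    stax.parallel(stax.Relu(), stax.Erf()),
    stax.FanInConcat(axis=-1),  # join branches along the feature axis
    stax.Dense(1))              # dense readout required after FanInConcat(-1)

key = random.PRNGKey(1)
x1 = random.normal(key, (4, 3))

exact = kernel_fn(x1, None, 'ntk')  # analytic infinite-width NTK, shape (4, 4)

kernel_fn_mc = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples=128)
empirical = kernel_fn_mc(x1, None, 'ntk')  # finite-width estimate of the same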