Example #1
def WideResnet(block_size, k, num_classes):
    return stax.serial(stax.Conv(16, (3, 3), padding='SAME'),
                       WideResnetGroup(block_size, int(16 * k)),
                       WideResnetGroup(block_size, int(32 * k), (2, 2)),
                       WideResnetGroup(block_size, int(64 * k), (2, 2)),
                       stax.GlobalAvgPool(), stax.Flatten(),
                       stax.Dense(num_classes, 1., 0.))
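
A minimal usage sketch (not part of the original example; it assumes the WideResnetGroup helper defined alongside this function and the usual imports, from neural_tangents import stax and from jax import random). stax.serial returns an (init_fn, apply_fn, kernel_fn) triple, so the same definition serves both as a finite-width network and as an infinite-width kernel:

from jax import random

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
key = random.PRNGKey(0)
x = random.normal(key, (3, 32, 32, 3))   # NHWC batch of 3 images
_, params = init_fn(key, x.shape)        # finite-width parameters
logits = apply_fn(params, x)             # shape (3, 10)
nngp = kernel_fn(x, None, 'nngp')        # infinite-width NNGP kernel, (3, 3)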
Example #2
def _build_network(input_shape, network, out_logits):
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(),
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == CONV:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
                stax.Relu(),
                stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
            )
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
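
A hedged usage sketch for the helper above. The module-level constants FLAT, POOLING, CONV and CONVOLUTION_CHANNELS are not shown in this snippet; the values below are assumptions for illustration only:

from jax import random
from neural_tangents import stax

FLAT = 'FLAT'                # assumed constant
CONVOLUTION_CHANNELS = 256   # assumed constant

_, _, kernel_fn = _build_network((8, 8, 3), FLAT, out_logits=1)
x = random.normal(random.PRNGKey(0), (4, 8, 8, 3))
ntk = kernel_fn(x, None, 'ntk')   # (4, 4) NTK matrix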
Example #3
    def test_composition_conv(self, avg_pool):
        rng = random.PRNGKey(0)
        x1 = random.normal(rng, (5, 10, 10, 3))
        x2 = random.normal(rng, (5, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        if avg_pool:
            Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
            marginalization = 'none'
        else:
            Readout = stax.serial(stax.Flatten(), stax.Dense(10))
            marginalization = 'auto'

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        ker_out = readout_ker_fn(
            block_ker_fn(x1, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1)
        self.assertAllClose(ker_out, composed_ker_out, True)

        if avg_pool:
            with self.assertRaises(ValueError):
                ker_out = readout_ker_fn(block_ker_fn(x1))

        ker_out = readout_ker_fn(
            block_ker_fn(x1, x2, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1, x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #4
def _build_network(input_shape, network, out_logits, use_dropout):
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.Flatten(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == INTERMEDIATE_CONV:
            return stax.Conv(CONVOLUTION_CHANNELS, (2, 2),
                             W_std=2.0,
                             b_std=0.05)
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example #5
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial((stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                          if use_pooling else stax.Identity()), phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  readout = stax.serial(stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
                        fc(1 if is_ntk else width))

  net = stax.serial(block, readout)
  return net
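
A sketch of how a builder like this is typically exercised (the argument values here are illustrative, not taken from the test suite):

from jax import random
from neural_tangents import stax

_, _, kernel_fn = _get_net(W_std=1.5, b_std=0.1, filter_shape=(3, 3),
                           is_conv=True, use_pooling=True, is_res=False,
                           padding='SAME', phi=stax.Relu(), strides=None,
                           width=256, is_ntk=True)
x1 = random.normal(random.PRNGKey(0), (2, 8, 8, 3))
x2 = random.normal(random.PRNGKey(1), (3, 8, 8, 3))
ntk = kernel_fn(x1, x2, 'ntk')   # (2, 3) cross-kernel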
Example #6
  def test_exp_normalized(self):
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))

    for do_clip in [True, False]:
      for gamma in [1., 2., 0.5]:
        for get in ['nngp', 'ntk']:
          with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
            _, _, kernel_fn = stax.serial(
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.GlobalAvgPool(),
                stax.Dense(1)
            )
            k_12 = kernel_fn(x1, x2, get=get)
            self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

            k_11 = kernel_fn(x1, None, get=get)
            self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_11)), 0)

            k_22 = kernel_fn(x2, None, get=get)
            self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_22)), 0)
Example #7
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
Example #8
  def test_composition_conv(self, avg_pool, same_inputs):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
Example #9
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
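
Note how the nested tuple structure of x1 and x2 above mirrors the nesting of the stax.parallel combinators. A minimal sketch of that correspondence, using a two-branch toy model rather than the network from this test:

from jax import random
from neural_tangents import stax

# One Dense branch on flat inputs, one Conv branch on images.
_, _, kernel_fn = stax.parallel(
    stax.Dense(128),
    stax.serial(stax.Conv(128, (2, 2)), stax.GlobalAvgPool()))

key = random.PRNGKey(0)
x_flat = random.normal(key, (3, 5))       # input to the Dense branch
x_img = random.normal(key, (3, 4, 4, 2))  # input to the Conv branch

# Inputs are passed as a tuple matching the parallel structure; the
# result is a matching tuple of kernels, one per branch output.
k_dense, k_conv = kernel_fn((x_flat, x_img), None, get='nngp')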
Example #10
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
Example #11
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
Example #12
  def test_vmap_axes(self, same_inputs):
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
      k1, k2, k3 = random.split(k, 3)
      x1 = random.normal(k1, (n, c1))
      x2 = random.normal(k2, (h2, n, c2))
      x3 = random.normal(k3, (c3, w3, n, h3))
      x = [(x1, x2), x3]
      return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2,), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH', 'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )

    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

    k = direct(x1, x2, params, pattern=(p1, p2))

    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
Example #13
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
Example #14
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (2, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (2, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #15
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (20, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (10, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #16
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial((stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                          if use_pooling else stax.Identity()), phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads,
            fixed=fixed, W_key_std=W_std, W_value_std=W_std, W_query_std=W_std,
            W_out_std=1.0, b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout)
Example #17
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
Example #18
def CNNStandard(n_channels,
                L,
                filter=(3, 3),
                data='cifar10',
                gap=True,
                nonlinearity='relu',
                parameterization='standard',
                order=None):
    if data == 'cifar10':
        num_classes = 10
    if data == 'cifar100':
        num_classes = 100
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    init_fn, f = jax_stax.serial(*[
        jax_stax.serial(
            MyConv(n_channels,
                   filter,
                   parameterization=parameterization,
                   order=order),
            nonlin,
        ) for _ in range(L)
    ])
    if gap:
        init_fn, f = jax_stax.serial((init_fn, f),
                                     stax.GlobalAvgPool()[:2],
                                     MyDense(num_classes,
                                             parameterization=parameterization,
                                             order=order))
    else:
        init_fn, f = jax_stax.serial((init_fn, f),
                                     stax.Flatten()[:2],
                                     MyDense(num_classes,
                                             parameterization=parameterization,
                                             order=order))
    return init_fn, f
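
A hedged usage sketch for CNNStandard. It assumes the MyConv and MyDense helpers referenced above, and that jax_stax is jax's example stax library, whose serial returns an (init_fn, apply_fn) pair rather than a kernel function:

from jax import random

init_fn, f = CNNStandard(n_channels=64, L=3, gap=True)
key = random.PRNGKey(0)
out_shape, params = init_fn(key, (-1, 32, 32, 3))  # CIFAR-10-shaped inputs
x = random.normal(key, (8, 32, 32, 3))
logits = f(params, x)                              # shape (8, 10)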
Example #19
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout, fan_in_mode):
        test_utils.skip_test(self)
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd()` are skipped when '
                    'axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                                    'require `is_gaussian`.')

        if ((axis == 3 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis '
                'requires a dense layer after concatenation '
                'or Hadamard product.')

        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if default_backend() == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Conv(out_chan=int(width * multiplier),
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            fan_in_layer,
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -4) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #20
  def test_kwargs(self, do_batch, mode):
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)
Example #21
    def test_input_req(self, same_inputs):
        test_utils.skip_test(self)

        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 7, 8, 4, 3))
        x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.Relu(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
            stax.Relu(),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
            stax.Flatten(), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=400,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.GlobalAvgPool(channel_axis=2))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
            stax.Relu(), stax.AvgPool((2, 1, 3), batch_axis=0,
                                      channel_axis=-2),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
            stax.Relu(), stax.GlobalAvgPool(channel_axis=2), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=300,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Flatten(),
            stax.Dense(1),
            stax.Erf(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2),
                      dimension_numbers=('CN', 'IO', 'NC')),
        )
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Flatten(), stax.Conv(out_chan=1024, filter_shape=()),
            stax.Relu(), stax.Dense(1))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=200,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='ntk')
        K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
Example #22
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #23
from neural_tangents import stax
from neural_tangents._src.empirical import _DEFAULT_TESTING_NTK_IMPLEMENTATION
from tests import test_utils

config.parse_flags_with_absl()
config.update('jax_numpy_rank_promotion', 'raise')

test_utils.update_test_tolerance()

prandom.seed(1)


@parameterized.product(
    same_inputs=[False, True],
    readout=[stax.Flatten(),
             stax.GlobalAvgPool(),
             stax.Identity()],
    readin=[stax.Flatten(),
            stax.GlobalAvgPool(),
            stax.Identity()])
class DiagonalTest(test_utils.NeuralTangentsTestCase):
    def _get_kernel_fn(self, same_inputs, readin, readout):
        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 5, 6, 3))
        x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
        layers = [readin]
        filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
        layers += [
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Relu(),
            stax.Conv(1, filter_shape, padding='SAME'),
Example #24
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout):
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNNs on CPU to save time.')

    if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                         'require `is_gaussian`.')

    if axis == 3 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Conv(
                out_chan=width,
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
Example #25
@parameterized.named_parameters(
    test_utils.cases_from_list(
        {
            'testcase_name':
            ' [{}_out={}_in={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                readout[0].__name__, readin[0].__name__),
            'same_inputs':
            same_inputs,
            'readout':
            readout,
            'readin':
            readin
        } for same_inputs in [False, True] for readout in
        [stax.Flatten(), stax.GlobalAvgPool(),
         stax.Identity()]
        for readin in [stax.Flatten(),
                       stax.GlobalAvgPool(),
                       stax.Identity()]))
class DiagonalTest(test_utils.NeuralTangentsTestCase):
    def _get_kernel_fn(self, same_inputs, readin, readout):
        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 5, 6, 3))
        x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
        layers = [readin]
        filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
        layers += [
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Relu(),
            stax.Conv(1, filter_shape, padding='SAME'),
Example #26
def ResNet50(num_classes,
             batchnorm=True,
             parameterization='standard',
             nonlinearity='relu'):
    # Define layer constructors
    if parameterization == 'standard':

        def MyGeneralConv(*args, **kwargs):
            return GeneralConv(*args, **kwargs)

        def MyDense(*args, **kwargs):
            return Dense(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyGeneralConv(*args, **kwargs):
            return stax._GeneralConv(*args, **kwargs)[:2]

        def MyDense(*args, **kwargs):
            return stax.Dense(*args, **kwargs)[:2]

    # Define nonlinearity
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    elif nonlinearity == 'swishten':
        nonlin = Swishten
    elif nonlinearity == 'softplus':
        nonlin = Softplus
    return jax_stax.serial(
        MyGeneralConv(('NHWC', 'HWIO', 'NHWC'),
                      64, (7, 7),
                      strides=(2, 2),
                      padding='SAME'),
        BatchNorm() if batchnorm else Identity, nonlin,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256],
                  strides=(1, 1),
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [64, 64],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [64, 64],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [128, 128, 512],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [256, 256, 1024],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [512, 512, 2048],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [512, 512],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [512, 512],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        stax.GlobalAvgPool()[:-1], MyDense(num_classes))
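
As with CNNStandard above, ResNet50 returns a finite-width (init_fn, apply_fn) pair from jax's stax. A hedged sketch, assuming the ConvBlock, IdentityBlock, GeneralConv, Dense, BatchNorm, MaxPool and nonlinearity helpers imported by this snippet:

from jax import random

init_fn, apply_fn = ResNet50(num_classes=1000)
key = random.PRNGKey(0)
out_shape, params = init_fn(key, (-1, 224, 224, 3))
x = random.normal(key, (2, 224, 224, 3))
logits = apply_fn(params, x)   # shape (2, 1000)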
Example #27
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    # Mask all padding with this value.
    mask_constant = 100.

    if use_dummy_data:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)
    else:
        # Build data pipelines.
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')

        # Embed words and pad / truncate sentences to a fixed size.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)

    # Build the infinite network.
    # Not using the finite model, hence width is set to 1 everywhere.
    _, _, kernel_fn = stax.serial(
        stax.Conv(out_chan=1,
                  filter_shape=(9, ),
                  strides=(1, ),
                  padding='VALID'), stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1,
                                 n_chan_key=1,
                                 n_chan_val=1,
                                 pos_emb_type='SUM',
                                 W_pos_emb_std=1.,
                                 pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                 n_heads=1), stax.Relu(), stax.GlobalAvgPool(),
        stax.Dense(out_dim=1))

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=-1,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)

    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))

    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print(f'Kernel construction and inference done in {duration} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
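
The examples on this page all use stax.GlobalAvgPool() as the readout that collapses spatial dimensions. A minimal standalone sketch of that behavior (shapes are illustrative):

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.GlobalAvgPool()
x = random.normal(random.PRNGKey(0), (4, 8, 8, 16))  # NHWC inputs
_, params = init_fn(random.PRNGKey(1), x.shape)
y = apply_fn(params, x)          # spatial axes averaged away: (4, 16)
k = kernel_fn(x, None, 'nngp')   # (4, 4) kernel over the batch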