Example 1
    def test_composition_conv(self, avg_pool):
        rng = random.PRNGKey(0)
        x1 = random.normal(rng, (5, 10, 10, 3))
        x2 = random.normal(rng, (5, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        if avg_pool:
            Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
            marginalization = 'none'
        else:
            Readout = stax.serial(stax.Flatten(), stax.Dense(10))
            marginalization = 'auto'

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        ker_out = readout_ker_fn(
            block_ker_fn(x1, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1)
        self.assertAllClose(ker_out, composed_ker_out, True)

        if avg_pool:
            with self.assertRaises(ValueError):
                ker_out = readout_ker_fn(block_ker_fn(x1))

        ker_out = readout_ker_fn(
            block_ker_fn(x1, x2, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1, x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example 2
  def test_composition_conv(self, avg_pool, same_inputs):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
Example 3
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial((stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                          if use_pooling else stax.Identity()), phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  readout = stax.serial(stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
                        fc(1 if is_ntk else width))

  net = stax.serial(block, readout)
  return net
Example 4
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
Example 5
def _build_network(input_shape, network, out_logits):
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(),
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == CONV:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
                stax.Relu(),
                stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
            )
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example 6
  def test_ab_relu_id(self, same_inputs, do_stabilize):
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (3, 2))
    fc = stax.Dense(5, 1, 0)

    X0_2 = None if same_inputs else random.normal(key, (4, 2))

    # Test that ABRelu(a, a) == a * Identity
    init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
    _, params = init_fn(key, input_shape=X0_1.shape)

    for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
      with self.subTest(a=a):
        _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
            fc, stax.ABRelu(a, a, do_stabilize=do_stabilize))

        X1_1_id = a * apply_id(params, X0_1)
        X1_1_ab_relu = apply_ab_relu(params, X0_1)
        self.assertAllClose(X1_1_id, X1_1_ab_relu)

        kernels_id = kernel_fn_id(X0_1 * a, None if X0_2 is None else a * X0_2)
        kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
        # Manually correct the value of `is_gaussian` because
        # `ab_relu` (incorrectly) sets `is_gaussian=False` when `a==b`.
        kernels_ab_relu = kernels_ab_relu.replace(is_gaussian=True)
        self.assertAllClose(kernels_id, kernels_ab_relu)
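The identity exercised here follows directly from the definition used by stax.ABRelu (assuming the usual form a * min(x, 0) + b * max(x, 0)): setting b = a gives a * min(x, 0) + a * max(x, 0) = a * x, so the layer reduces to a scaled identity, and its NNGP/NTK kernels must match those of the identity network evaluated on inputs scaled by a.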
Example 7
def _build_network(input_shape, network, out_logits, use_dropout):
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.Flatten(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == INTERMEDIATE_CONV:
            return stax.Conv(CONVOLUTION_CHANNELS, (2, 2),
                             W_std=2.0,
                             b_std=0.05)
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example 8
def WideResnetBlocknt(channels,
                      strides=(1, 1),
                      channel_mismatch=False,
                      batchnorm='std',
                      parameterization='ntk'):
    """A WideResnet block, with or without BatchNorm."""

    Main = stax_nt.serial(
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     strides,
                     padding='SAME',
                     parameterization=parameterization),
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     padding='SAME',
                     parameterization=parameterization))

    Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
        channels, (3, 3),
        strides,
        padding='SAME',
        parameterization=parameterization)
    return stax_nt.serial(stax_nt.FanOut(2), stax_nt.parallel(Main, Shortcut),
                          stax_nt.FanInSum())
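Blocks like this are normally chained into a group before being stacked into a full WideResnet; a minimal sketch under the same stax_nt import (WideResnetGroupnt is a hypothetical helper, not part of the snippet above):

def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
    # The first block absorbs the channel/stride mismatch; the remaining
    # n - 1 blocks are shape-preserving.
    blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                                batchnorm=batchnorm,
                                parameterization=parameterization)]
    for _ in range(n - 1):
        blocks += [WideResnetBlocknt(channels,
                                     batchnorm=batchnorm,
                                     parameterization=parameterization)]
    return stax_nt.serial(*blocks)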
Example 9
    def test_ab_relu_id(self, same_inputs):
        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (5, 7))
        fc = stax.Dense(10, 1, 0)

        X0_2 = None if same_inputs else random.normal(key, (9, 7))

        # Test that ABRelu(a, a) == a * Identity
        init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
        _, params = init_fn(key, input_shape=(-1, 7))

        for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
            with self.subTest(a=a):
                _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
                    fc, stax.ABRelu(a, a))

                X1_1_id = a * apply_id(params, X0_1)
                X1_1_ab_relu = apply_ab_relu(params, X0_1)
                self.assertAllClose(X1_1_id, X1_1_ab_relu, True)

                kernels_id = kernel_fn_id(X0_1 * a,
                                          None if X0_2 is None else a * X0_2)
                kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2,
                                                    ('nngp', 'ntk'))
                self.assertAllClose(kernels_id, kernels_ab_relu, True)
Example 10
def GP(x_train, y_train, x_test, y_test, w_std, b_std, l, C):
    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]

    k_layer = []
    K = net0[2](x_train, None)
    k_layer.append(K.nngp)

    for _ in range(l):  # stack `l` additional Relu + Dense layers on top of net0
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        K = net_l[2](K)
        k_layer.append(K.nngp)
        nets += [stax.serial(nets[-1], net_l)]

    kernel_fn = nets[-1][2]

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                        x_train,
                                                        y_train,
                                                        x_test,
                                                        get=('nngp', 'ntk'),
                                                        diag_reg=C)

    fx_test_nngp.block_until_ready()

    duration = time.time() - start
    #print('Kernel construction and inference done in %s seconds.' % duration)
    return accuracy(y_test, fx_test_nngp)
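A hypothetical invocation of GP (the hyperparameter values are illustrative only, and accuracy, x_train, y_train, x_test, y_test are assumed to be defined elsewhere):

# Depth-3 ReLU NNGP with unit weight variance and a small diagonal regularizer.
test_acc = GP(x_train, y_train, x_test, y_test,
              w_std=1.0, b_std=0.05, l=3, C=1e-4)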
Example 11
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
Example 12
  def test_elementwise(
      self,
      same_inputs,
      phi,
      n,
      diagonal_batch,
      diagonal_spatial
  ):
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
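For reference, the Relu branch of nngp_fn above is the closed-form arc-cosine kernel of order 1 (Cho & Saul): with theta = arctan2(sqrt(var1 * var2 - cov12**2), cov12), it evaluates

    nngp = (sqrt(var1 * var2 - cov12**2) + (pi - theta) * cov12) / (2 * pi),

which is exactly what the sqrt, angles and dot_sigma lines compute.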
Example 13
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
Example 14
  def test_vmap_axes(self, same_inputs):
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
      k1, k2, k3 = random.split(k, 3)
      x1 = random.normal(k1, (n, c1))
      x2 = random.normal(k2, (h2, n, c2))
      x3 = random.normal(k3, (c3, w3, n, h3))
      x = [(x1, x2), x3]
      return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2,), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH', 'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )

    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

    k = direct(x1, x2, params, pattern=(p1, p2))

    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
Example 15
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    Main = stax.serial(stax.Relu(),
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.Relu(), stax.Conv(channels, (3, 3),
                                              padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                       stax.FanInSum())
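Resnet in Example 29 below builds on a ResnetGroup helper that is not shown; a minimal sketch in terms of the block above:

def ResnetGroup(n, channels, strides=(1, 1)):
    # The first block adapts channels/strides; the remaining n - 1 blocks
    # are shape-preserving.
    blocks = [ResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [ResnetBlock(channels)]
    return stax.serial(*blocks)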
Example 16
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (2, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (2, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example 17
  def test_composition(self):
    rng = random.PRNGKey(0)
    xs = random.normal(rng, (10, 10))
    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(xs))
    composed_ker_out = composed_ker_fn(xs)

    self.assertAllClose(ker_out, composed_ker_out, True)
Example 18
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)

    # Build the infinite network.
    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.05), stax.Relu(),
                                  stax.Dense(1, 2., 0.05))

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=0,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
Example 19
    def testPredictOnCPU(self):
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_gp(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)
                        gp_inference = predict.gp_inference(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)

                        self.assertAllClose(predictor(None), predictor(np.inf),
                                            True)
                        self.assertAllClose(predictor(None), gp_inference,
                                            True)
Example 20
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
Example 21
def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        ntk_generator.ResnetGroup(block_size, int(16 * k)),
        ntk_generator.ResnetGroup(block_size, int(32 * k), (2, 2)),
        ntk_generator.ResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.Flatten(), stax.Dense(num_classes, 1., 0.))
Example 22
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
Example 23
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
Example 24
  def test_nonlinearity(self, phi, same_inputs, a, b, n):
    width = 2**10
    n_samples = 2**9
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(width),
        phi(a=a, b=b),
        stax.Dense(1))

    key1, key2, key_mc = random.split(random.PRNGKey(1), 3)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = np.cos(random.normal(key1, (2,) + shape))
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = np.cos(random.normal(key2, (3,) + shape))

    k = kernel_fn(x1, x2)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key_mc,
                                            n_samples)
    k_mc = mc_kernel_fn(x1, x2, ('nngp', 'ntk'))
    test_utils.assert_close_matrices(self, k_mc.nngp, k.nngp, 6e-2)
    test_utils.assert_close_matrices(self, k_mc.ntk, k.ntk, 6e-2)
Example 25
  def test_exp_normalized(self):
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))

    for do_clip in [True, False]:
      for gamma in [1., 2., 0.5]:
        for get in ['nngp', 'ntk']:
          with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
            _, _, kernel_fn = stax.serial(
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.GlobalAvgPool(),
                stax.Dense(1)
            )
            k_12 = kernel_fn(x1, x2, get=get)
            self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

            k_11 = kernel_fn(x1, None, get=get)
            self.assertEqual(k_11.shape, (x1.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_11)), 0)

            k_22 = kernel_fn(x2, None, get=get)
            self.assertEqual(k_22.shape, (x2.shape[0],) * 2)
            self.assertGreater(np.min(np.linalg.eigvalsh(k_22)), 0)
Example 26
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)

  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)

  ntk_auto = auto(x2, x1, params)

  # Check that implementations match
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
Example 27
  def test_linear(
      self,
      get,
      s,
      depth,
      same_inputs,
      b_std,
      W_std,
      parameterization,
  ):
    if parameterization == 'standard':
      width = 2**9 // s
    elif parameterization == 'ntk':
      if s != 2**9:
        raise absltest.SkipTest(
            '"ntk" parameterization does not depend on "s".')
      width = 2**10
    else:
      raise ValueError(parameterization)

    layers = []
    for i in range(depth + 1):
      s_in = 1 if i == 0 else s
      s_out = 1 if (i == depth and get == 'ntk') else s
      out_dim = 1 if (i == depth and get == 'ntk') else width * (i + 1)
      layers += [stax.Dense(out_dim,
                            W_std=W_std / (i + 1),
                            b_std=b_std if b_std is None else b_std / (i + 1),
                            parameterization=parameterization,
                            s=(s_in, s_out))]

    net = stax.serial(*layers)
    net = net, (BATCH_SIZE, 3), -1, 1
    _check_agreement_with_empirical(self, net, same_inputs, False, get == 'ntk',
                                    rtol=0.02, atol=10)
Example 28
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(5), stax.Relu(),
                                           stax.Dense(3))

        test_x1 = random.normal(input_key1, (12, 4, 4))
        test_x2 = None
        if not same_inputs:
            test_x2 = random.normal(input_key2, (9, 4, 4))

        kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                        trace_axes=trace_axes,
                                        diagonal_axes=diagonal_axes,
                                        vmap_axes=0,
                                        implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batching.batch(kernel_fn,
                                    device_count=device_count,
                                    batch_size=3)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
Example 29
def Resnet(block_size, num_classes):
    return stax.serial(stax.Conv(64, (3, 3), padding='SAME'),
                       ResnetGroup(block_size, 64),
                       ResnetGroup(block_size, 128, (2, 2)),
                       ResnetGroup(block_size, 256, (2, 2)),
                       ResnetGroup(block_size, 512, (2, 2)), stax.Flatten(),
                       stax.Dense(num_classes, 1., 0.05))
Example 30
    def test_flatten_first(self, same_inputs):
        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (5, 4, 3, 2))
        X0_2 = None if same_inputs else random.normal(key, (3, 4, 3, 2))

        X0_1_flat = np.reshape(X0_1, (X0_1.shape[0], -1))
        X0_2_flat = None if same_inputs else np.reshape(
            X0_2, (X0_2.shape[0], -1))

        _, _, fc_flat = stax.serial(stax.Dense(10, 2., 0.5), stax.Erf())
        _, _, fc = stax.serial(stax.Flatten(), stax.Dense(10, 2., 0.5),
                               stax.Erf())

        K_flat = fc_flat(X0_1_flat, X0_2_flat)
        K = fc(X0_1, X0_2)
        self.assertAllClose(K_flat, K, True)