Example #1
def WideResnetBlocknt(channels,
                      strides=(1, 1),
                      channel_mismatch=False,
                      batchnorm='std',
                      parameterization='ntk'):
    """A WideResnet block, with or without BatchNorm."""

    Main = stax_nt.serial(
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     strides,
                     padding='SAME',
                     parameterization=parameterization),
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     padding='SAME',
                     parameterization=parameterization))

    Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
        channels, (3, 3),
        strides,
        padding='SAME',
        parameterization=parameterization)
    return stax_nt.serial(stax_nt.FanOut(2), stax_nt.parallel(Main, Shortcut),
                          stax_nt.FanInSum())
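
The block relies on a `_batch_norm_internal` helper that is not shown in this snippet. A minimal sketch of what it might look like, assuming `stax_nt` is `neural_tangents.stax` (hypothetical; the original helper may differ):

# Hypothetical helper, not part of the original snippet: maps the
# `batchnorm` flag to a stax layer. neural_tangents has no infinite-width
# BatchNorm layer, so LayerNorm serves as a stand-in for 'std' here.
def _batch_norm_internal(batchnorm='std'):
    if batchnorm == 'std':
        return stax_nt.LayerNorm()
    return stax_nt.Identity()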
Example #2
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)

  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)

  ntk_auto = auto(x2, x1, params)

  # Check that implementations match
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
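
For reference, the example above assumes roughly the following preamble (not shown in the snippet; aliases are inferred from usage, and `np` may equally be plain NumPy):

# Inferred imports for the example above.
import jax
from jax import random
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax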
Example #3
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
Example #4
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    Main = stax.serial(stax.Relu(),
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.Relu(), stax.Conv(channels, (3, 3),
                                              padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                       stax.FanInSum())
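
As a usage sketch, blocks like this are typically stacked into groups. The following hypothetical `ResnetGroup` mirrors the WideResnet helpers elsewhere on this page:

# Hypothetical composition of ResnetBlock into a group of `n` blocks; the
# first block may change channels/resolution, the remaining blocks keep them.
def ResnetGroup(n, channels, strides=(1, 1)):
    blocks = [ResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks.append(ResnetBlock(channels))
    return stax.serial(*blocks)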
Example #5
def build_dense_network(
    hidden_layers: Sequence[int],
    activations: Union[Sequence, str, None] = "erf",
    w_std: float = 2.5,
    b_std: float = 1.0,
) -> NTModel:
    """Utility function to build a simple feedforward network with the
    neural tangents library.

    Args:
        hidden_layers (Sequence[int]): Iterable with the number of neurons.
            For example, [512, 512]
        activations (Union[Sequence, str], optional):
            Iterable with neural_tangents.stax activations or "relu" or "erf".
            Defaults to "erf".
        w_std (float): Standard deviation of the weight distribution.
        b_std (float): Standard deviation of the bias distribution.

    Returns:
        NTModel: jitted init, apply, and kernel functions, plus a
            predict function (None)
    """
    from jax.config import config  # pylint:disable=import-outside-toplevel

    config.update("jax_enable_x64", True)
    from jax import jit  # pylint:disable=import-outside-toplevel
    from neural_tangents import stax  # pylint:disable=import-outside-toplevel

    assert len(hidden_layers) >= 1, "You must provide at least one hidden layer"
    if activations is None:
        activations = [stax.Relu() for _ in hidden_layers]
    elif isinstance(activations, str):
        if activations.lower() == "relu":
            activations = [stax.Relu() for _ in hidden_layers]
        elif activations.lower() == "erf":
            activations = [stax.Erf() for _ in hidden_layers]
        else:
            raise ValueError(f"Unknown activation string: {activations}")
    else:
        for activation in activations:
            assert callable(activation), "You need to provide `neural_tangents.stax` activations"

    assert len(activations) == len(
        hidden_layers
    ), "The number of hidden layers should match the number of nonlinearities"
    stack = []

    for hidden_layer, activation in zip(hidden_layers, activations):
        stack.append(stax.Dense(hidden_layer, W_std=w_std, b_std=b_std))
        stack.append(activation)

    stack.append(stax.Dense(1, W_std=w_std, b_std=b_std))

    init_fn, apply_fn, kernel_fn = stax.serial(*stack)

    return NTModel(init_fn, jit(apply_fn), jit(kernel_fn, static_argnums=(2,)), None)
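
A hypothetical usage sketch, assuming `NTModel` is a simple container (e.g., a namedtuple) holding `(init_fn, apply_fn, kernel_fn, predict_fn)`:

import jax.numpy as jnp

# Build a two-hidden-layer ReLU network and evaluate its NNGP kernel.
model = build_dense_network([512, 512], activations="relu")
x = jnp.linspace(-1.0, 1.0, 8).reshape(8, 1)
# kernel_fn follows the neural_tangents convention kernel_fn(x1, x2, get);
# the attribute name assumes the NTModel container described above.
nngp = model.kernel_fn(x, None, "nngp")  # (8, 8) NNGP kernel matrix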
Example #6
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (2, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (2, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #7
def build_le_net(network_width):
    """ Construct the LeNet of width network_width with average pooling using neural tangent's stax."""
    return stax.serial(
        stax.Conv(out_chan=6 * network_width,
                  filter_shape=(3, 3),
                  strides=(1, 1),
                  padding='VALID'), stax.Relu(),
        stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        stax.Conv(out_chan=16 * network_width,
                  filter_shape=(3, 3),
                  strides=(1, 1),
                  padding='VALID'), stax.Relu(),
        stax.AvgPool(window_shape=(2, 2), strides=(2, 2)), stax.Flatten(),
        stax.Dense(120 * network_width), stax.Relu(),
        stax.Dense(84 * network_width), stax.Relu(), stax.Dense(10))
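
A short usage sketch for this constructor (the 28x28x1 input size is an assumption; any resolution large enough for the two VALID convolutions and poolings works):

from jax import random

# Instantiate the width-1 LeNet and run it on MNIST-sized inputs.
init_fn, apply_fn, kernel_fn = build_le_net(1)
key = random.PRNGKey(0)
x = random.normal(key, (4, 28, 28, 1))
_, params = init_fn(key, x.shape)
logits = apply_fn(params, x)  # shape (4, 10)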
Example #8
  def test_parameterizations(self, model, width, same_inputs, is_ntk,
                             filter_shape, proj_into_2d, parameterization):
    is_conv = 'conv' in model

    W_std, b_std = 2.**0.5, 0.5**0.5
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    elif proj_into_2d != PROJECTIONS[0]:
      raise jtu.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example #9
def _get_net_pool(width, is_ntk, pool_type, padding,
                  filter_shape, strides, normalize_edges):
  W_std, b_std = 2.**0.5, 0.5**0.5
  phi = stax.Relu()
  parameterization = 'ntk'

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)
  conv = partial(
      stax.Conv,
      filter_shape=(3, 2),
      strides=None,
      padding='SAME',
      W_std=W_std,
      b_std=b_std,
      parameterization=parameterization)

  if pool_type == 'AVG':
    pool_fn = partial(stax.AvgPool, normalize_edges=normalize_edges)
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool

  pool = pool_fn(filter_shape, strides, padding)

  return stax.serial(
      conv(width), phi, pool, conv(width), phi, global_pool_fn(),
      fc(1 if is_ntk else width)), INPUT_SHAPE
Example #10
  def testPredictOnCPU(self):
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
          for x in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

              x = x if x is None else x_test
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is not None:
                on_cpu = (not store_on_device or
                          xla_bridge.get_backend().platform == 'cpu')
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Example #11
def WideResnetnt(
        block_size,
        k,
        num_classes,
        batchnorm='std'):
    """Based on the WideResnet paper, with or without BatchNorm.

    (Set config.wrn_block_size=3, config.wrn_widening_f=10 in that case.)
    Uses default weight and bias init."""
    parameterization = 'standard'
    layers_lst = [
        stax_nt.Conv(16, (3, 3),
                     padding='SAME',
                     parameterization=parameterization),
        WideResnetGroupnt(block_size,
                          16 * k,
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size,
                          32 * k, (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size,
                          64 * k, (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm)
    ]
    layers_lst += [_batch_norm_internal(batchnorm), stax_nt.Relu()]
    layers_lst += [
        stax_nt.AvgPool((8, 8)),
        stax_nt.Flatten(),
        stax_nt.Dense(num_classes, parameterization=parameterization)
    ]
    return stax_nt.serial(*layers_lst)
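
`WideResnetGroupnt` is not defined in this snippet. A plausible sketch, mirroring `WideResnetBlocknt` from Example #1 (hypothetical; the original may differ):

# Hypothetical group helper: `n` WideResnet blocks, with channel mismatch
# and striding handled by the first block only.
def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
    blocks = [WideResnetBlocknt(channels, strides, True, batchnorm,
                                parameterization)]
    for _ in range(n - 1):
        blocks.append(WideResnetBlocknt(channels, (1, 1), False, batchnorm,
                                        parameterization))
    return stax_nt.serial(*blocks)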
Example #12
def GP(x_train, y_train, x_test, y_test, w_std, b_std, l, C):
    net0 = stax.Dense(1, w_std, b_std)
    nets = [net0]

    k_layer = []
    K = net0[2](x_train, None)
    k_layer.append(K.nngp)

    for _ in range(1, l + 1):
        net_l = stax.serial(stax.Relu(), stax.Dense(1, w_std, b_std))
        K = net_l[2](K)
        k_layer.append(K.nngp)
        nets += [stax.serial(nets[-1], net_l)]

    kernel_fn = nets[-1][2]

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
                                                        x_train,
                                                        y_train,
                                                        x_test,
                                                        get=('nngp', 'ntk'),
                                                        diag_reg=C)

    fx_test_nngp.block_until_ready()

    duration = time.time() - start
    #print('Kernel construction and inference done in %s seconds.' % duration)
    return accuracy(y_test, fx_test_nngp)
Example #13
def main(unused_argv):

    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # initialize the network first time, to compute NTK
    # numpy.random.random_integers is deprecated; randint is the equivalent.
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   high=np.iinfo(np.int32).max,
                                   size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # we assume that the NTK is approximately the same for any sample of parameters (true in the limit of infinite width)

    print("Making NTK")
    sys.stdout.flush()
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)
    g_dd = ntk(x_train, None, params)
    pickle.dump(g_dd, open("ntk_train_" + str(FLAGS.train_size) + ".p", "wb"))
    g_td = ntk(x_test, x_train, params)
    pickle.dump(g_td,
                open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "wb"))
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
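
The predictor is created above but never used. A hypothetical continuation of the function body, assuming the older three-argument `nt.predict.gradient_descent_mse(g_dd, y_train, g_td)` API this snippet targets (argument order differs across neural_tangents versions, so treat this as a sketch):

    # Evolve the initial network outputs to training time t under
    # gradient-descent MSE dynamics (hypothetical continuation).
    fx_train_0 = apply_fn(params, x_train)
    fx_test_0 = apply_fn(params, x_test)
    fx_train_t, fx_test_t = predictor(1e8, fx_train_0, fx_test_0)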
Example #14
def _build_network(input_shape, network, out_logits):
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(),
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == CONV:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
                stax.Relu(),
                stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
            )
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example #15
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)

    # Build the infinite network.
    _, _, kernel_fn = stax.serial(stax.Dense(1, 2., 0.05), stax.Relu(),
                                  stax.Dense(1, 2., 0.05))

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=0,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=1e-3)
    fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print('Kernel construction and inference done in %s seconds.' % duration)

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
Example #16
  def test_layernorm(self,
                     model,
                     width,
                     same_inputs,
                     is_ntk,
                     proj_into_2d,
                     layer_norm):
    is_conv = 'conv' in model
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'):
      raise absltest.SkipTest('FC models do not have these parameters.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_shape = FILTER_SHAPES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    parameterization = 'ntk'
    pool_type = 'AVG'
    use_dropout = False

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, 1, use_dropout)
    _check_agreement_with_empirical(self, net, same_inputs, use_dropout, is_ntk,
                                    0.07)
Example #17
    def testPredictOnCPU(self):
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_gp(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)
                        gp_inference = predict.gp_inference(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)

                        self.assertAllClose(predictor(None), predictor(np.inf),
                                            True)
                        self.assertAllClose(predictor(None), gp_inference,
                                            True)
Example #18
  def test_composition_conv(self, avg_pool, same_inputs):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
Example #19
    def test_composition_conv(self, avg_pool):
        rng = random.PRNGKey(0)
        x1 = random.normal(rng, (5, 10, 10, 3))
        x2 = random.normal(rng, (5, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        if avg_pool:
            Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
            marginalization = 'none'
        else:
            Readout = stax.serial(stax.Flatten(), stax.Dense(10))
            marginalization = 'auto'

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        ker_out = readout_ker_fn(
            block_ker_fn(x1, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1)
        self.assertAllClose(ker_out, composed_ker_out, True)

        if avg_pool:
            with self.assertRaises(ValueError):
                ker_out = readout_ker_fn(block_ker_fn(x1))

        ker_out = readout_ker_fn(
            block_ker_fn(x1, x2, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1, x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #20
    def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                       layer_norm):
        is_conv = 'conv' in model
        # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
        if is_conv:
            if xla_bridge.get_backend().platform == 'cpu':
                raise jtu.SkipTest(
                    'Not running CNN models on CPU to save time.')
        elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
            raise jtu.SkipTest('FC models do not have these parameters.')

        W_std, b_std = 2.**0.5, 0.5**0.5
        filter_size = FILTER_SIZES[0]
        padding = PADDINGS[0]
        strides = STRIDES[0]
        phi = stax.Relu()
        use_pooling, is_res = False, False
        parameterization = 'ntk'
        use_dropout = False

        self._check_agreement_with_empirical(W_std, b_std, filter_size,
                                             is_conv, is_ntk, is_res,
                                             layer_norm, padding, phi,
                                             proj_into_2d, same_inputs,
                                             strides, use_pooling, width,
                                             parameterization, use_dropout)
Example #21
    def test_empirical_ntk_diagonal_outputs(self, same_inputs, device_count,
                                            trace_axes, diagonal_axes):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)

        input_key1, input_key2, net_key = random.split(rng, 3)

        init_fn, apply_fn, _ = stax.serial(stax.Dense(5), stax.Relu(),
                                           stax.Dense(3))

        test_x1 = random.normal(input_key1, (12, 4, 4))
        test_x2 = None
        if not same_inputs:
            test_x2 = random.normal(input_key2, (9, 4, 4))

        kernel_fn = nt.empirical_ntk_fn(apply_fn,
                                        trace_axes=trace_axes,
                                        diagonal_axes=diagonal_axes,
                                        vmap_axes=0,
                                        implementation=2)

        _, params = init_fn(net_key, test_x1.shape)

        true_kernel = kernel_fn(test_x1, test_x2, params)
        batched_fn = batching.batch(kernel_fn,
                                    device_count=device_count,
                                    batch_size=3)
        batch_kernel = batched_fn(test_x1, test_x2, params)
        self.assertAllClose(true_kernel, batch_kernel)
Example #22
  def test_nonlinear(
      self,
      model,
      width,
      same_inputs,
      is_ntk,
      filter_shape,
      proj_into_2d,
      b_std,
      W_std,
      parameterization,
      s
  ):
    is_conv = 'conv' in model

    if parameterization == 'standard':
      width //= s

    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      test_utils.skip_test(self)
    elif proj_into_2d != PROJECTIONS[0] or filter_shape != FILTER_SHAPES[0]:
      raise absltest.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std=W_std,
                   b_std=b_std,
                   filter_shape=filter_shape,
                   is_conv=is_conv,
                   use_pooling=use_pooling,
                   is_res=is_res,
                   padding=padding,
                   phi=phi,
                   strides=strides,
                   width=width,
                   is_ntk=is_ntk,
                   proj_into_2d=proj_into_2d,
                   pool_type=pool_type,
                   layer_norm=layer_norm,
                   parameterization=parameterization,
                   s=s,
                   use_dropout=use_dropout)

    _check_agreement_with_empirical(
        self,
        net=net,
        same_inputs=same_inputs,
        use_dropout=use_dropout,
        is_ntk=is_ntk,
        rtol=0.015,
        atol=1000
    )
Example #23
  def testGpInference(self):
    reg = 1e-5
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5),
        stax.Relu(),
        stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
      if kernel_fn_is_analytic:
        kernel_fn = kernel_fn_analytic
      else:
        _, params = init_fn(key, x_train.shape)
        kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)
        def kernel_fn(x1, x2, get):
          return kernel_fn_empirical(x1, x2, get, params)

      for get in [None,
                  'nngp', 'ntk',
                  ('nngp',), ('ntk',),
                  ('nngp', 'ntk'), ('ntk', 'nngp')]:
        k_dd = kernel_fn(x_train, None, get)

        gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
        gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                            x_train,
                                                            y_train,
                                                            diag_reg=reg)
        for x_test in [None, 'x_test']:
          x_test = None if x_test is None else random.normal(key, (8, 2))
          k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

          for compute_cov in [True, False]:
            with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
                              get=get,
                              x_test=x_test if x_test is None else 'x_test',
                              compute_cov=compute_cov):
              if compute_cov:
                nngp_tt = (True if x_test is None else
                           kernel_fn(x_test, None, 'nngp'))
              else:
                nngp_tt = None

              out_ens = gd_ensemble(None, x_test, get, compute_cov)
              out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
              self._assertAllClose(out_ens_inf, out_ens, 0.08)

              if (get is not None and
                  'nngp' not in get and
                  compute_cov and
                  k_td is not None):
                with self.assertRaises(ValueError):
                  out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                            nngp_test_test=nngp_tt)
              else:
                out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
                self.assertAllClose(out_ens, out_gp_inf)
Example #24
  def test_vmap_axes(self, same_inputs):
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
      k1, k2, k3 = random.split(k, 3)
      x1 = random.normal(k1, (n, c1))
      x2 = random.normal(k2, (h2, n, c2))
      x3 = random.normal(k3, (c3, w3, n, h3))
      x = [(x1, x2), x3]
      return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2,), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH', 'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )

    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

    k = direct(x1, x2, params, pattern=(p1, p2))

    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
Example #25
def _get_inputs_and_model(width=1, n_classes=2):
    key = random.PRNGKey(1)
    key, split = random.split(key)
    x1 = random.normal(key, (8, 4, 3, 2))
    x2 = random.normal(split, (4, 4, 3, 2))
    init_fun, apply_fun, ker_fun = stax.serial(stax.Conv(width, (3, 3)),
                                               stax.Relu(), stax.Flatten(),
                                               stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fun, apply_fun, ker_fun, key
Example #26
def create_network(depth, width):
    layers = []
    for _ in range(depth):
        layers += [
            stax.Dense(width, W_std=1.5, b_std=0.0, parameterization='ntk'),
            stax.Relu()
        ]
    layers += [stax.Dense(1, W_std=1.5, b_std=0, parameterization='ntk')]
    return stax.serial(*layers)
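
A brief usage sketch (shapes are illustrative):

from jax import random

# Depth-3, width-512 ReLU network; evaluate its limiting NTK on toy inputs.
init_fn, apply_fn, kernel_fn = create_network(3, 512)
x = random.normal(random.PRNGKey(0), (4, 16))
ntk = kernel_fn(x, None, 'ntk')  # (4, 4) analytic NTK matrix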
Example #27
    def testPredictOnCPU(self):
        key1 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key2 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key3 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

                            x = x if x is None else x_test
                            predict_none = predictor(None,
                                                     x,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is not None:
                                on_cpu = (not store_on_device
                                          or xla_bridge.get_backend().platform
                                          == 'cpu')
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_inf))
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_none))
Example #28
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (20, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (10, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example #29
  def test_composition(self):
    rng = random.PRNGKey(0)
    xs = random.normal(rng, (10, 10))
    Block = stax.serial(stax.Dense(256), stax.Relu())

    _, _, ker_fn = Block
    _, _, composed_ker_fn = stax.serial(Block, Block)

    ker_out = ker_fn(ker_fn(xs))
    composed_ker_out = composed_ker_fn(xs)

    self.assertAllClose(ker_out, composed_ker_out, True)
Example #30
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    main = stax.serial(
        stax.Relu(),
        stax.Conv(
            channels, (3, 3), strides, padding='SAME',
            parameterization='standard'
        ),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding='SAME',
                  parameterization='standard'),
    )
    shortcut = (
        stax.Identity()
        if not channel_mismatch
        else stax.Conv(
            channels, (3, 3), strides, padding='SAME',
            parameterization='standard'
        )
    )
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum())