Example 1
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = onp.random.uniform(0.5, 0.9)  # onp: plain NumPy (the old `np.onp` alias was removed from JAX)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
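
A minimal usage sketch of the (init_fn, apply_fn, kernel_fn) triple returned by _get_net above; the hyperparameter values here are illustrative assumptions, not taken from the original test:

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = _get_net(
    W_std=1.5, b_std=0.1, filter_shape=(3, 3), is_conv=False,
    use_pooling=False, is_res=False, padding='SAME', phi=stax.Relu(),
    strides=None, width=512, is_ntk=True, proj_into_2d='FLAT',
    layer_norm=None, parameterization='ntk', use_dropout=False)

x = random.normal(random.PRNGKey(0), (4, 8))
ntk = kernel_fn(x, None, 'ntk')  # (4, 4) analytic NTK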
Example 2
def WideResnetnt(
        block_size,
        k,
        num_classes,
        batchnorm='std'):
    """Based on the WideResnet from the paper, with or without BatchNorm.

    (Set config.wrn_block_size=3, config.wrn_widening_f=10 in that case.)
    Uses the default weight and bias initialization."""
    parameterization = 'standard'
    layers_lst = [
        stax_nt.Conv(16, (3, 3),
                     padding='SAME',
                     parameterization=parameterization),
        WideResnetGroupnt(block_size,
                          16 * k,
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size,
                          32 * k, (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm),
        WideResnetGroupnt(block_size,
                          64 * k, (2, 2),
                          parameterization=parameterization,
                          batchnorm=batchnorm)
    ]
    layers_lst += [_batch_norm_internal(batchnorm), stax_nt.Relu()]
    layers_lst += [
        stax_nt.AvgPool((8, 8)),
        stax_nt.Flatten(),
        stax_nt.Dense(num_classes, parameterization=parameterization)
    ]
    return stax_nt.serial(*layers_lst)
Example 3
    def testPredictOnCPU(self):
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_gp(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)
                        gp_inference = predict.gp_inference(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)

                        self.assertAllClose(predictor(None), predictor(np.inf),
                                            True)
                        self.assertAllClose(predictor(None), gp_inference,
                                            True)
Example 4
def wide_resnet(block_size, k, num_classes):
    return stax.serial(stax.Conv(16, (3, 3), padding='SAME'),
                       wide_resnet_group(block_size, int(16 * k)),
                       wide_resnet_group(block_size, int(32 * k), (2, 2)),
                       wide_resnet_group(block_size, int(64 * k), (2, 2)),
                       stax.AvgPool((8, 8)), stax.Flatten(),
                       stax.Dense(num_classes, 1., 0.))
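
A hedged usage sketch for wide_resnet (assumes wide_resnet_group is defined alongside it, as in the source repo; input sizes are illustrative):

from jax import random

_, _, kernel_fn = wide_resnet(block_size=4, k=1, num_classes=10)
x1 = random.normal(random.PRNGKey(1), (2, 32, 32, 3))
x2 = random.normal(random.PRNGKey(2), (3, 32, 32, 3))
ntk = kernel_fn(x1, x2, 'ntk')  # (2, 3) analytic NTK on CIFAR-sized inputs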
Example 5
  def test_composition_conv(self, avg_pool, same_inputs):
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      Readout = stax.serial(stax.Conv(256, (3, 3)),
                            stax.GlobalAvgPool(),
                            stax.Dense(10))
    else:
      Readout = stax.serial(stax.Flatten(), stax.Dense(10))

    block_ker_fn, readout_ker_fn = Block[2], Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                  diagonal_spatial=False))
    ker_out_default = readout_ker_fn(block_ker_fn(x1, x2))
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        ker_out = readout_ker_fn(block_ker_fn(x1, x2, diagonal_spatial=True))
    else:
      ker_out_marg = readout_ker_fn(block_ker_fn(x1, x2,
                                                 diagonal_spatial=True))
      self.assertAllClose(composed_ker_out, ker_out_marg)
Example 6
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)

  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)

  ntk_auto = auto(x2, x1, params)

  # Check that implementations match
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
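
Since the empirical kernel functions above are ordinary JAX functions, they can also be jit-compiled for repeated evaluation; a small sketch reusing structured_derivatives, x1, x2, and params from the example above:

# Sketch: compile one of the empirical NTK functions for repeated calls.
ntk_fn = jax.jit(structured_derivatives)
ntk_sd_jit = ntk_fn(x2, x1, params)  # same (3, 2, 10, 10) NTK, compiled once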
Example 7
    def test_composition_conv(self, avg_pool):
        rng = random.PRNGKey(0)
        x1 = random.normal(rng, (5, 10, 10, 3))
        x2 = random.normal(rng, (5, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        if avg_pool:
            Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))
            marginalization = 'none'
        else:
            Readout = stax.serial(stax.Flatten(), stax.Dense(10))
            marginalization = 'auto'

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        ker_out = readout_ker_fn(
            block_ker_fn(x1, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1)
        self.assertAllClose(ker_out, composed_ker_out, True)

        if avg_pool:
            with self.assertRaises(ValueError):
                ker_out = readout_ker_fn(block_ker_fn(x1))

        ker_out = readout_ker_fn(
            block_ker_fn(x1, x2, marginalization=marginalization))
        composed_ker_out = composed_ker_fn(x1, x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
Example 8
def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME'),
        ntk_generator.ResnetGroup(block_size, int(16 * k)),
        ntk_generator.ResnetGroup(block_size, int(32 * k), (2, 2)),
        ntk_generator.ResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.Flatten(), stax.Dense(num_classes, 1., 0.))
Example 9
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial((stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                          if use_pooling else stax.Identity()), phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  readout = stax.serial(stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
                        fc(1 if is_ntk else width))

  net = stax.serial(block, readout)
  return net
Example 10
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
Example 11
def _build_network(input_shape, network, out_logits, use_dropout):
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.Flatten(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == INTERMEDIATE_CONV:
            return stax.Conv(CONVOLUTION_CHANNELS, (2, 2),
                             W_std=2.0,
                             b_std=0.05)
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example 12
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
Example 13
def _build_network(input_shape, network, out_logits):
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(),
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == CONV:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
                stax.Relu(),
                stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
            )
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example 14
  def testPredictOnCPU(self):
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
          for x in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

              x = x if x is None else x_test  # x is either None or the 'x_test' placeholder
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is not None:
                on_cpu = (not store_on_device or
                          xla_bridge.get_backend().platform == 'cpu')
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
Example 15
def Resnet(block_size, num_classes):
    return stax.serial(stax.Conv(64, (3, 3), padding='SAME'),
                       ResnetGroup(block_size, 64),
                       ResnetGroup(block_size, 128, (2, 2)),
                       ResnetGroup(block_size, 256, (2, 2)),
                       ResnetGroup(block_size, 512, (2, 2)), stax.Flatten(),
                       stax.Dense(num_classes, 1., 0.05))
Example 16
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    if 'conv' in model:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    key, split = random.split(key)
    output_dim = 1024 if get == 'nngp' else 1
    b_std = 0.5
    W_std = 2.0
    if activation_fn[2].__name__ == 'Sin':
      W_std = 0.9
    if activation_fn[2].__name__ == 'Rbf':
      W_std = 1.0
      b_std = 0.0

    if model == 'fc':
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1

    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    else:
      num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                       else 300)

    init_fn, apply_fn, kernel_fn = stax.serial(
        *[affine, activation_fn]*depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and model == 'fc':
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
Example 17
def _get_inputs_and_model(width=1, n_classes=2):
    key = random.PRNGKey(1)
    key, split = random.split(key)
    x1 = random.normal(key, (8, 4, 3, 2))
    x2 = random.normal(split, (4, 4, 3, 2))
    init_fun, apply_fun, ker_fun = stax.serial(stax.Conv(width, (3, 3)),
                                               stax.Relu(), stax.Flatten(),
                                               stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fun, apply_fun, ker_fun, key
Example 18
    def testPredictOnCPU(self):
        key1 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key2 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key3 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

                            x = x if x is None else x_test  # x is either None or the 'x_test' placeholder
                            predict_none = predictor(None,
                                                     x,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is not None:
                                on_cpu = (not store_on_device
                                          or xla_bridge.get_backend().platform
                                          == 'cpu')
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_inf))
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_none))
Example 19
def WideResnet(block_size, k, num_classes):
    return stax.serial(
        stax.Conv(16, (3, 3), padding='SAME', parameterization='standard'),
        WideResnetGroup(block_size, int(16 * k)),
        WideResnetGroup(block_size, int(32 * k), (2, 2)),
        WideResnetGroup(block_size, int(64 * k), (2, 2)),
        stax.AvgPool((8, 8), padding='SAME'),
        stax.Flatten(),
        stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
    )
Example 20
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  res_unit = stax.serial((stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                          if use_pooling else stax.Identity()), phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads,
            fixed=fixed, W_key_std=W_std, W_value_std=W_std, W_query_std=W_std,
            W_out_std=1.0, b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout)
Example 21
def WideResnet(block_size, k, num_classes, W_std=1., b_std=0.):
    return stax.serial(
        stax.Conv(16, (3, 3), W_std=W_std, b_std=b_std, padding='SAME'),
        WideResnetGroup(block_size, int(16 * k), W_std=W_std, b_std=b_std),
        WideResnetGroup(block_size,
                        int(32 * k), (2, 2),
                        W_std=W_std,
                        b_std=b_std),
        WideResnetGroup(block_size,
                        int(64 * k), (2, 2),
                        W_std=W_std,
                        b_std=b_std), stax.AvgPool((7, 7)), stax.Flatten(),
        stax.Dense(num_classes, W_std=W_std, b_std=b_std))
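
A hedged usage sketch for this WideResnet (assumes WideResnetGroup is defined as in the source repo; input sizes are illustrative):

from jax import random

_, _, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
x = random.normal(random.PRNGKey(0), (3, 32, 32, 3))
ntk = kernel_fn(x, None, 'ntk')  # (3, 3) analytic NTK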
Example 22
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    key = random.PRNGKey(1)
    key, split = random.split(key)
    x1 = random.normal(key, (8, 4, 3, 2))
    x2 = random.normal(split, (4, 4, 3, 2))

    if not use_conv:
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
Example 23
def build_le_net(network_width):
    """ Construct the LeNet of width network_width with average pooling using neural tangent's stax."""
    return stax.serial(
        stax.Conv(out_chan=6 * network_width,
                  filter_shape=(3, 3),
                  strides=(1, 1),
                  padding='VALID'), stax.Relu(),
        stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        stax.Conv(out_chan=16 * network_width,
                  filter_shape=(3, 3),
                  strides=(1, 1),
                  padding='VALID'), stax.Relu(),
        stax.AvgPool(window_shape=(2, 2), strides=(2, 2)), stax.Flatten(),
        stax.Dense(120 * network_width), stax.Relu(),
        stax.Dense(84 * network_width), stax.Relu(), stax.Dense(10))
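
A usage sketch for the LeNet constructor above, on 32x32 RGB inputs (the input size is an illustrative assumption):

from jax import random

init_fn, apply_fn, kernel_fn = build_le_net(network_width=1)
key = random.PRNGKey(0)
x = random.normal(key, (4, 32, 32, 3))
_, params = init_fn(key, x.shape)
logits = apply_fn(params, x)       # (4, 10) finite-width outputs
nngp = kernel_fn(x, None, 'nngp')  # (4, 4) infinite-width NNGP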
Example 24
    def test_flatten_first(self, same_inputs):
        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (5, 4, 3, 2))
        X0_2 = None if same_inputs else random.normal(key, (3, 4, 3, 2))

        X0_1_flat = np.reshape(X0_1, (X0_1.shape[0], -1))
        X0_2_flat = None if same_inputs else np.reshape(
            X0_2, (X0_2.shape[0], -1))

        _, _, fc_flat = stax.serial(stax.Dense(10, 2., 0.5), stax.Erf())
        _, _, fc = stax.serial(stax.Flatten(), stax.Dense(10, 2., 0.5),
                               stax.Erf())

        K_flat = fc_flat(X0_1_flat, X0_2_flat)
        K = fc(X0_1, X0_2)
        self.assertAllClose(K_flat, K, True)
Example 25
def _MyrtleNetwork(width, depth, W_std=jnp.sqrt(2.0), b_std=0.0):
    layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
    activation_fn = stax.Relu()
    layers = []
    conv = functools.partial(stax.Conv,
                             W_std=W_std,
                             b_std=b_std,
                             padding="SAME")

    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))]
    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))]
    layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2]
    layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3

    layers += [stax.Flatten(), stax.Dense(10, W_std, b_std)]

    return stax.serial(*layers)
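
The five stride-2 average-pooling layers take a 32x32 input down to 1x1, so CIFAR-sized inputs are the natural fit; a short sketch:

from jax import random

_, _, kernel_fn = _MyrtleNetwork(width=16, depth=5)
x = random.normal(random.PRNGKey(0), (2, 32, 32, 3))
nngp = kernel_fn(x, None, 'nngp')  # (2, 2) analytic NNGP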
Example 26
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    key = stateless_uniform(shape=[2],
                            seed=[1, 1],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    keys = tf_random_split(key)
    key = keys[0]
    split = keys[1]
    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

    if not use_conv:
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
Example 27
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
Example 28
def CNNStandard(n_channels,
                L,
                filter=(3, 3),
                data='cifar10',
                gap=True,
                nonlinearity='relu',
                parameterization='standard',
                order=None):
    if data == 'cifar10':
        num_classes = 10
    elif data == 'cifar100':
        num_classes = 100
    else:
        raise ValueError('Unexpected dataset: {}'.format(data))
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    else:
        raise ValueError('Unexpected nonlinearity: {}'.format(nonlinearity))
    init_fn, f = jax_stax.serial(*[
        jax_stax.serial(
            MyConv(n_channels,
                   filter,
                   parameterization=parameterization,
                   order=order),
            nonlin,
        ) for _ in range(L)
    ])
    if gap:
        init_fn, f = jax_stax.serial((init_fn, f),
                                     stax.GlobalAvgPool()[:2],
                                     MyDense(num_classes,
                                             parameterization=parameterization,
                                             order=order))
    else:
        init_fn, f = jax_stax.serial((init_fn, f),
                                     stax.Flatten()[:2],
                                     MyDense(num_classes,
                                             parameterization=parameterization,
                                             order=order))
    return init_fn, f
Example 29
def cnn1d(hidden_widths, nonlin, args):
    layers = []
    layers_ker = []

    window = (3, )
    stride = (1, )
    if args is not None:
        window = args

    for i, ch in enumerate(hidden_widths):
        layers += [
            stax.Conv(ch,
                      window,
                      stride,
                      'CIRCULAR',
                      b_std=0,
                      parameterization='ntk')
        ]
        if nonlin != 'relu':
            layers_ker += [i]
        else:
            layers += [stax.Relu()]
            layers_ker += [2 * i + 1]
    layers += [stax.Flatten()]
    if nonlin != 'relu':
        layers_ker += [i + 1]
    else:
        layers += [stax.Relu()]
        layers_ker += [2 * i + 3]
    layers += [stax.Dense(10, parameterization='ntk')]
    if nonlin != 'relu':
        layers_ker += [i + 2]
    else:
        layers_ker += [2 * i + 4]

    return layers, layers_ker
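
A hedged usage sketch for cnn1d: the returned layer list is meant to be assembled with stax.serial, while layers_ker records layer indices (per the bookkeeping logic above) for later kernel inspection:

from neural_tangents import stax

layers, layers_ker = cnn1d(hidden_widths=[32, 32], nonlin='relu', args=None)
init_fn, apply_fn, kernel_fn = stax.serial(*layers)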
Example 30
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout, fan_in_mode):
        test_utils.skip_test(self)
        if fan_in_mode in ['FanInSum', 'FanInProd']:
            if axis != 0:
                raise absltest.SkipTest(
                    '`FanInSum` and `FanInProd` are skipped when axis != 0.')
            axis = None
        if (fan_in_mode == 'FanInSum'
                or axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
            raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                                    'require `is_gaussian`.')

        if ((axis == 3 or fan_in_mode == 'FanInProd')
                and branch_in == 'dense_before_branch_in'):
            raise absltest.SkipTest(
                '`FanInConcat` or `FanInProd` on feature axis '
                'requires a dense layer after concatenation '
                'or Hadamard product.')

        if fan_in_mode == 'FanInSum':
            fan_in_layer = stax.FanInSum()
        elif fan_in_mode == 'FanInProd':
            fan_in_layer = stax.FanInProd()
        else:
            fan_in_layer = stax.FanInConcat(axis)

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if default_backend() == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
                branch_layers += [
                    stax.Conv(out_chan=int(width * multiplier),
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            fan_in_layer,
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if axis in (0, -4) else 0,
        )

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        test_utils.assert_close_matrices(self, empirical, exact, tol)