예제 #1
0
def main(unused_argv):
  """Cross-checks all empirical NTK implementations on a small CNN.

  Evaluates the empirical NTK of one randomly initialised network with every
  `nt.NtkImplementation` option listed below, prints the largest absolute
  difference for each ordered pair of results (self-pairs included) and
  asserts that every difference stays below a backend-dependent tolerance.

  Args:
    unused_argv: Not read by this function.
  """
  data_key1, data_key2, init_key = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(data_key1, (2, 8, 8, 3))
  x2 = random.normal(data_key2, (3, 8, 8, 3))

  # A vanilla CNN: three 3x3 convolutions followed by a dense readout.
  cnn_layers = [
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10),
  ]
  init_fn, f, _ = stax.serial(*cnn_layers)
  _, params = init_fn(init_key, x1.shape)

  # Arguments shared by every implementation under test.
  shared_kwargs = dict(f=f, trace_axes=(), vmap_axes=0)

  implementations = (
      # Default, baseline Jacobian contraction.
      nt.NtkImplementation.JACOBIAN_CONTRACTION,
      # NTK-vector products-based implementation.
      nt.NtkImplementation.NTK_VECTOR_PRODUCTS,
      # Structured derivatives-based implementation.
      nt.NtkImplementation.STRUCTURED_DERIVATIVES,
      # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
      nt.NtkImplementation.AUTO,
  )

  # One test-train NTK between `x2` (3 inputs) and `x1` (2 inputs) per
  # implementation.
  # NOTE(review): with `trace_axes=()` and a 10-unit readout this is presumably
  # a (3, 2, 10, 10) array -- confirm against the library documentation.
  ntks = []
  for implementation in implementations:
    ntk_fn = nt.empirical_ntk_fn(
        **shared_kwargs, implementation=implementation)
    ntks.append(ntk_fn(x2, x1, params))

  # Check that implementations match; a looser bound is used on TPU.
  tolerance = 0.1 if jax.default_backend() == 'tpu' else 1e-4
  for ntk1 in ntks:
    for ntk2 in ntks:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < tolerance, diff

  print('All NTK implementations match.')
예제 #2
0
def _build_network(input_shape, network, out_logits):
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.Dense(out_logits, W_std=2.0, b_std=0.5)
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(),
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == CONV:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (1, 2), W_std=1.5, b_std=0.1),
                stax.Relu(),
                stax.Conv(CONVOLUTION_CHANNELS, (3, 2), W_std=2.0, b_std=0.05),
            )
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
                stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
예제 #3
0
  def test_exp_normalized(self):
    """Checks shapes and positive-definiteness of `ExpNormalized` kernels.

    For every combination of `do_clip`, `gamma` and kernel type, the
    cross-kernel must have shape (n1, n2) and each self-kernel must be square
    with a strictly positive smallest eigenvalue.
    """
    key = random.PRNGKey(0)
    x1 = random.normal(key, (2, 6, 7, 1))
    x2 = random.normal(key, (4, 6, 7, 1))

    def check_self_kernel(kernel_fn, x, get):
      # Self-kernel of `x`: square in the batch size and positive definite.
      batch = x.shape[0]
      k_xx = kernel_fn(x, None, get=get)
      self.assertEqual(k_xx.shape, (batch,) * 2)
      self.assertGreater(np.min(np.linalg.eigvalsh(k_xx)), 0)

    for do_clip in [True, False]:
      for gamma in [1., 2., 0.5]:
        for get in ['nngp', 'ntk']:
          with self.subTest(do_clip=do_clip, gamma=gamma, get=get):
            layers = [
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.Conv(1, (3, 3)),
                stax.ExpNormalized(gamma, do_clip),
                stax.GlobalAvgPool(),
                stax.Dense(1),
            ]
            _, _, kernel_fn = stax.serial(*layers)

            k_12 = kernel_fn(x1, x2, get=get)
            self.assertEqual(k_12.shape, (x1.shape[0], x2.shape[0]))

            check_self_kernel(kernel_fn, x1, get)
            check_self_kernel(kernel_fn, x2, get)
예제 #4
0
  def test_nested_parallel(self, same_inputs, kernel_type):
    """Compares analytic and Monte Carlo kernels of a nested `parallel` net.

    Four inputs of different shapes are arranged in a nested tuple, some of
    them are masked, and the closed-form `kernel_fn` of a matching nested
    `stax.parallel` network is compared with a Monte Carlo estimate.

    Args:
      same_inputs: If true, the second input is `None`; otherwise a second
        nested tuple of inputs is built.
      kernel_type: Forwarded as `get` to both kernel functions.
    """
    platform = default_backend()
    # A looser tolerance is used on TPU.
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    # One key per input leaf, one for the masks and one for Monte Carlo
    # sampling.
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    # Four input pairs of different shapes, one per leaf of the network below.
    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    # The first two leaves of `x1` are always masked; the last two leaves of
    # `x2` are masked only when `x2` is actually used (see `x2` below).
    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    # The nesting of the input tuples mirrors the nesting of `stax.parallel`
    # in the network definition.
    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    # Base width; each branch below uses a slightly different width.
    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    # Explicit `vmap_axes` (axis 0 for every leaf) are only supplied on TPU;
    # on other backends `None` is passed.
    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    # Both kernels receive the same `mask_constant` that was used for masking.
    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
예제 #5
0
def _build_network(input_shape, network, out_logits, use_dropout):
    """Build the stax model matching an input rank and a network kind.

    Args:
        input_shape: Sequence whose length selects the input family: 1 for
            flat inputs, 3 for image inputs.
        network: One of the module constants `FLAT`, `POOLING` or
            `INTERMEDIATE_CONV`. Flat inputs only accept `FLAT`.
        out_logits: Width passed to the final `stax.Dense` layer (unused by
            the `INTERMEDIATE_CONV` variant, which is a single convolution).
        use_dropout: If true, a `stax.Dropout(0.9, mode='train')` layer is
            inserted before the readout; otherwise `stax.Identity()` is used
            in its place.

    Returns:
        Whatever the selected `stax` constructor returns.

    Raises:
        ValueError: If `input_shape` has a length other than 1 or 3, or if
            `network` is not recognised for image inputs.
        AssertionError: If a flat input is paired with a non-`FLAT` network.
    """
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    rank = len(input_shape)

    if rank == 1:
        # Compare against the module constant, as every other branch does,
        # rather than against a hard-coded string literal.
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))

    if rank != 3:
        raise ValueError('Expected flat or image test input.')

    # Image inputs from here on.
    if network == POOLING:
        return stax.serial(
            stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
            stax.GlobalAvgPool(), dropout,
            stax.Dense(out_logits, W_std=2.0, b_std=0.5))

    if network == FLAT:
        return stax.serial(
            stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
            stax.Flatten(), dropout,
            stax.Dense(out_logits, W_std=2.0, b_std=0.5))

    if network == INTERMEDIATE_CONV:
        # No readout and no dropout: just the convolution itself.
        return stax.Conv(CONVOLUTION_CHANNELS, (2, 2),
                         W_std=2.0,
                         b_std=0.05)

    raise ValueError(
        'Unexpected network type found: {}'.format(network))
예제 #6
0
def WideResnetBlocknt(channels,
                      strides=(1, 1),
                      channel_mismatch=False,
                      batchnorm='std',
                      parameterization='ntk'):
    """Build one WideResnet residual block, with or without BatchNorm.

    The block fans its input out into a main path (two norm/ReLU/3x3-conv
    stages, the first one strided) and a shortcut, then sums both paths.

    Args:
        channels: Output channels of every convolution in the block.
        strides: Strides of the first main-path convolution and of the
            shortcut convolution (when present).
        channel_mismatch: If true the shortcut is a strided 3x3 convolution;
            otherwise it is the identity.
        batchnorm: Forwarded to `_batch_norm_internal`.
        parameterization: Forwarded to every `stax_nt.Conv`.

    Returns:
        The result of `stax_nt.serial` over fan-out, both paths and fan-in.
    """

    def conv3x3(*conv_strides):
        # 3x3 'SAME' convolution; strides are passed positionally only when
        # given, otherwise the layer default applies.
        return stax_nt.Conv(channels, (3, 3),
                            *conv_strides,
                            padding='SAME',
                            parameterization=parameterization)

    main_layers = [
        _batch_norm_internal(batchnorm),
        stax_nt.Relu(),
        conv3x3(strides),
        _batch_norm_internal(batchnorm),
        stax_nt.Relu(),
        conv3x3(),
    ]
    Main = stax_nt.serial(*main_layers)

    if channel_mismatch:
        Shortcut = conv3x3(strides)
    else:
        Shortcut = stax_nt.Identity()

    return stax_nt.serial(stax_nt.FanOut(2),
                          stax_nt.parallel(Main, Shortcut),
                          stax_nt.FanInSum())
예제 #7
0
  def test_composition_conv(self, avg_pool, same_inputs):
    """Checks that chaining kernel functions equals the composed kernel.

    The kernel of `serial(Block, Readout)` must match applying the readout
    kernel function to the block kernel function's output, for the default
    and the `diagonal_spatial=False` settings. With `diagonal_spatial=True`
    the pooling readout is expected to raise `ValueError`, while the
    flattening readout must still match.

    Args:
      avg_pool: Selects the conv + global-average-pool readout instead of the
        flatten readout.
      same_inputs: If true, `x2` is `None`.
    """
    rng = random.PRNGKey(0)
    x1 = random.normal(rng, (3, 5, 5, 3))
    x2 = None if same_inputs else random.normal(rng, (4, 5, 5, 3))

    Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
    if avg_pool:
      readout_layers = [stax.Conv(256, (3, 3)),
                        stax.GlobalAvgPool(),
                        stax.Dense(10)]
    else:
      readout_layers = [stax.Flatten(), stax.Dense(10)]
    Readout = stax.serial(*readout_layers)

    block_ker_fn = Block[2]
    readout_ker_fn = Readout[2]
    _, _, composed_ker_fn = stax.serial(Block, Readout)

    def stagewise(**block_kwargs):
      # Block kernel first, then the readout kernel applied to its output.
      return readout_ker_fn(block_ker_fn(x1, x2, **block_kwargs))

    composed_ker_out = composed_ker_fn(x1, x2)
    ker_out_no_marg = stagewise(diagonal_spatial=False)
    ker_out_default = stagewise()
    self.assertAllClose(composed_ker_out, ker_out_no_marg)
    self.assertAllClose(composed_ker_out, ker_out_default)

    if avg_pool:
      with self.assertRaises(ValueError):
        stagewise(diagonal_spatial=True)
    else:
      ker_out_marg = stagewise(diagonal_spatial=True)
      self.assertAllClose(composed_ker_out, ker_out_marg)
예제 #8
0
  def test_vmap_axes(self, same_inputs):
    """Checks empirical NTK functions agree with and without `vmap_axes`.

    Builds a network over a nested input structure whose `n`-sized axis sits
    at a different position in every leaf, then verifies that the implicit
    and direct NTK functions, each with and without explicit `vmap_axes`, all
    produce the same kernel.

    Args:
      same_inputs: If true, the second input and second pattern are `None`.
    """
    # Sizes of the `n` axis for the first and the second input.
    n1, n2 = 3, 4
    c1, c2, c3 = 9, 5, 7
    h2, h3, w3 = 6, 8, 2

    def get_x(n, k):
      # Three leaves with the `n` axis at position 0, 1 and 2 respectively.
      k1, k2, k3 = random.split(k, 3)
      x1 = random.normal(k1, (n, c1))
      x2 = random.normal(k2, (h2, n, c2))
      x3 = random.normal(k3, (c3, w3, n, h3))
      x = [(x1, x2), x3]
      return x

    x1 = get_x(n1, random.PRNGKey(1))
    x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

    # Extra keyword input forwarded as `pattern=(p1, p2)` below; its leading
    # axis has size `n`.
    p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
    p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2))

    init_fn, apply_fn, _ = stax.serial(
        stax.parallel(
            stax.parallel(
                stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15)),  # 1
                stax.serial(stax.Conv(7, (2,), padding='SAME',
                                      dimension_numbers=('HNC', 'OIH', 'NHC')),
                            stax.Erf(),
                            stax.Aggregate(1, 0, -1),
                            stax.GlobalAvgPool(),
                            stax.Dense(3, 0.5, 0.2)),  # 2
            ),
            stax.serial(
                stax.Conv(5, (2, 3), padding='SAME',
                          dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                stax.Sin(),
            )  # 3
        ),
        stax.parallel(
            stax.FanInSum(),
            stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC'))
        )
    )

    # Parameters are initialised from the shapes of the input leaves.
    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))

    # The first `vmap_axes` entry mirrors the input structure `[(x1, x2), x3]`
    # and names the position of the `n` axis in each leaf (positive indices
    # here, equivalent negative indices in the direct variant below).
    # NOTE(review): the second and third entries presumably give the output
    # axes and the keyword-argument axes -- confirm against the library docs.
    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

    # The un-batched direct kernel is the reference for all comparisons.
    k = direct(x1, x2, params, pattern=(p1, p2))

    self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2)))
    self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2)))
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Build one residual block: a two-conv main path summed with a shortcut.

    Args:
        channels: Output channels of every convolution in the block.
        strides: Strides of the first main-path convolution and of the
            shortcut convolution (when present).
        channel_mismatch: If true the shortcut is a strided 3x3 convolution;
            otherwise it is the identity.

    Returns:
        The result of `stax.serial` over fan-out, both paths and fan-in.
    """
    main_layers = [
        stax.Relu(),
        stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding='SAME'),
    ]
    Main = stax.serial(*main_layers)

    if channel_mismatch:
        Shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        Shortcut = stax.Identity()

    return stax.serial(stax.FanOut(2),
                       stax.parallel(Main, Shortcut),
                       stax.FanInSum())
예제 #10
0
    def testGradientDescentMseEnsembleTrain(self):
        """Checks train-set predictions agree for `x_test=None` and `x_test=x`.

        Passing `None` as the test input must give the same result (within
        0.04) as passing the training inputs explicitly, both for `t=None`
        and for an array of times.
        """
        key = random.PRNGKey(1)
        x = random.normal(key, (8, 4, 6, 3))
        conv_layers = (stax.Conv(1, (2, 2)), stax.Relu(), stax.Conv(1, (2, 1)))
        _, _, kernel_fn = stax.serial(*conv_layers)
        y = random.normal(key, (8, 2, 5, 1))
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

        times = [None, np.array([0., 1., 10.])]
        for t in times:
            with self.subTest(t=t):
                implicit_train = predictor(t, None, None, compute_cov=True)
                explicit_train = predictor(t, x, None, compute_cov=True)
                self._assertAllClose(implicit_train, explicit_train, 0.04)
예제 #11
0
 def _get_kernel_fn(self, same_inputs, readin, readout):
     """Build a two-conv test network and matching random inputs.

     Args:
         same_inputs: If true, the returned `x2` is `None`.
         readin: First layer of the network. The convolutions use a (2, 3)
             filter when `readin[0].__name__` is 'Identity' and an empty
             filter shape otherwise.
         readout: Last layer of the network.

     Returns:
         A `(kernel_fn, x1, x2)` tuple.
     """
     key = random.PRNGKey(1)
     x1 = random.normal(key, (2, 5, 6, 3))
     x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

     if readin[0].__name__ == 'Identity':
         filter_shape = (2, 3)
     else:
         filter_shape = ()

     _, _, kernel_fn = stax.serial(
         readin,
         stax.Conv(1, filter_shape, padding='SAME'),
         stax.Relu(),
         stax.Conv(1, filter_shape, padding='SAME'),
         stax.Erf(),
         readout,
     )
     return kernel_fn, x1, x2
예제 #12
0
def build_le_net(network_width):
    """Construct a LeNet with average pooling using neural tangents' stax.

    Args:
        network_width: Multiplier applied to the base width of every hidden
            layer (6 and 16 convolution channels, 120 and 84 dense units).

    Returns:
        The result of `stax.serial` over the full layer stack, ending in a
        10-unit dense layer.
    """

    def conv_relu_pool(base_channels):
        # 3x3 'VALID' convolution, ReLU, then 2x2 average pooling, stride 2.
        return [
            stax.Conv(out_chan=base_channels * network_width,
                      filter_shape=(3, 3),
                      strides=(1, 1),
                      padding='VALID'),
            stax.Relu(),
            stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        ]

    def dense_relu(base_units):
        return [stax.Dense(base_units * network_width), stax.Relu()]

    layers = [
        *conv_relu_pool(6),
        *conv_relu_pool(16),
        stax.Flatten(),
        *dense_relu(120),
        *dense_relu(84),
        stax.Dense(10),
    ]
    return stax.serial(*layers)
예제 #13
0
    def testPredictOnCPU(self):
        """Checks GP predictions agree across batching/storage settings.

        For every combination of `store_on_device`, `device_count` and `get`,
        the gradient-descent predictor at `t=None` must match both the
        predictor at `t=inf` and the direct `gp_inference` result.
        """
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        layers = (stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(),
                  stax.Dense(1))
        _, _, kernel_fn = stax.serial(*layers)

        kernel_requests = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in kernel_requests:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        # Both entry points receive identical arguments.
                        inference_args = (kernel_fn_batched, x_train, y_train,
                                          x_test, get, 0., True)
                        predictor = predict.gradient_descent_mse_gp(
                            *inference_args)
                        gp_inference = predict.gp_inference(*inference_args)

                        at_none = predictor(None)
                        at_inf = predictor(np.inf)
                        self.assertAllClose(at_none, at_inf, True)
                        self.assertAllClose(predictor(None), gp_inference,
                                            True)
예제 #14
0
def WideResnetnt(
        block_size,
        k,
        num_classes,
        batchnorm='std'):
    """Build a WideResnet, with or without BatchNorm.

    Based on the WideResnet from the paper (set config.wrn_block_size=3 and
    config.wrn_widening_f=10 in that case). Every layer uses the 'standard'
    parameterization and default weight and bias initialisation.

    Args:
        block_size: Forwarded to each `WideResnetGroupnt`.
        k: Widening factor; the three groups use 16*k, 32*k and 64*k channels.
        num_classes: Width of the final dense layer.
        batchnorm: Forwarded to `_batch_norm_internal` and to every group.

    Returns:
        The result of `stax_nt.serial` over the full layer stack.
    """
    parameterization = 'standard'

    def group(width, *group_strides):
        # Strides are passed positionally only for the down-sampling groups.
        return WideResnetGroupnt(block_size,
                                 width,
                                 *group_strides,
                                 parameterization=parameterization,
                                 batchnorm=batchnorm)

    stem = stax_nt.Conv(16, (3, 3),
                        padding='SAME',
                        parameterization=parameterization)
    body = [group(16 * k), group(32 * k, (2, 2)), group(64 * k, (2, 2))]
    final_norm = [_batch_norm_internal(batchnorm), stax_nt.Relu()]
    head = [
        stax_nt.AvgPool((8, 8)),
        stax_nt.Flatten(),
        stax_nt.Dense(num_classes, parameterization=parameterization),
    ]
    return stax_nt.serial(stem, *body, *final_norm, *head)
예제 #15
0
 def WideResnet(block_size, k, num_classes):
     """Build a WideResnet from `ntk_generator.ResnetGroup` stages.

     Args:
         block_size: Forwarded to each `ResnetGroup`.
         k: Widening factor; group widths are int(16*k), int(32*k), int(64*k).
         num_classes: Width of the final dense layer.

     Returns:
         The result of `stax.serial` over stem, groups, flatten and readout.
     """
     stem = stax.Conv(16, (3, 3), padding='SAME')
     groups = [
         ntk_generator.ResnetGroup(block_size, int(16 * k)),
         ntk_generator.ResnetGroup(block_size, int(32 * k), (2, 2)),
         ntk_generator.ResnetGroup(block_size, int(64 * k), (2, 2)),
     ]
     head = [stax.Flatten(), stax.Dense(num_classes, 1., 0.)]
     return stax.serial(stem, *groups, *head)
예제 #16
0
  def testPredictOnCPU(self):
    """Checks ensemble predictions and where their results are stored.

    For every combination of `store_on_device`, `device_count`, `get` and
    test input, predictions at `t=None` and `t=inf` must agree. When a test
    input is given, the results must be on CPU exactly when
    `store_on_device` is false or the backend platform is 'cpu'.
    """
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    layers = (stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))
    _, _, kernel_fn = stax.serial(*layers)

    kernel_requests = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in kernel_requests:
          # The label keeps the sub-test name short; it is mapped to the
          # actual array inside the sub-test.
          for x_label in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x_label):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

              x = None if x_label is None else x_test
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is None:
                continue

              on_cpu = (not store_on_device or
                        xla_bridge.get_backend().platform == 'cpu')
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
              self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))
예제 #17
0
    def test_composition_conv(self, avg_pool):
        """Checks that chaining kernel functions equals the composed kernel.

        The kernel of `serial(Block, Readout)` must match the readout kernel
        function applied to the block kernel function's output, both for a
        single input and for a pair of inputs. With a pooling readout the
        block kernel is requested with `marginalization='none'`, and the
        default marginalization is expected to raise `ValueError`.

        Args:
            avg_pool: Selects the global-average-pool readout instead of the
                flatten readout.
        """
        rng = random.PRNGKey(0)
        x1 = random.normal(rng, (5, 10, 10, 3))
        x2 = random.normal(rng, (5, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (3, 3)), stax.Relu())
        if avg_pool:
            reduction = stax.GlobalAvgPool()
            marginalization = 'none'
        else:
            reduction = stax.Flatten()
            marginalization = 'auto'
        Readout = stax.serial(reduction, stax.Dense(10))

        block_ker_fn = Block[2]
        readout_ker_fn = Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)

        def check_composition(*inputs):
            # Stage-wise kernel first, then the composed one, then compare.
            stagewise = readout_ker_fn(
                block_ker_fn(*inputs, marginalization=marginalization))
            composed = composed_ker_fn(*inputs)
            self.assertAllClose(stagewise, composed, True)

        check_composition(x1)

        if avg_pool:
            with self.assertRaises(ValueError):
                readout_ker_fn(block_ker_fn(x1))

        check_composition(x1, x2)
예제 #18
0
def wide_resnet(block_size, k, num_classes):
    """Build a WideResnet from three `wide_resnet_group` stages.

    Args:
        block_size: Forwarded to each `wide_resnet_group`.
        k: Widening factor; group widths are int(16*k), int(32*k), int(64*k).
        num_classes: Width of the final dense layer.

    Returns:
        The result of `stax.serial` over stem, groups, pooling and readout.
    """
    # (base width, extra positional arguments) for each group; only the last
    # two groups receive explicit (2, 2) strides.
    group_specs = [(16, ()), (32, ((2, 2),)), (64, ((2, 2),))]

    stem = stax.Conv(16, (3, 3), padding='SAME')
    groups = [
        wide_resnet_group(block_size, int(base * k), *extra)
        for base, extra in group_specs
    ]
    head = [stax.AvgPool((8, 8)), stax.Flatten(),
            stax.Dense(num_classes, 1., 0.)]
    return stax.serial(stem, *groups, *head)
def Resnet(block_size, num_classes):
    """Build a Resnet from four `ResnetGroup` stages of growing width.

    Args:
        block_size: Forwarded to each `ResnetGroup`.
        num_classes: Width of the final dense layer.

    Returns:
        The result of `stax.serial` over stem, groups, flatten and readout.
    """
    # Positional arguments after `block_size` for each group; all but the
    # first group receive explicit (2, 2) strides.
    group_args = [(64,), (128, (2, 2)), (256, (2, 2)), (512, (2, 2))]

    stem = stax.Conv(64, (3, 3), padding='SAME')
    groups = [ResnetGroup(block_size, *args) for args in group_args]
    head = [stax.Flatten(), stax.Dense(num_classes, 1., 0.05)]
    return stax.serial(stem, *groups, *head)
예제 #20
0
  def test_hermite(self, same_inputs, degree, get, readout):
    """Compares the analytic `Hermite` kernel with a Monte Carlo estimate.

    Args:
      same_inputs: If true, `x2` is the same array as `x1`.
      degree: Degree passed to `stax.Hermite`. Degrees above 2 use more
        samples and call `test_utils.skip_test`.
      get: Kernel type; 'ntk' adds a `Dense(1)` readout, anything else an
        identity layer.
      readout: 'pool' selects global average pooling, anything else flatten.
    """
    root_key = random.PRNGKey(1)
    key1, key2, mc_key = random.split(root_key, 3)

    width = 10000
    if degree > 2:
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    if same_inputs:
      x2 = x1
    else:
      x2 = np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)

    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)

    # The tolerance grows linearly with the degree.
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
예제 #21
0
    def _test_analytic_kernel_composition(self, batching_fn):
        """Checks batched kernel functions compose like the un-batched net.

        For a fully-connected block and for a conv block with a pooling
        readout, applying the batched per-stage kernel functions in sequence
        must give the same result as the composed network's kernel function,
        for one input and for a pair of inputs.

        Args:
            batching_fn: Wrapper applied to each per-stage kernel function.
        """

        def check_match(stagewise_out, composed_out):
            if batching_fn == batch._parallel:
                # In the parallel setting, `x1_is_x2` is not computed
                # correctly when x1==x2.
                composed_out = composed_out.replace(
                    x1_is_x2=stagewise_out.x1_is_x2)
            self.assertAllClose(stagewise_out, composed_out)

        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self, rng_other = keys[0], keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        check_match(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
        check_match(ker_fn(ker_fn(x_self, x_other)),
                    composed_ker_fn(x_self, x_other))

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        check_match(readout_ker_fn(block_ker_fn(x_self)),
                    composed_ker_fn(x_self))
        check_match(readout_ker_fn(block_ker_fn(x_self, x_other)),
                    composed_ker_fn(x_self, x_other))
예제 #22
0
  def _test_activation(self, activation_fn, same_inputs, model, get,
                       rbf_gamma=None):
    """Compares an activation's analytic kernel with a Monte Carlo estimate.

    Args:
      activation_fn: Layer under test; `activation_fn[2].__name__` selects
        the weight/bias scales and the sample count.
      same_inputs: If true, the second input is `None`.
      model: 'fc' builds a dense network; any other value builds a
        convolutional one (pooled when 'pool' is in the name). Names
        containing 'conv' call `test_utils.skip_test`.
      get: Kernel type; 'nngp' uses a 1024-wide readout, otherwise width 1.
      rbf_gamma: If given, 'nngp' kernels of the 'fc' model are also compared
        with the kernel of `self._RBF(rbf_gamma / input_dim)`.
    """
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))
    output_dim = 1024 if get == 'nngp' else 1

    # (W_std, b_std) overrides for specific activations.
    activation_name = activation_fn[2].__name__
    scale_overrides = {'Sin': (0.9, 0.5), 'Rbf': (1.0, 0.0)}
    W_std, b_std = scale_overrides.get(activation_name, (2.0, 0.5))

    is_fc = model == 'fc'
    if is_fc:
      rtol = 0.04
      X0_1 = random.normal(key, (4, 2))
      X0_2 = None if same_inputs else random.normal(split, (2, 2))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      rtol = 0.05
      X0_1 = random.normal(key, (2, 4, 4, 3))
      X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
      affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      reduction = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
      readout = stax.serial(reduction, stax.Dense(output_dim))
      depth = 2

    # Fewer samples and twice the tolerance on CPU.
    if default_backend() == 'cpu':
      num_samplings = 200
      rtol *= 2
    elif activation_name in ('Sin', 'Rbf'):
      num_samplings = 500
    else:
      num_samplings = 300

    hidden_layers = [affine, activation_fn] * depth
    init_fn, apply_fn, kernel_fn = stax.serial(*hidden_layers, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    mc_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, split, num_samplings, implementation=2,
        vmap_axes=0
    )
    empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     empirical_kernel, rtol)

    # Check match with explicit RBF
    if rbf_gamma is not None and get == 'nngp' and is_fc:
      input_dim = X0_1.shape[1]
      _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
      direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
      test_utils.assert_close_matrices(self, analytic_kernel,
                                       direct_rbf_kernel, rtol)
예제 #23
0
    def testGradientDescentMseEnsembleTrain(self):
        """Checks train-set predictions agree for `x_test=None` and `x_test=x`.

        Passing `None` as the test input must give the same result (within
        0.04) as passing the training inputs explicitly, both for `t=None`
        and for an array of times.
        """
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x = np.asarray(normal((8, 4, 6, 3), seed=key))
        conv_layers = (stax.Conv(1, (2, 2)), stax.Relu(), stax.Conv(1, (2, 1)))
        _, _, kernel_fn = stax.serial(*conv_layers)
        y = np.asarray(normal((8, 2, 5, 1), seed=key))
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

        times = [None, np.array([0., 1., 10.])]
        for t in times:
            with self.subTest(t=t):
                implicit_train = predictor(t, None, None, compute_cov=True)
                explicit_train = predictor(t, x, None, compute_cov=True)
                self._assertAllClose(implicit_train, explicit_train, 0.04)
예제 #24
0
def MyConv(*args, parameterization='standard', order=None, **kwargs):
    """Wrapper for convolutional layer with different parameterizations.

    Args:
        *args: Positional arguments forwarded to the selected constructor.
        parameterization: 'standard' uses `jax_stax.Conv`; 'ntk' uses
            `stax.Conv` with `b_std=1.0` and keeps only the first two
            elements of its result; 'taylor' uses `TaylorConv` with
            `b_std=1.0`.
        order: Forwarded to `TaylorConv`; only read for 'taylor'.
        **kwargs: Keyword arguments forwarded to the selected constructor.

    Returns:
        The layer produced by the selected constructor.

    Raises:
        ValueError: If `parameterization` is not one of the supported names.
            Previously this case silently returned `None`, which only failed
            later when the caller tried to use the layer.
    """
    if parameterization == 'standard':
        return jax_stax.Conv(*args, **kwargs)
    if parameterization == 'ntk':
        # Drop everything after the first two elements so the result has the
        # same arity as the 'standard' layer.
        # NOTE(review): presumably this discards the kernel function --
        # confirm against the stax layer contract.
        return stax.Conv(*args, b_std=1.0, **kwargs)[:2]
    if parameterization == 'taylor':
        return TaylorConv(*args, b_std=1.0, order=order, **kwargs)
    raise ValueError(
        'Unknown parameterization: {!r}; expected one of '
        "'standard', 'ntk' or 'taylor'.".format(parameterization))
예제 #25
0
def _get_inputs_and_model(width=1, n_classes=2):
    """Create two random image batches and a small conv network.

    Args:
        width: Number of channels of the convolution.
        n_classes: Width of the dense readout.

    Returns:
        A tuple `(x1, x2, init_fun, apply_fun, ker_fun, key)`, where `key` is
        the key that was also used to draw `x1`.
    """
    key, split = random.split(random.PRNGKey(1))
    x1 = random.normal(key, (8, 4, 3, 2))
    x2 = random.normal(split, (4, 4, 3, 2))

    layers = [
        stax.Conv(width, (3, 3)),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(n_classes, 2., 0.5),
    ]
    init_fun, apply_fun, ker_fun = stax.serial(*layers)
    return x1, x2, init_fun, apply_fun, ker_fun, key
예제 #26
0
    def testPredictOnCPU(self):
        """Checks ensemble predictions and where their results are stored.

        For every combination of `store_on_device`, `device_count`, `get`
        and test input, predictions at `t=None` and `t=inf` must agree. When
        a test input is given, the results must be on CPU exactly when
        `store_on_device` is false or the backend platform is 'cpu'.
        """

        def make_key():
            # Every key is drawn with the same fixed seed.
            return stateless_uniform(shape=[2],
                                     seed=[1, 1],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)

        key1 = make_key()
        key2 = make_key()
        key3 = make_key()
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        layers = (stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(),
                  stax.Dense(1))
        _, _, kernel_fn = stax.serial(*layers)

        kernel_requests = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]
        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in kernel_requests:
                    # The label keeps the sub-test name short; it is mapped
                    # to the actual array inside the sub-test.
                    for x_label in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x_label):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

                            x = None if x_label is None else x_test
                            predict_none = predictor(
                                None, x, get, compute_cov=True)
                            predict_inf = predictor(
                                np.inf, x, get, compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is None:
                                continue

                            on_cpu = (
                                not store_on_device or
                                xla_bridge.get_backend().platform == 'cpu')
                            self.assertEqual(on_cpu,
                                             utils.is_on_cpu(predict_inf))
                            self.assertEqual(on_cpu,
                                             utils.is_on_cpu(predict_none))
예제 #27
0
 def WideResnet(block_size, k, num_classes):
   """Build a standard-parameterization WideResnet.

   Args:
     block_size: Forwarded to each `WideResnetGroup`.
     k: Widening factor; group widths are int(16*k), int(32*k), int(64*k).
     num_classes: Width of the final dense layer.

   Returns:
     The result of `stax.serial` over stem, groups, pooling and readout.
   """
   stem = stax.Conv(16, (3, 3), padding='SAME', parameterization='standard')
   groups = [
       WideResnetGroup(block_size, int(16 * k)),
       WideResnetGroup(block_size, int(32 * k), (2, 2)),
       WideResnetGroup(block_size, int(64 * k), (2, 2)),
   ]
   head = [
       stax.AvgPool((8, 8), padding='SAME'),
       stax.Flatten(),
       stax.Dense(num_classes, 1.0, 0.0, parameterization='standard'),
   ]
   return stax.serial(stem, *groups, *head)
예제 #28
0
    def _test_analytic_kernel_composition(self, batching_fn):
        """Checks batched kernel functions compose like the un-batched net.

        For a fully-connected block and for a conv block with a pooling
        readout, applying the batched per-stage kernel functions in sequence
        must give the same result as the composed network's kernel function,
        for one input and for a pair of inputs.

        Args:
            batching_fn: Wrapper applied to each per-stage kernel function.
        """

        def check_match(stagewise_out, composed_out):
            if batching_fn == batch._parallel:
                # In the parallel setting, `x1_is_x2` is not computed
                # correctly when x1==x2.
                composed_out = composed_out._replace(
                    x1_is_x2=stagewise_out.x1_is_x2)
            self.assertAllClose(stagewise_out, composed_out, True)

        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (2, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        check_match(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
        check_match(ker_fn(ker_fn(x_self, x_other)),
                    composed_ker_fn(x_self, x_other))

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (2, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        # The pooling readout is fed a block kernel computed with
        # `marginalization='none'`.
        check_match(
            readout_ker_fn(block_ker_fn(x_self, marginalization='none')),
            composed_ker_fn(x_self))
        check_match(
            readout_ker_fn(
                block_ker_fn(x_self, x_other, marginalization='none')),
            composed_ker_fn(x_self, x_other))
예제 #29
0
 def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
   """Build one standard-parameterization WideResnet residual block.

   Args:
     channels: Output channels of every convolution in the block.
     strides: Strides of the first main-path convolution and of the shortcut
       convolution (when present).
     channel_mismatch: If true the shortcut is a strided 3x3 convolution;
       otherwise it is the identity.

   Returns:
     The result of `stax.serial` over fan-out, both paths and fan-in.
   """

   def conv3x3(*conv_strides):
     # Strides are passed positionally only when given; otherwise the layer
     # default applies.
     return stax.Conv(channels, (3, 3), *conv_strides, padding='SAME',
                      parameterization='standard')

   main = stax.serial(
       stax.Relu(),
       conv3x3(strides),
       stax.Relu(),
       conv3x3(),
   )

   if channel_mismatch:
     shortcut = conv3x3(strides)
   else:
     shortcut = stax.Identity()

   return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                      stax.FanInSum())
예제 #30
0
 def conv(out_chan, s):
   """Create a `stax.Conv` layer with the surrounding configuration.

   Only `out_chan` and `s` vary per call; every other setting is read from
   names defined outside this function (not visible in this snippet).

   Args:
     out_chan: Forwarded as `out_chan`.
     s: Forwarded as `s`.

   Returns:
     The result of `stax.Conv` with the assembled keyword arguments.
   """
   conv_kwargs = dict(
       out_chan=out_chan,
       filter_shape=filter_shape,
       strides=strides,
       padding=padding,
       W_std=W_std,
       b_std=b_std,
       dimension_numbers=dimension_numbers,
       parameterization=parameterization,
       s=s,
   )
   return stax.Conv(**conv_kwargs)