def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = batch(get_ntk_fun_empirical(f), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)
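    # Note: `predictor` evaluates the closed-form solution for the linearized
    # network under gradient-descent MSE training: train/test outputs at time t
    # follow directly from the empirical NTK blocks g_dd (train-train) and
    # g_td (test-train), with no step-by-step simulation of the linear model.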

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
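For reference, a minimal sketch of the command-line flags this script assumes via absl (the names come from the FLAGS.* attributes used above; the defaults are illustrative, not the original values):

from absl import flags

flags.DEFINE_integer('train_size', 1000, 'Number of MNIST training examples.')
flags.DEFINE_integer('test_size', 1000, 'Number of MNIST test examples.')
flags.DEFINE_float('learning_rate', 1.0, 'SGD learning rate.')
flags.DEFINE_float('train_time', 1000.0,
                   'Continuous training time; steps = train_time / learning_rate.')
FLAGS = flags.FLAGS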
def _build_network(input_shape, network, out_logits):
  if len(input_shape) == 1:
    assert network == FLAT
    return stax.serial(
        stax.Dense(4096, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))
  elif len(input_shape) == 3:
    if network == POOLING:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(), stax.Dense(out_logits, W_std=2.0, b_std=0.05))
    elif network == FLAT:
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
          stax.Flatten(), stax.Dense(out_logits, W_std=2.0, b_std=0.05))
    else:
      raise ValueError('Unexpected network type found: {}'.format(network))
  else:
    raise ValueError('Expected flat or image test input.')
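A hedged usage sketch for the helper above (FLAT, POOLING and CONVOLUTION_CHANNELS are module-level constants assumed to be defined elsewhere in the file):

# Hypothetical call: 8x8 RGB inputs, pooling readout, 10 output logits.
init_fn, apply_fn, kernel_fn = _build_network((8, 8, 3), POOLING, 10)
# kernel_fn can then be evaluated analytically, e.g. kernel_fn(x1, x2, 'ntk').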
Example #3
  def test_sparse_inputs(self, act, kernel):
    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 128
    width = 4096

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

    if xla_bridge.get_backend().platform == 'gpu':
      jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
      samples = 100 * N_SAMPLES
    else:
      jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
      jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3

    # a batch of dense inputs
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)

    activation = stax.Relu() if act == 'relu' else stax.Erf()

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)
    mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn,
                                           random.split(key, 2)[0],
                                           samples)(x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    assert not np.any(np.isnan(exact))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:], True)
Example #4
PADDINGS = [
    'SAME',
    'VALID',
    'CIRCULAR'
]

STRIDES = [
    None,
    (1, 2),
    (2, 1),
]

ACTIVATIONS = {
    # TODO: investigate poor erf convergence.
    stax.Erf(): 'erf',
    stax.Relu(): 'Relu',
    stax.ABRelu(-0.5, 0.7): 'ABRelu(-0.5, 0.7)'
}

PROJECTIONS = [
    'FLAT',
    'POOL',
    'ATTN_FIXED',
    'ATTN_PARAM'
]

LAYER_NORM = [
    (-1,),
    (1, 3),
    (1, 2, 3),
]
Example #5
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)

        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        self.assertLess(ensemble_loss, 1e-5, True)

        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])
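        # The einsum contracts over the ensemble axis i and the logit axis k,
        # so cov_emp is the test-by-test empirical covariance of the ensemble
        # outputs, averaged over ensemble members and output logits.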

        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
Example #6
  def _get_phi(cls, i):
    return {0: stax.Relu(), 1: stax.Erf(), 2: stax.Abs()}[i % 3]
Example #7
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #8
    def test_mask_fc(self, same_inputs, get, concat, p, mask_axis,
                     mask_constant):
        width = 512
        n_samples = 128
        tol = 0.04
        key = random.PRNGKey(1)

        x1 = random.normal(key, (4, 6, 5, 7))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = random.normal(key, (2, 6, 5, 7))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        nn = stax.serial(
            stax.Flatten(), stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    stax.Dense(width, 1., 0.1),
                    stax.Abs(),
                    stax.DotGeneral(lhs=-0.2),
                    stax.Dense(width, 1.5, 0.01),
                ),
                stax.serial(
                    stax.Dense(width, 1.1, 0.1),
                    stax.DotGeneral(rhs=0.7),
                    stax.Erf(),
                    stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
                ),
                stax.serial(
                    stax.DotGeneral(rhs=0.5),
                    stax.Dense(width, 1.2),
                    stax.ABRelu(-0.2, 0.4),
                    stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            stax.Dense(width, 2., 0.01), stax.Relu())

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.1))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.1))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -2) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -2) else 0,
        )
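        # After stax.Flatten the branch outputs are rank 2, so concat in
        # (0, -2) concatenates along the batch axis; examples then interact
        # inside the network, which is why per-device batching (device_count=0)
        # and per-example vmapping (vmap_axes=None) are disabled above.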

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #9
    def test_input_req(self, same_inputs):
        test_utils.skip_test(self)

        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 7, 8, 4, 3))
        x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.Relu(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
            stax.Relu(),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
            stax.Flatten(), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=400,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.GlobalAvgPool(channel_axis=2))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
            stax.Relu(), stax.AvgPool((2, 1, 3), batch_axis=0,
                                      channel_axis=-2),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
            stax.Relu(), stax.GlobalAvgPool(channel_axis=2), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=300,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        _, _, wrong_conv_fn = stax.serial(
            stax.Flatten(),
            stax.Dense(1),
            stax.Erf(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2),
                      dimension_numbers=('CN', 'IO', 'NC')),
        )
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Flatten(), stax.Conv(out_chan=1024, filter_shape=()),
            stax.Relu(), stax.Dense(1))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=200,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='ntk')
        K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
Example #10
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # x_train
    import numpy
    # numpy.argmax(y_train,1)%2
    # y_train_tmp = numpy.zeros((y_train.shape[0],2))
    # y_train_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_train,1)%2] = 1
    # y_train = y_train_tmp
    # y_test_tmp = numpy.zeros((y_test.shape[0],2))
    # y_test_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_test,1)%2] = 1
    # y_test = y_test_tmp

    y_train_tmp = numpy.argmax(y_train, 1) % 2
    y_train = np.expand_dims(y_train_tmp, 1)
    y_test_tmp = numpy.argmax(y_test, 1) % 2
    y_test = np.expand_dims(y_test_tmp, 1)
    # print(y_train)
    # Build the network
    # init_fn, apply_fn, _ = stax.serial(
    #   stax.Dense(2048, 1., 0.05),
    #   # stax.Erf(),
    #   stax.Relu(),
    #   stax.Dense(2048, 1., 0.05),
    #   # stax.Erf(),
    #   stax.Relu(),
    #   stax.Dense(10, 1., 0.05))
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # key = random.PRNGKey(0)
    randnnn = numpy.random.random_integers(np.iinfo(np.int32).min,
                                           high=np.iinfo(np.int32).max,
                                           size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # params

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)
    # state

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
    # g_dd.shape

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
Example #11
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):

  @parameterized.named_parameters(
      test_utils.cases_from_list({
          'testcase_name':
              '_{}_{}_{}_{}'.format(
                  model,
                  phi[0].__name__,
                  'Same_inputs' if same_inputs else 'Different_inputs',
                  get),
          'model': model,
          'phi': phi,
          'same_inputs': same_inputs,
          'get': get,
      }
                          for model in ['fc', 'conv-pool', 'conv-flatten']
                          for phi in [
                              stax.Erf(),
                              stax.Gelu(),
                              stax.Sin(),
                          ]
                          for same_inputs in [False, True]
                          for get in ['nngp', 'ntk']))
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    output_dim = 1
    b_std = 0.01
    W_std = 1.0
    rtol = 2e-3
    deg = 25
    if get == 'ntk':
      rtol *= 2
    if default_backend() == 'tpu':
      rtol *= 2

    if model == 'fc':
      X0_1 = random.normal(key, (3, 7))
      X0_2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(1024, W_std, b_std)
      readout = stax.Dense(output_dim)
      depth = 1
    else:
      X0_1 = random.normal(key, (2, 8, 8, 3))
      X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
      readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                            stax.Flatten(),
                            stax.Dense(output_dim))
      depth = 2

    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
    analytic_kernel = kernel_fn(X0_1, X0_2, get)

    fn = lambda x: phi[1]((), x)
    _, _, kernel_fn = stax.serial(
        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
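    # ElementwiseNumerical is assumed to approximate the activation's kernel
    # maps by Gaussian quadrature with `deg` points, so its kernel should
    # closely match the analytic kernel of `phi` computed above.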
    numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

    test_utils.assert_close_matrices(self, analytic_kernel,
                                     numerical_activation_kernel, rtol)
Example #12
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
Example #13
    def test_vmap_axes(self, same_inputs):
        n1, n2 = 3, 4
        c1, c2, c3 = 9, 5, 7
        h2, h3, w3 = 6, 8, 2

        def get_x(n, k):
            k1, k2, k3 = random.split(k, 3)
            x1 = random.normal(k1, (n, c1))
            x2 = random.normal(k2, (h2, n, c2))
            x3 = random.normal(k3, (c3, w3, n, h3))
            x = [(x1, x2), x3]
            return x

        x1 = get_x(n1, random.PRNGKey(1))
        x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

        p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
        p2 = None if same_inputs else random.normal(random.PRNGKey(6),
                                                    (n2, h2, h2))

        init_fn, apply_fn, _ = stax.serial(
            stax.parallel(
                stax.parallel(
                    stax.serial(stax.Dense(4, 2., 0.1), stax.Relu(),
                                stax.Dense(3, 1., 0.15)),  # 1
                    stax.serial(
                        stax.Conv(7, (2, ),
                                  padding='SAME',
                                  dimension_numbers=('HNC', 'OIH', 'NHC')),
                        stax.Erf(), stax.Aggregate(1, 0, -1),
                        stax.GlobalAvgPool(), stax.Dense(3, 0.5, 0.2)),  # 2
                ),
                stax.serial(
                    stax.Conv(5, (2, 3),
                              padding='SAME',
                              dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                    stax.Sin(),
                )  # 3
            ),
            stax.parallel(
                stax.FanInSum(),
                stax.Conv(2, (2, 1),
                          dimension_numbers=('HWCN', 'OIHW', 'HNWC'))))

        _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
        implicit = jit(nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

        implicit_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(0, 1), 2], [-2,
                                                         -3], dict(pattern=0)),
                                implementation=2))
        direct_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(-2, -2),
                                            -2], [0, 1], dict(pattern=-3)),
                                implementation=1))
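        # vmap_axes is the (input_axes, output_axes, kwargs_axes) triple of
        # nt.empirical_ntk_fn: each entry mirrors the pytree structure of the
        # inputs x, the outputs f(x) and the extra keyword arguments (here
        # `pattern`), naming the axis that indexes examples in each leaf.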

        k = direct(x1, x2, params, pattern=(p1, p2))

        self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
        self.assertAllClose(k, direct_batched(x1, x2, params,
                                              pattern=(p1, p2)))
        self.assertAllClose(k,
                            implicit_batched(x1, x2, params, pattern=(p1, p2)))
Example #14
FLAGS["train_size"] = 128
FLAGS["test_size"] = 128
FLAGS["train_time"] = 10000.0
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
FLAGS=Struct(**FLAGS)

print('Loading data.')
x_train, y_train, x_test, y_test = \
  datasets.mnist(FLAGS.train_size, FLAGS.test_size)

# Build the network
init_fn, apply_fn, _ = stax.serial(
  stax.Dense(2048, 1., 0.05),
  stax.Erf(),
  stax.Dense(10, 1., 0.05))

key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 784))

# params

# Create and initialize an optimizer.
opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
state = opt_init(params)
# state
#%%


# Create an mse loss function and a gradient function.
Example #15
st.header("Define your (Finite) Network")

# st.sidebar.markdown("## Network")
n_hidden = st.slider("Hidden width", 16, 2048, 512, step=16)
depth = st.slider("Network depth (excludes input and output)",
                  1,
                  10,
                  2,
                  step=1)
sigma_w = st.slider("Sigma w", 0.1, 3.0, 1.5, step=0.1)
sigma_b = st.slider("Sigma b", 0.01, 0.1, 0.05, step=0.01)

activation_fn = st.selectbox("Activation Function", ("Erf", "ReLU", "None"))

activation_fn_dict = {"Erf": stax.Erf(), "ReLU": stax.Relu(), "None": None}
activation_fn = activation_fn_dict[activation_fn]

sequence = ((stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),
             activation_fn) if activation_fn else
            (stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b), ))
init_fn, apply_fn, kernel_fn = stax.serial(
    *(sequence * depth), stax.Dense(1, W_std=sigma_w, b_std=sigma_b))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2, ))

st.markdown("""
We define our network using the **Neural Tangents stax** module.
It allows us to define the architecture and initialisation, and
returns the network function plus the (infinite-width) kernel function.
""")
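As a follow-up sketch (not part of the original app), the kernel_fn built above can be fed to neural_tangents' inference utilities; x_train, y_train and x_test are placeholders for whatever data the app loads, and nt is assumed to be the neural_tangents import:

# Hypothetical continuation: exact infinite-width inference with the kernel above.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
nngp_mean, ntk_mean = predict_fn(x_test=x_test, get=('nngp', 'ntk'))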
Example #16
class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(
      self,
      same_inputs,
      phi,
      n,
      diagonal_batch,
      diagonal_spatial
  ):
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
Example #17
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.stateless_random_uniform(shape=[2],
                                          seed=[0, 0],
                                          minval=None,
                                          maxval=None,
                                          dtype=np.int32)
    _, params = init_fn(key, (1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)
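    # nt.linearize returns the first-order Taylor expansion of f around
    # `params`: f_lin(p, x) = f(params, x) + J(params, x) @ (p - params),
    # trained below with the same optimizer and loss as the full network.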

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                      loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
Example #18
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # The output width of 2 corresponds to the two classes.
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # (-1, 135): 135 is the feature length, here 9 * 15 = 135.
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Use whole batch
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)