Example #1
        def mc_sampling(count=10):
            empirical_mean = 0.
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))

            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(data_train, None, params)
                g_td = kernel_fn(data_test, data_train, params)
                predictor = predict.gradient_descent_mse(
                    g_dd, data_labels, g_td)

                fx_initial_train = f(params, data_train)
                fx_initial_test = f(params, data_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                empirical_mean += fx_pred_test
            return empirical_mean / count
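Note: Examples #1 and #6 approximate the infinite-training-time limit by evaluating the predictor at t = 1.0e8. With recent versions of `neural_tangents`, `predict.gradient_descent_mse` also accepts `t=None` for the exact t → ∞ solution through the keyword-based interface shown in Example #9; a minimal sketch, reusing the names defined above:

    # Sketch (recent neural_tangents API, not part of the original example):
    # exact infinite-time solution instead of the large finite time 1.0e8.
    predictor = predict.gradient_descent_mse(g_dd, data_labels)
    fx_inf_train, fx_inf_test = predictor(t=None,
                                          fx_train_0=fx_initial_train,
                                          fx_test_0=fx_initial_test,
                                          k_test_train=g_td)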
Example #2
        def mc_sampling(count=10):
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))
            collect_test_predict = []
            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(x_train, None, params)
                g_td = kernel_fn(x_test, x_train, params)
                predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

                fx_initial_train = f(params, x_train)
                fx_initial_test = f(params, x_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                collect_test_predict.append(fx_pred_test)
            collect_test_predict = np.array(collect_test_predict)
            mean_emp = np.mean(collect_test_predict, axis=0)
            mean_subtracted = collect_test_predict - mean_emp
            cov_emp = np.einsum(
                'ijk,ilk->jl', mean_subtracted, mean_subtracted,
                optimize=True) / (mean_subtracted.shape[0] *
                                  mean_subtracted.shape[-1])
            return mean_emp, cov_emp
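The einsum at the end of Example #2 estimates the covariance of the test predictions across the `count` Monte Carlo draws, contracting both the sample axis and the logit axis. As a reformulation (mine, not part of the source), the same quantity can be written with a tensordot:

    # Equivalent covariance estimate, assuming `mean_subtracted` has shape
    # (count, n_test, out_logits).
    cov_emp = np.tensordot(mean_subtracted, mean_subtracted,
                           axes=((0, 2), (0, 2))) / (mean_subtracted.shape[0] *
                                                     mean_subtracted.shape[-1])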
Example #3
 def predict_empirical(key):
   _, params = init_fn(key, train_shape)
   g_dd = kernel_fn(x_train, None, params)
   g_td = kernel_fn(x_test, x_train, params)
   predict_fn = predict.gradient_descent_mse(g_dd, y_train, diag_reg=reg)
   fx_train_0 = f(params, x_train)
   fx_test_0 = f(params, x_test)
   return predict_fn(ts, fx_train_0, fx_test_0, g_td)
Example #4
  def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                          fn_and_kernel, momentum, learning_rate, t, loss):
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
    grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

    trace_axes = () if g_dd.ndim == 4 else (-1,)
    if loss == 'mse_analytic':
      if momentum is not None:
        raise absltest.SkipTest(momentum)
      predictor = predict.gradient_descent_mse(g_dd, y_train,
                                               learning_rate=learning_rate,
                                               trace_axes=trace_axes)
    elif loss == 'mse':
      predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           trace_axes=trace_axes)
    else:
      raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
      self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    if momentum is None:
      opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    else:
      opt_init, opt_update, get_params = optimizers.momentum(learning_rate,
                                                             momentum)

    opt_state = opt_init(params)
    for i in range(t):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)

    fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
Example #5
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = batch(get_ntk_fun_empirical(f), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
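The `main` function above refers to module-level names that the snippet does not show: `FLAGS`, `datasets`, `util`, `batch`, and `get_ntk_fun_empirical`. It appears to be adapted from the Neural Tangents weight-space example script; a plausible preamble is sketched below (flag defaults are illustrative, and `get_ntk_fun_empirical` is presumably an earlier name for the empirical NTK constructor, `nt.empirical_ntk_fn` in current releases):

    # Assumed preamble for the snippet above; names and defaults are illustrative.
    from absl import flags
    from jax import grad, jit, random
    from jax.example_libraries import optimizers  # jax.experimental.optimizers in older JAX
    import jax.numpy as np
    from neural_tangents import predict, stax
    # `datasets` and `util` are the helper modules shipped with the
    # neural_tangents examples; `batch` and `get_ntk_fun_empirical` come from
    # the (older) neural_tangents API this example was written against.

    flags.DEFINE_integer('train_size', 128, 'Number of training examples.')
    flags.DEFINE_integer('test_size', 128, 'Number of test examples.')
    flags.DEFINE_float('learning_rate', 1.0, 'SGD learning rate.')
    flags.DEFINE_float('train_time', 1000.0, 'Continuous training time.')
    FLAGS = flags.FLAGS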
Example #6
        def mc_sampling(count=10):
            empirical_mean = 0.
            key = random.PRNGKey(100)
            for _ in range(count):
                key, split = random.split(key)
                params, f, theta = _empirical_kernel(split, train_shape[1:],
                                                     network, out_logits)
                g_dd = theta(data_train, None)
                g_td = theta(data_test, data_train)
                predictor = predict.gradient_descent_mse(
                    g_dd, data_labels, g_td)

                fx_initial_train = f(params, data_train)
                fx_initial_test = f(params, data_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                empirical_mean += fx_pred_test
            return empirical_mean / count
Example #7
    def testNTKMSEPrediction(self, train_shape, test_shape, network,
                             out_logits, fn_and_kernel):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = random.normal(split, train_shape)

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = random.normal(split, test_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        # Regress to an MSE loss.
        loss = lambda params, x: \
            0.5 * np.mean((f(params, x) - y_train) ** 2)
        grad_loss = jit(grad(loss))

        g_dd = ntk(x_train, None, 'ntk')
        g_td = ntk(x_test, x_train, 'ntk')

        predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

        atol = ATOL
        rtol = RTOL
        step_size = 0.1

        if len(train_shape) > 2:
            # Hacky way to up the tolerance just for convolutions.
            atol = ATOL * 2
            rtol = RTOL * 2
            step_size = 0.1

        train_time = 100.0
        steps = int(train_time / step_size)

        opt_init, opt_update, get_params = optimizers.sgd(step_size)
        opt_state = opt_init(params)

        fx_initial_train = f(params, x_train)
        fx_initial_test = f(params, x_test)

        fx_pred_train, fx_pred_test = predictor(0.0, fx_initial_train,
                                                fx_initial_test)

        self.assertAllClose(fx_initial_train, fx_pred_train, True)
        self.assertAllClose(fx_initial_test, fx_pred_test, True)

        for i in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

        params = get_params(opt_state)
        fx_train = f(params, x_train)
        fx_test = f(params, x_test)

        fx_pred_train, fx_pred_test = predictor(train_time, fx_initial_train,
                                                fx_initial_test)

        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            True, rtol, atol)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True,
                            rtol, atol)
Example #8
    def testPredictND(self):
        n_chan = 6
        key = random.PRNGKey(1)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = random.normal(key, (n_train, ) + im_shape)
        y_train = random.uniform(key, (n_train, 3, 2, n_chan))
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else random.normal(
                            key, (n_test, ) + im_shape)
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean(x - y)**2
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)
Example #9
  def test_kwargs(self, do_batch, mode):
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)
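All of the examples above follow the same pattern: compute an empirical NTK on the training and test inputs, pass the train/train kernel to `predict.gradient_descent_mse`, and evaluate the returned predictor at a finite or infinite training time. A minimal self-contained sketch of that pattern, assuming a recent `neural_tangents` release and toy shapes chosen only for illustration:

    from jax import jit, random
    import neural_tangents as nt
    from neural_tangents import predict, stax

    key = random.PRNGKey(0)
    x_train = random.normal(key, (16, 8))
    y_train = random.normal(key, (16, 1))
    x_test = random.normal(key, (4, 8))

    # A small finite-width network and its parameters at initialization.
    init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Erf(), stax.Dense(1))
    _, params = init_fn(key, x_train.shape)

    # Empirical NTK of the finite network.
    kernel_fn = jit(nt.empirical_ntk_fn(apply_fn))
    g_dd = kernel_fn(x_train, None, params)
    g_td = kernel_fn(x_test, x_train, params)

    # Mean predictions of the linearized network trained by continuous-time
    # gradient descent on the MSE loss, evaluated at time t = 1.0.
    predict_fn = predict.gradient_descent_mse(g_dd, y_train)
    fx_train_0 = apply_fn(params, x_train)
    fx_test_0 = apply_fn(params, x_test)
    fx_train_t, fx_test_t = predict_fn(t=1.0, fx_train_0=fx_train_0,
                                       fx_test_0=fx_test_0, k_test_train=g_td)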