# Example #1
    def testGpInference(self):
        """Compare closed-form GP inference with the t=inf ensemble limit.

        Iterates over analytic vs. empirical kernels, every supported `get`
        variant, train-only vs. train+test evaluation, and covariance on/off,
        checking that `predict.gp_inference` agrees with
        `predict.gradient_descent_mse_ensemble` evaluated at t=None and
        t=np.inf.
        """
        reg = 1e-5  # diagonal regularization used by both prediction paths
        # TF stateless PRNG key so the random inputs are reproducible.
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x_train = np.asarray(normal((4, 2), seed=key))
        init_fn, apply_fn, kernel_fn_analytic = stax.serial(
            stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
        y_train = np.asarray(normal((4, 10), seed=key))
        for kernel_fn_is_analytic in [True, False]:
            if kernel_fn_is_analytic:
                kernel_fn = kernel_fn_analytic
            else:
                # Finite-width empirical kernel with parameters fixed once,
                # wrapped to match the (x1, x2, get) kernel_fn signature.
                _, params = init_fn(key, x_train.shape)
                kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

                def kernel_fn(x1, x2, get):
                    return kernel_fn_empirical(x1, x2, get, params)

            for get in [
                    None, 'nngp', 'ntk', ('nngp', ), ('ntk', ),
                ('nngp', 'ntk'), ('ntk', 'nngp')
            ]:
                k_dd = kernel_fn(x_train, None, get)  # train-train kernel

                gp_inference = predict.gp_inference(k_dd,
                                                    y_train,
                                                    diag_reg=reg)
                gd_ensemble = predict.gradient_descent_mse_ensemble(
                    kernel_fn, x_train, y_train, diag_reg=reg)
                for x_test in [None, 'x_test']:
                    x_test = None if x_test is None else np.asarray(
                        normal((8, 2), seed=key))
                    k_td = None if x_test is None else kernel_fn(
                        x_test, x_train, get)

                    for compute_cov in [True, False]:
                        with self.subTest(
                                kernel_fn_is_analytic=kernel_fn_is_analytic,
                                get=get,
                                x_test=x_test if x_test is None else 'x_test',
                                compute_cov=compute_cov):
                            if compute_cov:
                                # With no test set, `True` is passed instead of
                                # a test-test NNGP — presumably a sentinel
                                # understood by `gp_inference`; confirm there.
                                nngp_tt = (True if x_test is None else
                                           kernel_fn(x_test, None, 'nngp'))
                            else:
                                nngp_tt = None

                            # The t=None and t=inf ensembles should coincide
                            # (up to the 0.08 tolerance).
                            out_ens = gd_ensemble(None, x_test, get,
                                                  compute_cov)
                            out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                      compute_cov)
                            self._assertAllClose(out_ens_inf, out_ens, 0.08)

                            if (get is not None and 'nngp' not in get
                                    and compute_cov and k_td is not None):
                                # Requesting test covariance without the NNGP
                                # kernel in `get` must raise.
                                with self.assertRaises(ValueError):
                                    out_gp_inf = gp_inference(
                                        get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
                            else:
                                out_gp_inf = gp_inference(
                                    get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                                self.assertAllClose(out_ens, out_gp_inf)
# Example #2
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        """Check NTK-ensemble mean/covariance against SGD-trained networks.

        Trains an ensemble of 1024 finite-width networks with full-batch SGD
        and compares the empirical mean and covariance of their outputs (on
        both train and test inputs) against the analytic
        `predict.gradient_descent_mse_ensemble` NTK prediction at the same
        training time.
        """
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        # One independent PRNG key per ensemble member.
        ensemble_key = tf_random_split(key, ensemble_size)

        # 0.5 * MSE loss and its gradient w.r.t. the optimizer state's params.
        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            # Full-batch SGD for `training_steps` iterations from a fresh init.
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        # Train the whole ensemble in parallel over the key axis.
        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                # `x is None` means evaluate on the training set.
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                # Empirical mean over the ensemble axis, and covariance
                # averaged over the output (logit) axis.
                mean_emp = np.mean(ensemble_fx, axis=0)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                # Analytic NTK prediction at the same finite training time.
                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
    def testPredictND(self):
        """Check prediction output shapes for N-dimensional outputs.

        Uses a convolutional network so outputs are rank-4, and verifies that
        `predict.gradient_descent_mse`, `predict.gradient_descent_mse_ensemble`
        and `predict.gradient_descent` produce predictions whose shapes match
        the labels — with any leading time axes from `ts` prepended — across
        many `trace_axes` settings.
        """
        n_chan = 6
        key = random.PRNGKey(1)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = random.normal(key, (n_train, ) + im_shape)
        y_train = random.uniform(key, (n_train, 3, 2, n_chan))
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)  # outputs at initialization

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            # `ts` is either a single implicit time or a (2, 3) grid of times;
            # its shape must be broadcast onto the prediction shape.
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else random.normal(
                            key, (n_test, ) + im_shape)
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean(x - y)**2
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        # MSE predictions: train-only vs. train+test call forms.
                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        # Ensemble predictions for both kernels at once.
                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        # General (non-MSE) gradient descent requires explicit
                        # times, so only runs when `ts` is given.
                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)
# Example #4
    def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                                 out_logits):
        """Check NTK and NNGP code paths agree when fed the same kernel.

        Builds "hacked" kernel functions that return the NTK in place of the
        NNGP (and vice versa), then verifies that running the NNGP equations
        on the NTK kernel reproduces the NTK mean, and running the NTK
        equations on the NNGP kernel reproduces the NNGP covariance. Also
        checks that an array of times broadcasts like per-time calls.
        """
        _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                       train_shape)
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7  # small diagonal regularizer for numerical stability
        predictor = predict.gradient_descent_mse_ensemble(ker_fun,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)

        # A (5, 2) grid of times spanning ten decades, to exercise
        # time-axis broadcasting.
        ts = np.logspace(-2, 8, 10).reshape((5, 2))

        for t in (None, 'ts'):
            for x in (None, 'x_test'):
                with self.subTest(t=t, x=x):
                    x = x if x is None else x_test
                    t = t if t is None else ts

                    ntk = predictor(t=t, get='ntk', x_test=x)

                    # Test time broadcasting
                    if t is not None:
                        # Calling once per scalar time and stacking must match
                        # the single broadcast call.
                        ntk_ind = np.array([
                            predictor(t=t, get='ntk', x_test=x)
                            for t in t.ravel()
                        ]).reshape(t.shape + ntk.shape[2:])
                        self.assertAllClose(ntk_ind, ntk)

                    # Create a hacked kernel function that always returns the ntk kernel
                    def always_ntk(x1, x2, get=('nngp', 'ntk')):
                        out = ker_fun(x1, x2, get=('nngp', 'ntk'))
                        if get == 'nngp' or get == 'ntk':
                            return out.ntk
                        else:
                            # Replace the nngp entry with the ntk kernel.
                            return out._replace(nngp=out.ntk)

                    predictor_ntk = predict.gradient_descent_mse_ensemble(
                        always_ntk, x_train, y_train, diag_reg=reg)

                    ntk_nngp = predictor_ntk(t=t, get='nngp', x_test=x)

                    # Test if you use nngp equations with ntk, you get the same mean
                    self.assertAllClose(ntk, ntk_nngp)

                    # Next test that if you go through the NTK code path, but with only
                    # the NNGP kernel, we recreate the NNGP dynamics.
                    # Create a hacked kernel function that always returns the nngp kernel
                    def always_nngp(x1, x2, get=('nngp', 'ntk')):
                        out = ker_fun(x1, x2, get=('nngp', 'ntk'))
                        if get == 'nngp' or get == 'ntk':
                            return out.nngp
                        else:
                            # Replace the ntk entry with the nngp kernel.
                            return out._replace(ntk=out.nngp)

                    predictor_nngp = predict.gradient_descent_mse_ensemble(
                        always_nngp, x_train, y_train, diag_reg=reg)

                    nngp_cov = predictor(t=t,
                                         get='nngp',
                                         x_test=x,
                                         compute_cov=True).covariance

                    # test time broadcasting for covariance
                    nngp_ntk_cov = predictor_nngp(t=t,
                                                  get='ntk',
                                                  x_test=x,
                                                  compute_cov=True).covariance
                    if t is not None:
                        nngp_ntk_cov_ind = np.array([
                            predictor_nngp(t=t,
                                           get='ntk',
                                           x_test=x,
                                           compute_cov=True).covariance
                            for t in t.ravel()
                        ]).reshape(t.shape + nngp_cov.shape[2:])
                        self.assertAllClose(nngp_ntk_cov_ind, nngp_ntk_cov)

                    # Test if you use ntk equations with nngp, you get the same cov
                    # Although, due to accumulation of numerical errors, only roughly.
                    self.assertAllClose(nngp_cov, nngp_ntk_cov)
# Example #5
  def test_kwargs(self, do_batch, mode):
    """Check kernel keyword arguments thread through all prediction paths.

    Builds a network whose kernel needs extra kwargs (`pattern`, and `rng`
    for dropout in empirical mode), evaluates the kernel in Monte-Carlo,
    empirical, or analytic mode (optionally batched), and verifies that
    `predict.gp_inference`, `predict.gradient_descent_mse_ensemble` and
    `predict.gradient_descent_mse` all agree when passed those kwargs.
    """
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    # Pairwise aggregation patterns — presumably consumed by
    # `stax.Aggregate`; confirm against its signature.
    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    # Kwargs for train-train, test-train and test-test kernel calls.
    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      # Empirical kernels additionally need params, an explicit `get`, and
      # dropout RNG keys per side of each kernel call.
      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      # The ensemble predictor supplies `get` itself; drop it from kwargs.
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)