def testGPInferenceGet(self, train_shape, test_shape, network, out_logits):
        """Checks `predict.gp_inference` output structure for every `get`.

        A string `get` should yield a single `predict.Gaussian`; a tuple
        `get` should yield a tuple of Gaussians whose order matches the
        requested kernels.
        """
        # Deterministic data: inputs are cos-squashed normals, labels are
        # Bernoulli draws cast to float32.
        rng = random.PRNGKey(0)

        rng, subkey = random.split(rng)
        x_train = np.cos(random.normal(subkey, train_shape))

        rng, subkey = random.split(rng)
        y_train = np.array(
            random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
            np.float32)

        rng, subkey = random.split(rng)
        x_test = np.cos(random.normal(subkey, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        def infer(get):
            # Shared wrapper so every `get` variant uses identical arguments.
            return predict.gp_inference(kernel_fn,
                                        x_train,
                                        y_train,
                                        x_test,
                                        get,
                                        diag_reg=0.,
                                        compute_cov=True)

        # Scalar `get`: a single Gaussian is returned.
        assert isinstance(infer('ntk'), predict.Gaussian)
        assert isinstance(infer('nngp'), predict.Gaussian)

        # Tuple `get`: one Gaussian per requested kernel.
        out = infer(('ntk',))
        assert len(out) == 1 and isinstance(out[0], predict.Gaussian)

        out = infer(('ntk', 'nngp'))
        assert (len(out) == 2 and isinstance(out[0], predict.Gaussian)
                and isinstance(out[1], predict.Gaussian))

        # Swapping the order of requested kernels only permutes the outputs.
        out2 = infer(('nngp', 'ntk'))
        self.assertAllClose(out[0], out2[1], True)
        self.assertAllClose(out[1], out2[0], True)
    def testPredictOnCPU(self):
        """Batched kernels: GD predictor at t=inf must match GP inference.

        Sweeps device placement, device count and every `get` variant.
        """
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        all_gets = ('ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp'))
        for store_on_device in (False, True):
            for device_count in (0, 1):
                for get in all_gets:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        batched_kernel_fn = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        mse_predictor = predict.gradient_descent_mse_gp(
                            batched_kernel_fn, x_train, y_train, x_test, get,
                            0., True)
                        gp_out = predict.gp_inference(
                            batched_kernel_fn, x_train, y_train, x_test, get,
                            0., True)

                        # `t=None` and `t=np.inf` both denote the
                        # infinite-time limit; both must equal closed-form
                        # GP inference.
                        self.assertAllClose(mse_predictor(None),
                                            mse_predictor(np.inf), True)
                        self.assertAllClose(mse_predictor(None), gp_out,
                                            True)
    def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                  out_logits, get):
        """Infinite-time GD predictions must equal closed-form GP inference."""
        rng = random.PRNGKey(0)

        rng, subkey = random.split(rng)
        x_train = np.cos(random.normal(subkey, train_shape))

        rng, subkey = random.split(rng)
        y_train = np.array(
            random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
            np.float32)

        rng, subkey = random.split(rng)
        x_test = np.cos(random.normal(subkey, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        predict_fn = predict.gradient_descent_mse_gp(kernel_fn,
                                                     x_train,
                                                     y_train,
                                                     x_test,
                                                     get,
                                                     diag_reg=reg,
                                                     compute_cov=True)

        # `t=None` is shorthand for the infinite-time limit.
        at_inf = predict_fn(np.inf)
        at_none = predict_fn(None)
        gp_out = predict.gp_inference(kernel_fn, x_train, y_train,
                                      x_test, get, reg, True)

        self.assertAllClose(at_none, at_inf, True)
        self.assertAllClose(at_none, gp_out, True)
  def testGpInference(self):
    """Compares `predict.gp_inference` against the GD-ensemble predictor.

    Sweeps analytic vs. empirical kernels, every `get` combination,
    train-only vs. train/test inference, and covariance on/off. Also checks
    that requesting a test covariance without the NNGP kernel raises.
    """
    reg = 1e-5  # Diagonal regularizer added to the train-train kernel.
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5),
        stax.Relu(),
        stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
      if kernel_fn_is_analytic:
        kernel_fn = kernel_fn_analytic
      else:
        # Finite-width empirical kernel at fixed random parameters.
        # NOTE: `kernel_fn` closes over `params`; both are rebound together
        # on each outer iteration, so the closure always sees the matching
        # parameters.
        _, params = init_fn(key, x_train.shape)
        kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)
        def kernel_fn(x1, x2, get):
          return kernel_fn_empirical(x1, x2, get, params)

      for get in [None,
                  'nngp', 'ntk',
                  ('nngp',), ('ntk',),
                  ('nngp', 'ntk'), ('ntk', 'nngp')]:
        k_dd = kernel_fn(x_train, None, get)

        gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
        gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                            x_train,
                                                            y_train,
                                                            diag_reg=reg)
        for x_test in [None, 'x_test']:
          # `None` exercises inference on the training set only; the string
          # placeholder is immediately replaced by actual test inputs.
          x_test = None if x_test is None else random.normal(key, (8, 2))
          k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

          for compute_cov in [True, False]:
            with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
                              get=get,
                              x_test=x_test if x_test is None else 'x_test',
                              compute_cov=compute_cov):
              if compute_cov:
                # Test-test NNGP is only needed when covariances on test
                # points are requested (`True` acts as a flag in the
                # train-only case).
                nngp_tt = (True if x_test is None else
                           kernel_fn(x_test, None, 'nngp'))
              else:
                nngp_tt = None

              # `t=None` and `t=inf` should agree (loose tolerance).
              out_ens = gd_ensemble(None, x_test, get, compute_cov)
              out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
              self._assertAllClose(out_ens_inf, out_ens, 0.08)

              # Covariance on test points requires the NNGP kernel; an
              # NTK-only `get` with `compute_cov` must raise.
              if (get is not None and
                  'nngp' not in get and
                  compute_cov and
                  k_td is not None):
                with self.assertRaises(ValueError):
                  out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                            nngp_test_test=nngp_tt)
              else:
                out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
                self.assertAllClose(out_ens, out_gp_inf)
    def testNTKMeanPrediction(self, train_shape, test_shape, network,
                              out_logits):
        """NTK-mode GP mean/variance vs. a Monte-Carlo network ensemble."""
        rng = random.PRNGKey(0)

        rng, subkey = random.split(rng)
        x_train = np.cos(random.normal(subkey, train_shape))

        rng, subkey = random.split(rng)
        y_train = np.array(
            random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
            np.float32)

        rng, subkey = random.split(rng)
        x_test = np.cos(random.normal(subkey, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)
        mean_pred, var = predict.gp_inference(kernel_fn,
                                              x_train,
                                              y_train,
                                              x_test,
                                              diag_reg=0.,
                                              mode='NTK',
                                              compute_var=True)

        # On TPU fall back to the host-side (onp) eigh — presumably the
        # device implementation is unavailable there; TODO confirm.
        eigh = (np.onp.linalg.eigh
                if xla_bridge.get_backend().platform == 'tpu'
                else np.linalg.eigh)

        self.assertEqual(var.shape[0], x_test.shape[0])
        # The predictive variance must be PSD up to numerical noise.
        min_eigh = np.min(eigh(var)[0])
        self.assertGreater(min_eigh + 1e-10, 0.)

        def mc_sampling(count=10):
            # Average the t -> 1e8 GD test prediction over `count` random
            # finite-width initializations.
            running_sum = 0.
            mc_key = random.PRNGKey(100)
            for _ in range(count):
                mc_key, init_key = random.split(mc_key)
                params, f, theta = _empirical_kernel(init_key,
                                                     train_shape[1:],
                                                     network, out_logits)
                g_dd = theta(x_train, None)
                g_td = theta(x_test, x_train)
                predictor = predict.gradient_descent_mse(
                    g_dd, y_train, g_td)

                fx_train_0 = f(params, x_train)
                fx_test_0 = f(params, x_test)

                _, fx_test_final = predictor(1.0e8, fx_train_0,
                                             fx_test_0)
                running_sum += fx_test_final
            return running_sum / count

        mean_emp = mc_sampling(100)

        self.assertAllClose(mean_pred, mean_emp, True, RTOL, ATOL)
    def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                  out_logits, mode):
        """The t -> inf GD prediction must agree with `gp_inference`."""
        # TODO(alemi): Add some finite time tests.

        rng = random.PRNGKey(0)

        rng, subkey = random.split(rng)
        x_train = np.cos(random.normal(subkey, train_shape))

        rng, subkey = random.split(rng)
        y_train = np.array(
            random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
            np.float32)

        rng, subkey = random.split(rng)
        x_test = np.cos(random.normal(subkey, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        # Closed-form posterior mean and variance.
        mean_pred, var = predict.gp_inference(kernel_fn,
                                              x_train,
                                              y_train,
                                              x_test,
                                              diag_reg=reg,
                                              mode=mode,
                                              compute_var=True)
        # Gradient-descent predictor, evaluated in the infinite-time limit.
        predict_fn = predict.gradient_descent_mse_gp(kernel_fn,
                                                     x_train,
                                                     y_train,
                                                     x_test,
                                                     diag_reg=reg,
                                                     mode=mode,
                                                     compute_var=True)

        inf_mean_pred, inf_var = predict_fn(np.inf)

        self.assertAllClose(mean_pred, inf_mean_pred, True)
        self.assertAllClose(var, inf_var, True)
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        """Mean/covariance of an SGD-trained ensemble vs. the NTK-GP posterior.

        Trains `ensemble_size` finite-width networks to convergence with
        full-batch SGD and checks that the empirical mean and covariance of
        their test predictions agree with `predict.gp_inference` ('ntk').
        """
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        # One independent PRNG key per ensemble member.
        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            # Full-batch SGD from a fresh random initialization.
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        # Train the whole ensemble in parallel over the leading key axis.
        params = vmap(train_network)(ensemble_key)

        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        # Every member must have fit the training data. (Fix: the original
        # passed a stray `True` as the `msg` argument of `assertLess`.)
        self.assertLess(ensemble_loss, 1e-5)

        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        # Empirical covariance across ensemble members, averaged over the
        # output (logit) axis.
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
    def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                                 out_logits):
        """Checks the analytic NTK-GP mean/covariance against Monte Carlo.

        Compares `predict.gp_inference` (get='ntk') with the empirical
        mean/covariance of finite networks evolved by `gradient_descent_mse`
        to effectively infinite time (t = 1e8).
        """

        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        # Bernoulli labels cast to float32 for MSE regression.
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)
        # Closed-form posterior mean and covariance under the infinite NTK.
        mean_pred, cov_pred = predict.gp_inference(kernel_fn,
                                                   x_train,
                                                   y_train,
                                                   x_test,
                                                   'ntk',
                                                   diag_reg=0.,
                                                   compute_cov=True)

        # NOTE(review): on TPU the host-side (onp) eigh is used instead of
        # the jax one — presumably eigh is unsupported there; confirm.
        if xla_bridge.get_backend().platform == 'tpu':
            eigh = np.onp.linalg.eigh
        else:
            eigh = np.linalg.eigh

        self.assertEqual(cov_pred.shape[0], x_test.shape[0])
        # The covariance must be PSD up to numerical noise (1e-10 slack).
        min_eigh = np.min(eigh(cov_pred)[0])
        self.assertGreater(min_eigh + 1e-10, 0.)

        def mc_sampling(count=10):
            # Empirical mean/covariance of `count` finite networks; note the
            # local `key` and `kernel_fn` deliberately shadow the outer ones.
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))
            collect_test_predict = []
            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(x_train, None, params)
                g_td = kernel_fn(x_test, x_train, params)
                predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

                fx_initial_train = f(params, x_train)
                fx_initial_test = f(params, x_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                collect_test_predict.append(fx_pred_test)
            collect_test_predict = np.array(collect_test_predict)
            mean_emp = np.mean(collect_test_predict, axis=0)
            mean_subtracted = collect_test_predict - mean_emp
            # Covariance across samples, averaged over the output axis.
            cov_emp = np.einsum(
                'ijk,ilk->jl', mean_subtracted, mean_subtracted,
                optimize=True) / (mean_subtracted.shape[0] *
                                  mean_subtracted.shape[-1])
            return mean_emp, cov_emp

        atol = ATOL
        rtol = RTOL
        mean_emp, cov_emp = mc_sampling(100)

        self.assertAllClose(mean_pred, mean_emp, True, rtol, atol)
        self.assertAllClose(cov_pred, cov_emp, True, rtol, atol)
# Beispiel #9
# 0
  def test_kwargs(self, do_batch, mode):
    """Kernel kwargs (`pattern`, `rng`, ...) must flow through `predict`.

    Checks that per-pair keyword arguments reach the kernel function
    consistently in 'mc', 'empirical' and 'analytic' modes, both for
    infinite-time GP inference and for finite-time MSE predictions.
    """
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    # Aggregation patterns consumed by `stax.Aggregate` via the `pattern`
    # kwarg; one (x1-side, x2-side) pair per kernel evaluation.
    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      # Empirical kernels additionally need `params`, `get`, and per-side
      # dropout rngs.
      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      # Drop `get` before forwarding the kwargs to the ensemble predictor —
      # presumably it passes `get` itself; TODO confirm.
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)