Example 1
    def testPredictOnCPU(self):
        x_train = random.normal(random.PRNGKey(1), (10, 4, 5, 3))
        x_test = random.normal(random.PRNGKey(1), (8, 4, 5, 3))

        y_train = random.uniform(random.PRNGKey(1), (10, 7))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        # Sweep over batching configurations and requested kernel types.
        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get):
                        kernel_fn_batched = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_gp(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)
                        gp_inference = predict.gp_inference(
                            kernel_fn_batched, x_train, y_train, x_test, get,
                            0., True)

                        # `None` and `np.inf` should both give the t -> inf
                        # limit, which should in turn match direct GP inference
                        # (the trailing True is check_dtypes in this harness).
                        self.assertAllClose(predictor(None), predictor(np.inf),
                                            True)
                        self.assertAllClose(predictor(None), gp_inference,
                                            True)
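
A minimal standalone sketch of the batching pattern exercised above, assuming the same (older) neural_tangents API in which batch.batch takes (kernel_fn, batch_size, device_count, store_on_device); the import paths and layer widths here are assumptions, not the test suite's own setup:

import jax.numpy as np
from jax import random
from neural_tangents import stax
from neural_tangents.utils import batch  # module path may differ by version

# Infinite-width kernel of a small fully-connected network.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

x1 = random.normal(random.PRNGKey(0), (16, 8))
x2 = random.normal(random.PRNGKey(1), (4, 8))

# Compute the kernel in batches of 4 inputs on a single device, keeping
# intermediates on the accelerator; the result should match the unbatched
# kernel_fn(x1, x2, get='ntk') up to numerical noise.
kernel_fn_batched = batch.batch(kernel_fn, 4, 0, True)
ntk = kernel_fn_batched(x1, x2, get='ntk')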
Example 2
    def testZeroTimeAgreement(self, train_shape, test_shape, network,
                              out_logits):
        """Test that the NTK and NNGP agree at t=0."""

        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        prediction = predict.gradient_descent_mse_gp(ker_fun,
                                                     x_train,
                                                     y_train,
                                                     x_test,
                                                     diag_reg=reg,
                                                     get=('NTK', 'NNGP'),
                                                     compute_cov=True)

        zero_prediction = prediction(0.0)

        self.assertAllClose(zero_prediction.ntk, zero_prediction.nngp, True)
        # At t = 0 the predictive mean is zero and the covariance is the
        # prior NNGP covariance on the test points, for both NTK and NNGP.
        reference = (np.zeros(
            (test_shape[0], out_logits)), ker_fun(x_test, x_test, get='nngp'))
        self.assertAllClose((reference, ) * 2, zero_prediction, True)
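
These tests rely on a _build_network helper that the listing omits. A plausible reconstruction (purely hypothetical: the real helper's widths, architectures, and network flags may differ) dispatches on the network argument to build a stax init_fn/apply_fn/kernel_fn triple:

def _build_network(input_shape, network, out_logits):
    # Hypothetical stand-in for the omitted test helper.
    if network == 'FLAT':
        return stax.serial(stax.Dense(4096), stax.Relu(),
                           stax.Dense(out_logits))
    elif network == 'POOLING':
        return stax.serial(stax.Conv(256, (3, 3)), stax.Relu(),
                           stax.GlobalAvgPool(), stax.Dense(out_logits))
    raise ValueError('Unrecognized network: %s' % str(network))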
Example 3
    def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                  out_logits, get):
        """Test that t = inf gradient descent matches direct GP inference."""

        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        prediction = predict.gradient_descent_mse_gp(kernel_fn,
                                                     x_train,
                                                     y_train,
                                                     x_test,
                                                     get,
                                                     diag_reg=reg,
                                                     compute_cov=True)

        # Despite the names, both of these are infinite-time predictions.
        finite_prediction = prediction(np.inf)
        finite_prediction_none = prediction(None)
        gp_inference = predict.gp_inference(kernel_fn, x_train, y_train,
                                            x_test, get, reg, True)

        self.assertAllClose(finite_prediction_none, finite_prediction, True)
        self.assertAllClose(finite_prediction_none, gp_inference, True)
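
For the mean, the t = inf agreement checked here reduces to kernel ridge regression. A sketch of the closed form both code paths should compute (not the library's implementation; neural_tangents may additionally scale diag_reg, e.g. by the mean of the train kernel's diagonal):

def infinite_time_mean(kernel_fn, x_train, y_train, x_test, reg, get='ntk'):
    # mean(t -> inf) = K(test, train) (K(train, train) + reg I)^{-1} y_train
    k_dd = kernel_fn(x_train, x_train, get=get)
    k_td = kernel_fn(x_test, x_train, get=get)
    reg_I = reg * np.eye(k_dd.shape[0])
    return np.dot(k_td, np.linalg.solve(k_dd + reg_I, y_train))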
Example 4
    def testNTKPredCovPosDef(self, train_shape, test_shape, network,
                             out_logits):
        """Test that the NTK predictive covariance is positive definite."""
        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        ntk_predictions = predict.gradient_descent_mse_gp(ker_fun,
                                                          x_train,
                                                          y_train,
                                                          x_test,
                                                          diag_reg=reg,
                                                          get='ntk',
                                                          compute_cov=True)

        ts = np.logspace(-2, 8, 10)

        ntk_cov_predictions = [ntk_predictions(t).covariance for t in ts]

        # Eigendecomposition is unsupported (or low-precision) on TPU in this
        # JAX version, so fall back to plain NumPy (re-exported as np.onp).
        if xla_bridge.get_backend().platform == 'tpu':
            eigh = np.onp.linalg.eigh
        else:
            eigh = np.linalg.eigh

        check_symmetric = np.array(
            [np.max(np.abs(cov - cov.T)) for cov in ntk_cov_predictions])
        check_pos_evals = np.min(
            np.array([eigh(cov)[0] + 1e-10 for cov in ntk_cov_predictions]))

        self.assertAllClose(check_symmetric, np.zeros_like(check_symmetric),
                            True)
        self.assertGreater(check_pos_evals, 0.)
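
An equivalent positive-definiteness check (a sketch, not part of the test suite) uses a jittered Cholesky factorization, which succeeds exactly when the jittered matrix is positive definite and avoids the TPU eigh workaround:

import numpy as onp

def _is_pos_def(cov, jitter=1e-10):
    # Cholesky raises LinAlgError iff cov + jitter * I is not positive
    # definite, mirroring the eigenvalue check above.
    try:
        onp.linalg.cholesky(cov + jitter * onp.eye(cov.shape[0]))
        return True
    except onp.linalg.LinAlgError:
        return False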
Example 5
    def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                  out_logits, mode):
        # TODO(alemi): Add some finite time tests.

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        mean_pred, var = predict.gp_inference(ker_fun,
                                              data_train,
                                              data_labels,
                                              data_test,
                                              diag_reg=reg,
                                              mode=mode,
                                              compute_var=True)
        prediction = predict.gradient_descent_mse_gp(ker_fun,
                                                     data_train,
                                                     data_labels,
                                                     data_test,
                                                     diag_reg=reg,
                                                     mode=mode,
                                                     compute_var=True)

        inf_mean_pred, inf_var = prediction(np.inf)

        self.assertAllClose(mean_pred, inf_mean_pred, True)
        self.assertAllClose(var, inf_var, True)
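
Note that this example targets an older API surface than Examples 2-4: gp_inference and gradient_descent_mse_gp take mode= and compute_var= in place of get= and compute_cov=, and return a variance rather than a full covariance. The assertion is otherwise the same t = inf agreement checked in Example 3.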
Example 6
    def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                                 out_logits):
        key = random.PRNGKey(0)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        prediction = predict.gradient_descent_mse_gp(ker_fun,
                                                     x_train,
                                                     y_train,
                                                     x_test,
                                                     diag_reg=reg,
                                                     get='NTK',
                                                     compute_cov=True)

        ts = np.logspace(-2, 8, 10)
        ntk_predictions = [prediction(t).mean for t in ts]

        # Create a hacked kernel function that always returns the NTK kernel.
        def always_ntk(x1, x2, get=('nngp', 'ntk')):
            out = ker_fun(x1, x2, get=('nngp', 'ntk'))
            if get == 'nngp' or get == 'ntk':
                return out.ntk
            else:
                return out._replace(nngp=out.ntk)

        ntk_nngp_prediction = predict.gradient_descent_mse_gp(always_ntk,
                                                              x_train,
                                                              y_train,
                                                              x_test,
                                                              diag_reg=reg,
                                                              get='NNGP',
                                                              compute_cov=True)

        ntk_nngp_predictions = [ntk_nngp_prediction(t).mean for t in ts]

        # Test that using the NNGP equations with the NTK kernel gives the
        # same mean.
        self.assertAllClose(ntk_predictions, ntk_nngp_predictions, True)

        # Next, test that going through the NTK code path with only the
        # NNGP kernel recreates the NNGP dynamics.
        nngp_prediction = predict.gradient_descent_mse_gp(ker_fun,
                                                          x_train,
                                                          y_train,
                                                          x_test,
                                                          diag_reg=reg,
                                                          get='NNGP',
                                                          compute_cov=True)

        # Create a hacked kernel function that always returns the NNGP kernel.
        def always_nngp(x1, x2, get=('nngp', 'ntk')):
            out = ker_fun(x1, x2, get=('nngp', 'ntk'))
            if get == 'nngp' or get == 'ntk':
                return out.nngp
            else:
                return out._replace(ntk=out.nngp)

        nngp_ntk_prediction = predict.gradient_descent_mse_gp(always_nngp,
                                                              x_train,
                                                              y_train,
                                                              x_test,
                                                              diag_reg=reg,
                                                              get='NTK',
                                                              compute_cov=True)

        nngp_cov_predictions = [nngp_prediction(t).covariance for t in ts]
        nngp_ntk_cov_predictions = [
            nngp_ntk_prediction(t).covariance for t in ts
        ]

        # Test that using the NTK equations with the NNGP kernel gives the
        # same covariance, though only approximately due to accumulated
        # numerical error.
        self.assertAllClose(nngp_cov_predictions, nngp_ntk_cov_predictions,
                            True)
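
The two hacked kernels above share one pattern: compute both kernels, then report one of them under the other's name via the namedtuple's _replace. A generic version (a sketch; swap_kernel is a hypothetical helper, with ker_fun as in the tests above):

def swap_kernel(ker_fun, source, target):
    # Return a kernel function that reports the `source` kernel under the
    # `target` name, so e.g. the NNGP code path runs on the NTK kernel.
    def swapped(x1, x2, get=('nngp', 'ntk')):
        out = ker_fun(x1, x2, get=('nngp', 'ntk'))
        if isinstance(get, str):
            return getattr(out, source)
        return out._replace(**{target: getattr(out, source)})
    return swapped

# always_ntk above is swap_kernel(ker_fun, 'ntk', 'nngp');
# always_nngp is swap_kernel(ker_fun, 'nngp', 'ntk').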