예제 #1
0
    def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
        """Assemble the initial `ODEState` from user-supplied initial values.

        `y_train`, `dtype` and `momentum` are not defined in this block; they
        are read from the surrounding scope.

        Args:
          fx_train_or_state_0:
            either an `ODEState`, whose four fields are then used directly, or
            the initial train-set outputs (`None` means zeros shaped like
            `y_train`).
          fx_test_0:
            initial test-set outputs, or `None`. Ignored when
            `fx_train_or_state_0` is an `ODEState` (the state's `fx_test` is
            used instead).
          fx_test_shape:
            shape that test-set entries are broadcast to / created with.

        Returns:
          `ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)` with every
          non-`None` entry broadcast to its target shape.

        Raises:
          ValueError: if `momentum is None` but a momentum variable
            (`qx_train` / `qx_test`) was supplied through the state.
        """
        if isinstance(fx_train_or_state_0, ODEState):
            # A full state takes precedence over the separate `fx_test_0`.
            state = fx_train_or_state_0
            fx_train_0, fx_test_0 = state.fx_train, state.fx_test
            qx_train_0, qx_test_0 = state.qx_train, state.qx_test
        else:
            fx_train_0 = fx_train_or_state_0
            qx_train_0 = None
            qx_test_0 = None

        # Train outputs: default to zeros, otherwise match `y_train`'s shape.
        fx_train_0 = (np.broadcast_to(fx_train_0, y_train.shape)
                      if fx_train_0 is not None
                      else np.zeros_like(y_train, dtype))

        # Test outputs stay `None` when absent.
        if fx_test_0 is not None:
            fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

        if momentum is not None:
            # Momentum variables default to zeros when not supplied.
            if qx_train_0 is None:
                qx_train_0 = np.zeros_like(y_train, dtype)
            else:
                qx_train_0 = np.broadcast_to(qx_train_0, y_train.shape)

            # Without test outputs there is no test momentum variable either.
            if fx_test_0 is None:
                qx_test_0 = None
            elif qx_test_0 is None:
                qx_test_0 = np.zeros(fx_test_shape, dtype)
            else:
                qx_test_0 = np.broadcast_to(qx_test_0, fx_test_shape)
        elif not (qx_train_0 is None and qx_test_0 is None):
            raise ValueError('Got passed momentum state variables, while '
                             '`momentum is None`.')

        return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count
예제 #2
0
    def testZeroTimeAgreement(self, train_shape, test_shape, network,
                              out_logits):
        """Check that NTK and NNGP ensemble predictions coincide at `t=0`.

        For both the train set (`x_test=None`) and the test set, the predictor
        evaluated at `t=0.0` is compared against a reference consisting of a
        zero mean and the NNGP kernel of the corresponding inputs; the same
        reference is expected for both the 'NTK' and 'NNGP' outputs. For the
        train set the check is repeated with `x_train` passed explicitly as
        `x_test`.
        """
        _, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        diag_reg = 1e-7
        predictor = predict.gradient_descent_mse_ensemble(
            ker_fun, x_train, y_train, diag_reg=diag_reg)

        # Keyword arguments shared by every predictor call below.
        at_zero = dict(t=0.0, get=('NTK', 'NNGP'), compute_cov=True)

        for label in (None, 'x_test'):
            with self.subTest(x=label):
                on_train = label is None
                inputs = None if on_train else x_test
                prediction = predictor(x_test=inputs, **at_zero)

                if on_train:
                    nngp = ker_fun(x_train, None, get='nngp')
                    expected = (np.zeros_like(y_train, nngp.dtype), nngp)
                else:
                    expected = (np.zeros((test_shape[0], out_logits)),
                                ker_fun(x_test, None, get='nngp'))

                # Same reference for both the NTK and the NNGP entries.
                self.assertAllClose((expected, ) * 2, prediction,
                                    check_dtypes=False)

                if on_train:
                    # Passing the train inputs explicitly must give the same
                    # result as `x_test=None`.
                    explicit = predictor(x_test=x_train, **at_zero)
                    self.assertAllClose((expected, ) * 2, explicit)
예제 #3
0
 def init(x0):
   """Return `(x0, zeros, per_axis_zeros)` as the initial state.

   `zeros` has the shape and dtype of `x0`; `per_axis_zeros` is a list holding,
   for every axis of `x0`, a 1-D zero array of that axis' length with `x0`'s
   dtype.
   """
   per_axis_zeros = []
   for axis_len in x0.shape:
     per_axis_zeros.append(np.zeros(axis_len, dtype=x0.dtype))
   full_zeros = np.zeros_like(x0)
   return x0, full_zeros, per_axis_zeros
예제 #4
0
 def init(x0):
   """Return the initial state `(x0, m0, u0)`.

   `m0` and `u0` are two separate zero arrays with the shape and dtype of `x0`.
   """
   return (x0, np.zeros_like(x0), np.zeros_like(x0))
예제 #5
0
 def init(x0):
   """Return the initial state `(x0, avg_sq_grad, mom)`.

   Both extra entries start as separate zero arrays matching `x0` in shape and
   dtype.
   """
   sq_grad_zeros, mom_zeros = (np.zeros_like(x0) for _ in range(2))
   return x0, sq_grad_zeros, mom_zeros
예제 #6
0
 def init(x0):
   """Return the initial state `(x0, g_sq, m)`.

   `g_sq` and `m` are independent zero arrays shaped and typed like `x0`.
   """
   state = [x0]
   state.extend(np.zeros_like(x0) for _ in range(2))
   return tuple(state)
예제 #7
0
 def init(x0):
   """Return the initial state `(x0, v0)`, with `v0` zeros shaped like `x0`."""
   return (x0, np.zeros_like(x0))
예제 #8
0
    def predict_fn(
        get: Get,
        k_test_train=None,
        nngp_test_test: np.ndarray = None
    ) -> Dict[str, Union[np.ndarray, Gaussian]]:
        """`test`-set posterior given respective covariance matrices.

        Args:
          get:
            string, the mode of the Gaussian process, either "nngp" or "ntk",
            or a tuple, or `None`. If `None` then both `nngp` and `ntk`
            predictions are returned.
          k_test_train:
            test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel`
            namedtuple, (c) `Kernel` object. Must contain the necessary `nngp`
            and/or `ntk` kernels for arguments provided to the returned
            `predict_fn` function. For example, if you request to compute
            posterior test [only] NTK covariance, `k_test_train` must contain
            both `ntk` and `nngp` kernels. If `None`, returns predictions on
            the training set. Note that train-set outputs are always
            `N(y_train, 0)` and mostly returned for API consistency.
          nngp_test_test:
            A test-test NNGP array. Provide if you want to compute test-test
            posterior covariance. `nngp_test_test=None` means to not compute
            it. If `k_test_train is None`, pass any non-`None` value (e.g.
            `True`) if you want to get non-regularized (`diag_reg=0`)
            train-train posterior covariance. Note that non-regularized
            train-set outputs will always be the zero-variance Gaussian
            `N(y_train, 0)` and mostly returned for API consistency. For
            regularized train-set posterior outputs according to a positive
            `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
            `nngp_test_test=nngp_train_train`.

        Returns:
          A dict with one entry per mode in `get`. Each value is a
          `Gaussian(mean, covariance)` when `nngp_test_test` is not `None`,
          and only the posterior mean otherwise.
          NOTE(review): the previous documentation described a bare
          `Gaussian` / mean return value, so a wrapper outside this block
          presumably unpacks the dict - confirm.

        Raises:
          ValueError: if the NTK test-set covariance is requested while
            `k_train_train` or `k_test_train` lacks an `nngp` attribute, or if
            a mode other than "nngp" / "ntk" reaches the covariance
            computation.
        """
        if get is None:
            get = ('nngp', 'ntk')

        out = {}

        # NOTE(review): `get` is iterated directly; a bare string would be
        # walked character by character unless it is normalized to a tuple
        # before this function runs - confirm against the caller/decorator.
        for g in get:
            k_dd = _get_attr(k_train_train, g)
            k_td = None if k_test_train is None else _get_attr(k_test_train, g)

            if k_td is None:
                # Train set predictions.
                y = y_train.astype(k_dd.dtype)
            else:
                # Test set predictions.
                y = np.tensordot(k_td, k_inv_y(g), (odd, first))
                # Move the trailing axes produced by the contraction back to
                # the positions given by `trace_axes`.
                y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

            if nngp_test_test is not None:
                if k_td is None:
                    # Train-set covariance is reported as all zeros.
                    out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
                else:
                    if (g == 'ntk' and (not hasattr(k_train_train, 'nngp')
                                        or not hasattr(k_test_train, 'nngp'))):
                        # Adjacent literals are concatenated verbatim, so each
                        # fragment must carry its own trailing space.
                        raise ValueError(
                            'If `"ntk" in get`, and `nngp_test_test is not None`, '
                            'and `k_test_train is not None`, i.e. you request the '
                            'NTK posterior covariance on the test set, you need '
                            'both NTK and NNGP train-train and test-train matrices '
                            'contained in `k_test_train` and `k_train_train`. '
                            'Hence they must be `namedtuple`s with `nngp` and '
                            '`ntk` attributes.')

                    # Solve against the NNGP test-train kernel using the
                    # solver selected for mode `g`; used by both branches.
                    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'),
                                               even)

                    if g == 'nngp':
                        cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                        cov = nngp_test_test - utils.zip_axes(cov)
                        out[g] = Gaussian(y, cov)

                    elif g == 'ntk':
                        term_1 = solve(g)(k_td, even)
                        cov = np.tensordot(_get_attr(k_train_train, 'nngp'),
                                           term_1, (odd, first))
                        cov = np.tensordot(term_1, cov, (first, first))

                        term_2 = np.tensordot(k_td, k_td_nngp_inv_y,
                                              (odd, first))
                        # Add the axis-swapped copy of the cross term.
                        term_2 += np.moveaxis(term_2, first, last)
                        cov = utils.zip_axes(cov - term_2) + nngp_test_test
                        out[g] = Gaussian(y, cov)

                    else:
                        raise ValueError(g)

            else:
                # No covariance requested: return the mean only.
                out[g] = y

        return out