Example No. 1
# Imports as in `neural_tangents/utils/empirical.py` (module paths approximate):
from jax import numpy as np
from neural_tangents.utils import utils
from neural_tangents.utils.typing import Axes


def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug at '
                     'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace out each pair of matching trace axes; `shrink` accounts for axes
  # already removed by earlier traces.
  shrink = 0
  for c in reversed(trace_axes):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - shrink)
    shrink += 1

  # Take the diagonal along each pair of matching diagonal axes; `np.diagonal`
  # appends each extracted diagonal as the last axis.
  for i, d in enumerate(diagonal_axes):
    ntk = np.diagonal(ntk, axis1=d - i, axis2=output_ndim + d - shrink - 2 * i)

  # Interleave the remaining axis pairs, move the extracted diagonals into
  # their positions in the result, and normalize by the traced axes' total size.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
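
The helper above is built on `np.trace` and `np.diagonal` over matching axis pairs. The following is a plain-NumPy sketch, not the library code, of what those two operations do to an "unzipped" kernel; the shapes and names are illustrative only.

import numpy as np

N1, N2, X = 3, 4, 5
ntk = np.random.rand(N1, X, N2, X)            # unzipped layout (N1, X, N2, X)

# "trace_axes=(1,)": contract the X pair and normalize by its size.
traced = np.trace(ntk, axis1=1, axis2=3) / X  # shape (N1, N2)
assert np.allclose(traced, np.einsum('ixjx->ij', ntk) / X)

# "diagonal_axes=(1,)": keep only the diagonal of the X pair (no averaging).
diag = np.diagonal(ntk, axis1=1, axis2=3)     # shape (N1, N2, X)
assert np.allclose(diag, np.einsum('ixjx->ijx', ntk))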
Example No. 2
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

    Args:
      ntk:
        input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
      trace_axes:
        axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
        and remove the respective pairs of axes from the `ntk`.
      diagonal_axes:
        axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
        diagonal along the respective pairs of axes from the `ntk` (and hence
        replace each such pair of axes with a single axis in the result).

    Returns:
      An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
      `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
      replaced with a single `Y` axis).
    """

    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    # Trace out each pair of matching trace axes.
    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    # Take the diagonal along each remaining pair; the axis indices are shifted
    # to account for axes already removed by the traces above and by earlier
    # diagonals.
    for i, d in enumerate(diagonal_axes):
        axis1 = d - i
        axis2 = output_ndim + d - 2 * i - n_trace
        for c in trace_axes:
            if c < d:
                axis1 -= 1
                axis2 -= 1
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size
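
For the shape example in the docstring, here is a compact einsum analogue in plain NumPy. It is independent of the helper and of its axis-ordering conventions (`zip_axes`, `get_res_batch_dims`), so the actual output ordering may differ; all names are illustrative.

import numpy as np

N1, N2, X, Y = 2, 3, 4, 5
ntk = np.random.rand(N1, X, Y, N2, X, Y)   # (N1, X, Y, N2, X, Y)

# trace_axes=(1,): average out the X pair; diagonal_axes=(2,): keep only the
# diagonal of the Y pair.
out = np.einsum('ixyjxy->ijy', ntk) / X    # shape (N1, N2, Y)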
Example No. 3
def reshape_cov(cov):
    # Nested helper (a closure): `k_dd`, `k_td`, `t_shape` and `_get_first`
    # come from the enclosing predict function's scope. The flat covariance is
    # reshaped to the time axes (`t_shape`) followed by two copies of the
    # kernel's even-indexed axes, and the trailing axis pairs are then
    # interleaved ("zipped").
    k = _get_first(k_dd if k_td is None else k_td)
    cov_shape_t = t_shape + k.shape[::2] * 2
    return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))
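
`utils.zip_axes` interleaves two blocks of axes into adjacent pairs. Below is a rough plain-NumPy stand-in for that interleaving, under the assumption that the convention is `(..., X, Y, X, Y) -> (..., X, X, Y, Y)`; the helper name and implementation here are illustrative, not the library's.

import numpy as np

def zip_axes_sketch(x, start):
    # Interleave the trailing axes: (start axes..., A, B, ..., A, B, ...)
    # -> (start axes..., A, A, B, B, ...).
    n = (x.ndim - start) // 2
    perm = list(range(start))
    for i in range(n):
        perm += [start + i, start + n + i]
    return np.transpose(x, perm)

cov = np.zeros((7, 2, 3, 2, 3))               # (T, X, Y, X, Y)
assert zip_axes_sketch(cov, 1).shape == (7, 2, 2, 3, 3)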
Example No. 4
    def predict_fn(
        get: Get,
        k_test_train=None,
        nngp_test_test: np.ndarray = None
    ) -> Dict[str, Union[np.ndarray, Gaussian]]:
        """`test`-set posterior given respective covariance matrices.

        Args:
          get:
            string, the mode of the Gaussian process, either "nngp" or "ntk", a
            tuple of these, or `None`. If `None`, both `nngp` and `ntk`
            predictions are returned.
          k_test_train:
            test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple,
            (c) `Kernel` object. Must contain the necessary `nngp` and/or `ntk`
            kernels for the arguments provided to the returned `predict_fn`
            function. For example, if you request the posterior test [only] NTK
            covariance, `k_test_train` must contain both `ntk` and `nngp`
            kernels. If `None`, returns predictions on the training set. Note
            that train-set outputs are always `N(y_train, 0)` and are mostly
            returned for API consistency.
          nngp_test_test:
            A test-test NNGP array. Provide if you want to compute the test-test
            posterior covariance; `nngp_test_test=None` means it is not computed.
            If `k_test_train is None`, pass any non-`None` value (e.g. `True`) if
            you want to get non-regularized (`diag_reg=0`) train-train posterior
            covariance. Note that non-regularized train-set outputs will always
            be the zero-variance Gaussian `N(y_train, 0)` and are mostly returned
            for API consistency. For regularized train-set posterior outputs
            according to a positive `diag_reg`, pass `k_test_train=k_train_train`
            and, optionally, `nngp_test_test=nngp_train_train`.

        Returns:
          Either a `Gaussian('mean', 'variance')` namedtuple or the `mean` of the
          GP posterior on the `test` set, depending on whether `nngp_test_test`
          is provided.
        """
        if get is None:
            get = ('nngp', 'ntk')

        out = {}

        for g in get:
            k_dd = _get_attr(k_train_train, g)
            k_td = None if k_test_train is None else _get_attr(k_test_train, g)

            if k_td is None:
                # Train set predictions.
                y = y_train.astype(k_dd.dtype)
            else:
                # Test set predictions.
                y = np.tensordot(k_td, k_inv_y(g), (odd, first))
                y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

            if nngp_test_test is not None:
                if k_td is None:
                    out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
                else:
                    if (g == 'ntk' and (not hasattr(k_train_train, 'nngp')
                                        or not hasattr(k_test_train, 'nngp'))):
                        raise ValueError(
                            'If `"ntk" in get`, and `nngp_test_test is not None`, '
                            'and `k_test_train is not None`, i.e. you request the '
                            'NTK posterior covariance on the test set, you need '
                            'both NTK and NNGP train-train and test-train matrices '
                            'contained in `k_test_train` and `k_train_train`. '
                            'Hence they must be `namedtuple`s with `nngp` and '
                            '`ntk` attributes.')

                    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'),
                                               even)

                    if g == 'nngp':
                        # Standard GP posterior covariance:
                        #   nngp_tt - nngp_td @ nngp_dd^{-1} @ nngp_dt.
                        cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                        cov = nngp_test_test - utils.zip_axes(cov)
                        out[g] = Gaussian(y, cov)

                    elif g == 'ntk':
                        # NTK posterior covariance (infinite-time gradient
                        # descent limit of an infinitely wide network):
                        #   nngp_tt
                        #   + ntk_td @ ntk_dd^{-1} @ nngp_dd @ ntk_dd^{-1} @ ntk_dt
                        #   - (ntk_td @ ntk_dd^{-1} @ nngp_dt + transpose).
                        term_1 = solve(g)(k_td, even)
                        cov = np.tensordot(_get_attr(k_train_train, 'nngp'),
                                           term_1, (odd, first))
                        cov = np.tensordot(term_1, cov, (first, first))

                        term_2 = np.tensordot(k_td, k_td_nngp_inv_y,
                                              (odd, first))
                        term_2 += np.moveaxis(term_2, first, last)
                        cov = utils.zip_axes(cov - term_2) + nngp_test_test
                        out[g] = Gaussian(y, cov)

                    else:
                        raise ValueError(g)

            else:
                out[g] = y

        return out
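
A hedged usage sketch of how this `predict_fn` is typically obtained and called; the `neural_tangents` API names below are my best understanding and may differ across library versions, and the shapes are toy.

import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))

x_train, y_train = np.ones((10, 4)), np.ones((10, 1))
x_test = np.ones((5, 4))

# Analytic train-train, test-train and test-test kernels.
k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
nngp_test_test = kernel_fn(x_test, None, 'nngp')

# `gp_inference` closes over the train-train kernel and `y_train`, and returns
# a `predict_fn` like the one shown above.
predict_fn = nt.predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)

# Requesting the NTK posterior on the test set requires both `ntk` and `nngp`
# in `k_test_train`, per the docstring above.
ntk_posterior = predict_fn(get='ntk',
                           k_test_train=k_test_train,
                           nngp_test_test=nngp_test_test)
mean, covariance = ntk_posterior  # `Gaussian` namedtuple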