Example #1
def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug '
                     'at https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  shrink = 0
  for c in reversed(trace_axes):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - shrink)
    shrink += 1

  for i, d in enumerate(diagonal_axes):
    ntk = np.diagonal(ntk, axis1=d - i, axis2=output_ndim + d - shrink - 2 * i)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
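
The helper above assumes an NTK whose output axes come in `(input_1, input_2)` pairs, e.g. `(N1, X, Y, N2, X, Y)`. A minimal standalone sketch (not library code; plain `jax.numpy` only) of what tracing and taking diagonals over such a pair of axes does:

import jax.numpy as np

ntk = np.ones((3, 4, 5, 3, 4, 5))   # (N1, X, Y, N2, X, Y)
output_ndim = ntk.ndim // 2         # 3

# Tracing over the pair of `X` axes (axis 1 and its mirror 1 + output_ndim)
# removes both axes; dividing by the traced size mirrors `ntk / contract_size`.
traced = np.trace(ntk, axis1=1, axis2=1 + output_ndim) / ntk.shape[1]
print(traced.shape)                 # (3, 5, 3, 5)

# Taking the diagonal along the pair of `Y` axes collapses them into a single
# axis, which `np.diagonal` appends at the end of the shape.
diag = np.diagonal(traced, axis1=1, axis2=3)
print(diag.shape)                   # (3, 3, 5)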
Example #2
  def testAxes(self, diagonal_axes, trace_axes):
    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, (4, 5, 6, 3))
    data_other = random.normal(other_split, (2, 5, 6, 3))

    _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
    _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

    if any(d == c for d in _diagonal_axes for c in _trace_axes):
      raise absltest.SkipTest(
          'diagonal axes must be different from channel axes.')

    get_kernel = KERNELS['empirical_logits_3']
    kwargs = dict(
        key=key,
        input_shape=(5, 6, 3),
        network=CONV,
        diagonal_axes=diagonal_axes,
        trace_axes=trace_axes
    )

    implicit, direct, nngp = get_kernel(**kwargs)
    implicit_batched, direct_batched, _ = get_kernel(**kwargs, vmap_axes=0)

    n_marg = len(_diagonal_axes)
    n_chan = len(_trace_axes)

    g_nngp = nngp(data_self, None)
    self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

    g_direct = direct(data_self, None)
    self.assertEqual(g_nngp.shape, g_direct.shape)

    g_direct_batched = direct_batched(data_self, None)
    g = implicit(data_self, None)
    g_batched = implicit_batched(data_self, None)

    self.assertAllClose(g_direct, g)
    self.assertAllClose(g_direct, g_direct_batched)
    self.assertAllClose(g_direct, g_batched)

    if 0 not in _trace_axes and 0 not in _diagonal_axes:
      g_nngp = nngp(data_other, data_self)
      self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

      g_direct = direct(data_other, data_self)
      self.assertEqual(g_nngp.shape, g_direct.shape)

      g_direct_batched = direct_batched(data_other, data_self)
      g = implicit(data_other, data_self)
      g_batched = implicit_batched(data_other, data_self)

      self.assertAllClose(g_direct, g)
      self.assertAllClose(g_direct, g_direct_batched)
      self.assertAllClose(g_direct, g_batched)
Example #3
  def sum_and_contract(j1, j2, output_ndim):
    _diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
    _trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)

    def contract(x, y):
      param_axes = list(range(x.ndim))[output_ndim:]
      contract_axes = _trace_axes + param_axes
      return utils.dot_general(x, y, contract_axes, _diagonal_axes)

    return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
Example #4
    def sum_and_contract(fx, j1, j2):
        ndim = fx.ndim
        size = utils.size_at(fx, trace_axes)

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim)
        _trace_axes = utils.canonicalize_axis(trace_axes, ndim)

        def contract(x, y):
            param_axes = list(range(x.ndim))[ndim:]
            contract_axes = _trace_axes + param_axes
            return utils.dot_general(x, y, contract_axes,
                                     _diagonal_axes) / size

        return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
Example #5
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the respective pairs of axes from the `ntk`.
    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).
  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).
  """

    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    for i, d in enumerate(diagonal_axes):
        axis1 = d - i
        axis2 = output_ndim + d - 2 * i - n_trace
        for c in trace_axes:
            if c < d:
                axis1 -= 1
                axis2 -= 1
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size
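
To make the docstring's shape example concrete, here is a rough sketch of calling the function above (assuming `_trace_and_diagonal` and its `utils` helpers are importable from the library, and `jax.numpy` imported as `np`; the shapes follow the documented behavior):

N1, N2, X, Y, Z = 3, 2, 4, 5, 6
ntk = np.ones((N1, X, Y, Z, N2, X, Y, Z))  # (N1, X, Y, Z, N2, X, Y, Z)

# `trace_axes=(1,)` removes the pair of `X` axes (averaging over them), and
# `diagonal_axes=(2,)` collapses the pair of `Y` axes into a single `Y` axis.
out = _trace_and_diagonal(ntk, trace_axes=(1,), diagonal_axes=(2,))
print(out.shape)  # (3, 2, 5, 6, 6), i.e. `(N1, N2, Y, Z, Z)` as documented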
Example #6
    def testAxes(self, diagonal_axes, trace_axes):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
        data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
        _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

        if any(d == c for d in _diagonal_axes for c in _trace_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        implicit, direct, nngp = KERNELS['empirical_logits_3'](
            key, (5, 6, 3),
            CONV,
            diagonal_axes=diagonal_axes,
            trace_axes=trace_axes)

        n_marg = len(_diagonal_axes)
        n_chan = len(_trace_axes)

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        g_nngp = nngp(data_self, None)

        self.assertAllClose(g, g_direct)
        self.assertEqual(g_nngp.shape, g.shape)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

        if 0 not in _trace_axes and 0 not in _diagonal_axes:
            g = implicit(data_other, data_self)
            g_direct = direct(data_other, data_self)
            g_nngp = nngp(data_other, data_self)

            self.assertAllClose(g, g_direct)
            self.assertEqual(g_nngp.shape, g.shape)
            self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg,
                             g_nngp.ndim)
Example #7
def _get_fx_test_shape(y_train: np.ndarray, k_test_train: np.ndarray,
                       y_axes: Axes) -> Tuple[int, ...]:
    if k_test_train is None:
        return y_train.shape

    shape = list(k_test_train.shape[::2])
    y_axes = utils.canonicalize_axis(y_axes, y_train)
    for i, c in enumerate(y_train.shape):
        if i in y_axes:
            shape.insert(i, c)
    return tuple(shape)
Example #8
    def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
        b_axes = utils.canonicalize_axis(b_axes, b)
        last_b_axes = range(-len(b_axes), 0)
        x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

        b = np.moveaxis(b, b_axes, last_b_axes)
        b = b.reshape((A.shape[1], -1))

        x = sp.linalg.cho_solve(C, b)
        x = x.reshape(x_shape)
        return x
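
The snippet above is a closure excerpted from the library: `A`, `C` (a cached Cholesky factorization of the regularized kernel) and `x_non_channel_shape` come from the enclosing scope. A self-contained sketch of the same solve pattern (the array names here are illustrative assumptions, not library API):

import jax.numpy as np
from jax.scipy.linalg import cho_factor, cho_solve

A = 2. * np.eye(10)           # an (n, n) positive-definite kernel matrix
C = cho_factor(A)             # factorize once, reuse for many right-hand sides
b = np.ones((10, 3))          # right-hand sides, one per column

x = cho_solve(C, b)           # solves A @ x == b
print(np.allclose(A @ x, b))  # True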
Example #9
def gradient_descent_mse_ensemble(kernel_fn: KernelFn,
                                  x_train: np.ndarray,
                                  y_train: np.ndarray,
                                  learning_rate: float = 1.,
                                  diag_reg: float = 0.0,
                                  diag_reg_absolute_scale: bool = False,
                                  trace_axes: Axes = (-1, ),
                                  **kernel_fn_kwargs):
    r"""Predicts the Gaussian embedding induced by gradient descent on MSE loss.

  This is equivalent to an infinite ensemble of infinite-width networks after
  marginalizing out the initialization, if `kernel_fn` is the kernel function of
  the infinite-width network. Note that `kernel_fn` can in principle also be an
  empirical / Monte Carlo finite-width kernel function, but in this case the
  returned output will not have a simple interpretation (unless these functions
  are used to approximate the infinite-width kernel).

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as the kernel computation,
  and either eigendecomposition (`t` is a scalar or an array) or Cholesky
  factorization (`t=None`) of `kernel_fn(x_train, None, get)` is performed and
  cached for future invocations (or both, if the function is called on both
  finite and infinite (`t=None`) times).

  Args:
    kernel_fn:
      A kernel function that computes NNGP and/or NTK. Must have a signature
      `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object
      or a `namedtuple` with `nngp` and/or `ntk` attributes. Therefore, it can
      be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn`, or an
      `EmpiricalKernelFn` (but only `nt.empirical_kernel_fn` and not
      `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since the latter two do
      not accept a `get` argument). Note that for meaningful outputs, the kernel
      function must represent or at least approximate the infinite-width kernel.
    x_train:
      training inputs.
    y_train:
      training targets.
    learning_rate:
      learning rate, step size.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `kernel_fn(x_train, None, get)`, i.e. computing
      `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky
      factorization or eigendecomposition.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be
      `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))`.
    trace_axes:
      `f(x_train)` axes such that `kernel_fn(x_train, None, get)`,
      `kernel_fn(x_test, x_train, get)`[, and `kernel_fn(x_test, None, get)`]
      lack these pairs of dimensions and are to be interpreted as
      :math:`\Theta \otimes I`, i.e. block-diagonal along `trace_axes`. These
      can be specified either to save space and compute, or to even improve
      approximation accuracy of the infinite-width or infinite-samples limit,
      since in these limits the covariance along channel / feature / logit
      axes indeed converges to a constant-diagonal matrix. However, if you
      target linearized dynamics of a specific finite-width network,
      `trace_axes=()` will yield the most accurate results.
    **kernel_fn_kwargs:
      optional keyword arguments passed to `kernel_fn`.

  Returns:
    A function with signature `predict_fn(t, x_test, get, compute_cov)`
    returning either mean or mean and covariance of the infinite ensemble of
    infinite-width networks outputs on `x_test` at time[s] `t`, in the `get`
    regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
  """
    expm1 = _make_expm1_fn(y_train.size)
    inv_expm1 = _make_inv_expm1_fn(y_train.size)

    trace_axes = utils.canonicalize_axis(trace_axes, y_train)
    trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
    n_trace_axes = len(trace_axes)
    last_t_axes = range(-n_trace_axes, 0)
    trace_shape = tuple(y_train.shape[a] for a in trace_axes)

    y_train_flat = np.moveaxis(y_train, trace_axes,
                               last_t_axes).reshape((-1, ) + trace_shape)

    k_dd_cache = {}

    def get_k_train_train(get: Sequence[str]) -> _Kernel:
        if len(get) == 1:
            get = get[0]
            if get not in k_dd_cache:
                k_dd_cache[get] = kernel_fn(x_train, None, get,
                                            **kernel_fn_kwargs)

        elif len(get) == 2:
            if not any(g in k_dd_cache for g in get):
                k_dd_cache.update(
                    kernel_fn(x_train, None, get,
                              **kernel_fn_kwargs)._asdict())
            else:
                for g in get:
                    if g not in k_dd_cache:
                        k_dd_cache[g] = kernel_fn(x_train, None, g,
                                                  **kernel_fn_kwargs)

        else:
            raise ValueError(get)
        return _Kernel(**k_dd_cache)

    @lru_cache(2)
    def eigenspace(get: str):
        k_dd = getattr(get_k_train_train((get, )), get)
        k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                         diag_reg_absolute_scale)
        evals, evecs = np.linalg.eigh(k_dd)
        evals = np.expand_dims(evals, 0)
        return evals, evecs

    @lru_cache(4)
    def predict_inf(get: Get):
        _, get = utils.canonicalize_get(get)
        k_dd = get_k_train_train(get)
        return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                            trace_axes)

    def get_matrices(get: Get, x_test: Optional[np.ndarray],
                     compute_cov: bool):
        get = _get_dependency(get, compute_cov)
        k_dd = get_k_train_train(get)
        if x_test is None:
            k_td = None
            nngp_tt = compute_cov or None
        else:
            k_td = kernel_fn(x_test, x_train, get, **kernel_fn_kwargs)
            if compute_cov:
                nngp_tt = kernel_fn(x_test, None, 'nngp', **kernel_fn_kwargs)
            else:
                nngp_tt = None
        return k_dd, k_td, nngp_tt

    @utils.get_namedtuple('Gaussians')
    def predict_fn(t: ArrayOrScalar = None,
                   x_test: np.ndarray = None,
                   get: Get = None,
                   compute_cov: bool = False) -> Dict[str, Gaussian]:
        """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`, but is computed
        using linear solve for test predictions instead of eigendecomposition,
        saving time and precision.
      x_test:
        test inputs. `None` means to return non-regularized (`diag_reg=0`)
        predictions on the train-set inputs. For regularized predictions, pass
        `x_test=x_train`.
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov:
        if `True`, compute both `mean` and `variance`; otherwise only `mean`.

    Returns:
      `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
      `compute_cov == True` with potentially additional leading time dimensions.
    """
        if get is None:
            get = ('nngp', 'ntk')

        # train-train, test-train, test-test.
        k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

        # Infinite time.
        if t is None:
            return predict_inf(get)(get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)

        # Finite time.
        t = np.array(t) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, 1))

        def reshape_mean(mean):
            k = _get_first(k_dd if k_td is None else k_td)
            mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
            mean = np.moveaxis(mean, last_t_axes, trace_axes)
            return mean

        def reshape_cov(cov):
            k = _get_first(k_dd if k_td is None else k_td)
            cov_shape_t = t_shape + k.shape[::2] * 2
            return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

        out = {}

        for g in get:
            evals, evecs = eigenspace(g)

            # Training set.
            if k_td is None:
                mean = np.einsum('ji,ti,ki,k...->tj...',
                                 evecs,
                                 -expm1(evals, t),
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            # Test set.
            else:
                neg_inv_expm1 = -inv_expm1(evals, t)
                ktd_g = utils.make_2d(getattr(k_td, g))
                mean = np.einsum('lj,ji,ti,ki,k...->tl...',
                                 ktd_g,
                                 evecs,
                                 neg_inv_expm1,
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            mean = reshape_mean(mean)

            if nngp_tt is not None:
                nngp_dd = utils.make_2d(k_dd.nngp)

                # Training set.
                if k_td is None:
                    if g == 'nngp':
                        cov = np.einsum('ji,ti,ki->tjk',
                                        evecs,
                                        (np.maximum(evals, 0.) *
                                         np.exp(-2 * np.maximum(evals, 0.) *
                                                t / y_train.size)),
                                        evecs,
                                        optimize=True)

                    elif g == 'ntk':
                        exp = np.einsum('mi,ti,ki->tmk',
                                        evecs,
                                        np.exp(-np.maximum(evals, 0.) * t /
                                               y_train.size),
                                        evecs,
                                        optimize=True)
                        cov = np.einsum('tmk,kl,tnl->tmn',
                                        exp,
                                        nngp_dd,
                                        exp,
                                        optimize=True)

                    else:
                        raise ValueError(g)

                # Test set.
                else:
                    _nngp_tt = np.expand_dims(utils.make_2d(nngp_tt), 0)

                    if g == 'nngp':
                        cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                                   ktd_g,
                                                   evecs,
                                                   -inv_expm1(evals, 2 * t),
                                                   evecs,
                                                   ktd_g,
                                                   optimize=True)

                    elif g == 'ntk':
                        term_1 = np.einsum('mi,ti,ki,lk->tml',
                                           evecs,
                                           neg_inv_expm1,
                                           evecs,
                                           ktd_g,
                                           optimize=True)
                        term_2 = np.einsum(
                            'mj,ji,ti,ki,lk->tml',
                            ktd_g,
                            evecs,
                            neg_inv_expm1,
                            evecs,
                            utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
                            optimize=True)
                        term_2 += np.moveaxis(term_2, 1, 2)
                        cov = np.einsum('tji,jk,tkl->til',
                                        term_1,
                                        nngp_dd,
                                        term_1,
                                        optimize=True)
                        cov += -term_2 + _nngp_tt

                    else:
                        raise ValueError(g)

                out[g] = Gaussian(mean, reshape_cov(cov))

            else:
                out[g] = mean

        return out

    return predict_fn
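
A hedged usage sketch of the function above, assuming the public `nt.predict` namespace and an infinite-width `kernel_fn` built with `neural_tangents.stax`; the toy `x_train`, `y_train` and `x_test` arrays are placeholders, not real data:

import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

x_train, y_train, x_test = np.ones((10, 8)), np.ones((10, 1)), np.ones((4, 8))

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4)

# Mean and covariance of the NNGP posterior at infinite training time.
nngp_mean, nngp_cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)

# NTK ensemble mean predictions at several finite training times; the output
# gains a leading time dimension matching `t.shape`.
ntk_means = predict_fn(t=np.array([1., 10., 100.]), x_test=x_test, get='ntk')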
Example #10
def gp_inference(k_train_train,
                 y_train: np.ndarray,
                 diag_reg: float = 0.,
                 diag_reg_absolute_scale: bool = False,
                 trace_axes: Axes = (-1, )):
    r"""Compute the mean and variance of the `posterior` of NNGP and NTK.

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as a Cholesky factorization
  of `k_train_train.nngp` or `k_train_train.ntk` (or both) is performed and
  cached for future invocations.

  Args:
    k_train_train:
      train-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
      `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
      for arguments provided to the returned `predict_fn` function. For
      example, if you request to compute posterior test [only] NTK covariance in
      future `predict_fn` invocations, `k_train_train` must contain both `ntk`
      and `nngp` kernels.
    y_train:
      train targets.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.
    trace_axes:
      `f(x_train)` axes such that `k_train_train`,
      `k_test_train`[, and `nngp_test_test`] lack these pairs of dimensions and
      are to be interpreted as :math:`\Theta \otimes I`, i.e. block-diagonal
      along `trace_axes`. These can be specified either to save space and
      compute, or to even improve approximation accuracy of the infinite-width
      or infinite-samples limit, since in these limits the covariance along
      channel / feature / logit axes indeed converges to a constant-diagonal
      matrix. However, if you target linearized dynamics of a specific
      finite-width network, `trace_axes=()` will yield the most accurate
      results.

  Returns:
    A function of signature `predict_fn(get, k_test_train, nngp_test_test)`
    computing posterior Gaussian distribution (mean or mean and covariance)
    on a given test set.
  """
    even, odd, first, last = _get_axes(_get_first(k_train_train))
    trace_axes = utils.canonicalize_axis(trace_axes, y_train)

    @lru_cache(2)
    def solve(g: str):
        k_dd = _get_attr(k_train_train, g)
        return _get_cho_solve(k_dd, diag_reg, diag_reg_absolute_scale)

    @lru_cache(2)
    def k_inv_y(g: str):
        return solve(g)(y_train, trace_axes)

    @utils.get_namedtuple('Gaussians')
    def predict_fn(
        get: Get,
        k_test_train=None,
        nngp_test_test: np.ndarray = None
    ) -> Dict[str, Union[np.ndarray, Gaussian]]:
        """`test`-set posterior given respective covariance matrices.

    Args:
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
        returned.
      k_test_train:
        test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
        `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
        for arguments provided to the returned `predict_fn` function. For
        example, if you request to compute posterior test [only] NTK covariance,
        `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
        returns predictions on the training set. Note that train-set outputs are
        always `N(y_train, 0)` and mostly returned for API consistency.
      nngp_test_test:
        A test-test NNGP array. Provide if you want to compute test-test
        posterior covariance. `nngp_test_test=None` means to not compute it. If
        `k_test_train is None`, pass any non-`None` value (e.g. `True`) if you
        want to get non-regularized (`diag_reg=0`) train-train posterior
        covariance. Note that non-regularized train-set outputs will always be
        the zero-variance Gaussian `N(y_train, 0)` and mostly returned for API
        consistency. For regularized train-set posterior outputs according to a
        positive `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
        `nngp_test_test=nngp_train_train`.

    Returns:
      Either a `Gaussian('mean', 'variance')` namedtuple or `mean` of the GP
      posterior on the `test` set.
    """
        if get is None:
            get = ('nngp', 'ntk')

        out = {}

        for g in get:
            k_dd = _get_attr(k_train_train, g)
            k_td = None if k_test_train is None else _get_attr(k_test_train, g)

            if k_td is None:
                # Train set predictions.
                y = y_train.astype(k_dd.dtype)
            else:
                # Test set predictions.
                y = np.tensordot(k_td, k_inv_y(g), (odd, first))
                y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

            if nngp_test_test is not None:
                if k_td is None:
                    out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
                else:
                    if (g == 'ntk' and (not hasattr(k_train_train, 'nngp')
                                        or not hasattr(k_test_train, 'nngp'))):
                        raise ValueError(
                            'If `"ntk" in get`, and `nngp_test_test is not None`, '
                            'and `k_test_train is not None`, i.e. you request the '
                            'NTK posterior covariance on the test set, you need '
                            'both NTK and NNGP train-train and test-train matrices '
                            'contained in `k_test_train` and `k_train_train`. '
                            'Hence they must be `namedtuple`s with `nngp` and '
                            '`ntk` attributes.')

                    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'),
                                               even)

                    if g == 'nngp':
                        cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                        cov = nngp_test_test - utils.zip_axes(cov)
                        out[g] = Gaussian(y, cov)

                    elif g == 'ntk':
                        term_1 = solve(g)(k_td, even)
                        cov = np.tensordot(_get_attr(k_train_train, 'nngp'),
                                           term_1, (odd, first))
                        cov = np.tensordot(term_1, cov, (first, first))

                        term_2 = np.tensordot(k_td, k_td_nngp_inv_y,
                                              (odd, first))
                        term_2 += np.moveaxis(term_2, first, last)
                        cov = utils.zip_axes(cov - term_2) + nngp_test_test
                        out[g] = Gaussian(y, cov)

                    else:
                        raise ValueError(g)

            else:
                out[g] = y

        return out

    return predict_fn
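
A hedged sketch of driving the returned `predict_fn`; an infinite-width `kernel_fn` (e.g. built with `neural_tangents.stax` as in the previous sketch) and `x_train`, `y_train`, `x_test` arrays are assumed to be defined:

import neural_tangents as nt

# Both 'nngp' and 'ntk' entries are needed if the NTK posterior covariance
# will be requested later.
k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
nngp_test_test = kernel_fn(x_test, None, 'nngp')

predict_fn = nt.predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)

# Posterior mean only.
ntk_mean = predict_fn('ntk', k_test_train)

# Posterior mean and covariance, returned as a `Gaussian` namedtuple.
ntk_mean, ntk_cov = predict_fn('ntk', k_test_train, nngp_test_test)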
Example #11
def gradient_descent_mse(
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    diag_reg: float = 0.,
    diag_reg_absolute_scale: bool = False,
    trace_axes: Axes = (-1, )
) -> Callable[
    [ArrayOrScalar, ArrayOrScalar, ArrayOrScalar, Optional[np.ndarray]], Union[
        np.ndarray, Tuple[np.ndarray, np.ndarray]]]:
    r"""Predicts the outcome of function space gradient descent training on MSE.

  Solves in closed form for the continuous-time version of gradient descent.

  Uses the closed-form solution for gradient descent on an MSE loss in function
  space detailed in [*,**] given a Neural Tangent or Neural Network Gaussian
  Process Kernel over the dataset. Given NNGP or NTK, this function will return
  a function that predicts the time evolution for function space points at
  arbitrary time[s] (training step[s]) `t`. Note that these time[s] (step[s])
  are continuous and are interpreted in units of the `learning_rate` so
  `absolute_time = learning_rate * t`, and the scales of `learning_rate` and `t`
  are interchangeable.

  Note that first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as either eigendecomposition
  (`t` is a scalar or an array) or Cholesky factorization (`t=None`) of
  `k_train_train` is performed and cached for future invocations (or both, if
  the function is called on both finite and infinite (`t=None`) times).

  [*] https://arxiv.org/abs/1806.07572
  [**] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>>
    >>> t = 1e-7
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_train_train = kernel_fn(x_train, None, params)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    k_train_train:
      kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train:
      targets for the training data.
    learning_rate:
      learning rate, step size.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
      Cholesky factorization or eigendecomposition.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.
    trace_axes:
      `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to
      save space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield the most
      accurate results.

  Returns:
    A function of signature
    `predict_fn(t, fx_train_0, fx_test_0, k_test_train)` that
    returns output train [and test] set[s] predictions at time[s] `t`.
  """
    _, odd, first, _ = _get_axes(k_train_train)
    trace_axes = utils.canonicalize_axis(trace_axes, y_train)
    trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
    n_t_axes, n_non_t_axes = len(trace_axes), y_train.ndim - len(trace_axes)
    last_t_axes = tuple(range(-n_t_axes, 0))
    non_t_axes = tuple(range(-y_train.ndim, -n_t_axes))

    @lru_cache(1)
    def get_predict_fn_inf():
        with jax.core.eval_context():
            solve = _get_cho_solve(k_train_train, diag_reg,
                                   diag_reg_absolute_scale)

        def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
            fx_train_t = y_train.astype(k_train_train.dtype)
            if fx_test_0 is None:
                return fx_train_t

            rhs = y_train if fx_train_0 is None else y_train - fx_train_0
            dfx_test = np.tensordot(k_test_train, solve(rhs, trace_axes),
                                    (odd, first))
            dfx_test = np.moveaxis(dfx_test, last_t_axes, trace_axes)
            fx_test_t = fx_test_0 + dfx_test

            if fx_train_0 is None:
                return fx_test_t
            return fx_train_t, fx_test_t

        return predict_fn_inf

    @lru_cache(1)
    def get_predict_fn_finite():
        with jax.core.eval_context():
            expm1_fn, inv_expm1_fn = _get_fns_in_eigenbasis(
                k_train_train, diag_reg, diag_reg_absolute_scale,
                (_make_expm1_fn(y_train.size), _make_inv_expm1_fn(
                    y_train.size)))

        rhs_shape = tuple(y_train.shape[a] for a in trace_axes)

        def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
            t = np.array(t) * learning_rate
            t_shape, t_ndim = t.shape, t.ndim
            first_t_axes = tuple(range(t_ndim))
            t = t.reshape((-1, 1))

            rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
            rhs = np.moveaxis(rhs, trace_axes,
                              last_t_axes).reshape((-1, ) + rhs_shape)
            shape = t_shape + k_train_train.shape[1::2] + rhs_shape

            if fx_train_0 is not None:
                dfx_train = expm1_fn(rhs, t).reshape(shape)
                dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
                fx_train_t = np.expand_dims(fx_train_0,
                                            first_t_axes) + dfx_train

            if fx_test_0 is not None:
                dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
                dfx_test = np.tensordot(k_test_train, dfx_test,
                                        (odd, non_t_axes))
                dfx_test = np.moveaxis(
                    dfx_test,
                    tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) +
                    last_t_axes,
                    tuple(range(t_ndim)) + trace_axes)
                fx_test_t = np.expand_dims(fx_test_0, first_t_axes) + dfx_test

            if fx_train_0 is not None and fx_test_0 is not None:
                return fx_train_t, fx_test_t
            if fx_test_0 is None:
                return fx_train_t
            return fx_test_t

        return predict_fn_finite

    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_0: ArrayOrScalar = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`, but is computed
        using identity or linear solve for train and test predictions
        respectively instead of eigendecomposition, saving time and precision.
        Equivalent of training steps (but can be fractional).
      fx_train_0:
        output of the network at `t == 0` on the training set. `fx_train_0=None`
        means to not compute predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need non-regularized (`diag_reg=0`)
        predictions on the training set. For regularized train-set predictions,
        pass `k_test_train=k_train_train`.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_0, fx_test_0, k_test_train)

        # Infinite time
        if t is None:
            return get_predict_fn_inf()(fx_train_0, fx_test_0, k_test_train)

        # Finite time
        return get_predict_fn_finite()(t, fx_train_0, fx_test_0, k_test_train)

    return predict_fn
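
Continuing the docstring example above (a hedged sketch; `predict_fn`, `fx_train_0`, `fx_test_0` and `k_test_train` as defined there, with `jax.numpy` imported as `np`), `t` may also be an array of times, and `t=None` gives the closed-form infinite-time solution via a linear solve:

# Predictions at several training times at once; the outputs gain leading
# time dimensions matching `t.shape`.
ts = np.array([1., 10., 100.])
fx_train_ts, fx_test_ts = predict_fn(ts, fx_train_0, fx_test_0, k_test_train)

# Infinite training time: uses a Cholesky solve instead of eigendecomposition.
fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,
                                       k_test_train)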
Example #12
def gradient_descent(
    loss: Callable[[np.ndarray, np.ndarray], float],
    k_train_train: np.ndarray,
    y_train: np.ndarray,
    learning_rate: float = 1.,
    momentum: float = None,
    trace_axes: Axes = (-1, )
) -> Callable[[
        ArrayOrScalar, Union[ArrayOrScalar,
                             ODEState], ArrayOrScalar, Optional[np.ndarray]
], Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]]:
    r"""Predicts the outcome of function space training using gradient descent.

  Uses an ODE solver. If `momentum != None`, solves a continuous-time version of
  gradient descent with momentum (note: this case uses standard momentum as
  opposed to Nesterov momentum).

  Solves the function space ODE for [momentum] gradient descent with a given
  `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s]
  at arbitrary time[s] (step[s]) `t`. Note that for gradient descent
  `absolute_time = learning_rate * t` and the scales of the learning rate and
  query step[s] `t` are interchangeable. However, the momentum gradient descent
  ODE is solved in the units of `learning_rate**0.5`, and therefore
  `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and
  training time[s] (step[s]) `t` scales are not interchangeable.

  [*] https://arxiv.org/abs/1902.06720

  Example:
    >>> from neural_tangents import empirical_ntk_fn
    >>> from neural_tangents import predict
    >>>
    >>> t = 1e-7
    >>> learning_rate = 1e-2
    >>> momentum = 0.9
    >>>
    >>> kernel_fn = empirical_ntk_fn(f)
    >>> k_test_train = kernel_fn(x_test, x_train, params)
    >>>
    >>> from jax.experimental import stax
    >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
    >>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
    >>>                                       y_train, learning_rate, momentum)
    >>>
    >>> fx_train_0 = f(params, x_train)
    >>> fx_test_0 = f(params, x_test)
    >>>
    >>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
    >>>                                    k_test_train)

  Args:
    loss:
      a loss function whose signature is `loss(f(x_train), y_train)`. Note:
      the loss function should treat the batch and output dimensions
      symmetrically.
    k_train_train:
      kernel on the training data. Must have the shape of
      `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.
    y_train:
      targets for the training data.
    learning_rate:
      learning rate, step size.
    momentum:
      momentum scalar.
    trace_axes:
      `f(x_train)` axes such that `k_train_train` lacks these pairs of
      dimensions and is to be interpreted as :math:`\Theta \otimes I`, i.e.
      block-diagonal along `trace_axes`. These can be specified either to
      save space and compute, or to even improve approximation accuracy of the
      infinite-width or infinite-samples limit, since in these limits the
      covariance along channel / feature / logit axes indeed converges to a
      constant-diagonal matrix. However, if you target linearized dynamics of a
      specific finite-width network, `trace_axes=()` will yield the most
      accurate results.

  Returns:
    A function that returns output train [and test] set[s] predictions at
    time[s] `t`.
  """
    _, odd, _, _ = _get_axes(k_train_train)
    trace_axes = utils.canonicalize_axis(trace_axes, y_train)
    non_t_axes = tuple(a for a in range(y_train.ndim) if a not in trace_axes)
    last_t_axes = range(-len(trace_axes), 0)

    dtype = k_train_train.dtype
    grad_loss = grad(lambda fx: loss(fx, y_train))

    if momentum is not None:
        learning_rate **= 0.5
        momentum = (momentum - 1.0) / learning_rate

    def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
        if isinstance(fx_train_or_state_0, ODEState):
            fx_train_0 = fx_train_or_state_0.fx_train
            fx_test_0 = fx_train_or_state_0.fx_test
            qx_train_0 = fx_train_or_state_0.qx_train
            qx_test_0 = fx_train_or_state_0.qx_test
        else:
            fx_train_0 = fx_train_or_state_0
            qx_train_0 = qx_test_0 = None

        if fx_train_0 is None:
            fx_train_0 = np.zeros_like(y_train, dtype)
        else:
            fx_train_0 = np.broadcast_to(fx_train_0, y_train.shape)

        if fx_test_0 is not None:
            fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

        if momentum is None:
            if qx_train_0 is not None or qx_test_0 is not None:
                raise ValueError('Got passed momentum state variables, while '
                                 '`momentum is None`.')
        else:
            qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None
                          else np.broadcast_to(qx_train_0, y_train.shape))
            qx_test_0 = (None if fx_test_0 is None else
                         (np.zeros(fx_test_shape, dtype) if qx_test_0 is None
                          else np.broadcast_to(qx_test_0, fx_test_shape)))

        return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count

    def get_dstate_dt(k_test_train):
        def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
            fx_train_t, fx_test_t, qx_train_t, qx_test_t = (state_t.fx_train,
                                                            state_t.fx_test,
                                                            state_t.qx_train,
                                                            state_t.qx_test)

            dy_df_t = grad_loss(fx_train_t)

            fx_train_t = -np.moveaxis(
                np.tensordot(k_train_train, dy_df_t,
                             (odd, non_t_axes)), last_t_axes, trace_axes)
            if fx_test_t is not None:
                fx_test_t = -np.moveaxis(
                    np.tensordot(k_test_train, dy_df_t,
                                 (odd, non_t_axes)), last_t_axes, trace_axes)

            if momentum is None:
                return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

            fx_train_t += momentum * qx_train_t
            if qx_test_t is not None:
                fx_test_t += momentum * qx_test_t

            return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

        return dstate_dt

    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state). Note that only
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily append an [almost] `0` at the start.
        t0 = np.where(t[0] == 0, np.full((1, ), -1e-24, t.dtype),
                      np.zeros((1, ), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t

    return predict_fn
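
Finally, a rough sketch of the `ODEState` workflow described under `fx_train_or_state_0` above (continuing the docstring example; `predict_fn`, `fx_train_0`, `fx_test_0` and `k_test_train` as defined there):

# Wrap the initial outputs into a full ODE state. Passing an `ODEState` in
# returns an `ODEState` out, including the momentum auxiliary variables.
state_0 = predict.ODEState(fx_train_0, fx_test_0)
state_t = predict_fn(1e3, state_0, k_test_train=k_test_train)

# Resume the solver for another 1e3 steps from the saved state.
state_2t = predict_fn(1e3, state_t, k_test_train=k_test_train)
fx_train_t, fx_test_t = state_2t.fx_train, state_2t.fx_test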