Example #1
        def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
            t = np.array(t) * learning_rate
            t_shape, t_ndim = t.shape, t.ndim
            t = t.reshape((-1, 1))

            rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
            rhs = np.moveaxis(rhs, trace_axes,
                              last_t_axes).reshape((-1, ) + rhs_shape)
            shape = t_shape + k_train_train.shape[1::2] + rhs_shape

            if fx_train_0 is not None:
                dfx_train = expm1_fn(rhs, t).reshape(shape)
                dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
                fx_train_t = fx_train_0 + dfx_train

            if fx_test_0 is not None:
                dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
                dfx_test = np.tensordot(k_test_train, dfx_test,
                                        (odd, non_t_axes))
                dfx_test = np.moveaxis(
                    dfx_test,
                    tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) +
                    last_t_axes,
                    tuple(range(t_ndim)) + trace_axes)
                fx_test_t = fx_test_0 + dfx_test

            if fx_train_0 is not None and fx_test_0 is not None:
                return fx_train_t, fx_test_t
            if fx_test_0 is None:
                return fx_train_t
            return fx_test_t
Example #2
def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_dilation=None,
                         rhs_dilation=None, dimension_numbers=None,
                         feature_group_count=1, batch_group_count=1, precision=None):
  """ A general conv API that integrates normal conv, deconvolution,
  dilated convolution, etc."""
  dim = None
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  if lhs_spec != out_spec:
    raise TypeError('Current implementation requires the `data_format` of the '
                    'inputs and outputs to be the same.')
  if len(lhs_spec) >= 6:
    raise TypeError('Current implementation does not support 4 or higher-'
                    'dimensional convolution, but got: ', len(lhs_spec) - 2)
  dim = len(lhs_spec) - 2
  if lhs_dilation and rhs_dilation:
    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
      lhs_dilation, rhs_dilation = None, None
    else:
      raise TypeError('Current implementation does not support deconvolution '
                      'and dilation being performed at the same time, but got '
                      'lhs_dilation: {}, rhs_dilation: {}'.format(
                          lhs_dilation, rhs_dilation))
  if padding not in ['SAME', 'VALID']:
    raise TypeError('Current implementation requires the padding parameter '
                    'to be either `VALID` or `SAME`, but got: ', padding)
  # Convert params from int/Sequence[int] to list of ints.
  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
    window_strides, lhs_dilation, rhs_dilation
  )
  # Preprocess the shapes
  dim_maps = {}
  if isinstance(lhs_spec, str):
    dim_maps['I'] = list(rhs_spec).index('I')
    dim_maps['O'] = list(rhs_spec).index('O')
    dim_maps['N'] = list(lhs_spec).index('N')
    dim_maps['C'] = list(lhs_spec).index('C')
  else:
    dim_maps['I'] = rhs_spec[1]
    dim_maps['O'] = rhs_spec[0]
    dim_maps['N'] = lhs_spec[0]
    dim_maps['C'] = lhs_spec[1]

  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
  # Adjust the filters, put the dimension 'I' and 'O' at last.
  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
  spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'}
  data_format = 'N' + spatial_dim_maps[dim] + 'C'
  tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
                2: [nn.conv2d, nn.conv2d_transpose],
                3: [nn.conv3d, nn.conv3d_transpose]}

  output = None
  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation)
  else:
    output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, padding, data_format, lhs_dilation)
  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
  return np.asarray(output)
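For reference, a hypothetical call in the common NHWC/HWIO layout (a sketch, not a guaranteed API: it assumes the TensorFlow-backed `nn` module and the `_conv_general_param_type_converter` helper accept these inputs):

import numpy as onp

lhs = onp.ones((1, 8, 8, 3), onp.float32)   # NHWC input: batch, height, width, channels
rhs = onp.ones((3, 3, 3, 16), onp.float32)  # HWIO filters: height, width, in, out
out = conv_general_dilated(lhs, rhs, window_strides=(1, 1), padding='SAME',
                           output_shape=None,
                           dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
# Expected: out.shape == (1, 8, 8, 16) for 'SAME' padding and unit strides.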
Example #3
def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug '
                     'at https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  shrink = 0
  for c in reversed(trace_axes):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - shrink)
    shrink += 1

  for i, d in enumerate(diagonal_axes):
    ntk = np.diagonal(ntk, axis1=d - i, axis2=output_ndim + d - shrink - 2 * i)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
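A plain-numpy sanity check of the diagonal case (a sketch; it assumes the `utils` helpers behave as in neural-tangents): with `ntk` of shape `(N1, X, N2, X)`, `trace_axes=()` and `diagonal_axes=(1,)`, the output keeps a single shared `X` axis.

import numpy as onp

ntk = onp.random.rand(3, 5, 4, 5)        # (N1, X, N2, X)
expected = onp.einsum('axbx->abx', ntk)  # out[n1, n2, x] == ntk[n1, x, n2, x]
# _index_and_contract(ntk, trace_axes=(), diagonal_axes=(1,)) should match
# `expected` (contract_size == 1, so no rescaling).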
Example #4
def diagonal_between(x: np.ndarray,
                     start_axis: int = 0,
                     end_axis: int = -1) -> np.ndarray:
    """Returns the diagonal along all dimensions between start and end axes."""
    if end_axis == -1:
        end_axis = x.ndim
    half_ndim, ragged = divmod(end_axis - start_axis, 2)
    if ragged:
        raise ValueError(
            f'Need even number of axes to flatten, got {end_axis - start_axis}.'
        )
    if half_ndim == 0:
        return x

    side_shape = x.shape[start_axis:start_axis + half_ndim]
    side_size = size_at(side_shape)

    shape_2d = x.shape[:start_axis] + (side_size,
                                       side_size) + x.shape[end_axis:]
    shape_result = x.shape[:start_axis] + side_shape + x.shape[end_axis:]

    x = np.diagonal(x.reshape(shape_2d),
                    axis1=start_axis,
                    axis2=start_axis + 1)
    x = np.moveaxis(x, -1, start_axis)
    return x.reshape(shape_result)
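A small usage sketch (assuming `size_at` returns the product of the given shape): extracting the joint diagonal over the axis pair `(1, 2)` of a 4D array.

import numpy as onp

x = onp.arange(2 * 3 * 3 * 5).reshape((2, 3, 3, 5))
out = diagonal_between(x, start_axis=1, end_axis=3)
# out.shape == (2, 3, 5) and out[b, i, k] == x[b, i, i, k].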
Example #5
            def reshape(m):
                if m is not None:
                    if m.shape[self.channel_axis] != 1:
                        raise NotImplementedError(
                            f'Different channel-wise masks are currently not '
                            f'supported for infinite-width layers '
                            f'(got `mask.shape == {m.shape}`). '
                            f'Please describe your use case at '
                            f'https://github.com/google/neural-tangents/issues/new'
                        )

                    m = np.squeeze(
                        np.moveaxis(m, (self.batch_axis, self.channel_axis),
                                    (0, -1)), -1)
                    if self.is_reversed:
                        m = np.moveaxis(m, range(1, m.ndim),
                                        range(m.ndim - 1, 0, -1))
                return m
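For intuition, the axis bookkeeping in isolation (hypothetical `batch_axis=1`, `channel_axis=0`, `is_reversed=False`):

import numpy as onp

m = onp.ones((1, 8, 10), bool)                          # channel, batch, spatial
m2 = onp.squeeze(onp.moveaxis(m, (1, 0), (0, -1)), -1)  # batch first, channel axis dropped
# m2.shape == (8, 10)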
Example #6
def dot_general(lhs: np.ndarray,
                rhs: np.ndarray,
                contracting_dims: Axes,
                batch_dims: Axes,
                precision=None) -> np.ndarray:
  """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but preserves the dimension order in the output. See XLA's
   `DotGeneral<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

  Args:
    lhs: array.
    rhs: array of the same dimensionality as `lhs`, or `None`, in which case
      `lhs` is used in its place (a self dot product).
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
  contracting_dims = canonicalize_axis(contracting_dims, lhs)
  batch_dims = canonicalize_axis(batch_dims, lhs)

  n_batch_dims = len(batch_dims)
  leading_batch_dims = range(n_batch_dims)

  dimension_numbers = ((contracting_dims, contracting_dims),
                       (leading_batch_dims, leading_batch_dims))

  lhs = np.moveaxis(lhs, batch_dims, leading_batch_dims)
  if rhs is None:
    rhs = lhs
  else:
    rhs = np.moveaxis(rhs, batch_dims, leading_batch_dims)

  prod = tf_dot_general(lhs, rhs, dimension_numbers)
  prod = zip_axes(prod, n_batch_dims)

  res_batch_dims = get_res_batch_dims(contracting_dims, batch_dims)
  prod = np.moveaxis(prod, leading_batch_dims, res_batch_dims)
  return prod
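A plain-numpy check of the claimed semantics in the simplest case (a sketch; it assumes `tf_dot_general` and the helpers match neural-tangents' `utils`): with `batch_dims=(0,)` and `contracting_dims=(2,)`, the result should agree with a batched matrix product.

import numpy as onp

lhs = onp.random.rand(7, 2, 5)
rhs = onp.random.rand(7, 3, 5)
expected = onp.einsum('nxc,nyc->nxy', lhs, rhs)
# dot_general(lhs, rhs, contracting_dims=(2,), batch_dims=(0,)) should
# match `expected`.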
Example #7
def reverse_zipped(mat: np.ndarray, start_axis: int = 0) -> np.ndarray:
    if mat is not None:
        source_axes = tuple(j for i in range(mat.ndim - 2, start_axis - 1, -2)
                            for j in (i, i + 1))

        target_axes = range(start_axis, mat.ndim)
        mat = np.moveaxis(mat, source_axes, target_axes)

    return mat
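Shape-wise, `reverse_zipped` reverses the order of the trailing zipped axis pairs while keeping each pair intact, e.g.:

import numpy as onp

m = onp.zeros((2, 3, 4, 5))   # zipped pairs: (2, 3) then (4, 5)
out = reverse_zipped(m)
# out.shape == (4, 5, 2, 3)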
Example #8
    def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
        b_axes = utils.canonicalize_axis(b_axes, b)
        last_b_axes = range(-len(b_axes), 0)
        x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

        b = np.moveaxis(b, b_axes, last_b_axes)
        b = b.reshape((A.shape[1], -1))

        x = np.asarray(tf.linalg.cholesky_solve(C, b))
        x = x.reshape(x_shape)
        return x
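The underlying identity, in plain TensorFlow terms (hypothetical 2x2 system): if `C` is the Cholesky factor of `A`, then `tf.linalg.cholesky_solve(C, b)` returns `x` with `A @ x == b`.

import numpy as onp
import tensorflow as tf

A = onp.array([[4., 1.], [1., 3.]])
b = onp.array([[1.], [2.]])
C = tf.linalg.cholesky(A)
x = tf.linalg.cholesky_solve(C, b)
onp.testing.assert_allclose(A @ x.numpy(), b, atol=1e-6)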
Example #9
        def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
            fx_train_t = y_train.astype(k_train_train.dtype)
            if fx_test_0 is None:
                return fx_train_t

            rhs = y_train if fx_train_0 is None else y_train - fx_train_0
            dfx_test = np.tensordot(k_test_train, solve(rhs, trace_axes),
                                    (odd, first))
            dfx_test = np.moveaxis(dfx_test, last_t_axes, trace_axes)
            fx_test_t = fx_test_0 + dfx_test

            if fx_train_0 is None:
                return fx_test_t
            return fx_train_t, fx_test_t
Example #10
def _zip_axes(x: np.ndarray,
              start_axis: int = 0,
              end_axis: int = -1,
              unzip: bool = False) -> np.ndarray:
  """Zip/unzip (interleave/de-interleave) axes starting from `start_axis`.

  Changes the shape as follows:
    If `unzip == True`:
    `[..., X, X, ..., Y, Y, ..., Z, Z, ...] -> [..., X, Y, Z, ..., X, Y, Z, ..]`
    If `unzip == False`:
    `[..., X, Y, Z, ..., X, Y, Z, ...] -> [..., X, X, ..., Y, Y, ..., Z, Z, ..]`

  Args:
    x: `np.ndarray` with an even number of dimensions following `start_axis`.
    start_axis: `int`, index of the axis from which to zip/unzip.
    end_axis: `int`, index of the axis until which to zip/unzip.
    unzip: `bool`, set to `True` to unzip instead of zip.

  Returns:
    A `np.ndarray` with a new shape.
  """
  if end_axis == -1:
    end_axis = len(x.shape)
  half_ndim, ragged = divmod(end_axis - start_axis, 2)
  if ragged:
    raise ValueError(
        f'Need even number of axes to zip, got {end_axis - start_axis}.')

  odd_axes = range(start_axis + 1, end_axis, 2)
  last_axes = range(end_axis - half_ndim, end_axis)

  if unzip:
    x = np.moveaxis(x, odd_axes, last_axes)
  else:
    x = np.moveaxis(x, last_axes, odd_axes)
  return x
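A shape check of both directions (hypothetical sizes):

import numpy as onp

x = onp.zeros((2, 3, 4, 2, 3, 4))   # [X, Y, Z, X, Y, Z]
zipped = _zip_axes(x)               # [X, X, Y, Y, Z, Z]
# zipped.shape == (2, 2, 3, 3, 4, 4)
# _zip_axes(zipped, unzip=True).shape == x.shape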
Example #11
        def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
            fx_train_t, fx_test_t, qx_train_t, qx_test_t = (state_t.fx_train,
                                                            state_t.fx_test,
                                                            state_t.qx_train,
                                                            state_t.qx_test)

            dy_df_t = grad_loss(fx_train_t)

            fx_train_t = -np.moveaxis(
                np.tensordot(k_train_train, dy_df_t,
                             (odd, non_t_axes)), last_t_axes, trace_axes)
            if fx_test_t is not None:
                fx_test_t = -np.moveaxis(
                    np.tensordot(k_test_train, dy_df_t,
                                 (odd, non_t_axes)), last_t_axes, trace_axes)

            if momentum is None:
                return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

            fx_train_t += momentum * qx_train_t
            if qx_test_t is not None:
                fx_test_t += momentum * qx_test_t

            return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count
Example #12
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the respective pairs of axes from the `ntk`.
    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).
  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug '
                     'at https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2

  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    axis1 = d - i
    axis2 = output_ndim + d - 2 * i - n_trace
    for c in trace_axes:
      if c < d:
        axis1 -= 1
        axis2 -= 1
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
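A plain-numpy check of the traced case (a sketch; it assumes the `utils` helpers behave as in neural-tangents): with `ntk` of shape `(N1, X, N2, X)` and `trace_axes=(1,)`, the output is the trace over the `X` pair divided by `X`.

import numpy as onp

ntk = onp.random.rand(3, 5, 4, 5)           # (N1, X, N2, X)
expected = onp.einsum('axbx->ab', ntk) / 5  # trace over X, averaged
# _trace_and_diagonal(ntk, trace_axes=(1,), diagonal_axes=()) should
# match `expected`.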
Example #13
    def testPredCovPosDef(self, train_shape, test_shape, network, out_logits):
        _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                       train_shape)
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        ts = np.logspace(-3, 3, 10)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            ker_fun, x_train, y_train)

        for get in ('nngp', 'ntk'):
            for x in (None, 'x_test'):
                for t in (None, 'ts'):
                    with self.subTest(get=get, x=x, t=t):
                        cov = predict_fn_mse_ens(
                            t=t if t is None else ts,
                            get=get,
                            x_test=x if x is None else x_test,
                            compute_cov=True).covariance

                        self.assertAllClose(cov, np.moveaxis(cov, -1, -2))
                        self.assertGreater(np.min(np.linalg.eigh(cov)[0]),
                                           -1e-4)
Example #14
def transpose_zipped(x: np.ndarray) -> np.ndarray:
  return np.moveaxis(x, range(1, x.ndim, 2), range(0, x.ndim, 2))
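E.g., each zipped axis pair `(a, b)` becomes `(b, a)`:

import numpy as onp

x = onp.zeros((2, 3, 4, 5))   # pairs (2, 3) and (4, 5)
# transpose_zipped(x).shape == (3, 2, 5, 4)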
Example #15
    def predict_fn(
        get: Get,
        k_test_train=None,
        nngp_test_test: np.ndarray = None
    ) -> Dict[str, Union[np.ndarray, Gaussian]]:
        """`test`-set posterior given respective covariance matrices.

    Args:
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
        returned.
      k_test_train:
        test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c)
        `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels
        for arguments provided to the returned `predict_fn` function. For
        example, if you request to compute posterior test [only] NTK covariance,
        `k_test_train` must contain both `ntk` and `nngp` kernels. If `None`,
        returns predictions on the training set. Note that train-set outputs are
        always `N(y_train, 0)` and mostly returned for API consistency.
      nngp_test_test:
        A test-test NNGP array. Provide if you want to compute test-test
        posterior covariance. `nngp_test_test=None` means do not compute it. If
        `k_test_train is None`, pass any non-`None` value (e.g. `True`) if you
        want to get non-regularized (`diag_reg=0`) train-train posterior
        covariance. Note that non-regularized train-set outputs will always be
        the zero-variance Gaussian `N(y_train, 0)` and mostly returned for API
        consistency. For regularized train-set posterior outputs according to a
        positive `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
        `nngp_test_test=nngp_train_train`.

    Returns:
      Either a `Gaussian('mean', 'variance')` namedtuple or `mean` of the GP
      posterior on the `test` set.
    """
        if get is None:
            get = ('nngp', 'ntk')

        out = {}

        for g in get:
            k_dd = _get_attr(k_train_train, g)
            k_td = None if k_test_train is None else _get_attr(k_test_train, g)

            if k_td is None:
                # Train set predictions.
                y = y_train.astype(k_dd.dtype)
            else:
                # Test set predictions.
                y = np.tensordot(k_td, k_inv_y(g), (odd, first))
                y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

            if nngp_test_test is not None:
                if k_td is None:
                    out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
                else:
                    if (g == 'ntk' and (not hasattr(k_train_train, 'nngp')
                                        or not hasattr(k_test_train, 'nngp'))):
                        raise ValueError(
                            'If `"ntk" in get`, and `nngp_test_test is not None`, '
                            'and `k_test_train is not None`, i.e. you request the '
                            'NTK posterior covariance on the test set, you need '
                            'both NTK and NNGP train-train and test-train '
                            'matrices contained in `k_test_train` and '
                            '`k_train_train`. Hence they must be `namedtuple`s '
                            'with `nngp` and `ntk` attributes.')

                    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'),
                                               even)

                    if g == 'nngp':
                        cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
                        cov = nngp_test_test - utils.zip_axes(cov)
                        out[g] = Gaussian(y, cov)

                    elif g == 'ntk':
                        term_1 = solve(g)(k_td, even)
                        cov = np.tensordot(_get_attr(k_train_train, 'nngp'),
                                           term_1, (odd, first))
                        cov = np.tensordot(term_1, cov, (first, first))

                        term_2 = np.tensordot(k_td, k_td_nngp_inv_y,
                                              (odd, first))
                        term_2 += np.moveaxis(term_2, first, last)
                        cov = utils.zip_axes(cov - term_2) + nngp_test_test
                        out[g] = Gaussian(y, cov)

                    else:
                        raise ValueError(g)

            else:
                out[g] = y

        return out
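In the simplest scalar-output case, the `'nngp'` posterior mean reduces to the textbook GP formula `k_test_train @ inv(k_train_train) @ y_train`; a plain-numpy illustration with hypothetical kernels:

import numpy as onp

k_dd = onp.array([[2.0, 0.5], [0.5, 1.0]])  # train-train kernel
k_td = onp.array([[0.3, 0.7]])              # test-train kernel
y = onp.array([[1.0], [0.0]])               # train targets
mean = k_td @ onp.linalg.solve(k_dd, y)     # posterior mean, shape (1, 1)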
Example #16
def gradient_descent_mse_ensemble(kernel_fn: KernelFn,
                                  x_train: np.ndarray,
                                  y_train: np.ndarray,
                                  learning_rate: float = 1.,
                                  diag_reg: float = 0.0,
                                  diag_reg_absolute_scale: bool = False,
                                  trace_axes: Axes = (-1, ),
                                  **kernel_fn_kwargs):
    r"""Predicts the gaussian embedding induced by gradient descent on MSE loss.

  This is equivalent to an infinite ensemble of infinite-width networks after
  marginalizing out the initialization, if `kernel_fn` is the kernel function of
  the infinite-width network. Note that `kernel_fn` can in principle also be an
  empirical / Monte Carlo finite-width kernel function, but in this case the
  returned output will not have a simple interpretation (unless these functions
  are used to approximate the infinite-width kernel).

  Note that the first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as the kernel computation,
  and either eigendecomposition (`t` is a scalar or an array) or Cholesky
  factorization (`t=None`) of `kernel_fn(x_train, None, get)` is performed and
  cached for future invocations (or both, if the function is called on both
  finite and infinite (`t=None`) times).

  Args:
    kernel_fn:
      A kernel function that computes NNGP and/or NTK. Must have a signature
      `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object
      or a `namedtuple` with `nngp` and/or `ntk` attributes. Therefore, it can
      be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn`, or an
      `EmpiricalKernelFn` (but only `nt.empirical_kernel_fn` and not
      `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since the latter two do
      not accept a `get` argument). Note that for meaningful outputs, the kernel
      function must represent or at least approximate the infinite-width kernel.
    x_train:
      training inputs.
    y_train:
      training targets.
    learning_rate:
      learning rate, step size.
    diag_reg:
      a scalar representing the strength of the diagonal regularization for
      `kernel_fn(x_train, None, get)`, i.e. computing
      `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky
      factorization or eigendecomposition.
    diag_reg_absolute_scale:
      `True` for `diag_reg` to represent regularization in absolute units,
      `False` to be
      `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))`.
    trace_axes:
      `f(x_train)` axes such that `kernel_fn(x_train, None, get)`,
      `kernel_fn(x_test, x_train, get)`[, and `kernel_fn(x_test, None, get)`]
      lack these pairs of dimensions and are to be interpreted as
      :math:`\Theta \otimes I`, i.e. block-diagonal along `trace_axes`. These
      can be specified either to save space and compute, or even to improve
      approximation accuracy of the infinite-width or infinite-samples limit,
      since in these limits the covariance along channel / feature / logit
      axes indeed converges to a constant-diagonal matrix. However, if you
      target linearized dynamics of a specific finite-width network,
      `trace_axes=()` will yield the most accurate result.
    **kernel_fn_kwargs:
      optional keyword arguments passed to `kernel_fn`.

  Returns:
    A function with signature `predict_fn(t, x_test, get, compute_cov)`
    returning either mean or mean and covariance of the infinite ensemble of
    infinite-width networks outputs on `x_test` at time[s] `t`, in the `get`
    regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
  """
    expm1 = _make_expm1_fn(y_train.size)
    inv_expm1 = _make_inv_expm1_fn(y_train.size)

    trace_axes = utils.canonicalize_axis(trace_axes, y_train)
    trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
    n_trace_axes = len(trace_axes)
    last_t_axes = range(-n_trace_axes, 0)
    trace_shape = tuple(y_train.shape[a] for a in trace_axes)

    y_train_flat = np.moveaxis(y_train, trace_axes,
                               last_t_axes).reshape((-1, ) + trace_shape)

    k_dd_cache = {}

    def get_k_train_train(get: Tuple[str, ...]) -> _Kernel:
        if len(get) == 1:
            get = get[0]
            if get not in k_dd_cache:
                k_dd_cache[get] = kernel_fn(x_train, None, get,
                                            **kernel_fn_kwargs)

        elif len(get) == 2:
            if not any(g in k_dd_cache for g in get):
                k_dd_cache.update(
                    kernel_fn(x_train, None, get,
                              **kernel_fn_kwargs)._asdict())
            else:
                for g in get:
                    if g not in k_dd_cache:
                        k_dd_cache[g] = kernel_fn(x_train, None, g,
                                                  **kernel_fn_kwargs)

        else:
            raise ValueError(get)
        return _Kernel(**k_dd_cache)

    @lru_cache(2)
    def eigenspace(get: str):
        k_dd = getattr(get_k_train_train((get, )), get)
        k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg,
                                         diag_reg_absolute_scale)
        return tf.linalg.eigh(k_dd)

    @lru_cache(4)
    def predict_inf(get: Get):
        _, get = utils.canonicalize_get(get)
        k_dd = get_k_train_train(get)
        return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale,
                            trace_axes)

    def get_matrices(get: Get, x_test: Optional[np.ndarray],
                     compute_cov: bool):
        get = _get_dependency(get, compute_cov)
        k_dd = get_k_train_train(get)
        if x_test is None:
            k_td = None
            nngp_tt = compute_cov or None
        else:
            k_td = kernel_fn(x_test, x_train, get, **kernel_fn_kwargs)
            if compute_cov:
                nngp_tt = kernel_fn(x_test, None, 'nngp', **kernel_fn_kwargs)
            else:
                nngp_tt = None
        return k_dd, k_td, nngp_tt

    @utils.get_namedtuple('Gaussians')
    def predict_fn(t: ArrayOrScalar = None,
                   x_test: np.ndarray = None,
                   get: Get = None,
                   compute_cov: bool = False) -> Dict[str, Gaussian]:
        """Return output mean and covariance on the test set at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape. `t=None` is treated as
        infinity and returns the same result as `t=np.inf`, but is computed
        using linear solve for test predictions instead of eigendecomposition,
        saving time and precision.
      x_test:
        test inputs. `None` means to return non-regularized (`diag_reg=0`)
        predictions on the train-set inputs. For regularized predictions, pass
        `x_test=x_train`.
      get:
        string, the mode of the Gaussian process, either "nngp" or "ntk", or a
        tuple. `get=None` is equivalent to `get=("nngp", "ntk")`.
      compute_cov:
        if `True`, compute both `mean` and `variance`; only `mean` otherwise.

    Returns:
      `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if
      `compute_cov == True` with potentially additional leading time dimensions.
    """
        if get is None:
            get = ('nngp', 'ntk')

        # train-train, test-train, test-test.
        k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov)

        # Infinite time.
        if t is None:
            return predict_inf(get)(get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)

        # Finite time.
        t = np.array(t) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, 1))

        def reshape_mean(mean):
            k = _get_first(k_dd if k_td is None else k_td)
            mean = mean.reshape(t_shape + k.shape[::2] + trace_shape)
            mean = np.moveaxis(mean, last_t_axes, trace_axes)
            return mean

        def reshape_cov(cov):
            k = _get_first(k_dd if k_td is None else k_td)
            cov_shape_t = t_shape + k.shape[::2] * 2
            return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape))

        out = {}

        for g in get:
            evals, evecs = eigenspace(g)

            # Training set.
            if k_td is None:
                mean = tf.einsum('ji,ti,ki,k...->tj...',
                                 evecs,
                                 -expm1(evals, t),
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            # Test set.
            else:
                neg_inv_expm1 = -inv_expm1(evals, t)
                ktd_g = utils.make_2d(getattr(k_td, g))
                mean = tf.einsum('lj,ji,ti,ki,k...->tl...',
                                 ktd_g,
                                 evecs,
                                 neg_inv_expm1,
                                 evecs,
                                 y_train_flat,
                                 optimize=True)

            mean = reshape_mean(mean)

            if nngp_tt is not None:
                nngp_dd = utils.make_2d(k_dd.nngp)

                # Training set.
                if k_td is None:
                    if g == 'nngp':
                        cov = np.einsum('ji,ti,ki->tjk',
                                        evecs,
                                        (np.maximum(evals, 0.) *
                                         np.exp(-2 * np.maximum(evals, 0.) *
                                                t / y_train.size)),
                                        evecs,
                                        optimize=True)

                    elif g == 'ntk':
                        exp = np.einsum('mi,ti,ki->tmk',
                                        evecs,
                                        np.exp(-np.maximum(evals, 0.) * t /
                                               y_train.size),
                                        evecs,
                                        optimize=True)
                        cov = np.einsum('tmk,kl,tnl->tmn',
                                        exp,
                                        nngp_dd,
                                        exp,
                                        optimize=True)

                    else:
                        raise ValueError(g)

                # Test set.
                else:
                    _nngp_tt = utils.make_2d(nngp_tt)

                    if g == 'nngp':
                        cov = _nngp_tt - np.einsum('mj,ji,ti,ki,lk->tml',
                                                   ktd_g,
                                                   evecs,
                                                   -inv_expm1(evals, 2 * t),
                                                   evecs,
                                                   ktd_g,
                                                   optimize=True)

                    elif g == 'ntk':
                        term_1 = np.einsum('mi,ti,ki,lk->tml',
                                           evecs,
                                           neg_inv_expm1,
                                           evecs,
                                           ktd_g,
                                           optimize=True)
                        term_2 = np.einsum(
                            'mj,ji,ti,ki,lk->tml',
                            ktd_g,
                            evecs,
                            neg_inv_expm1,
                            evecs,
                            utils.make_2d(k_td.nngp),  # pytype:disable=attribute-error
                            optimize=True)
                        term_2 += np.moveaxis(term_2, 1, 2)
                        cov = np.einsum('tji,jk,tkl->til',
                                        term_1,
                                        nngp_dd,
                                        term_1,
                                        optimize=True)
                        cov += -term_2 + _nngp_tt

                    else:
                        raise ValueError(g)

                out[g] = Gaussian(mean, reshape_cov(cov))

            else:
                out[g] = mean

        return out

    return predict_fn
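A hypothetical end-to-end usage with an infinite-width kernel (a sketch; it assumes `neural_tangents` is installed and `x_train`, `y_train`, `x_test` are arrays):

from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))
predict_fn = gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
# Mean and covariance of the ensemble's NTK predictions at time t = 1.0:
ntk_mean, ntk_cov = predict_fn(t=1.0, x_test=x_test, get='ntk',
                               compute_cov=True)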