def predict_fn_finite(t, fx_train_0, fx_test_0, k_test_train):
  t = np.array(t) * learning_rate
  t_shape, t_ndim = t.shape, t.ndim
  t = t.reshape((-1, 1))

  rhs = -y_train if fx_train_0 is None else fx_train_0 - y_train
  rhs = np.moveaxis(rhs, trace_axes, last_t_axes).reshape((-1,) + rhs_shape)
  shape = t_shape + k_train_train.shape[1::2] + rhs_shape

  if fx_train_0 is not None:
    dfx_train = expm1_fn(rhs, t).reshape(shape)
    dfx_train = np.moveaxis(dfx_train, last_t_axes, trace_axes)
    fx_train_t = fx_train_0 + dfx_train

  if fx_test_0 is not None:
    dfx_test = inv_expm1_fn(rhs, t).reshape(shape)
    dfx_test = np.tensordot(k_test_train, dfx_test, (odd, non_t_axes))
    dfx_test = np.moveaxis(
        dfx_test,
        tuple(range(n_non_t_axes, n_non_t_axes + t_ndim)) + last_t_axes,
        tuple(range(t_ndim)) + trace_axes)
    fx_test_t = fx_test_0 + dfx_test

  if fx_train_0 is not None and fx_test_0 is not None:
    return fx_train_t, fx_test_t
  if fx_test_0 is None:
    return fx_train_t
  return fx_test_t
def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
                         lhs_dilation=None, rhs_dilation=None,
                         dimension_numbers=None, feature_group_count=1,
                         batch_group_count=1, precision=None):
  """A general conv API that integrates normal conv, deconvolution, dilated
  convolution, etc."""
  dim = None
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  if lhs_spec != out_spec:
    raise TypeError('Current implementation requires the `data_format` of the '
                    'inputs and outputs to be the same.')
  if len(lhs_spec) >= 6:
    raise TypeError('Current implementation does not support 4 or higher '
                    'dimensional convolution, but got: ', len(lhs_spec) - 2)
  dim = len(lhs_spec) - 2
  if lhs_dilation and rhs_dilation:
    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
      lhs_dilation, rhs_dilation = None, None
    else:
      raise TypeError('Current implementation does not support deconvolution '
                      'and dilation to be performed at the same time, but got '
                      'lhs_dilation: {}, rhs_dilation: {}'.format(
                          lhs_dilation, rhs_dilation))
  if padding not in ['SAME', 'VALID']:
    raise TypeError('Current implementation requires the padding parameter '
                    'to be either `VALID` or `SAME`, but got: ', padding)

  # Convert params from int/Sequence[int] to list of ints.
  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
      window_strides, lhs_dilation, rhs_dilation)

  # Preprocess the shapes.
  dim_maps = {}
  if isinstance(lhs_spec, str):
    dim_maps['I'] = list(rhs_spec).index('I')
    dim_maps['O'] = list(rhs_spec).index('O')
    dim_maps['N'] = list(lhs_spec).index('N')
    dim_maps['C'] = list(lhs_spec).index('C')
  else:
    dim_maps['I'] = rhs_spec[1]
    dim_maps['O'] = rhs_spec[0]
    dim_maps['N'] = lhs_spec[0]
    dim_maps['C'] = lhs_spec[1]

  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
  # Adjust the filters, put the dimensions 'I' and 'O' last.
  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))

  spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'}
  data_format = 'N' + spatial_dim_maps[dim] + 'C'
  tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
                2: [nn.conv2d, nn.conv2d_transpose],
                3: [nn.conv3d, nn.conv3d_transpose]}

  output = None
  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format,
                                rhs_dilation)
  else:
    output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides,
                                padding, data_format, lhs_dilation)
  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
  return np.asarray(output)
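# A minimal sanity-check sketch for the layout handling above, using
# `tf.nn.conv2d` directly (the primitive `conv_general_dilated` dispatches to
# for `dim == 2`). All shapes below are illustrative assumptions.
def _example_conv_nhwc():
  import numpy as onp
  import tensorflow as tf

  lhs = onp.ones((2, 8, 8, 3), onp.float32)  # NHWC input
  rhs = onp.ones((3, 3, 3, 4), onp.float32)  # HWIO filter
  # Mirrors conv_general_dilated(lhs, rhs, (1, 1), 'SAME', None,
  #                              dimension_numbers=('NHWC', 'HWIO', 'NHWC')).
  out = tf.nn.conv2d(lhs, rhs, strides=[1, 1, 1, 1], padding='SAME',
                     data_format='NHWC')
  assert out.shape == (2, 8, 8, 4)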
def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug '
                     'at https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  shrink = 0
  for c in reversed(trace_axes):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - shrink)
    shrink += 1

  for i, d in enumerate(diagonal_axes):
    ntk = np.diagonal(ntk, axis1=d - i, axis2=output_ndim + d - shrink - 2 * i)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
def diagonal_between(x: np.ndarray,
                     start_axis: int = 0,
                     end_axis: int = -1) -> np.ndarray:
  """Returns the diagonal along all dimensions between start and end axes."""
  if end_axis == -1:
    end_axis = x.ndim

  half_ndim, ragged = divmod(end_axis - start_axis, 2)
  if ragged:
    raise ValueError(
        f'Need even number of axes to flatten, got {end_axis - start_axis}.')
  if half_ndim == 0:
    return x

  side_shape = x.shape[start_axis:start_axis + half_ndim]
  side_size = size_at(side_shape)

  shape_2d = x.shape[:start_axis] + (side_size, side_size) + x.shape[end_axis:]
  shape_result = x.shape[:start_axis] + side_shape + x.shape[end_axis:]

  x = np.diagonal(x.reshape(shape_2d), axis1=start_axis, axis2=start_axis + 1)
  x = np.moveaxis(x, -1, start_axis)
  return x.reshape(shape_result)
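# A small reference check of `diagonal_between` semantics (a sketch; the
# input shape is an illustrative assumption). For `start_axis=1` and
# `end_axis=5`, the result satisfies `out[n, a, b] == x[n, a, b, a, b]`.
def _example_diagonal_between():
  import numpy as onp

  x = onp.arange(2 * 3 * 4 * 3 * 4).reshape((2, 3, 4, 3, 4))
  out = onp.asarray(diagonal_between(x, start_axis=1, end_axis=5))
  assert out.shape == (2, 3, 4)
  for a in range(3):
    for b in range(4):
      assert (out[:, a, b] == x[:, a, b, a, b]).all()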
def reshape(m):
  if m is not None:
    if m.shape[self.channel_axis] != 1:
      raise NotImplementedError(
          f'Different channel-wise masks are not supported for '
          f'infinite-width layers now (got `mask.shape == {m.shape}`). '
          f'Please describe your use case at '
          f'https://github.com/google/neural-tangents/issues/new')
    m = np.squeeze(
        np.moveaxis(m, (self.batch_axis, self.channel_axis), (0, -1)), -1)
    if self.is_reversed:
      m = np.moveaxis(m, range(1, m.ndim), range(m.ndim - 1, 0, -1))
  return m
def dot_general(lhs: np.ndarray,
                rhs: np.ndarray,
                contracting_dims: Axes,
                batch_dims: Axes,
                precision=None) -> np.ndarray:
  """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but preserves the dimension order in the output. See XLA's
  `DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

  Args:
    lhs: array.
    rhs: array, must have the same dimensionality as `lhs`.
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
  contracting_dims = canonicalize_axis(contracting_dims, lhs)
  batch_dims = canonicalize_axis(batch_dims, lhs)

  n_batch_dims = len(batch_dims)
  leading_batch_dims = range(n_batch_dims)

  dimension_numbers = ((contracting_dims, contracting_dims),
                       (leading_batch_dims, leading_batch_dims))

  lhs = np.moveaxis(lhs, batch_dims, leading_batch_dims)
  if rhs is None:
    rhs = lhs
  else:
    rhs = np.moveaxis(rhs, batch_dims, leading_batch_dims)

  prod = tf_dot_general(lhs, rhs, dimension_numbers)
  prod = zip_axes(prod, n_batch_dims)

  res_batch_dims = get_res_batch_dims(contracting_dims, batch_dims)
  prod = np.moveaxis(prod, leading_batch_dims, res_batch_dims)
  return prod
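# Sketch of `dot_general`'s output convention on a simple case (shapes are
# illustrative assumptions). With shared batch axis 0 and contracting axis 2,
# the result should match a plain einsum with the batch axis left in place:
def _example_dot_general():
  import numpy as onp

  lhs = onp.random.rand(5, 3, 7)
  rhs = onp.random.rand(5, 4, 7)
  out = dot_general(lhs, rhs, contracting_dims=(2,), batch_dims=(0,))
  ref = onp.einsum('bnc,bmc->bnm', lhs, rhs)
  assert onp.allclose(onp.asarray(out), ref)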
def reverse_zipped(mat: np.ndarray, start_axis: int = 0) -> np.ndarray:
  if mat is not None:
    source_axes = tuple(j
                        for i in range(mat.ndim - 2, start_axis - 1, -2)
                        for j in (i, i + 1))
    target_axes = range(start_axis, mat.ndim)
    mat = np.moveaxis(mat, source_axes, target_axes)
  return mat
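# Shape-level illustration of `reverse_zipped` (sizes are illustrative
# assumptions): trailing axis *pairs* are reversed while the order within
# each pair is kept.
def _example_reverse_zipped():
  import numpy as onp

  mat = onp.zeros((2, 3, 4, 5, 6, 7))  # pairs: (2, 3), (4, 5), (6, 7)
  out = reverse_zipped(mat)
  assert out.shape == (6, 7, 4, 5, 2, 3)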
def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
  b_axes = utils.canonicalize_axis(b_axes, b)
  last_b_axes = range(-len(b_axes), 0)
  x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

  b = np.moveaxis(b, b_axes, last_b_axes)
  b = b.reshape((A.shape[1], -1))

  x = np.asarray(tf.linalg.cholesky_solve(C, b))
  x = x.reshape(x_shape)
  return x
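# `cho_solve` above closes over `A`, `C` and `x_non_channel_shape` from its
# enclosing scope. A self-contained sketch of the underlying primitive
# (assumes TensorFlow; the SPD matrix below is an illustrative assumption):
def _example_cholesky_solve():
  import numpy as onp
  import tensorflow as tf

  a = onp.random.rand(4, 4)
  a = a @ a.T + 1e-3 * onp.eye(4)        # symmetric positive-definite
  chol = tf.linalg.cholesky(a)           # lower-triangular factor
  b = onp.random.rand(4, 2)
  x = tf.linalg.cholesky_solve(chol, b)  # solves a @ x == b
  assert onp.allclose(a @ x.numpy(), b)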
def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
  fx_train_t = y_train.astype(k_train_train.dtype)
  if fx_test_0 is None:
    return fx_train_t

  rhs = y_train if fx_train_0 is None else y_train - fx_train_0
  dfx_test = np.tensordot(k_test_train, solve(rhs, trace_axes), (odd, first))
  dfx_test = np.moveaxis(dfx_test, last_t_axes, trace_axes)
  fx_test_t = fx_test_0 + dfx_test

  if fx_train_0 is None:
    return fx_test_t
  return fx_train_t, fx_test_t
def _zip_axes(x: np.ndarray,
              start_axis: int = 0,
              end_axis: int = -1,
              unzip: bool = False) -> np.ndarray:
  """Zip/unzip (interleave/de-interleave) axes starting from `start_axis`.

  Changes the shape as follows:
    If `unzip == True`:
    `[..., X, X, ..., Y, Y, ..., Z, Z, ...] -> [..., X, Y, Z, ..., X, Y, Z, ...]`
    If `unzip == False`:
    `[..., X, Y, Z, ..., X, Y, Z, ...] -> [..., X, X, ..., Y, Y, ..., Z, Z, ...]`

  Args:
    x: `np.ndarray` with an even number of dimensions following `start_axis`.
    start_axis: `int`, number of axis from which to zip/unzip.
    end_axis: `int`, number of axis until which to zip/unzip.
    unzip: `bool`, set to `True` to unzip instead of zip.

  Returns:
    A `np.ndarray` with a new shape.
  """
  if end_axis == -1:
    end_axis = len(x.shape)

  half_ndim, ragged = divmod(end_axis - start_axis, 2)
  if ragged:
    raise ValueError(
        f'Need even number of axes to zip, got {end_axis - start_axis}.')

  odd_axes = range(start_axis + 1, end_axis, 2)
  last_axes = range(end_axis - half_ndim, end_axis)

  if unzip:
    x = np.moveaxis(x, odd_axes, last_axes)
  else:
    x = np.moveaxis(x, last_axes, odd_axes)
  return x
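# Shape-level illustration of zipping (sizes are illustrative assumptions):
# zipping interleaves the two halves of the trailing axes, and `unzip=True`
# inverts it.
def _example_zip_axes():
  import numpy as onp

  x = onp.zeros((2, 3, 4, 3, 4))              # [N, X, Y, X, Y]
  z = _zip_axes(x, start_axis=1)              # -> [N, X, X, Y, Y]
  assert z.shape == (2, 3, 3, 4, 4)
  u = _zip_axes(z, start_axis=1, unzip=True)  # back to [N, X, Y, X, Y]
  assert u.shape == (2, 3, 4, 3, 4)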
def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
  fx_train_t, fx_test_t, qx_train_t, qx_test_t = (state_t.fx_train,
                                                  state_t.fx_test,
                                                  state_t.qx_train,
                                                  state_t.qx_test)

  dy_df_t = grad_loss(fx_train_t)

  fx_train_t = -np.moveaxis(
      np.tensordot(k_train_train, dy_df_t, (odd, non_t_axes)),
      last_t_axes, trace_axes)
  if fx_test_t is not None:
    fx_test_t = -np.moveaxis(
        np.tensordot(k_test_train, dy_df_t, (odd, non_t_axes)),
        last_t_axes, trace_axes)

  if momentum is None:
    return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

  fx_train_t += momentum * qx_train_t
  if qx_test_t is not None:
    fx_test_t += momentum * qx_test_t

  return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk: input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes: axes (among `X, Y, Z, ...`) to trace over, i.e. compute the
      trace along and remove the respective pairs of axes from the `ntk`.
    diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along,
      i.e. extract the diagonal along the respective pairs of axes from the
      `ntk` (and hence reduce the resulting `ntk` axes count by 2).

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug '
                     'at https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    axis1 = d - i
    axis2 = output_ndim + d - 2 * i - n_trace
    for c in trace_axes:
      if c < d:
        axis1 -= 1
        axis2 -= 1
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
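# Reference semantics of `_trace_and_diagonal` on a small kernel (sizes are
# illustrative assumptions): tracing a pair of `X` axes is a normalized
# einsum-trace, and a diagonal pair keeps a single copy of the axis.
def _example_trace_and_diagonal():
  import numpy as onp

  n1, n2, x, y = 2, 3, 4, 5
  ntk = onp.random.rand(n1, x, y, n2, x, y)
  out = _trace_and_diagonal(ntk, trace_axes=(1,), diagonal_axes=(2,))
  ref = onp.einsum('ixyjxy->ijy', ntk) / x  # shape (n1, n2, y)
  assert onp.allclose(onp.asarray(out), ref)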
def testPredCovPosDef(self, train_shape, test_shape, network, out_logits):
  _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                 train_shape)
  _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)
  ts = np.logspace(-3, 3, 10)
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      ker_fun, x_train, y_train)

  for get in ('nngp', 'ntk'):
    for x in (None, 'x_test'):
      for t in (None, 'ts'):
        with self.subTest(get=get, x=x, t=t):
          cov = predict_fn_mse_ens(
              t=t if t is None else ts,
              get=get,
              x_test=x if x is None else x_test,
              compute_cov=True).covariance

          self.assertAllClose(cov, np.moveaxis(cov, -1, -2))
          self.assertGreater(np.min(np.linalg.eigh(cov)[0]), -1e-4)
def transpose_zipped(x: np.ndarray) -> np.ndarray:
  return np.moveaxis(x, range(1, x.ndim, 2), range(0, x.ndim, 2))
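# Shape-level illustration (sizes are illustrative assumptions): each zipped
# axis pair `(n_i, m_i)` is swapped to `(m_i, n_i)`, i.e. a transpose of a
# zipped kernel.
def _example_transpose_zipped():
  import numpy as onp

  x = onp.zeros((2, 3, 4, 5))
  assert transpose_zipped(x).shape == (3, 2, 5, 4)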
def predict_fn(get: Get,
               k_test_train=None,
               nngp_test_test: np.ndarray = None
               ) -> Dict[str, Union[np.ndarray, Gaussian]]:
  """`test`-set posterior given respective covariance matrices.

  Args:
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple, or `None`. If `None`, then both `nngp` and `ntk` predictions
      are returned.
    k_test_train: test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel`
      namedtuple, (c) `Kernel` object. Must contain the necessary `nngp`
      and/or `ntk` kernels for arguments provided to the returned `predict_fn`
      function. For example, if you request to compute posterior test [only]
      NTK covariance, `k_test_train` must contain both `ntk` and `nngp`
      kernels. If `None`, returns predictions on the training set. Note that
      train-set outputs are always `N(y_train, 0)` and mostly returned for API
      consistency.
    nngp_test_test: a test-test NNGP array. Provide if you want to compute
      test-test posterior covariance. `nngp_test_test=None` means to not
      compute it. If `k_test_train is None`, pass any non-`None` value (e.g.
      `True`) if you want to get the non-regularized (`diag_reg=0`)
      train-train posterior covariance. Note that non-regularized train-set
      outputs will always be the zero-variance Gaussian `N(y_train, 0)` and
      are mostly returned for API consistency. For regularized train-set
      posterior outputs according to a positive `diag_reg`, pass
      `k_test_train=k_train_train`, and, optionally,
      `nngp_test_test=nngp_train_train`.

  Returns:
    Either a `Gaussian('mean', 'covariance')` namedtuple or `mean` of the GP
    posterior on the `test` set.
  """
  if get is None:
    get = ('nngp', 'ntk')

  out = {}
  for g in get:
    k_dd = _get_attr(k_train_train, g)
    k_td = None if k_test_train is None else _get_attr(k_test_train, g)

    if k_td is None:
      # Train set predictions.
      y = y_train.astype(k_dd.dtype)
    else:
      # Test set predictions.
      y = np.tensordot(k_td, k_inv_y(g), (odd, first))
      y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

    if nngp_test_test is not None:
      if k_td is None:
        out[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
      else:
        if (g == 'ntk' and (not hasattr(k_train_train, 'nngp') or
                            not hasattr(k_test_train, 'nngp'))):
          raise ValueError(
              'If `"ntk" in get`, and `nngp_test_test is not None`, '
              'and `k_test_train is not None`, i.e. you request the '
              'NTK posterior covariance on the test set, you need '
              'both NTK and NNGP train-train and test-train matrices '
              'contained in `k_test_train` and `k_train_train`. '
              'Hence they must be `namedtuple`s with `nngp` and '
              '`ntk` attributes.')

        k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

        if g == 'nngp':
          cov = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          cov = nngp_test_test - utils.zip_axes(cov)
          out[g] = Gaussian(y, cov)

        elif g == 'ntk':
          term_1 = solve(g)(k_td, even)
          cov = np.tensordot(_get_attr(k_train_train, 'nngp'), term_1,
                             (odd, first))
          cov = np.tensordot(term_1, cov, (first, first))

          term_2 = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
          term_2 += np.moveaxis(term_2, first, last)
          cov = utils.zip_axes(cov - term_2) + nngp_test_test
          out[g] = Gaussian(y, cov)

        else:
          raise ValueError(g)

    else:
      out[g] = y

  return out
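# Hypothetical call pattern for the enclosing `gp_inference` (a sketch only:
# `kernel_fn` and all arguments below are illustrative assumptions passed in
# by the caller, not defined in this module).
def _example_gp_inference(kernel_fn, x_train, y_train, x_test):
  k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
  k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
  nngp_test_test = kernel_fn(x_test, None, 'nngp')
  predict = gp_inference(k_train_train, y_train, diag_reg=1e-4)
  # NTK posterior on the test set requires both `ntk` and `nngp` kernels.
  return predict('ntk', k_test_train, nngp_test_test)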
def gradient_descent_mse_ensemble(kernel_fn: KernelFn,
                                  x_train: np.ndarray,
                                  y_train: np.ndarray,
                                  learning_rate: float = 1.,
                                  diag_reg: float = 0.0,
                                  diag_reg_absolute_scale: bool = False,
                                  trace_axes: Axes = (-1,),
                                  **kernel_fn_kwargs):
  r"""Predicts the Gaussian embedding induced by gradient descent on MSE loss.

  This is equivalent to an infinite ensemble of infinite-width networks after
  marginalizing out the initialization, if `kernel_fn` is the kernel function
  of the infinite-width network. Note that `kernel_fn` can in principle also
  be an empirical / Monte Carlo finite-width kernel function, but in this case
  the returned output will not have a simple interpretation (unless these
  functions are used to approximate the infinite-width kernel).

  Note that the first invocation of the returned `predict_fn` will be slow and
  allocate a lot of memory for its whole lifetime, as the kernel computation,
  and either eigendecomposition (`t` is a scalar or an array) or Cholesky
  factorization (`t=None`) of `kernel_fn(x_train, None, get)` is performed and
  cached for future invocations (or both, if the function is called on both
  finite and infinite (`t=None`) times).

  Args:
    kernel_fn: a kernel function that computes NNGP and/or NTK. Must have a
      signature `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a
      `Kernel` object or a `namedtuple` with `nngp` and/or `ntk` attributes.
      Therefore, it can be an `AnalyticKernelFn`, but also a
      `MonteCarloKernelFn`, or an `EmpiricalKernelFn` (but only
      `nt.empirical_kernel_fn` and not `nt.empirical_ntk_fn` or
      `nt.empirical_nngp_fn`, since the latter two do not accept a `get`
      argument). Note that for meaningful outputs, the kernel function must
      represent or at least approximate the infinite-width kernel.
    x_train: training inputs.
    y_train: training targets.
    learning_rate: learning rate, step size.
    diag_reg: a scalar representing the strength of the diagonal
      regularization for `kernel_fn(x_train, None, get)`, i.e. computing
      `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky
      factorization or eigendecomposition.
    diag_reg_absolute_scale: `True` for `diag_reg` to represent regularization
      in absolute units, `False` to be
      `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))`.
    trace_axes: `f(x_train)` axes such that `kernel_fn(x_train, None, get)`,
      `kernel_fn(x_test, x_train, get)`[, and `kernel_fn(x_test, None, get)`]
      lack these pairs of dimensions and are to be interpreted as
      :math:`\Theta \otimes I`, i.e. block-diagonal along `trace_axes`. These
      can be specified either to save space and compute, or to even improve
      approximation accuracy of the infinite-width or infinite-samples limit,
      since in these limits the covariance along channel / feature / logit
      axes indeed converges to a constant-diagonal matrix. However, if you
      target linearized dynamics of a specific finite-width network,
      `trace_axes=()` will yield the most accurate result.
    **kernel_fn_kwargs: optional keyword arguments passed to `kernel_fn`.

  Returns:
    A function with signature `predict_fn(t, x_test, get, compute_cov)`
    returning either mean or mean and covariance of the infinite ensemble of
    infinite-width networks outputs on `x_test` at time[s] `t`, in the `get`
    regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
""" expm1 = _make_expm1_fn(y_train.size) inv_expm1 = _make_inv_expm1_fn(y_train.size) trace_axes = utils.canonicalize_axis(trace_axes, y_train) trace_axes = tuple(-y_train.ndim + a for a in trace_axes) n_trace_axes = len(trace_axes) last_t_axes = range(-n_trace_axes, 0) trace_shape = tuple(y_train.shape[a] for a in trace_axes) y_train_flat = np.moveaxis(y_train, trace_axes, last_t_axes).reshape((-1, ) + trace_shape) k_dd_cache = {} def get_k_train_train(get: Tuple[str, ...]) -> _Kernel: if len(get) == 1: get = get[0] if get not in k_dd_cache: k_dd_cache[get] = kernel_fn(x_train, None, get, **kernel_fn_kwargs) elif len(get) == 2: if not any(g in k_dd_cache for g in get): k_dd_cache.update( kernel_fn(x_train, None, get, **kernel_fn_kwargs)._asdict()) else: for g in get: if g not in k_dd_cache: k_dd_cache[g] = kernel_fn(x_train, None, g, **kernel_fn_kwargs) else: raise ValueError(get) return _Kernel(**k_dd_cache) @lru_cache(2) def eigenspace(get: str): k_dd = getattr(get_k_train_train((get, )), get) k_dd = _add_diagonal_regularizer(utils.make_2d(k_dd), diag_reg, diag_reg_absolute_scale) return tf.linalg.eigh(k_dd) @lru_cache(4) def predict_inf(get: Get): _, get = utils.canonicalize_get(get) k_dd = get_k_train_train(get) return gp_inference(k_dd, y_train, diag_reg, diag_reg_absolute_scale, trace_axes) def get_matrices(get: Get, x_test: Optional[np.ndarray], compute_cov: bool): get = _get_dependency(get, compute_cov) k_dd = get_k_train_train(get) if x_test is None: k_td = None nngp_tt = compute_cov or None else: k_td = kernel_fn(x_test, x_train, get, **kernel_fn_kwargs) if compute_cov: nngp_tt = kernel_fn(x_test, None, 'nngp', **kernel_fn_kwargs) else: nngp_tt = None return k_dd, k_td, nngp_tt @utils.get_namedtuple('Gaussians') def predict_fn(t: ArrayOrScalar = None, x_test: np.ndarray = None, get: Get = None, compute_cov: bool = False) -> Dict[str, Gaussian]: """Return output mean and covariance on the test set at time[s] `t`. Args: t: a scalar of array of scalars of any shape. `t=None` is treated as infinity and returns the same result as `t=np.inf`, but is computed using linear solve for test predictions instead of eigendecomposition, saving time and precision. x_test: test inputs. `None` means to return non-regularized (`diag_reg=0`) predictions on the train-set inputs. For regularized predictions, pass `x_test=x_train`. get: string, the mode of the Gaussian process, either "nngp" or "ntk", or a tuple. `get=None` is equivalent to `get=("nngp", "ntk")`. compute_cov: if `True` computing both `mean` and `variance` and only `mean` otherwise. Returns: `fx_test_mean_t` or `(fx_test_mean_t, fx_test_cov_t)` if `compute_cov == True` with potentially additional leading time dimensions. """ if get is None: get = ('nngp', 'ntk') # train-train, test-train, test-test. k_dd, k_td, nngp_tt = get_matrices(get, x_test, compute_cov) # Infinite time. if t is None: return predict_inf(get)(get=get, k_test_train=k_td, nngp_test_test=nngp_tt) # Finite time. t = np.array(t) * learning_rate t_shape = t.shape t = t.reshape((-1, 1)) def reshape_mean(mean): k = _get_first(k_dd if k_td is None else k_td) mean = mean.reshape(t_shape + k.shape[::2] + trace_shape) mean = np.moveaxis(mean, last_t_axes, trace_axes) return mean def reshape_cov(cov): k = _get_first(k_dd if k_td is None else k_td) cov_shape_t = t_shape + k.shape[::2] * 2 return utils.zip_axes(cov.reshape(cov_shape_t), len(t_shape)) out = {} for g in get: evals, evecs = eigenspace(g) # Training set. 
      if k_td is None:
        mean = tf.einsum('ji,ti,ki,k...->tj...',
                         evecs, -expm1(evals, t), evecs, y_train_flat,
                         optimize=True)

      # Test set.
      else:
        neg_inv_expm1 = -inv_expm1(evals, t)
        ktd_g = utils.make_2d(getattr(k_td, g))
        mean = tf.einsum('lj,ji,ti,ki,k...->tl...',
                         ktd_g, evecs, neg_inv_expm1, evecs, y_train_flat,
                         optimize=True)

      mean = reshape_mean(mean)

      if nngp_tt is not None:
        nngp_dd = utils.make_2d(k_dd.nngp)

        # Training set.
        if k_td is None:
          if g == 'nngp':
            cov = np.einsum(
                'ji,ti,ki->tjk',
                evecs,
                (np.maximum(evals, 0.) *
                 np.exp(-2 * np.maximum(evals, 0.) * t / y_train.size)),
                evecs,
                optimize=True)
          elif g == 'ntk':
            exp = np.einsum(
                'mi,ti,ki->tmk',
                evecs,
                np.exp(-np.maximum(evals, 0.) * t / y_train.size),
                evecs,
                optimize=True)
            cov = np.einsum('tmk,kl,tnl->tmn', exp, nngp_dd, exp,
                            optimize=True)
          else:
            raise ValueError(g)

        # Test set.
        else:
          _nngp_tt = utils.make_2d(nngp_tt)
          if g == 'nngp':
            cov = _nngp_tt - np.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, -inv_expm1(evals, 2 * t), evecs, ktd_g,
                optimize=True)
          elif g == 'ntk':
            term_1 = np.einsum('mi,ti,ki,lk->tml',
                               evecs, neg_inv_expm1, evecs, ktd_g,
                               optimize=True)
            term_2 = np.einsum(
                'mj,ji,ti,ki,lk->tml',
                ktd_g, evecs, neg_inv_expm1, evecs,
                utils.make_2d(k_td.nngp),  # pytype: disable=attribute-error
                optimize=True)
            term_2 += np.moveaxis(term_2, 1, 2)
            cov = np.einsum('tji,jk,tkl->til', term_1, nngp_dd, term_1,
                            optimize=True)
            cov += -term_2 + _nngp_tt
          else:
            raise ValueError(g)

        out[g] = Gaussian(mean, reshape_cov(cov))
      else:
        out[g] = mean

    return out

  return predict_fn
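# Hedged usage sketch for `gradient_descent_mse_ensemble` (a sketch only:
# `kernel_fn` and the data arrays are illustrative assumptions supplied by
# the caller, following the signature documented above).
def _example_mse_ensemble(kernel_fn, x_train, y_train, x_test):
  predict_fn = gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
  # Mean and covariance of the NTK-GP posterior at training time t = 1.
  gaussian = predict_fn(t=1., x_test=x_test, get='ntk', compute_cov=True)
  return gaussian.mean, gaussian.covariance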