def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Trace and take diagonals along paired axes of the `ntk`.

  The `ntk` axes are treated as two halves of `ntk.ndim // 2` axes each; axis
  `a` of the first half is paired with axis `ntk.ndim // 2 + a`.

  Args:
    ntk: kernel with an even number of axes.
    trace_axes: axes (indexed within one half) whose pairs are traced over
      with `np.trace`, i.e. removed from the result.
    diagonal_axes: axes (indexed within one half) whose pairs are replaced by
      a single diagonal axis with `np.diagonal`.

  Returns:
    The traced / diagonalised kernel, with the remaining paired axes passed
    through `utils.zip_axes`, the diagonal axes moved to the positions given
    by `utils.get_res_batch_dims`, and the whole array divided by
    `utils.size_at(ntk.shape[:ntk.ndim // 2], trace_axes)`.

  Raises:
    ValueError: if `ntk` has an odd number of axes.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order - confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  n_trace = len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the last axis backwards: every pair already removed had a
  # larger first-half index, so only the second-half index shifts.
  for shrink, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - shrink)

  for i, d in enumerate(diagonal_axes):
    # Traced pairs located before `d` shift BOTH of its axes to the left.
    n_traced_before = sum(1 for c in trace_axes if c < d)
    # `np.diagonal` appends the diagonal as the last axis, so each earlier
    # diagonal removed one axis from each half in front of `d`.
    axis1 = d - i - n_traced_before
    axis2 = output_ndim + d - n_trace - 2 * i - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # The trailing `n_marg` axes are the diagonals; zip only the paired ones.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk: input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes: axes (among `X, Y, Z, ...`) to trace over, i.e. compute the
      trace along and remove the respective pairs of axes from the `ntk`.
    diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along, i.e.
      extract the diagonal along the respective pairs of axes from the `ntk`
      (and hence reduce the resulting `ntk` axes count by 2).

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis). The result is divided by
    `utils.size_at(ntk.shape[:ntk.ndim // 2], trace_axes)`.

  Raises:
    ValueError: if `ntk` has an odd number of axes.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order - confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the last axis backwards: every pair already removed had a
  # larger first-half index, so only the second-half index shifts (by `i`).
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    # Traced pairs located before `d` shift both of its axes to the left.
    n_traced_before = sum(1 for c in trace_axes if c < d)
    # `np.diagonal` appends the diagonal as the last axis, so each earlier
    # diagonal removed one axis from each half in front of `d`.
    axis1 = d - i - n_traced_before
    axis2 = output_ndim + d - 2 * i - n_trace - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # The trailing `n_diag` axes are the diagonals; zip only the paired ones.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
def reshape_cov(cov):
  """Reshape `cov` to the kernel-derived shape and zip its trailing axes.

  Relies on `k_dd`, `k_td` and `t_shape` from the enclosing scope. The
  reference kernel is `_get_first(k_td)`, or `_get_first(k_dd)` when `k_td`
  is `None`.

  Args:
    cov: array to reshape; it must be reshapeable to
      `t_shape + k.shape[::2] * 2`, where `k` is the reference kernel.

  Returns:
    The result of `utils.zip_axes` applied to the reshaped array with
    `len(t_shape)` as its second argument (presumably the first axis to zip -
    confirm against `utils.zip_axes`).
  """
  reference = k_dd if k_td is None else k_td
  kernel = _get_first(reference)
  # Every second kernel dimension, repeated twice (tuple repetition).
  paired_shape = kernel.shape[::2] * 2
  unzipped = cov.reshape(t_shape + paired_shape)
  return utils.zip_axes(unzipped, len(t_shape))
def predict_fn(
    get: Get,
    k_test_train=None,
    nngp_test_test: np.ndarray = None
) -> Dict[str, Union[np.ndarray, Gaussian]]:
  """`test`-set posterior given respective covariance matrices.

  Args:
    get: string, the mode of the Gaussian process, either "nngp" or "ntk", or
      a tuple, or `None`. If `None` then both `nngp` and `ntk` predictions are
      returned.
    k_test_train: test-train kernel. Can be (a) `np.ndarray`, (b) `Kernel`
      namedtuple, (c) `Kernel` object. Must contain the necessary `nngp`
      and/or `ntk` kernels for the requested modes. For example, to compute
      the posterior test NTK covariance, `k_test_train` must contain both
      `ntk` and `nngp` kernels. If `None`, returns predictions on the training
      set. Note that train-set outputs are always `N(y_train, 0)` and mostly
      returned for API consistency.
    nngp_test_test: a test-test NNGP array. Provide it to compute the
      test-test posterior covariance; `nngp_test_test=None` means it is not
      computed. If `k_test_train is None`, pass any non-`None` value (e.g.
      `True`) to get the non-regularized (`diag_reg=0`) train-train posterior
      covariance, which is always the zero-variance Gaussian `N(y_train, 0)`.
      For regularized train-set posterior outputs according to a positive
      `diag_reg`, pass `k_test_train=k_train_train`, and, optionally,
      `nngp_test_test=nngp_train_train`.

  Returns:
    A dictionary keyed by each requested mode. The value is the posterior
    `mean` alone when `nngp_test_test is None`, and a
    `Gaussian('mean', 'variance')` namedtuple otherwise.

  Raises:
    ValueError: if the NTK test-set covariance is requested but
      `k_train_train` or `k_test_train` has no `nngp` attribute, or if a
      covariance is requested for a mode other than "nngp" / "ntk".
  """
  modes = ('nngp', 'ntk') if get is None else get
  predictions = {}

  for g in modes:
    k_dd = _get_attr(k_train_train, g)
    k_td = None if k_test_train is None else _get_attr(k_test_train, g)
    on_train_set = k_td is None

    if on_train_set:
      # Train set predictions.
      y = y_train.astype(k_dd.dtype)
    else:
      # Test set predictions.
      y = np.tensordot(k_td, k_inv_y(g), (odd, first))
      y = np.moveaxis(y, range(-len(trace_axes), 0), trace_axes)

    # Mean only: no covariance was requested.
    if nngp_test_test is None:
      predictions[g] = y
      continue

    # Train-set covariance is identically zero.
    if on_train_set:
      predictions[g] = Gaussian(y, np.zeros_like(k_dd, k_dd.dtype))
      continue

    has_nngp = (hasattr(k_train_train, 'nngp') and
                hasattr(k_test_train, 'nngp'))
    if g == 'ntk' and not has_nngp:
      raise ValueError(
          'If `"ntk" in get`, and `nngp_test_test is not None`, '
          'and `k_test_train is not None`, i.e. you request the '
          'NTK posterior covariance on the test set, you need '
          'both NTK and NNGP train-train and test-train matrices '
          'contained in `k_test_train` and `k_train_train`. '
          'Hence they must be `namedtuple`s with `nngp` and '
          '`ntk` attributes.')

    k_td_nngp_inv_y = solve(g)(_get_attr(k_test_train, 'nngp'), even)

    if g == 'nngp':
      reduction = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
      cov = nngp_test_test - utils.zip_axes(reduction)

    elif g == 'ntk':
      ntk_solved = solve(g)(k_td, even)
      inner = np.tensordot(
          _get_attr(k_train_train, 'nngp'), ntk_solved, (odd, first))
      quadratic = np.tensordot(ntk_solved, inner, (first, first))
      cross = np.tensordot(k_td, k_td_nngp_inv_y, (odd, first))
      # Add the cross term with its `first` axes moved to `last`.
      cross += np.moveaxis(cross, first, last)
      cov = utils.zip_axes(quadratic - cross) + nngp_test_test

    else:
      raise ValueError(g)

    predictions[g] = Gaussian(y, cov)

  return predictions