def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Trace and take diagonals of `ntk` along pairs of axes.

  Thin wrapper around `_trace_and_diagonal`, kept so existing callers of this
  name keep working.

  The previous standalone implementation computed the `np.diagonal` axis
  positions without accounting for traced axes that precede a diagonal axis.
  For example, `trace_axes=(0,)` with `diagonal_axes=(1,)` on an
  8-dimensional kernel silently took the diagonal along the wrong axes. On a
  6-dimensional kernel, `trace_axes=(0,)` with `diagonal_axes=(2,)` indexed out
  of bounds. Delegating keeps a single, correct implementation.

  Args:
    ntk: input kernel of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes: axes (among `X, Y, Z, ...`) to trace over (and remove).
    diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along.

  Returns:
    The kernel with `trace_axes` traced out and `diagonal_axes` pairs replaced
    by single diagonal axes, exactly as returned by `_trace_and_diagonal`.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions.
  """
  return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.

    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the respective pairs of axes from the `ntk`. The result is
      divided by `utils.size_at` of these axes (presumably their total size),
      which normalizes the traces.

    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk`. Each such
      pair is replaced by a single axis, so the resulting `ntk` axes count
      drops by 1 per diagonal axis.

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).

  Raises:
    ValueError: if `ntk` has an odd number of dimensions, i.e. cannot be split
      into two equal halves.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns sorted, non-negative axes. Confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace the largest axes first. This leaves the first-half position `c` of
  # each remaining trace axis unaffected. Its second-half partner has shifted
  # left by the `i` (larger) first-half axes already removed.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  # After tracing, each half has `output_ndim - n_trace` axes. `np.diagonal`
  # removes both axes and appends the diagonal as the last axis. The `i`
  # earlier (smaller) diagonals therefore shift `d` left by `i` in the first
  # half and by `2 * i` in the second half. Each traced axis before `d`
  # shifts it left by one more in both halves.
  for i, d in enumerate(diagonal_axes):
    n_traced_before = sum(1 for c in trace_axes if c < d)
    axis1 = d - i - n_traced_before
    axis2 = output_ndim - n_trace + d - 2 * i - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # Presumably pairs up matching non-diagonal axes of the two halves
  # (e.g. `(N1, Z, N2, Z) -> (N1, N2, Z, Z)`, per the Returns example).
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  # Move the `n_diag` trailing diagonal axes to their output positions.
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size