def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Traces out and takes diagonals along paired axes of `ntk`.

  The axes of `ntk` are treated as two halves of `output_ndim = ntk.ndim // 2`
  axes each; axis `a` of the first half is paired with axis `output_ndim + a`
  of the second half.

  Args:
    ntk: kernel with an even number of axes.
    trace_axes: axes (indexed within one half) whose pairs are traced over and
      removed from the result.
    diagonal_axes: axes (indexed within one half) whose pairs are replaced by
      a single axis holding their diagonal.

  Returns:
    The contracted kernel divided by `utils.size_at` of the first-half shape at
    `trace_axes`.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2

  # NOTE(review): the offset arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order -- confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  n_marg = len(diagonal_axes)
  n_trace = len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the largest axis down: the `i` pairs already removed all sit
  # after `c` within each half, so only the second-half index shifts (by `i`).
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    # Trace axes that preceded `d` have been removed from BOTH halves, so they
    # shift the first-half index once and the second-half index once more on
    # top of the `n_trace` first-half removals.
    n_before = sum(1 for c in trace_axes if c < d)
    axis1 = d - i - n_before
    axis2 = output_ndim + d - n_trace - 2 * i - n_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # The `n_marg` diagonal axes now sit at the end; the leading axes are handed
  # to `zip_axes` (presumably to interleave the two halves -- TODO confirm).
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
def ntk_fn(x1: np.ndarray,
           x2: Optional[np.ndarray],
           params: PyTree,
           keys: Optional[Union[PRNGKey,
                                Tuple[PRNGKey, PRNGKey],
                                np.ndarray]] = None,
           **apply_fn_kwargs) -> np.ndarray:
  """Computes a single sample of the empirical NTK (jacobian outer product).

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters about which we would like to compute the
      neural tangent kernel.
    keys: `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
      dtype `uint32`. If `keys=None`, then the function `f` is deterministic
      and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
      and `x2` must be the same and share the same PRNG key; else `x1` and
      `x2` use two different PRNG keys.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

  Returns:
    A single sample of the empirical NTK. The shape of the kernel is "almost"
    `zip(f(x1).shape, f(x2).shape)` except for:
    1) `trace_axes` are absent as they are contracted over.
    2) `diagonal_axes` are present only once.
    All other axes are present twice.
  """
  key1, key2 = _read_keys(keys)

  def _jacobian(fn):
    """Returns the Jacobian of `fn(params)` with respect to `params`."""
    with tf.GradientTape() as tape:
      tape.watch(params)
      y = fn(params)
    return np.asarray(tape.jacobian(y, params))

  f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
  j1 = _jacobian(f1)

  if x2 is None:
    # Same inputs on both sides: reuse the Jacobian instead of recomputing it.
    j2 = j1
  else:
    f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)
    j2 = _jacobian(f2)

  # Only `ndim` and the size along `trace_axes` of the output are used below.
  fx1 = eval_on_shapes(f1)(params)
  # NOTE(review): assumes `sum_and_contract` takes `(j1, j2, output_ndim)` and
  # does not normalize by the traced size itself -- confirm against its
  # definition in this file.
  ntk = sum_and_contract(j1, j2, fx1.ndim)
  return ntk / utils.size_at(fx1, trace_axes)
def sum_and_contract(fx, j1, j2):
  """Contracts matching leaves of `j1` and `j2` and sums the results.

  Args:
    fx: object exposing `ndim`; also passed to `utils.size_at` together with
      the enclosing `trace_axes` to obtain the normalizer.
    j1: first tree of arrays (presumably Jacobians of the output with respect
      to each parameter leaf -- TODO confirm against callers).
    j2: second tree, with the same structure as `j1`.

  Returns:
    The sum over leaves of `utils.dot_general(...) / size`.
  """
  out_ndim = fx.ndim
  normalizer = utils.size_at(fx, trace_axes)
  diag_axes = utils.canonicalize_axis(diagonal_axes, out_ndim)
  traced_axes = utils.canonicalize_axis(trace_axes, out_ndim)

  def contract(x, y):
    # Every axis past the first `out_ndim` ones is contracted as well.
    trailing_axes = list(range(out_ndim, x.ndim))
    to_contract = traced_axes + trailing_axes
    return utils.dot_general(x, y, to_contract, diag_axes) / normalizer

  per_leaf = tree_multimap(contract, j1, j2)
  return tree_reduce(operator.add, per_leaf)
def _trace_and_diagonal(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk: input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes: axes (among `X, Y, Z, ...`) to trace over, i.e. compute the
      trace along and remove the respective pairs of axes from the `ntk`.
    diagonal_axes: axes (among `X, Y, Z, ...`) to take the diagonal along, i.e.
      extract the diagonal along the respective pairs of axes from the `ntk`
      (and hence reduce the resulting `ntk` axes count by 2).

  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis). The result is divided by
    `utils.size_at` of the first-half shape at `trace_axes`.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2

  # NOTE(review): the offset arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order -- confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

  n_diag, n_trace = len(diagonal_axes), len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the largest axis down: the `i` pairs already removed all sit
  # after `c` within each half, so only the second-half index shifts (by `i`).
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    # Each traced axis in front of `d` was removed from both halves, shifting
    # both indices of the pair by one more position.
    n_before = sum(1 for c in trace_axes if c < d)
    axis1 = d - i - n_before
    axis2 = output_ndim + d - 2 * i - n_trace - n_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # The `n_diag` diagonal axes now sit at the end; the leading axes are handed
  # to `zip_axes` (presumably to interleave the two halves -- TODO confirm).
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
  return ntk / contract_size
def ntk_fn(x1: np.ndarray,
           x2: Optional[np.ndarray],
           params: PyTree,
           **apply_fn_kwargs) -> np.ndarray:
  """Draws one sample of the empirical NTK (outer product of Jacobians).

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs; `None` stands for `x2=x1`. On `trace_axes` and
      `diagonal_axes` the shape of `f(x2)` must match that of `f(x1)`.
    params: a `PyTree` of parameters with respect to which the neural tangent
      kernel is computed.
    **apply_fn_kwargs: keyword arguments for `apply_fn`. They are divided by
      `_split_kwargs` into one set per input batch; in particular an rng key
      becomes two different keys if `x1!=x2` and the same key if `x1==x2`
      (see `_read_key` for details).

  Returns:
    One sample of the empirical NTK. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, with these exceptions:
    1) `trace_axes` are contracted over and therefore absent.
    2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """
  kwargs1, kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

  def params_jacobian(fn):
    # Differentiate `fn(params)` with respect to `params` on a fresh tape.
    with tf.GradientTape() as tape:
      tape.watch(params)
      outputs = fn(params)
    return np.asarray(tape.jacobian(outputs, params))

  f1 = _get_f_params(f, x1, **kwargs1)
  j1 = params_jacobian(f1)

  # With a single batch the second Jacobian is the first one.
  j2 = j1 if x2 is None else params_jacobian(
      _get_f_params(f, x2, **kwargs2))

  fx1 = eval_on_shapes(f1)(params)
  # NOTE(review): assumes `sum_and_contract` takes `(j1, j2, output_ndim)` --
  # confirm against its definition in this file.
  unnormalized = sum_and_contract(j1, j2, fx1.ndim)
  return unnormalized / utils.size_at(fx1, trace_axes)
def nngp_fn(x1: np.ndarray,
            x2: Optional[np.ndarray],
            params: PyTree,
            keys: Optional[Union[PRNGKey,
                                 Tuple[PRNGKey, PRNGKey],
                                 np.ndarray]] = None,
            **apply_fn_kwargs) -> np.ndarray:
  """Computes a single sample of the empirical NNGP.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
      matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
    params: A `PyTree` of parameters at which `f` is evaluated.
    keys: `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
      dtype `uint32`. If `keys=None`, then the function `f` is deterministic
      and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
      and `x2` must be the same and share the same PRNG key; else `x1` and
      `x2` use two different PRNG keys.
    **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

  Returns:
    A single sample of the empirical NNGP. The shape of the kernel is "almost"
    `zip(f(x1).shape, f(x2).shape)` except for:
    1) `trace_axes` are absent as they are contracted over.
    2) `diagonal_axes` are present only once.
    All other axes are present twice.
  """
  key1, key2 = _read_keys(keys)

  def output(x, rng):
    """Evaluates `f` on `x` and returns the masked value of the result."""
    out = f(params, x, rng=rng, **apply_fn_kwargs)
    return utils.get_masked_array(out).masked_value

  out1 = output(x1, key1)
  # With a single batch the second output is the first one; `f` is not
  # evaluated again.
  out2 = out1 if x2 is None else output(x2, key2)

  dot = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
  return dot / utils.size_at(out1, trace_axes)
def nngp_fn(x1: np.ndarray,
            x2: Optional[np.ndarray],
            params: PyTree,
            **apply_fn_kwargs) -> np.ndarray:
  """Draws one sample of the empirical NNGP.

  Args:
    x1: first batch of inputs.
    x2: second batch of inputs; `None` stands for `x2=x1`. On `trace_axes` and
      `diagonal_axes` the shape of `f(x2)` must match that of `f(x1)`.
    params: a `PyTree` of parameters at which `f` is evaluated.
    **apply_fn_kwargs: keyword arguments for `apply_fn`. They are divided by
      `_split_kwargs` into one set per input batch; in particular an rng key
      becomes two different keys if `x1!=x2` and the same key if `x1==x2`
      (see `_read_key` for details).

  Returns:
    One sample of the empirical NNGP. Its shape is "almost"
    `zip(f(x1).shape, f(x2).shape)`, with these exceptions:
    1) `trace_axes` are contracted over and therefore absent.
    2) `diagonal_axes` appear only once.
    Every other axis appears twice.
  """
  def output(x, **kwargs):
    # Evaluate the network and keep only the masked value of its output.
    raw = f(params, x, **kwargs)
    return utils.get_masked_array(raw).masked_value

  kwargs1, kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

  out1 = output(x1, **kwargs1)
  # With a single batch the second output is the first one.
  out2 = out1 if x2 is None else output(x2, **kwargs2)

  unnormalized = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
  return unnormalized / utils.size_at(out1, trace_axes)
def contract(out1, out2):
  """Contracts two outputs and normalizes by the size of the traced axes.

  Args:
    out1: first output; also used to compute the normalizer.
    out2: second output.

  Returns:
    `utils.dot_general(out1, out2, trace_axes, diagonal_axes)` divided by
    `utils.size_at(out1, trace_axes)`, where `trace_axes` and `diagonal_axes`
    come from the enclosing scope.
  """
  unnormalized = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
  normalizer = utils.size_at(out1, trace_axes)
  return unnormalized / normalizer