# Example 1
def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Trace over and take diagonals along paired axes of `ntk`.

  The first `ntk.ndim // 2` axes of `ntk` are paired one-to-one with the last
  `ntk.ndim // 2` axes. Every pair selected by `trace_axes` is contracted with
  `np.trace` (both axes removed); every pair selected by `diagonal_axes` is
  replaced by its diagonal with `np.diagonal` (two axes become one).

  Args:
    ntk: array with an even number of dimensions.
    trace_axes: axes, indexing the first half of `ntk`'s dimensions, whose
      pairs are traced over.
    diagonal_axes: axes, with the same indexing, whose pairs are replaced by
      their diagonal.

  Returns:
    The contracted array divided by
    `utils.size_at(ntk.shape[:ntk.ndim // 2], trace_axes)` (presumably the
    product of the traced axis sizes -- confirm against `utils`), with the
    diagonal axes moved to the positions given by `utils.get_res_batch_dims`.

  Raises:
    ValueError: if `ntk` has an odd number of dimensions.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError(
        'Expected an even-dimensional kernel. Please file a bug at '
        'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order -- confirm against `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  n_trace = len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the largest axis down: each trace already performed removed one
  # first-half axis located before the second-half partner of `c`, so that
  # partner has moved left by `i`; the first-half position `c` is unaffected.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    # Traced axes that preceded `d` are gone from both halves, so both indices
    # shift left by that count in addition to the shifts caused by the `i`
    # diagonals already extracted (each removed one axis from every half).
    n_traced_before = sum(c < d for c in trace_axes)
    axis1 = d - i - n_traced_before
    axis2 = output_ndim + d - n_trace - 2 * i - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # `np.diagonal` appends each extracted diagonal as a new trailing axis, so the
  # last `n_marg` axes are the diagonals; only the leading axes are zipped.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
# Example 2
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

    Args:
      ntk:
        input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
      trace_axes:
        axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
        and remove the respective pairs of axes from the `ntk`.
      diagonal_axes:
        axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
        diagonal along the respective pairs of axes from the `ntk` (and hence
        replace each such pair of axes with a single axis).

    Returns:
      An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
      `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
      replaced with a single `Y` axis). The result is divided by
      `utils.size_at(ntk.shape[:ntk.ndim // 2], trace_axes)`.

    Raises:
      ValueError: if `ntk` has an odd number of dimensions.
    """
    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
    # returns non-negative axes in ascending order -- confirm against `utils`.
    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    # Trace from the largest axis down: every trace already performed removed
    # one first-half axis ahead of the second-half partner of `c`, shifting
    # that partner left by `i`.
    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    for i, d in enumerate(diagonal_axes):
        # Traced axes that preceded `d` no longer exist in either half, and
        # each of the `i` diagonals already taken removed one axis per half.
        n_traced_before = sum(c < d for c in trace_axes)
        axis1 = d - i - n_traced_before
        axis2 = output_ndim + d - 2 * i - n_trace - n_traced_before
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    # `np.diagonal` appends each diagonal as a trailing axis, so the last
    # `n_diag` axes are the diagonals; only the leading axes are zipped.
    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size