Beispiel #1
0
def _index_and_contract(ntk: np.ndarray,
                        trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
  """Traces and takes diagonals of `ntk` along paired output axes.

  The first `ntk.ndim // 2` axes of `ntk` are treated as the axes of the first
  output and the remaining axes as the matching axes of the second output.

  Args:
    ntk:
      kernel with an even number of axes.
    trace_axes:
      output axes to trace over. Each such pair of axes is summed along its
      diagonal and removed; the result is divided by the total size of these
      axes.
    diagonal_axes:
      output axes whose pair of `ntk` axes is replaced by a single axis holding
      the diagonal.

  Returns:
    The reduced kernel, divided by the size of the traced axes. The final axis
    layout is produced by `utils.zip_axes` and `utils.get_res_batch_dims`
    (NOTE(review): presumably remaining paired axes are interleaved and the
    diagonal axes moved to their positions in the output - confirm in `utils`).

  Raises:
    ValueError: if `ntk` has an odd number of axes.
  """
  if ntk.ndim % 2 == 1:
    raise ValueError('Expected an even-dimensional kernel. Please file a bug at '
                     'https://github.com/google/neural-tangents/issues/new')

  output_ndim = ntk.ndim // 2
  # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
  # returns non-negative axes in ascending order - confirm in `utils`.
  trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
  diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
  n_marg = len(diagonal_axes)
  n_trace = len(trace_axes)
  contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

  # Trace from the last axis to the first so that first-half positions of the
  # axes still to be traced are unaffected; the second-half position shifts
  # left by one for every first-half axis already removed.
  for i, c in enumerate(reversed(trace_axes)):
    ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

  for i, d in enumerate(diagonal_axes):
    # Traced axes located before `d` were removed from BOTH halves, so both
    # indices must shift left by their count. Without this correction, a trace
    # axis preceding a diagonal axis selects the wrong (or out-of-range) axes.
    n_traced_before = sum(1 for c in trace_axes if c < d)
    # Each earlier diagonal removed one axis per half (moved to the end), and
    # all `n_trace` first-half trace axes precede the second half.
    axis1 = d - i - n_traced_before
    axis2 = output_ndim + d - n_trace - 2 * i - n_traced_before
    ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

  # `np.diagonal` appends each diagonal as a new trailing axis; the leading
  # `ntk.ndim - n_marg` axes are the surviving paired axes.
  ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_marg)
  res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
  ntk = np.moveaxis(ntk, range(-n_marg, 0), res_diagonal_axes)
  return ntk / contract_size
Beispiel #2
0
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Draws one sample of the empirical NTK (outer product of Jacobians).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs; `None` means the Jacobian of `x1` is reused
        for both sides. `f(x2)` must match `f(x1)` in shape along
        `trace_axes` and `diagonal_axes`.
      params:
        `PyTree` of parameters with respect to which the Jacobians are taken.
      keys:
        `None`, a single PRNG key, a tuple of two PRNG keys, or a (2, 2)
        `uint32` array. `keys=None` means `f` is deterministic and needs no
        key; a single key requires `x1` and `x2` to be the same and to share
        that key; otherwise `x1` and `x2` use two different keys.
      **apply_fn_kwargs:
        keyword arguments forwarded to `apply_fn`.

    Returns:
      One sample of the empirical NTK. Its shape is "almost"
      `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
      Every other axis appears twice.
    """

    def _jacobian_wrt_params(fn):
      # Record `fn(params)` on a tape, then differentiate the outputs with
      # respect to the watched parameters once recording has finished.
      with tf.GradientTape() as tape:
        tape.watch(params)
        outputs = fn(params)
      return np.asarray(tape.jacobian(outputs, params))

    rng1, rng2 = _read_keys(keys)

    fn1 = _get_f_params(f, x1, rng1, **apply_fn_kwargs)
    jac1 = _jacobian_wrt_params(fn1)

    if x2 is not None:
      fn2 = _get_f_params(f, x2, rng2, **apply_fn_kwargs)
      jac2 = _jacobian_wrt_params(fn2)
    else:
      # Same inputs on both sides: reuse the Jacobian already computed.
      jac2 = jac1

    out_spec = eval_on_shapes(fn1)(params)
    kernel = sum_and_contract(jac1, jac2, out_spec.ndim)
    return kernel / utils.size_at(out_spec, trace_axes)
Beispiel #3
0
    def sum_and_contract(fx, j1, j2):
        """Contracts matching Jacobian leaves and sums the results.

        Args:
          fx:
            object exposing `ndim` and accepted by `utils.size_at`; supplies
            the number of output axes and the normalization size.
          j1:
            tree of Jacobians for the first input.
          j2:
            tree of Jacobians for the second input, with the same structure
            as `j1`.

        Returns:
          The sum over all leaf pairs of `utils.dot_general` applied to the
          pair, each divided by the size of `fx` along `trace_axes`.
        """
        out_ndim = fx.ndim
        normalizer = utils.size_at(fx, trace_axes)

        diag = utils.canonicalize_axis(diagonal_axes, out_ndim)
        traced = utils.canonicalize_axis(trace_axes, out_ndim)

        def _contract_leaf(jac_a, jac_b):
            # Axes past the first `out_ndim` ones are contracted together
            # with the trace axes (presumably they are the parameter axes of
            # the Jacobian leaf - confirm against the caller).
            trailing = list(range(out_ndim, jac_a.ndim))
            product = utils.dot_general(jac_a, jac_b, traced + trailing, diag)
            return product / normalizer

        per_leaf = tree_multimap(_contract_leaf, j1, j2)
        return tree_reduce(operator.add, per_leaf)
Beispiel #4
0
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

    Args:
      ntk:
        input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
      trace_axes:
        axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
        and remove the respective pairs of axes from the `ntk`.
      diagonal_axes:
        axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
        diagonal along the respective pairs of axes from the `ntk` (and hence
        reduce the resulting `ntk` axes count by 2).

    Returns:
      An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
      `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
      replaced with a single `Y` axis). The result is divided by the total size
      of the traced axes.

    Raises:
      ValueError: if `ntk` has an odd number of axes.
    """
    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    # NOTE(review): the index arithmetic below assumes `canonicalize_axis`
    # returns non-negative axes in ascending order - confirm in `utils`.
    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    # Trace from the last axis to the first: positions of the remaining
    # first-half axes stay valid, while the second-half partner shifts left by
    # one for each first-half axis already removed.
    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    for i, d in enumerate(diagonal_axes):
        # Traced axes located before `d` were removed from both halves, so
        # both indices shift left by their count.
        n_traced_before = sum(1 for c in trace_axes if c < d)
        # Each earlier diagonal removed one axis per half, and all `n_trace`
        # first-half trace axes precede the second half.
        axis1 = d - i - n_traced_before
        axis2 = output_ndim + d - 2 * i - n_trace - n_traced_before
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    # `np.diagonal` appends every diagonal as a new trailing axis; the leading
    # `ntk.ndim - n_diag` axes are the surviving paired axes.
    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size
Beispiel #5
0
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             **apply_fn_kwargs) -> np.ndarray:
    """Draws one sample of the empirical NTK (outer product of Jacobians).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs; `None` means the Jacobian of `x1` is reused
        for both sides. `f(x2)` must match `f(x1)` in shape along
        `trace_axes` and `diagonal_axes`.
      params:
        `PyTree` of parameters with respect to which the Jacobians are taken.
      **apply_fn_kwargs:
        keyword arguments for `apply_fn`. They are divided by `_split_kwargs`
        into `apply_fn_kwargs1` and `apply_fn_kwargs2`, one set per input. In
        particular, the rng key in `apply_fn_kwargs` is split into two
        different (if `x1!=x2`) or identical (if `x1==x2`) rng keys. See the
        `_read_key` function for more details.

    Returns:
      One sample of the empirical NTK. Its shape is "almost"
      `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
      Every other axis appears twice.
    """

    def _param_jacobian(fn):
      # Record `fn(params)` on a tape, then differentiate the outputs with
      # respect to the watched parameters once recording has finished.
      with tf.GradientTape() as tape:
        tape.watch(params)
        outputs = fn(params)
      return np.asarray(tape.jacobian(outputs, params))

    kwargs1, kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

    fn1 = _get_f_params(f, x1, **kwargs1)
    jac1 = _param_jacobian(fn1)

    # With no second batch the first Jacobian serves both sides.
    jac2 = (jac1 if x2 is None
            else _param_jacobian(_get_f_params(f, x2, **kwargs2)))

    out_spec = eval_on_shapes(fn1)(params)
    kernel = sum_and_contract(jac1, jac2, out_spec.ndim)
    return kernel / utils.size_at(out_spec, trace_axes)
Beispiel #6
0
  def nngp_fn(x1: np.ndarray,
              x2: Optional[np.ndarray],
              params: PyTree,
              keys: Union[PRNGKey,
                          Tuple[PRNGKey, PRNGKey],
                          np.ndarray] = None,
              **apply_fn_kwargs) -> np.ndarray:
    """Draws one sample of the empirical NNGP.

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs; `None` means the outputs for `x1` are reused
        for both sides. `f(x2)` must match `f(x1)` in shape along
        `trace_axes` and `diagonal_axes`.
      params:
        `PyTree` of parameters at which `f` is evaluated.
      keys:
        `None`, a single PRNG key, a tuple of two PRNG keys, or a (2, 2)
        `uint32` array. `keys=None` means `f` is deterministic and needs no
        key; a single key requires `x1` and `x2` to be the same and to share
        that key; otherwise `x1` and `x2` use two different keys.
      **apply_fn_kwargs:
        keyword arguments forwarded to `apply_fn`.

    Returns:
      One sample of the empirical NNGP. Its shape is "almost"
      `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
      Every other axis appears twice.
    """
    rng1, rng2 = _read_keys(keys)

    def _forward(inputs, rng):
      # Evaluate the network and keep only the `masked_value` of its output.
      raw = f(params, inputs, rng=rng, **apply_fn_kwargs)
      return utils.get_masked_array(raw).masked_value

    y1 = _forward(x1, rng1)
    # With no second batch the first outputs serve both sides.
    y2 = y1 if x2 is None else _forward(x2, rng2)

    gram = utils.dot_general(y1, y2, trace_axes, diagonal_axes)
    return gram / utils.size_at(y1, trace_axes)
Beispiel #7
0
  def nngp_fn(x1: np.ndarray,
              x2: Optional[np.ndarray],
              params: PyTree,
              **apply_fn_kwargs) -> np.ndarray:
    """Draws one sample of the empirical NNGP.

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs; `None` means the outputs for `x1` are reused
        for both sides. `f(x2)` must match `f(x1)` in shape along
        `trace_axes` and `diagonal_axes`.
      params:
        `PyTree` of parameters at which `f` is evaluated.
      **apply_fn_kwargs:
        keyword arguments for `apply_fn`. They are divided by `_split_kwargs`
        into `apply_fn_kwargs1` and `apply_fn_kwargs2`, one set per input. In
        particular, the rng key in `apply_fn_kwargs` is split into two
        different (if `x1!=x2`) or identical (if `x1==x2`) rng keys. See the
        `_read_key` function for more details.

    Returns:
      One sample of the empirical NNGP. Its shape is "almost"
      `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, since they are contracted over;
      2) `diagonal_axes` appear only once.
      Every other axis appears twice.
    """

    def _forward(inputs, **kwargs):
      # Evaluate the network and keep only the `masked_value` of its output.
      raw = f(params, inputs, **kwargs)
      return utils.get_masked_array(raw).masked_value

    kwargs1, kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

    y1 = _forward(x1, **kwargs1)
    # With no second batch the first outputs serve both sides.
    y2 = y1 if x2 is None else _forward(x2, **kwargs2)

    gram = utils.dot_general(y1, y2, trace_axes, diagonal_axes)
    return gram / utils.size_at(y1, trace_axes)
Beispiel #8
0
 def contract(out1, out2):
     """Return `utils.dot_general` of the two outputs, size-normalized.

     The product over `trace_axes` / `diagonal_axes` is divided by the size of
     `out1` along `trace_axes`.
     """
     return (utils.dot_general(out1, out2, trace_axes, diagonal_axes)
             / utils.size_at(out1, trace_axes))