Example 1
def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
           params: PyTree, **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (implicit differentiation).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function, and the results passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` will be split into two different (if
        `x1 != x2`) or identical (if `x1 == x2`) rng keys. See the `_read_key`
        function for more details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)`, except that:
      1) `trace_axes` are absent, as they are contracted over;
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
    f1 = _flatten(_get_f_params(f, x1, **kwargs1))
    f2 = (f1 if utils.all_none(x2) else
          _flatten(_get_f_params(f, x2, **kwargs2)))

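    # `delta_vjp(delta)` pulls a cotangent back through `f2` (giving
    # `J(f2)^T @ delta`), and the outer JVP pushes the result forward through
    # `f1`, so `delta_vjp_jvp` is the linear map
    # `delta -> J(f1) @ J(f2)^T @ delta`, whose matrix is exactly the NTK.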
    def delta_vjp_jvp(delta):
        def delta_vjp(delta):
            return vjp(f2, params)[1](delta)

        return jvp(f1, (params,), delta_vjp(delta))[1]

    # Since we are taking the Jacobian of a linear function (which does not
    # depend on its coefficients), it is more efficient to substitute fx_dummy
    # for the outputs of the network. fx_dummy has the same shape as the output
    # of the network on the batch `x2`.
    fx2_struct = eval_shape(f2, params)

    @utils.nt_tree_fn()
    def dummy_output(fx_struct):
        return np.ones(fx_struct.shape, fx_struct.dtype)

    fx_dummy = dummy_output(fx2_struct)

    ntk = jacobian(delta_vjp_jvp)(fx_dummy)
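    # When the network returns a list/tuple of outputs, `jacobian` produces a
    # full block matrix of cross-derivatives; only the diagonal blocks
    # `ntk[i][i]`, pairing each output subtree with itself, are kept below.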
    if utils.is_list_or_tuple(fx_dummy):
        fx_treedef = tree_structure(
            eval_shape(_get_f_params(f, x1, **kwargs1), params))
        ntk = [ntk[i][i] for i in range(len(fx_dummy))]
        ntk = tree_unflatten(fx_treedef, ntk)

    return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
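A possible usage sketch (not part of the source above): in `neural_tangents`, a closure like `ntk_fn` is obtained through the public `nt.empirical_ntk_fn` wrapper. The network and input shapes below are illustrative assumptions.

import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))
key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 8))

x1 = jax.random.normal(key, (4, 8))
ntk_fn = nt.empirical_ntk_fn(apply_fn)  # trace_axes defaults to (-1,)
ntk = ntk_fn(x1, None, params)          # x2=None means x2=x1
print(ntk.shape)                        # (4, 4): the output axis is traced out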
Example 2
def get_batch_size(x):
    if utils.is_list_or_tuple(x):
        return get_batch_size(x[0])
    return x.shape[0]
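A quick illustration (hypothetical inputs, assuming the helper above and its `utils` module are in scope): the function recurses into nested lists/tuples and reads the leading axis of the first array leaf.

import jax.numpy as np

x = [np.zeros((32, 8)), (np.zeros((32, 4)), np.zeros((32, 2)))]
assert get_batch_size(x) == 32  # batch size read from the first leaf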
Example 3
def get_n1_n2(k):
    if utils.is_list_or_tuple(k):
        # TODO(schsam): We might want to check for consistency here, but I
        # can't imagine a case where we could get inconsistent kernels.
        return get_n1_n2(k[0])
    return k.nngp.shape[:2]
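A quick illustration (using a hypothetical stand-in for the library's `Kernel` class): `get_n1_n2` returns the two batch sizes, i.e. the leading two axes of the `nngp` covariance, recursing into the first kernel of any list/tuple.

from collections import namedtuple
import jax.numpy as np

Kernel = namedtuple('Kernel', ['nngp'])  # stand-in with only the field we need
k = [Kernel(nngp=np.zeros((3, 5, 7)))]
assert get_n1_n2(k) == (3, 5)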