Example #1
    def ntk_fn(x1, x2, params, keys=None):
        """Computes the empirical ntk.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
            would like to compute the NTK.
          params: A PyTree of parameters about which we would like to compute the
            neural tangent kernel.
          keys: None or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
            dtype uint32. If `keys == None`, then the function `f` is deterministic
            and requires no PRNG key; else if `keys` is a single PRNG key, then x1
            and x2 share the same PRNG key; else x1 and x2 use two different PRNG
            keys.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        key1, key2 = _read_keys(keys)
        f1 = partial(f, rng=key1)
        jac_fn1 = jacobian(f1)
        j1 = jac_fn1(params, x1)
        if x2 is None:
            j2 = j1
        else:
            f2 = partial(f, rng=key2)
            jac_fn2 = jacobian(f2)
            j2 = jac_fn2(params, x2)

        ntk = sum_and_contract(j1, j2)
        # TODO(schsam): If we care, this will not work if the output is not of
        # shape [n, output_dim].
        return np.transpose(ntk, (0, 2, 1, 3))
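The helper `_read_keys` used above (and in several later examples) is never shown. Below is a minimal sketch consistent with the `keys` docstring; it is a hypothetical reconstruction for illustration, not the library's actual implementation:

    import jax.numpy as np

    def _read_keys(keys):
        # Hypothetical reconstruction from the docstring: `keys` may be None,
        # a single PRNG key, a tuple of two keys, or a (2, 2) uint32 array.
        if keys is None:
            return None, None                  # `f` is deterministic
        if isinstance(keys, tuple):
            return keys                        # two independent keys
        keys = np.asarray(keys)
        if keys.shape == (2, 2):
            return keys[0], keys[1]            # a stacked pair of keys
        return keys, keys                      # one key shared by x1 and x2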
Example #2
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      keys:
        `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
        and `x2` must be the same and share the same PRNG key; else `x1` and
        `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)

    f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
    jac_fn1 = jacobian(f1)
    j1 = jac_fn1(params)
    if x2 is None:
      j2 = j1
    else:
      f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)
      jac_fn2 = jacobian(f2)
      j2 = jac_fn2(params)

    fx1 = eval_shape(f1, params)
    ntk = sum_and_contract(j1, j2, fx1.ndim)
    return ntk / utils.size_at(fx1, trace_axes)
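A usage sketch for this variant, assuming the public `neural_tangents` wrapper `nt.empirical_ntk_fn` and an illustrative two-layer `stax` network (the network and batch sizes are assumptions, not part of the snippet):

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    # A small illustrative network; any (init_fn, apply_fn) pair works.
    init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 8))
    x1 = random.normal(random.PRNGKey(1), (4, 8))   # n1 = 4
    x2 = random.normal(random.PRNGKey(2), (6, 8))   # n2 = 6

    # `nt.empirical_ntk_fn` wraps `apply_fn` as `f` and returns an `ntk_fn`
    # like the one above.
    ntk_fn = nt.empirical_ntk_fn(apply_fn)
    theta = ntk_fn(x1, x2, params)  # shape (4, 6): the output axis is traced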
Example #3
    def ntk_fun(x1, x2, params):
        """Computes the empirical ntk.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
            would like to compute the NTK.
          params: A PyTree of parameters about which we would like to compute the
            neural tangent kernel.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        if x2 is None:
            x2 = x1
        fx2_struct = eval_shape(f, params, x2)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(lambda p: f(p, x2), params)[1](delta)

            return jvp(lambda p: f(p, x1), (params, ), delta_vjp(delta))[1]

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)
        ndim = len(fx2_struct.shape)
        ordering = (0, ndim) + tuple(range(1, ndim)) + \
            tuple(x + ndim for x in range(1, ndim))
        return np.transpose(ntk, ordering)
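Why taking `jacobian(delta_vjp_jvp)` yields the NTK (a short derivation, added for clarity): the inner `delta_vjp` computes the vector-Jacobian product J(x2)^T delta, and the outer `jvp` pushes that parameter-space tangent through f(., x1), so

    delta_vjp_jvp(delta) = J(x1) J(x2)^T delta.

This map is linear in `delta`, so its Jacobian is the constant matrix J(x1) J(x2)^T, which is exactly the empirical NTK. In particular, the value of `fx_dummy` never matters; only its shape and dtype do, which is why a dummy array of ones suffices.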
Example #4
    def ntk_fn(x1, x2, params, keys=None):
        """Computes the empirical ntk.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
            would like to compute the NTK.
          params: A PyTree of parameters about which we would like to compute the
            neural tangent kernel.
          keys: None or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
            dtype uint32. If `keys == None`, then the function `f` is deterministic
            and requires no PRNG key; else if `keys` is a single PRNG key, then x1
            and x2 must be the same and share the same PRNG key; else x1 and x2 use
            two different PRNG keys.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        key1, key2 = _read_keys(keys)
        # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)
        if x2 is None:
            x2 = x1

        f_dummy = partial(f, rng=random.PRNGKey(1))
        fx2_struct = eval_shape(f_dummy, params, x2)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)
        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(lambda p: f(p, x2, rng=key2), params)[1](delta)
            return jvp(lambda p: f(p, x1, rng=key1), (params,), delta_vjp(delta))[1]

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)
        ndim = len(fx2_struct.shape)
        ordering = (0, ndim) + tuple(range(1, ndim)) + \
            tuple(x + ndim for x in range(1, ndim))
        return np.transpose(ntk, ordering)
Example #5
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (implicit differentiation).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
            matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be
            split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the
            `split_kwargs` function, and these will be passed to `apply_fn`. In
            particular, the rng key in `apply_fn_kwargs` will be split into two
            different (if `x1 != x2`) or same (if `x1 == x2`) rng keys. See the
            `_read_key` function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except for:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        f1 = _flatten(_get_f_params(f, x1, **kwargs1))
        f2 = (f1 if utils.all_none(x2) else _flatten(
            _get_f_params(f, x2, **kwargs2)))

        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(f2, params)[1](delta)

            return jvp(f1, (params, ), delta_vjp(delta))[1]

        # Since we are taking the Jacobian of a linear function (which does not
        # depend on its coefficients), it is more efficient to substitute fx_dummy
        # for the outputs of the network. fx_dummy has the same shape as the output
        # of the network on a single piece of input data.
        fx2_struct = eval_shape(f2, params)

        @utils.nt_tree_fn()
        def dummy_output(fx_struct):
            return np.ones(fx_struct.shape, fx_struct.dtype)

        fx_dummy = dummy_output(fx2_struct)

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)
        if utils.is_list_or_tuple(fx_dummy):
            fx_treedef = tree_structure(
                eval_shape(_get_f_params(f, x1, **kwargs1), params))
            ntk = [ntk[i][i] for i in range(len(fx_dummy))]
            ntk = tree_unflatten(fx_treedef, ntk)

        return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
Example #6
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `_split_kwargs`
        function, and these will be passed to `apply_fn`. In particular, the rng
        key in `apply_fn_kwargs` will be split into two different (if `x1!=x2`)
        or same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """

    apply_fn_kwargs1, apply_fn_kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)
    f1 = _get_f_params(f, x1, **apply_fn_kwargs1)
    jac_fn1 = jacobian(f1)
    j1 = jac_fn1(params)
    if x2 is None:
      j2 = j1
    else:
      f2 = _get_f_params(f, x2, **apply_fn_kwargs2)
      jac_fn2 = jacobian(f2)
      j2 = jac_fn2(params)

    fx1 = eval_shape(f1, params)
    ntk = sum_and_contract(j1, j2, fx1.ndim)
    return ntk / utils.size_at(fx1, trace_axes)
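Note the final division: with the default `trace_axes=(-1,)`, the returned kernel is, roughly, the NTK averaged over the traced output axis,

    Theta[i, j] = (1/d) * sum_c <df_c(x1[i])/dparams, df_c(x2[j])/dparams>,

where d is the size of the traced axis and c ranges over it. (The exact axis bookkeeping is described in the Returns section of the docstring above.)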
Example #7
    def ntk_direct(f, params, x1, x2):
      jac_fn = jacobian(f)
      j1 = jac_fn(params, x1)

      if x2 is None:
        j2 = j1
      else:
        j2 = jac_fn(params, x2)

      return sum_and_contract(j1, j2)
Example #8
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (implicit differentiation).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      keys:
        `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
        and `x2` must be the same and share the same PRNG key; else `x1` and
        `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)
    # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)

    f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
    f2 = f1 if x2 is None else _get_f_params(f, x2, key2, **apply_fn_kwargs)

    def delta_vjp_jvp(delta):
      def delta_vjp(delta):
        return vjp(f2, params)[1](delta)
      return jvp(f1, (params,), delta_vjp(delta))[1]

    # Since we are taking the Jacobian of a linear function (which does not
    # depend on its coefficients), it is more efficient to substitute fx_dummy
    # for the outputs of the network. fx_dummy has the same shape as the output
    # of the network on a single piece of input data.
    fx2_struct = eval_shape(f2, params)
    fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

    ntk = jacobian(delta_vjp_jvp)(fx_dummy)
    return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
Example #9
def _compute_ntk(f, fx_dummy, params, x1, x2):
    """Computes the ntk without batching for inputs x1 and x2.

    The Neural Tangent Kernel is defined as J(X_1)^T J(X_2) where J is the
    jacobian df/dparams. Computing the NTK directly involves instantiating the
    jacobian, which takes O(dataset_size * output_dim * parameters) memory. It
    turns out it is substantially more efficient (especially as the number of
    parameters grows) to compute the NTK implicitly.

    This involves using JAX's autograd to compute derivatives of linear
    functions, whose Jacobians do not depend on the point of evaluation. Thus,
    we find it more efficient to substitute fx_dummy for the outputs of the
    network. fx_dummy has the same shape as the output of the network on a
    single piece of input data.

    TODO(schsam): Write up a better description of the implicit method.

    Args:
      f: The function whose NTK we are computing. f should have the signature
        f(params, inputs) and should return an ndarray of outputs with shape
        [|inputs|, output_dim].
      fx_dummy: A dummy evaluation of f on a single input that we use to
        instantiate an ndarray with the correct shape
        (aka [|inputs|, output_dim]). It should be possible at some point to
        use JAX's tracing mechanism to do this more efficiently.
      params: A set of parameters about which we would like to compute the
        neural tangent kernel. This should be any structure that can be mapped
        over by JAX's tree utilities.
      x1: A first ndarray of inputs, of shape [n1, ...], over which we would
        like to compute the NTK.
      x2: A second ndarray of inputs, of shape [n2, ...], over which we would
        like to compute the NTK.

    Returns:
      An ndarray containing the NTK, of shape [n1 * output_dim, n2 * output_dim].
    """
    fx_dummy = np.concatenate([fx_dummy] * len(x2))
    output_dim = fx_dummy.shape[1]

    def dzdt(delta):
        _, dfdw = vjp(lambda p: f(p, x2), params)
        dfdw, = dfdw(delta)

        def z(t):
            p = tree_multimap(np.add, params, tree_map(lambda x: t * x, dfdw))
            return f(p, x1)

        _, dzdot = jvp(z, (0.0, ), (1.0, ))
        return dzdot

    theta = jacobian(dzdt)(fx_dummy)
    return np.reshape(theta, (len(x1) * output_dim, len(x2) * output_dim))
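A toy check of the implicit trick against the explicit Jacobian contraction; the linear model and sizes here are illustrative assumptions:

    import jax.numpy as np
    from jax import jacobian, jvp, vjp, random

    def f(params, x):
        return x @ params['w'] + params['b']  # outputs of shape [n, output_dim]

    params = {'w': random.normal(random.PRNGKey(0), (3, 2)), 'b': np.zeros((2,))}
    x1 = random.normal(random.PRNGKey(1), (4, 3))
    x2 = random.normal(random.PRNGKey(2), (5, 3))

    # Explicit: flatten each per-leaf Jacobian and contract over parameters.
    j1, j2 = jacobian(f)(params, x1), jacobian(f)(params, x2)
    explicit = sum(np.tensordot(j1[k].reshape(4 * 2, -1),
                                j2[k].reshape(5 * 2, -1), (1, 1)) for k in j1)

    # Implicit: the Jacobian of the linear map delta -> J(x1) J(x2)^T delta.
    def delta_vjp_jvp(delta):
        delta_vjp = vjp(lambda p: f(p, x2), params)[1](delta)
        return jvp(lambda p: f(p, x1), (params,), delta_vjp)[1]

    fx_dummy = np.ones((5, 2))                    # shape of f(params, x2)
    implicit = jacobian(delta_vjp_jvp)(fx_dummy)  # shape (4, 2, 5, 2)
    assert np.allclose(explicit, implicit.reshape(4 * 2, 5 * 2), atol=1e-4)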
Example #10
def get_ntk_fun_empirical_direct(f):
    """Computes the ntk without batching for inputs x1 and x2.

    The Neural Tangent Kernel is defined as J(X_1)^T J(X_2) where J is the
    jacobian df/dparams.

    Args:
      f: The function whose NTK we are computing. f should have the signature
        f(params, inputs) and should return an `np.ndarray` of outputs with
        shape [|inputs|, output_dim].

    Returns:
      A function `ntk_fun` that computes the empirical ntk.
    """
    jac_fn = jacobian(f)

    def sum_and_contract(j1, j2):
        def contract(x, y):
            param_count = int(np.prod(x.shape[2:]))
            x = np.reshape(x, x.shape[:2] + (param_count, ))
            y = np.reshape(y, y.shape[:2] + (param_count, ))
            return np.dot(x, np.transpose(y, (0, 2, 1)))

        return tree_reduce(operator.add, tree_multimap(contract, j1, j2))

    def ntk_fun(x1, x2, params):
        """Computes the empirical ntk.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
            would like to compute the NTK.
          params: A PyTree of parameters about which we would like to compute the
            neural tangent kernel.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        j1 = jac_fn(params, x1)

        if x2 is None:
            j2 = j1
        else:
            j2 = jac_fn(params, x2)

        ntk = sum_and_contract(j1, j2)
        # TODO(schsam): If we care, this will not work if the output is not of
        # shape [n, output_dim].
        return np.transpose(ntk, (0, 2, 1, 3))

    return ntk_fun
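A usage sketch for the factory above, with the snippet's free names (`jacobian`, `np`, `operator`, `tree_reduce`, `tree_multimap`) imported as in its source module; the linear `f` is an illustrative assumption:

    from jax import random

    def f(params, x):
        return x @ params  # outputs of shape [|inputs|, output_dim]

    params = random.normal(random.PRNGKey(0), (3, 2))
    x1 = random.normal(random.PRNGKey(1), (4, 3))
    x2 = random.normal(random.PRNGKey(2), (5, 3))

    ntk_fun = get_ntk_fun_empirical_direct(f)
    theta = ntk_fun(x1, x2, params)  # shape (4, 5, 2, 2) = [n1, n2] + out + out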
Example #11
    def ntk_fn(x1, x2, params, keys=None, **apply_fn_kwargs):
        """Computes the empirical ntk.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which we
            would like to compute the NTK.
          params: A PyTree of parameters about which we would like to compute the
            neural tangent kernel.
          keys: None or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
            dtype uint32. If `keys == None`, then the function `f` is deterministic
            and requires no PRNG key; else if `keys` is a single PRNG key, then x1
            and x2 must be the same and share the same PRNG key; else x1 and x2 use
            two different PRNG keys.
          **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        key1, key2 = _read_keys(keys)
        # TODO: find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)
        if x2 is None:
            x2 = x1

        f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
        f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)

        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(f2, params)[1](delta)

            return jvp(f1, (params, ), delta_vjp(delta))[1]

        # Since we are taking the Jacobian of a linear function (which does not
        # depend on its coefficients), it is more efficient to substitute fx_dummy
        # for the outputs of the network. fx_dummy has the same shape as the output
        # of the network on a single piece of input data.
        fx2_struct = eval_shape(f2, params)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)
        ndim = len(fx2_struct.shape)
        ordering = (0, ndim) + tuple(range(1, ndim)) + \
            tuple(x + ndim for x in range(1, ndim))
        return np.transpose(ntk, ordering)
Example #12
def j_fn(x, *args):
    _kwargs = {k: v for k, v in zip(keys, args)}
    fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs)
    jx = jacobian(fx)(params)
    return jx
Example #13
def scalar_solve2(f, y):
    y_1d = y[np.newaxis]
    return np.linalg.solve(api.jacobian(f)(y_1d), y_1d).squeeze()
Example #14
def explicit_jacobian_solve(matvec, b):
    return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))
Example #15
def vector_solve(f, y):
    return np.linalg.solve(api.jacobian(f)(y), y)
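These last three snippets all solve f(x) = y by materializing the Jacobian of `f` via `api.jacobian` (an old-style alias of `jax.jacobian`), which recovers the exact solution when `f` is linear. A small illustration with an assumed linear `f`:

    import jax.numpy as np
    from jax import jacobian

    A = np.array([[2.0, 1.0], [0.0, 3.0]])
    f = lambda x: A @ x                     # linear, so jacobian(f)(y) == A
    y = np.array([5.0, 6.0])

    x = np.linalg.solve(jacobian(f)(y), y)  # same as vector_solve(f, y)
    assert np.allclose(f(x), y)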