Пример #1
0
 def test_eval_shape_big_random_array(self):
   """Shape-evaluating a 1e12-element random draw must not raise.

   The test is skipped unless omnistaging is enabled.
   """
   if not config.omnistaging_enabled:
     raise SkipTest("after deleting lazy constants, requires omnistaging")

   huge_shape = (int(1e12),)

   def draw_normal(seed):
     return random.normal(random.PRNGKey(seed), huge_shape)

   # check_jaxpr would materialize the array, hence checks are skipped here.
   with core.skipping_checks():
     api.eval_shape(draw_normal, 0)  # completing without an error is the test
Пример #2
0
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (implicit differentiation).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have
            a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. They are split into one set
            per input batch by `utils.split_kwargs`. In particular, the rng key
            in `apply_fn_kwargs` will be split into two different (if
            `x1 != x2`) or same (if `x1 == x2`) rng keys. See the `_read_key`
            function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except for:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

        f1 = _flatten(_get_f_params(f, x1, **kwargs1))
        if utils.all_none(x2):
            # No second batch: the same function serves both sides.
            f2 = f1
        else:
            f2 = _flatten(_get_f_params(f, x2, **kwargs2))

        def delta_vjp_jvp(delta):
            # Pull `delta` back through f2, then push the result through f1.
            pulled_back = vjp(f2, params)[1](delta)
            return jvp(f1, (params,), pulled_back)[1]

        # The map above is linear, so its Jacobian does not depend on the point
        # at which it is taken. A dummy array of ones shaped like the output of
        # f2 is therefore used instead of the actual network outputs.
        fx2_struct = eval_shape(f2, params)

        @utils.nt_tree_fn()
        def dummy_output(fx_struct):
            return np.ones(fx_struct.shape, fx_struct.dtype)

        fx_dummy = dummy_output(fx2_struct)
        ntk = jacobian(delta_vjp_jvp)(fx_dummy)

        if utils.is_list_or_tuple(fx_dummy):
            # Keep only the matching (i, i) blocks and restore the tree
            # structure of the un-flattened output of `f` on `x1`.
            unflattened_f1 = _get_f_params(f, x1, **kwargs1)
            fx_treedef = tree_structure(eval_shape(unflattened_f1, params))
            diagonal_blocks = [ntk[idx][idx] for idx in range(len(fx_dummy))]
            ntk = tree_unflatten(fx_treedef, diagonal_blocks)

        return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
Пример #3
0
    def ntk_fun(x1, x2, params):
        """Computes the empirical NTK.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which
            we would like to compute the NTK. `None` means `x1` is reused.
          params: A PyTree of parameters about which we would like to compute
            the neural tangent kernel.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        x2 = x1 if x2 is None else x2

        fx2_struct = eval_shape(f, params, x2)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

        def f_on_x1(p):
            return f(p, x1)

        def f_on_x2(p):
            return f(p, x2)

        def delta_vjp_jvp(delta):
            # Pull `delta` back through f(., x2), then push through f(., x1).
            pulled_back = vjp(f_on_x2, params)[1](delta)
            return jvp(f_on_x1, (params,), pulled_back)[1]

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)

        # Bring the two leading (batch) axes to the front, followed by the
        # remaining axes of each half in their original order.
        ndim = len(fx2_struct.shape)
        batch_axes = (0, ndim)
        first_rest = tuple(range(1, ndim))
        second_rest = tuple(range(ndim + 1, 2 * ndim))
        return np.transpose(ntk, batch_axes + first_rest + second_rest)
Пример #4
0
    def ntk_fn(x1, x2, params, keys=None):
        """Computes the empirical NTK.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which
            we would like to compute the NTK. `None` means `x1` is reused.
          params: A PyTree of parameters about which we would like to compute
            the neural tangent kernel.
          keys: `None`, a PRNG key, a tuple of PRNG keys, or a (2, 2) array of
            dtype uint32. If `keys is None`, then the function `f` is
            deterministic and requires no PRNG key; else if `keys` is a single
            PRNG key, then x1 and x2 must be the same and share the same PRNG
            key; else x1 and x2 use two different PRNG keys.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        key1, key2 = _read_keys(keys)
        # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)
        x2 = x1 if x2 is None else x2

        # The fixed-key variant of `f` is only used for shape evaluation.
        shape_probe = partial(f, rng=random.PRNGKey(1))
        fx2_struct = eval_shape(shape_probe, params, x2)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

        def f_on_x1(p):
            return f(p, x1, rng=key1)

        def f_on_x2(p):
            return f(p, x2, rng=key2)

        def delta_vjp_jvp(delta):
            # Pull `delta` back through f(., x2), then push through f(., x1).
            pulled_back = vjp(f_on_x2, params)[1](delta)
            return jvp(f_on_x1, (params,), pulled_back)[1]

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)

        # Bring the two leading (batch) axes to the front, followed by the
        # remaining axes of each half in their original order.
        ndim = len(fx2_struct.shape)
        batch_axes = (0, ndim)
        first_rest = tuple(range(1, ndim))
        second_rest = tuple(range(ndim + 1, 2 * ndim))
        return np.transpose(ntk, batch_axes + first_rest + second_rest)
Пример #5
0
    def test_eval_shape_shape_error(self):
        """eval_shape raises TypeError for a (3, 3) by (4, 4) dot product."""
        def dot_then_tanh(lhs, rhs):
            return np.tanh(np.dot(lhs, rhs) + 3.)

        lhs = np.ones((3, 3))
        rhs = np.ones((4, 4))

        with self.assertRaises(TypeError):
            api.eval_shape(dot_then_tanh, lhs, rhs)
Пример #6
0
    def test_eval_shape_constants(self):
        """eval_shape works on a zero-argument function built from constants."""
        def constant_product():
            lhs = np.ones((2, 3))
            rhs = np.ones((3, 4))
            product = np.dot(lhs, rhs)
            return np.tanh(product + 3.)

        expected = (2, 4)
        self.assertEqual(api.eval_shape(constant_product), expected)
Пример #7
0
    def test_eval_shape_tuple_itemgetting(self):
        """eval_shape accepts a tuple argument that the function indexes into."""
        def sum_by_index(pair, offset):
            return pair[0] + pair[1] + offset

        pair = (np.ones(2), np.ones(2))
        offset = 3.

        result = api.eval_shape(sum_by_index, pair, offset)
        self.assertEqual(result, (2, ))
Пример #8
0
    def test_eval_shape_output_dict(self):
        """eval_shape mirrors a dict-valued output with per-entry shapes."""
        def sum_into_dict(pair, offset):
            total = pair[0] + pair[1] + offset
            return {'hi': total}

        pair = (np.ones(2), np.ones(2))
        offset = 3.

        result = api.eval_shape(sum_into_dict, pair, offset)
        self.assertEqual(result, {'hi': (2, )})
Пример #9
0
 def force_fn(R, **kwargs):
     nonlocal _force_fn
     if _force_fn is None:
         out_shape = eval_shape(energy_or_force_fn, R, **kwargs).shape
         if out_shape == ():
             _force_fn = force(energy_or_force_fn)
         else:
             _force_fn = energy_or_force_fn
     return _force_fn(R, **kwargs)
Пример #10
0
    def test_eval_shape(self):
        """eval_shape reports (2, 4) for a (2, 3) by (3, 4) dot product."""
        def dot_then_tanh(lhs, rhs):
            return np.tanh(np.dot(lhs, rhs) + 3.)

        lhs = np.ones((2, 3))
        rhs = np.ones((3, 4))

        result = api.eval_shape(dot_then_tanh, lhs, rhs)
        self.assertEqual(result, (2, 4))
Пример #11
0
    def test_eval_shape_tuple_unpacking(self):
        """eval_shape accepts a tuple argument that the function unpacks."""
        def sum_unpacked(pair, offset):
            first, second = pair
            return first + second + offset

        pair = (np.ones(2), np.ones(2))
        offset = 3.

        result = api.eval_shape(sum_unpacked, pair, offset)
        self.assertEqual(result, (2, ))
Пример #12
0
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (implicit differentiation).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      keys:
        `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then
        `x1` and `x2` must be the same and share the same PRNG key; else `x1`
        and `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)
    # TODO(xlc): find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)

    f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
    if x2 is None:
      # No second batch: the same function serves both sides.
      f2 = f1
    else:
      f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)

    def delta_vjp_jvp(delta):
      # Pull `delta` back through f2, then push the result through f1.
      pulled_back = vjp(f2, params)[1](delta)
      return jvp(f1, (params,), pulled_back)[1]

    # The map above is linear, so its Jacobian does not depend on the point at
    # which it is taken. A dummy array of ones shaped like the output of f2 is
    # therefore used instead of the actual network outputs.
    fx2_struct = eval_shape(f2, params)
    fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

    kernel = jacobian(delta_vjp_jvp)(fx_dummy)
    return _trace_and_diagonal(kernel, trace_axes, diagonal_axes)
Пример #13
0
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (jacobian outer product).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have
            a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. They are split into one set
            per input batch by `utils.split_kwargs`. In particular, the rng key
            in `apply_fn_kwargs` will be split into two different (if
            `x1!=x2`) or same (if `x1==x2`) rng keys. See the `_read_key`
            function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except for:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        fx1 = eval_shape(f, params, x1, **kwargs1)
        x_axis, fx_axis, kw_axes = _canonicalize_axes(
            vmap_axes, x1, fx1, **kwargs1)

        # Keyword arguments travel positionally (in `keys` order) so that they
        # can be mapped over by `vmap` below.
        keys = apply_fn_kwargs.keys()

        def j_fn(x, *args):
            named_args = dict(zip(keys, args))
            fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **named_args)
            return jacobian(fx)(params)

        if x_axis is not None or kw_axes:
            in_axes = [x_axis]
            in_axes.extend(kw_axes[k] if k in kw_axes else None for k in keys)
            j_fn = vmap(j_fn, in_axes=in_axes, out_axes=fx_axis)

        j1 = j_fn(x1, *[kwargs1[k] for k in keys])
        if utils.all_none(x2):
            # No second batch: reuse the first Jacobian.
            j2 = j1
        else:
            j2 = j_fn(x2, *[kwargs2[k] for k in keys])

        return sum_and_contract(fx1, j1, j2)
Пример #14
0
        def get_ntk(x1, x2, *args):
            """Computes the NTK for `x1` and `x2` with positional keyword values.

            `args` holds the keyword-argument values for `x1` in its first half
            and those for `x2` in its second half, both ordered like `keys`.
            """
            half = len(args) // 2
            _kwargs1 = dict(zip(keys, args[:half]))
            _kwargs2 = dict(zip(keys, args[half:]))

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            if utils.all_none(x2):
                # No second input: the same function serves both sides.
                f2 = f1
            else:
                f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes,
                                   **_kwargs2)

            def delta_vjp_jvp(delta):
                # Pull `delta` back through f2, then push through f1.
                pulled_back = vjp(f2, params)[1](delta)
                return jvp(f1, (params,), pulled_back)[1]

            fx1 = eval_shape(f1, params)
            fx2 = eval_shape(f2, params)

            # Apply the transposed linear map to every element of the basis
            # built from fx1, then restore the structure of fx1 along axis 0.
            eye = _std_basis(fx1)
            transposed = linear_transpose(delta_vjp_jvp, fx2)
            ntk = vmap(transposed)(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            return _diagonal(ntk, fx1)
Пример #15
0
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             keys: Union[PRNGKey,
                         Tuple[PRNGKey, PRNGKey],
                         np.ndarray] = None,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      keys:
        `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then
        `x1` and `x2` must be the same and share the same PRNG key; else `x1`
        and `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)

    f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
    j1 = jacobian(f1)(params)

    if x2 is None:
      # No second batch: reuse the first Jacobian.
      j2 = j1
    else:
      f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)
      j2 = jacobian(f2)(params)

    fx1 = eval_shape(f1, params)
    contracted = sum_and_contract(j1, j2, fx1.ndim)
    normalizer = utils.size_at(fx1, trace_axes)
    return contracted / normalizer
Пример #16
0
    def test_eval_shape_duck_typing(self):
        """eval_shape accepts objects that only expose `shape` and `dtype`."""
        class ShapeDtypeStub(object):
            def __init__(self, shape, dtype):
                self.shape = shape
                self.dtype = dtype

        def affine(A, b, x):
            return np.dot(A, x) + b

        matrix = ShapeDtypeStub((3, 4), np.float32)
        bias = ShapeDtypeStub((5, ), np.float32)
        operand = ShapeDtypeStub((4, 5), np.float32)

        result = api.eval_shape(affine, matrix, bias, operand)
        self.assertEqual(result, (3, 5))
Пример #17
0
    def ntk_fn(x1, x2, params, keys=None, **apply_fn_kwargs):
        """Computes the empirical NTK.

        Args:
          x1: A first `np.ndarray` of inputs, of shape [n1, ...], over which we
            would like to compute the NTK.
          x2: A second `np.ndarray` of inputs, of shape [n2, ...], over which
            we would like to compute the NTK. `None` means `x1` is reused.
          params: A PyTree of parameters about which we would like to compute
            the neural tangent kernel.
          keys: `None`, a PRNG key, a tuple of PRNG keys, or a (2, 2) array of
            dtype uint32. If `keys is None`, then the function `f` is
            deterministic and requires no PRNG key; else if `keys` is a single
            PRNG key, then x1 and x2 must be the same and share the same PRNG
            key; else x1 and x2 use two different PRNG keys.
          **apply_fn_kwargs: keyword arguments passed to `apply_fn`.

        Returns:
          A `np.ndarray` of shape [n1, n2] + output_shape + output_shape.
        """
        key1, key2 = _read_keys(keys)
        # TODO: find a good way to check utils.x1_is_x2(x1, x2) == (key1==key2)
        x2 = x1 if x2 is None else x2

        f1 = _get_f_params(f, x1, key1, **apply_fn_kwargs)
        f2 = _get_f_params(f, x2, key2, **apply_fn_kwargs)

        def delta_vjp_jvp(delta):
            # Pull `delta` back through f2, then push the result through f1.
            pulled_back = vjp(f2, params)[1](delta)
            return jvp(f1, (params,), pulled_back)[1]

        # The map above is linear, so its Jacobian does not depend on the point
        # at which it is taken. A dummy array of ones shaped like the output of
        # f2 is therefore used instead of the actual network outputs.
        fx2_struct = eval_shape(f2, params)
        fx_dummy = np.ones(fx2_struct.shape, fx2_struct.dtype)

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)

        # Bring the two leading (batch) axes to the front, followed by the
        # remaining axes of each half in their original order.
        ndim = len(fx2_struct.shape)
        batch_axes = (0, ndim)
        first_rest = tuple(range(1, ndim))
        second_rest = tuple(range(ndim + 1, 2 * ndim))
        return np.transpose(ntk, batch_axes + first_rest + second_rest)
Пример #18
0
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    """Checks batching of `lax._select_and_scatter_add` over all batch dims."""
    rng = jtu.rand_small(self.rng())
    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(
          operand, cotangents, lax.ge_p, dims, strides, pads)

    # The cotangent shape is the output shape of the matching gather-add.
    ones = (1,) * len(shape)

    def gather_add(x):
      return lax._select_and_gather_add(
          x, x, lax.ge_p, dims, strides, pads, ones, ones)

    cotangent_shape = api.eval_shape(gather_add, np.ones(shape, dtype)).shape

    arg_shapes = (cotangent_shape, shape)
    arg_dtypes = (dtype, dtype)
    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, arg_shapes, arg_dtypes, rng)
0
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. They are split into one set per
        input batch by `_split_kwargs`. In particular, the rng key in
        `apply_fn_kwargs` will be split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    apply_fn_kwargs1, apply_fn_kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

    f1 = _get_f_params(f, x1, **apply_fn_kwargs1)
    j1 = jacobian(f1)(params)

    if x2 is None:
      # No second batch: reuse the first Jacobian.
      j2 = j1
    else:
      f2 = _get_f_params(f, x2, **apply_fn_kwargs2)
      j2 = jacobian(f2)(params)

    fx1 = eval_shape(f1, params)
    contracted = sum_and_contract(j1, j2, fx1.ndim)
    normalizer = utils.size_at(fx1, trace_axes)
    return contracted / normalizer
Пример #20
0
def _displacement_or_metric_to_metric_sq(
        displacement_or_metric: DisplacementOrMetricFn) -> MetricFn:
    """Checks whether or not a displacement or metric was provided."""
    for dim in range(1, 4):
        try:
            R = ShapedArray((dim, ), f32)
            dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 0:
                return lambda Ra, Rb, **kwargs: \
                  displacement_or_metric(Ra, Rb, **kwargs) ** 2
            else:
                return lambda Ra, Rb, **kwargs: space.square_distance(
                    displacement_or_metric(Ra, Rb, **kwargs))
        except TypeError:
            continue
        except ValueError:
            continue
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger'
        'than 4.')
Пример #21
0
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (implicit differentiation).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have
            a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. They are split into one set
            per input batch by `utils.split_kwargs`. In particular, the rng key
            in `apply_fn_kwargs` will be split into two different (if
            `x1 != x2`) or same (if `x1 == x2`) rng keys. See the `_read_key`
            function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except for:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        fx1 = eval_shape(f, params, x1, **kwargs1)
        x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1,
                                                      **kwargs1)

        # Keyword values are passed positionally (ordered like `keys`) so that
        # `vmap` can map over them. Both generators below are consumed exactly
        # once, in the final `get_ntk(...)` call.
        keys = apply_fn_kwargs.keys()
        args1 = (kwargs1[k] for k in keys)
        # For a keyword listed in `kw_axes` whose second-batch value is None,
        # the first-batch value is used instead.
        args2 = (kwargs1[k]
                 if k in kw_axes and kwargs2[k] is None else kwargs2[k]
                 for k in keys)

        def get_ntk(x1, x2, *args):
            # NOTE: `x1`, `x2`, `args1`, `args2` and `fx1` in this function
            # shadow the names of the enclosing scope.
            # First half of `args`: keyword values for x1; second half: for x2.
            args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
            _kwargs1 = {k: v for k, v in zip(keys, args1)}
            _kwargs2 = {k: v for k, v in zip(keys, args2)}

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            # Without a second input the same function serves both sides.
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

            def delta_vjp_jvp(delta):
                # Pull `delta` back through f2, then push through f1.
                def delta_vjp(delta):
                    return vjp(f2, params)[1](delta)

                return jvp(f1, (params, ), delta_vjp(delta))[1]

            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            # Apply the transposed linear map to every element of the basis
            # built from fx1, then restore the structure of fx1 along axis 0.
            eye = _std_basis(fx1)
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            ntk = _diagonal(ntk, fx1)
            return ntk

        if x_axis is not None or kw_axes:
            # Under vmap both inputs are mapped, so x2 must be concrete.
            x2 = x1 if utils.all_none(x2) else x2

            kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys]
            # in_axes1 maps over x1 and the first half of the keyword values;
            # in_axes2 maps over x2 and the second half.
            in_axes1 = [x_axis, None] + kw_in_axes + [None] * len(kw_in_axes)
            in_axes2 = [None, x_axis] + [None] * len(kw_in_axes) + kw_in_axes

            # Inner vmap runs over x1, outer vmap over x2. `fx1` here is the
            # outer-scope shape structure computed at the top of this function.
            get_ntk = vmap(vmap(get_ntk, in_axes1, fx_axis), in_axes2,
                           _add(fx_axis, _ndim(fx1)))

        return _trace_and_diagonal(get_ntk(x1, x2, *args1, *args2), trace_axes,
                                   diagonal_axes)
Пример #22
0
 def test_eval_shape_big_random_array(self):
   """Shape-evaluating a 1e12-element random draw must not raise."""
   huge_shape = (int(1e12),)

   def draw_normal(seed):
     return random.normal(random.PRNGKey(seed), huge_shape)

   # check_jaxpr would materialize the array, hence checks are skipped here.
   with core.skipping_checks():
     api.eval_shape(draw_normal, 0)  # completing without an error is the test