Example #1
  def nngp_fn(x1: np.ndarray,
              x2: Optional[np.ndarray],
              params: PyTree,
              keys: Union[None,
                          PRNGKey,
                          Tuple[PRNGKey, PRNGKey],
                          np.ndarray] = None,
              **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NNGP.

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters at which we would like to compute the
        empirical NNGP.
      keys: `None` or a PRNG key or a tuple of PRNG keys or a (2, 2) array of
        dtype `uint32`. If `keys=None`, then the function `f` is deterministic
        and requires no PRNG key; else if `keys` is a single PRNG key, then `x1`
        and `x2` must be the same and share the same PRNG key; else `x1` and
        `x2` use two different PRNG keys.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`.

    Returns:
      A single sample of the empirical NNGP. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    key1, key2 = _read_keys(keys)

    def output(x, rng):
      out = f(params, x, rng=rng, **apply_fn_kwargs)
      masked_output = utils.get_masked_array(out)
      return masked_output.masked_value

    out1 = output(x1, key1)
    if x2 is None:
      out2 = out1
    else:
      out2 = output(x2, key2)

    dot = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
    return dot / utils.size_at(out1, trace_axes)
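
Note that `nngp_fn` above is a closure: `f`, `trace_axes`, `diagonal_axes` and `_read_keys` come from an enclosing factory that this excerpt does not show. A minimal usage sketch, assuming the factory is `neural_tangents.empirical_nngp_fn` with its default `trace_axes=(-1,)`:

import neural_tangents as nt
from neural_tangents import stax
from jax import random

# A small fully-connected network; `apply_fn(params, x)` plays the role of `f`.
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key, key_x1, key_x2 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key_x1, (10, 32))
x2 = random.normal(key_x2, (20, 32))
_, params = init_fn(key, x1.shape)

# Build the closure shown above and draw one NNGP sample at these parameters.
nngp_fn = nt.empirical_nngp_fn(apply_fn)  # defaults: trace_axes=(-1,)
kernel = nngp_fn(x1, x2, params)          # shape (10, 20)

Since this `apply_fn` is deterministic, the `keys` argument can be left at its default `None`.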
Example #2
  def nngp_fn(x1: np.ndarray,
              x2: Optional[np.ndarray],
              params: PyTree,
              **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NNGP.

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters at which we would like to compute the
        empirical NNGP.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        by the `_split_kwargs` function into `apply_fn_kwargs1` and
        `apply_fn_kwargs2`, which are then passed to `apply_fn`. In particular,
        the rng key in `apply_fn_kwargs` will be split into two different (if
        `x1 != x2`) or the same (if `x1 == x2`) rng keys. See the `_read_key`
        function for more details.

    Returns:
      A single sample of the empirical NNGP. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """

    def output(x, **kwargs):
      out = f(params, x, **kwargs)
      masked_output = utils.get_masked_array(out)
      return masked_output.masked_value

    apply_fn_kwargs1, apply_fn_kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

    out1 = output(x1, **apply_fn_kwargs1)
    if x2 is None:
      out2 = out1
    else:
      out2 = output(x2, **apply_fn_kwargs2)

    dot = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
    return dot / utils.size_at(out1, trace_axes)
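
The private `_split_kwargs` helper is not shown on this page. A hypothetical sketch of the behaviour the docstring describes (the `'rng'` keyword name and the `x2 is None` check are assumptions, not the library's actual code):

from jax import random

def _split_kwargs_sketch(apply_fn_kwargs, x1, x2, rng_name='rng'):
  # Hypothetical: copy the kwargs for each input, splitting only the PRNG key.
  kwargs1, kwargs2 = dict(apply_fn_kwargs), dict(apply_fn_kwargs)
  if rng_name in apply_fn_kwargs:
    key = apply_fn_kwargs[rng_name]
    if x2 is None:
      # `x2=None` means `x2=x1`, so both evaluations reuse the same key.
      key1 = key2 = key
    else:
      key1, key2 = random.split(key)
    kwargs1[rng_name], kwargs2[rng_name] = key1, key2
  return kwargs1, kwargs2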
Example #3
 def contract(x, y):
     param_axes = list(range(x.ndim))[ndim:]
     contract_axes = _trace_axes + param_axes
     return utils.dot_general(x, y, contract_axes,
                              _diagonal_axes) / size
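
This snippet also relies on enclosing-scope names (`ndim`, `_trace_axes`, `_diagonal_axes`, `size`). Assuming `x` and `y` carry the network's output axes first and parameter axes after them (as the `[ndim:]` slice suggests), and that `_diagonal_axes` is empty, a rough standalone equivalent is:

import jax.numpy as np

def contract_sketch(x, y, ndim, trace_axes, size):
  # `trace_axes` are assumed to be already canonicalized, non-negative output
  # axes; everything past the first `ndim` output axes is a parameter axis.
  param_axes = tuple(range(ndim, x.ndim))
  axes = tuple(trace_axes) + param_axes
  # With no diagonal axes, the contraction is a plain tensordot over the
  # traced output axes and the parameter axes, normalized by `size`.
  return np.tensordot(x, y, axes=(axes, axes)) / size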
Example #4
 def contract(out1, out2):
     dot = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
     return dot / utils.size_at(out1, trace_axes)
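
For the library defaults (`trace_axes=(-1,)` and empty `diagonal_axes`), the contraction in Example #4 reduces to a dot product over the output channels divided by their number, i.e. K[i, j] = sum_k f(x1)[i, k] * f(x2)[j, k] / width. A minimal sketch of that special case:

import jax.numpy as np

def contract_default(out1, out2):
  # Special case of Example #4 for trace_axes=(-1,) and diagonal_axes=():
  # contract the last (channel) axis and average over its size.
  channel_axis = out1.ndim - 1
  dot = np.tensordot(out1, out2, axes=((channel_axis,), (channel_axis,)))
  return dot / out1.shape[channel_axis]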