Example 1
    def vectorize(self, vectorized):
        vec_args = [a for a in self.args if a in vectorized]
        excluded = [i for i, a in enumerate(self.args) if a not in vectorized]

        if not vec_args:
            return

        arg_sig = ",".join(f'({vectorized[a]})' for a in vec_args)
        sig = f"{arg_sig}->({self.core_shape})"

        self.fun = jnp.vectorize(self.fun, excluded=excluded, signature=sig)
        for wrt, d in self.derivatives.items():
            wrtsig = (vectorized[var] for var in reversed(wrt))
            if self.core_shape:
                outsig = ','.join((self.core_shape, *wrtsig))
            else:
                outsig = ','.join(wrtsig)
            dsig = f"{arg_sig}->({outsig})"
            vecd = jnp.vectorize(d, excluded=excluded, signature=dsig)
            self.derivatives[wrt] = vecd

        core_shape = self.core_shape
        out_core_ndim = len(core_shape.split(',')) if core_shape else 0
        self.core_ndim = {a: len(vectorized[a].split(',')) for a in vec_args}
        self.core_ndim[None] = out_core_ndim
        self.isvectorized = True
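A minimal standalone sketch (not part of the original class) of the kind of call this method builds: `excluded` positions are passed through unchanged and the gufunc `signature` describes only the vectorized arguments. The function and names below are hypothetical.

import jax.numpy as jnp

def weighted_dot(x, w, power):
    # `power` plays the role of a non-vectorized (excluded) argument.
    return jnp.sum((x * w) ** power)

# Roughly what arg_sig='(n),(n)' with an empty core_shape produces.
f = jnp.vectorize(weighted_dot, excluded={2}, signature='(n),(n)->()')
print(f(jnp.ones((5, 3)), jnp.ones(3), 2).shape)  # (5,)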
Example 2
  def test_exclude_errors(self):
    with self.assertRaisesRegex(
        TypeError, "jax.numpy.vectorize can only exclude"):
      jnp.vectorize(lambda x: x, excluded={'foo'})

    with self.assertRaisesRegex(
        ValueError, r"excluded=\{-1\} contains negative numbers"):
      jnp.vectorize(lambda x: x, excluded={-1})

    f = jnp.vectorize(lambda x: x, excluded={1})
    with self.assertRaisesRegex(
        ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
      f(1.0)
Example 3
def cosine_similarity(
    predictions: chex.Array,
    targets: chex.Array,
    epsilon: float = 0.,
) -> chex.Array:
    r"""Computes the cosine similarity between targets and predictions.

  The cosine **similarity** is a measure of similarity between vectors defined
  as the cosine of the angle between them, which is also the inner product of
  those vectors normalized to have unit norm.

  References:
    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

  Args:
    predictions: The predicted vector.
    targets: Ground truth target vector.
    epsilon: minimum norm for terms in the denominator of the cosine similarity.

  Returns:
    cosine similarity values.
  """
    chex.assert_equal_shape([targets, predictions])
    chex.assert_type([targets, predictions], float)
    # vectorize norm fn, to treat all dimensions except the last as batch dims.
    batched_norm_fn = jnp.vectorize(utils.safe_norm,
                                    signature='(k)->()',
                                    excluded={1})
    # normalise the last dimension of targets and predictions.
    unit_targets = targets / jnp.expand_dims(batched_norm_fn(targets, epsilon),
                                             axis=-1)
    unit_predictions = predictions / jnp.expand_dims(
        batched_norm_fn(predictions, epsilon), axis=-1)
    # return cosine similarity.
    return jnp.sum(unit_targets * unit_predictions, axis=-1)
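A self-contained sketch of the same pattern on tiny arrays; `safe_norm_sketch` is a hypothetical stand-in for `utils.safe_norm` (the norm clipped from below), used only to make the example runnable without optax.

import jax.numpy as jnp

def safe_norm_sketch(x, min_norm):
    return jnp.maximum(jnp.sqrt(jnp.sum(x * x)), min_norm)

# Vectorize over leading batch dims of the array; epsilon (arg 1) is excluded.
batched_norm_fn = jnp.vectorize(safe_norm_sketch, signature='(k)->()', excluded={1})
targets = jnp.array([[3.0, 4.0], [1.0, 0.0]])
predictions = jnp.array([[3.0, 4.0], [0.0, 1.0]])
unit_t = targets / jnp.expand_dims(batched_norm_fn(targets, 1e-6), axis=-1)
unit_p = predictions / jnp.expand_dims(batched_norm_fn(predictions, 1e-6), axis=-1)
print(jnp.sum(unit_t * unit_p, axis=-1))  # [1. 0.]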
Example 4
def jax_funcify_Composite(op, vectorize=True, **kwargs):
    jax_impl = jax_funcify(op.fgraph)

    def composite(*args):
        return jax_impl(*args)[0]

    return jnp.vectorize(composite)
Example 5
def logpdf(x, mean, cov, allow_singular=None):
    if allow_singular is not None:
        raise NotImplementedError(
            "allow_singular argument of multivariate_normal.logpdf")
    x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
    if not mean.shape:
        return (-1 / 2 * jnp.square(x - mean) / cov - 1 / 2 *
                (np.log(2 * np.pi) + jnp.log(cov)))
    else:
        n = mean.shape[-1]
        if not np.shape(cov):
            y = x - mean
            return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov - n / 2 *
                    (np.log(2 * np.pi) + jnp.log(cov)))
        else:
            if cov.ndim < 2 or cov.shape[-2:] != (n, n):
                raise ValueError(
                    "multivariate_normal.logpdf got incompatible shapes")
            L = lax.linalg.cholesky(cov)
            y = jnp.vectorize(partial(lax.linalg.triangular_solve,
                                      lower=True,
                                      transpose_a=True),
                              signature="(n,n),(n)->(n)")(L, x - mean)
            return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) -
                    n / 2 * np.log(2 * np.pi) -
                    jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
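A minimal sketch of the batching trick used for the triangular solve above: a '(n,n),(n)->(n)' signature lets leading batch dimensions of the matrix and the vector broadcast together. `jnp.linalg.solve` is used here only because it is self-contained (it also batches natively); the point is the signature mechanics.

import jax.numpy as jnp

solve_vec = jnp.vectorize(jnp.linalg.solve, signature='(n,n),(n)->(n)')
A = jnp.stack([2.0 * jnp.eye(3), 4.0 * jnp.eye(3)])  # shape (2, 3, 3)
b = jnp.ones((2, 3))
print(solve_vec(A, b))  # [[0.5 0.5 0.5] [0.25 0.25 0.25]]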
Example 6
  def forward(self, x: Array) -> Array:
    """Computes y = f(x)."""
    self._check_forward_input_shape(x)

    def unbatched(single_x, matrix, bias):
      return matrix @ single_x + bias

    batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)")
    return batched(x, self._matrix, self._bias)
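A sketch of how this signature broadcasts: `x` may carry arbitrary leading batch dimensions while the unbatched matrix and bias are reused for every batch element. The values below are arbitrary.

import jax.numpy as jnp

def unbatched(single_x, matrix, bias):
    return matrix @ single_x + bias

batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)")
x = jnp.ones((5, 3))       # batch of 5 vectors
matrix = 2.0 * jnp.eye(3)
bias = jnp.zeros(3)
print(batched(x, matrix, bias).shape)  # (5, 3)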
Example 7
  def inverse(self, y: Array) -> Array:
    """Computes x = f^{-1}(y)."""
    self._check_inverse_input_shape(y)

    def unbatched(single_y, matrix, bias):
      return jnp.linalg.solve(matrix, single_y - bias)

    batched = jnp.vectorize(unbatched, signature="(m),(m,m),(m)->(m)")
    return batched(y, self._matrix, self._bias)
Example 8
 def covariance(self) -> Array:
     """Calculates the covariance."""
     probs = self.probs
     cov_matrix = -self._total_count[..., None, None] * (
         probs[..., None, :] * probs[..., :, None])
     chex.assert_shape(cov_matrix, probs.shape + self.event_shape)
     # Missing diagonal term in the covariance matrix.
     cov_matrix += jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(
         self._total_count[..., None] * probs)
     return cov_matrix
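A sketch of the vectorized `jnp.diag` call in isolation: the '(k)->(k,k)' signature turns each length-k vector in the batch into a (k, k) diagonal matrix.

import jax.numpy as jnp

batched_diag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
probs = jnp.array([[0.2, 0.8], [0.5, 0.5]])
print(batched_diag(probs).shape)  # (2, 2, 2)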
Example 9
 def entropy(self) -> Array:
     """Calculates the Shannon entropy (in nats)."""
     # The method `_entropy_scalar` does not work when `self.total_count` is an
     # array (instead of a scalar) or when we jit the function, so we default to
     # computing the entropy using an alternative method that uses a lax while
     # loop and does not create intermediate arrays whose shape depends on
     # `self.total_count`.
     entropy_fn = jnp.vectorize(self._entropy_scalar_with_lax,
                                signature='(),(k),(k)->()')
     return entropy_fn(self.total_count, self.probs, self.log_of_probs)
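A sketch of the '(),(k),(k)->()' layout with a simpler, hypothetical stand-in function: per batch element the wrapped function receives a scalar and two length-k vectors and returns a scalar.

import jax.numpy as jnp

def weighted_neg_logsum(n, p, logp):
    return -n * jnp.sum(p * logp)

fn = jnp.vectorize(weighted_neg_logsum, signature='(),(k),(k)->()')
total_count = jnp.array([1.0, 2.0])
probs = jnp.array([[0.5, 0.5], [0.1, 0.9]])
print(fn(total_count, probs, jnp.log(probs)).shape)  # (2,)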
Example 10
def _vmap_2d(fn: Callable[[float, float, float], float], cov12: np.ndarray,
             var1: np.ndarray, var2: Optional[np.ndarray],
             diagonal_batch: bool, diagonal_spatial: bool) -> np.ndarray:
    """Effectively a "2D vmap" of `fn(cov12, var1, var2)`.

  Applicable for all possible kernel layouts.

  Args:
    fn:
      scalar-valued, elementwise `fn(cov12, var1, var2)` function to apply.

    cov12:
      covariance tensor (`q12`), `nngp`/`ntk`/`cov1`/`cov2`, of shape
      `(N1[, N2])`, `(N1[, N2], X, Y, ...)`, `(N1[, N2], X, X, Y, Y, ...)`
      depending on `diagonal_batch`, `diagonal_spatial`, and the number of
      spatial dimensions.

    var1:
      variance tensor (`q11`), has shape `(N1[, X, Y, ...])`.

    var2:
      variance tensor (`q22`), has shape `(N1[, X, Y, ...])`.

    diagonal_batch:
      `True` if `cov12` has only one batch dimension.

    diagonal_spatial:
      `True` if `cov12` has spatial dimensions appearing once (vs twice).

  Returns:
    Resulting array `[fn(cov12[i, j], var1[i], var2[j])]_{i j}`. Has the same
    shape as `cov12`.
  """
    batch_ndim = 1 if diagonal_batch else 2
    start = 2 - batch_ndim
    cov_end = batch_ndim if diagonal_spatial else cov12.ndim
    _cov12 = utils.make_2d(cov12, start, cov_end)

    var_end = 1 if diagonal_spatial else var1.ndim
    var1 = var1.reshape(var1.shape[:start] + (-1, ) + var1.shape[var_end:])
    var2 = var1 if var2 is None else var2.reshape(var2.shape[:start] + (-1, ) +
                                                  var2.shape[var_end:])

    fn = vmap(vmap(np.vectorize(fn),
                   in_axes=(start, None, start),
                   out_axes=start),
              in_axes=(start, start, None),
              out_axes=start)
    out = fn(_cov12, var1, var2)  # type: np.ndarray
    out_shape = (cov12.shape[:start] + cov12.shape[start:cov_end:2] +
                 cov12.shape[start + 1:cov_end:2] + cov12.shape[cov_end:])
    out = out.reshape(out_shape)
    out = utils.zip_axes(out, start, cov_end)
    return out
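A sketch of the innermost step in isolation (assuming `np` refers to `jax.numpy`, as elsewhere in these snippets): with no `signature`, `vectorize` maps the scalar-valued `fn` elementwise over broadcast inputs, and the surrounding `vmap`s then lift it over the two batch axes.

import jax.numpy as jnp

def scalar_fn(cov12, var1, var2):
    return cov12 / jnp.sqrt(var1 * var2)

elementwise = jnp.vectorize(scalar_fn)
# Broadcasts (2, 3), (3,) and a scalar to a (2, 3) result of 0.25s.
print(elementwise(jnp.ones((2, 3)), 4.0 * jnp.ones(3), 4.0))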
Example 11
 def test_bad_inputs(self):
   matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
   with self.assertRaisesRegex(
       TypeError, "wrong number of positional arguments"):
     matmat(jnp.zeros((3, 2)))
   with self.assertRaisesRegex(
       ValueError,
       r"input with shape \(2,\) does not have enough dimensions"):
     matmat(jnp.zeros((2,)), jnp.zeros((2, 2)))
   with self.assertRaisesRegex(
       ValueError, r"inconsistent size for core dimension 'm'"):
     matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))
Example 12
  def covariance(self) -> Array:
    """Calculates the covariance.

    Constructs a diagonal matrix with the variance vector as diagonal. Note that
    TFP would drop leading dimensions in the covariance if
    `self._scale_diag.ndims < self._loc.ndims`. To keep things simple and
    predictable, and for consistency with other distributions, in Distrax the
    `covariance` has shape `batch_shape + (num_dims, num_dims)`.

    Returns:
      Diagonal covariance matrix.
    """
    return jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(self.variance())
Example 13
def get_spin_spherical(transformer, shape, spins):
    """Returns set of spin-weighted spherical functions.

  Args:
    transformer: SpinSphericalFourierTransformer instance.
    shape: Desired shape (batch, latitude, longitude, spins, channels).
    spins: Desired spins.

  Returns:
    Array of spherical functions and array of their spectral coefficients.
  """
    # Make some arbitrary reproducible complex inputs.
    batch_size, resolution, _, num_spins, num_channels = shape
    if len(spins) != num_spins:
        raise ValueError('len(spins) must match desired shape.')
    ell_max = sphere_utils.ell_max_from_resolution(resolution)
    num_coefficients = np_spin_spherical_harmonics.n_coeffs_from_ell_max(
        ell_max)
    shape_coefficients = (batch_size, num_spins, num_channels,
                          num_coefficients)
    # These numbers are chosen arbitrarily, but not randomly, since random
    # coefficients make for functions that are hard to interpret visually.
    # Something simpler like linspace(-1-1j, 1+1j) would give every complex
    # number the same phase, which is also undesirable.
    coefficients = (jnp.linspace(
        -0.5, 0.7 + 0.5j,
        np.prod(shape_coefficients)).reshape(shape_coefficients))

    # Broadcast
    to_matrix = jnp.vectorize(spin_spherical_harmonics.coefficients_to_matrix,
                              signature='(i)->(j,k)')
    coefficients = to_matrix(coefficients)
    # Transpose back to (batch, ell, m, spin, channel) format.
    coefficients = jnp.transpose(coefficients, (0, 3, 4, 1, 2))

    # Coefficients for ell < |spin| are always zero.
    for i, spin in enumerate(spins):
        coefficients = coefficients.at[:, :abs(spin), :, i].set(0.0)

    # Convert to spatial domain.
    batched_backward_transform = jax.vmap(
        transformer.swsft_backward_spins_channels, in_axes=(0, None))
    sphere = batched_backward_transform(coefficients, spins)

    return sphere, coefficients
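A sketch of the '(i)->(j,k)' pattern with a hypothetical stand-in for `coefficients_to_matrix`: output core dimensions need not appear among the inputs; their sizes are taken from the wrapped function's output.

import jax.numpy as jnp

def to_outer(v):  # stand-in for spin_spherical_harmonics.coefficients_to_matrix
    return jnp.outer(v, v)

to_matrix = jnp.vectorize(to_outer, signature='(i)->(j,k)')
coeffs = jnp.ones((4, 2, 5))       # leading dims are batch dims
print(to_matrix(coeffs).shape)     # (4, 2, 5, 5)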
Example 14
 def _get_localized_kernel(self, ell_max, num_channels_in):
   # We interpolate along ell to obtain all weights from the learnable weights,
   # hence it doesn't make sense to have more parameters than num_ell.
   if self.num_filter_params > ell_max + 1:
     raise ValueError("num_filter_params must be <= ell_max + 1")
   ell_in = jnp.linspace(0, 1, self.num_filter_params)
   ell_out = jnp.linspace(0, 1, ell_max + 1)
   # `vectorize` is over leading dimensions, so we put ell as the last
   # dimension and transpose it to the first later.
   learnable_shape = (len(self.spins_in), len(self.spins_out),
                      num_channels_in, self.features,
                      self.num_filter_params)
   learnable_weights = self.param("kernel", self.initializer, learnable_shape)
   # `jnp.interp` works on 1D inputs; we vectorize it to interpolate over a
   # single dimension of n-D inputs.
   vectorized_interp = jnp.vectorize(jnp.interp, signature="(m),(n),(n)->(m)")
   weights = vectorized_interp(ell_out, ell_in, learnable_weights)
   # Make ell the first dimension.
   return weights.transpose((4, 0, 1, 2, 3))
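A sketch of the vectorized interpolation on its own: the query points and knots carry no batch dimensions, while the knot values keep their leading parameter dimensions; only the last axis is interpolated over. The shapes below are arbitrary.

import jax.numpy as jnp

vectorized_interp = jnp.vectorize(jnp.interp, signature="(m),(n),(n)->(m)")
ell_out = jnp.linspace(0.0, 1.0, 7)     # (m,) query points
ell_in = jnp.linspace(0.0, 1.0, 3)      # (n,) knots
weights = jnp.ones((2, 4, 3))           # (..., n) batched knot values
print(vectorized_interp(ell_out, ell_in, weights).shape)  # (2, 4, 7)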
Example 15
 def __init__(self, connections, input_current=2,
              V_rest=-0.07, V_reset=-0.07, V_th=-0.054, V_is=-0.08, V_es=0, tm=0.02, ts=0.01, P_max=1, rm_gs=0.5,
              Rm=0.010):
     self.V_es = V_es
     self.V_is = V_is
     self.V_th = V_th
     self.V_rest = V_rest
     self.V_reset = V_reset
     self.Rm = Rm
     self.input_current = input_current
     self.P_max = P_max
     self.rm_gs = rm_gs
     self.ts = ts
     self.tm = tm
     self.number_of_neurons = len(connections)
     self.connectivity_matrix = np.fromfunction(np.vectorize(lambda i, j: connections.get(i, {}).get(j, 0)),
                                                (self.number_of_neurons, self.number_of_neurons)).T
     self.neurons = None
     self.P_synapses = None
     self.dt = None
     self.Ie = None
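A small sketch of the connectivity-matrix construction in isolation. Note this is classic `numpy.vectorize` (a Python-level elementwise map), not `jax.numpy.vectorize`, and the `connections` dict here is a hypothetical example input.

import numpy as np

connections = {0: {1: 1.0}, 1: {0: 0.5}}
n = len(connections)
lookup = np.vectorize(lambda i, j: connections.get(i, {}).get(j, 0))
# np.fromfunction passes index arrays; the vectorized lookup maps over them.
print(np.fromfunction(lookup, (n, n)).T)
# [[0.  0.5]
#  [1.  0. ]]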
Example 16
def make_graph(data, save_name):
    prior = beta.pdf(x, a=data["prior"]["a"], b=data["prior"]["b"])
    n_0 = data["likelihood"]["n_0"]
    n_1 = data["likelihood"]["n_1"]
    samples = jnp.concatenate([jnp.zeros(n_0), jnp.ones(n_1)])
    likelihood_function = jnp.vectorize(
        lambda p: jnp.exp(bernoulli.logpmf(samples, p).sum()))
    likelihood = likelihood_function(x)
    posterior = beta.pdf(x, a=data["posterior"]["a"], b=data["posterior"]["b"])

    fig, ax = plt.subplots()
    axt = ax.twinx()
    fig1 = ax.plot(
        x,
        prior,
        "k",
        label=f"prior Beta({data['prior']['a']}, {data['prior']['b']})",
        linewidth=2.0,
    )
    fig2 = axt.plot(x,
                    likelihood,
                    "r:",
                    label=f"likelihood Bernoulli",
                    linewidth=2.0)
    fig3 = ax.plot(
        x,
        posterior,
        "b-.",
        label=
        f"posterior Beta({data['posterior']['a']}, {data['posterior']['b']})",
        linewidth=2.0,
    )
    fig_list = fig1 + fig2 + fig3
    labels = [fig.get_label() for fig in fig_list]
    ax.legend(fig_list, labels, loc="upper left", shadow=True)
    axt.set_ylabel("Likelihood")
    ax.set_ylabel("Prior/Posterior")
    ax.set_title(f"$N_0$:{n_0}, $N_1$:{n_1}")
    pml.savefig(save_name)
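A sketch of the likelihood step in isolation (only `jax.scipy.stats.bernoulli` is needed): with no `signature`, `jnp.vectorize` evaluates the scalar-to-scalar closure at every grid point while `samples` is closed over.

import jax.numpy as jnp
from jax.scipy.stats import bernoulli

samples = jnp.concatenate([jnp.zeros(2), jnp.ones(3)])
likelihood_function = jnp.vectorize(
    lambda p: jnp.exp(bernoulli.logpmf(samples, p).sum()))
p_grid = jnp.linspace(0.01, 0.99, 5)
print(likelihood_function(p_grid).shape)  # (5,)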
Example 17
 def test_wrong_output_type(self):
   f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
   with self.assertRaisesRegex(
       TypeError, "output must be a tuple"):
     f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
Example 18
 def test_matmat(self, left_shape, right_shape, result_shape):
     matmat = np.vectorize(np.dot, signature='(n,m),(m,k)->(n,k)')
     self.assertEqual(
         matmat(np.zeros(left_shape), np.zeros(right_shape)).shape,
         result_shape)
Example 19
 def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
     """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
     fn = jnp.vectorize(_rational_quadratic_spline_inv,
                        signature='(),(n),(n),(n)->(),()')
     x, logdet = fn(y, self._x_pos, self._y_pos, self._knot_slopes)
     return x, logdet
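A sketch of the multiple-output signature form with a simpler, hypothetical stand-in for `_rational_quadratic_spline_inv`: a '...->(),()' signature makes the wrapped function return a tuple, and each output is broadcast over the batch dimensions.

import jax.numpy as jnp

def nearest_knot(y, knots):
    idx = jnp.argmin(jnp.abs(knots - y))
    return knots[idx], jnp.abs(knots[idx] - y)

fn = jnp.vectorize(nearest_knot, signature='(),(n)->(),()')
ys = jnp.array([0.12, 0.87])
knots = jnp.linspace(0.0, 1.0, 5)
nearest, dist = fn(ys, knots)
print(nearest.shape, dist.shape)  # (2,) (2,)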
Example 20
 def test_vecmat(self, left_shape, right_shape, result_shape):
   vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
   self.assertEqual(vecvec(jnp.zeros(left_shape),
                           jnp.zeros(right_shape)).shape, result_shape)
Example 21
 def test_inconsistent_output_size(self):
   f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
   with self.assertRaisesRegex(
       ValueError, r"inconsistent size for core dimension 'n'"):
     f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
Example 22
 def inverse_and_log_det(self, inputs: Array) -> Tuple[Array, Array]:
     fn = jnp.vectorize(_rational_quadratic_spline_inv,
                        signature="(),(n),(n),(n)->(),()")
     outputs, log_det = fn(inputs, self.x_pos, self.y_pos, self.knot_slopes)
     return outputs, log_det
Example 23
 def test_wrong_num_outputs(self):
   f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
   with self.assertRaisesRegex(
       TypeError, "wrong number of output arguments"):
     f(1, 2)
Example 24
def ElementwiseNumerical(
        fn: Callable[[float], float],
        deg: int,
        df: Optional[Callable[[float], float]] = None) -> InternalLayer:
    """Activation function using numerical integration.

  Supports general activation functions using Gauss-Hermite quadrature.

  Args:
    fn: activation function.
    deg: number of sample points and weights for quadrature. It must be >= 1.
      We observe that for smooth activations deg=25 is a good starting point.
      For non-smooth activation functions (e.g. ReLU, Abs) quadrature is not
      recommended (for now use `nt.monte_carlo_kernel_fn`). Due to bivariate
      integration, compute time and memory scale as O(deg**2) for more
      precision. See eq (13) in
      https://mathworld.wolfram.com/Hermite-GaussQuadrature.html
      for error estimates in the case of 1d Gauss-Hermite quadrature.
    df: optional, derivative of the activation function (`fn`). If not provided,
      it is computed by `jax.grad`. Providing analytic derivative can speed up
      the NTK computations.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.
  """
    warnings.warn(
        f'Numerical Activation Layer with fn={fn}, deg={deg} used! '
        'Note that numerical error is controlled by `deg` and, for a given '
        'tolerance level, the required `deg` will depend strongly on the '
        'choice of `fn`.')

    quad_points = osp.special.roots_hermite(deg)

    if df is None:
        warnings.warn(
            'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
            'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.'
        )
        df = np.vectorize(grad(fn))

    def kernel_fn(k: Kernel) -> Kernel:
        """Kernel transformation of activation function using quadrature."""
        cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

        d1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
        d2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

        end_axis = 1 if k.diagonal_spatial else cov1.ndim
        q11 = utils.interleave_ones(d1, 0, end_axis, True)
        q22 = utils.interleave_ones(d1 if d2 is None else d2, 0, end_axis,
                                    False)

        def nngp_ntk_fn(nngp, q11, q22, ntk=None):
            """Simple Gauss-Hermite quadrature routine."""
            xs, ws = quad_points
            grid = np.outer(ws, ws)
            x = xs.reshape((xs.shape[0], ) + (1, ) * (nngp.ndim + 1))
            y = xs.reshape((1, xs.shape[0]) + (1, ) * nngp.ndim)
            xy_axes = (0, 1)

            nngp = np.expand_dims(nngp, xy_axes)
            q11, q22 = np.expand_dims(q11,
                                      xy_axes), np.expand_dims(q22, xy_axes)

            def integrate(f):
                fvals = f(_sqrt(2 * q11) *
                          x) * f(nngp / _sqrt(q11 / 2, 1e-30) * x +
                                 _sqrt(2 * (q22 - nngp**2 / q11)) * y)
                return np.tensordot(grid, fvals, (xy_axes, xy_axes)) / np.pi

            if ntk is not None:
                ntk *= integrate(df)
            nngp = integrate(fn)
            return nngp, ntk

        def nngp_fn_diag(nngp):
            xs, ws = quad_points
            x = xs.reshape((xs.shape[0], ) + (1, ) * nngp.ndim)
            x_axes = (0, )
            nngp = np.expand_dims(nngp, x_axes)
            fval = fn(_sqrt(2 * nngp) * x)**2
            return np.tensordot(ws, fval, (x_axes, x_axes)) / np.sqrt(np.pi)

        nngp, ntk = nngp_ntk_fn(nngp, q11, q22, ntk)

        if k.diagonal_batch and k.diagonal_spatial:
            cov1 = nngp_fn_diag(cov1)
            if cov2 is not None:
                cov2 = nngp_fn_diag(cov2)

        else:
            start_axis = 1 if k.diagonal_batch else 0
            q11 = utils.interleave_ones(d1, start_axis, end_axis, True)
            q22 = utils.interleave_ones(d1, start_axis, end_axis, False)
            cov1, _ = nngp_ntk_fn(cov1, q11, q22)

            if cov2 is not None:
                q11 = utils.interleave_ones(d2, start_axis, end_axis, True)
                q22 = utils.interleave_ones(d2, start_axis, end_axis, False)
                cov2, _ = nngp_ntk_fn(cov2, q11, q22)

        return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

    return _elementwise(fn, f'ElementwiseNumerical({fn},deg={deg})', kernel_fn)
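A sketch of the `df = np.vectorize(grad(fn))` step on its own (assuming `np` refers to `jax.numpy` here): `grad` gives the scalar derivative and `vectorize` lifts it elementwise.

import jax
import jax.numpy as jnp

fn = jnp.tanh
df = jnp.vectorize(jax.grad(fn))
x = jnp.linspace(-1.0, 1.0, 5)
print(jnp.allclose(df(x), 1.0 - jnp.tanh(x) ** 2))  # True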
Example 25
 def test_wrong_output_shape(self):
   f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
   with self.assertRaisesRegex(
       ValueError, r"output shape \(2, 2\) does not match"):
     f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))
Example 26
def Elementwise(
    fn: Optional[Callable[[float], float]] = None,
    nngp_fn: Optional[Callable[[float, float, float], float]] = None,
    d_nngp_fn: Optional[Callable[[float, float, float], float]] = None
) -> InternalLayer:
    """Elementwise application of `fn` using provided `nngp_fn`.

  Constructs a layer given only scalar-valued nonlinearity / activation
  `fn` and the 2D integral `nngp_fn`. NTK function is derived automatically in
  closed form from `nngp_fn`.

  If you cannot provide the `nngp_fn`, see `nt.stax.ElementwiseNumerical` to use
  numerical integration or `nt.monte_carlo.monte_carlo_kernel_fn` to use Monte
  Carlo sampling.

  If your function is implemented separately (e.g. `nt.stax.Relu`, etc.), it's best
  to use the custom implementation, since it uses symbolically simplified
  expressions that are more precise and numerically stable.

  Example:
    >>> fn = jax.scipy.special.erf  # type: Callable[[float], float]
    
    >>> def nngp_fn(cov12: float, var1: float, var2: float) -> float:
    >>>  prod = (1 + 2 * var1) * (1 + 2 * var2)
    >>>  return np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi
    
    >>> # Use autodiff and vectorization to construct the layer:
    >>> _, _, kernel_fn_auto = stax.Elementwise(fn, nngp_fn)
    
    >>> # Use custom pre-derived expressions
    >>> # (should be faster and more numerically stable):
    >>> _, _, kernel_fn_stax = stax.Erf()
    
    >>> kernel_fn_auto(x1, x2) == kernel_fn_stax(x1, x2)  # usually `True`.

  Args:
    fn:
      a scalar-input/valued function `fn : R -> R`, the activation /
      nonlinearity. If `None`, invoking the finite width `apply_fn` will raise
      an exception.

    nngp_fn:
      a scalar-valued function
      `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`, where the
      expectation is over bivariate normal `x1, x2` with variances `var1`,
      `var2` and covariance `cov12`. Needed for both NNGP and NTK calculation.
      If `None`, invoking infinite width `kernel_fn` will raise an exception.

    d_nngp_fn:
      an optional scalar-valued function
      `d_nngp_fn : (cov12, var1, var2) |-> E[fn'(x_1) * fn'(x_2)]` with the same
      `x1, x2` distribution as in `nngp_fn`. If `None`, will be computed using
      automatic differentiation as `d_nngp_fn = d(nngp_fn)/d(cov12)`, which may
      lead to worse precision or numerical stability. `nngp_fn` and `d_nngp_fn`
      are used to derive the closed-form expression for the NTK.

  Returns:
    `(init_fn, apply_fn, kernel_fn)`.

  Raises:
    NotImplementedError: if a `fn`/`nngp_fn` is not provided, but `apply_fn`/
      `kernel_fn` is called respectively.
  """
    if fn is not None:
        name = fn.__name__
    elif nngp_fn is not None:
        name = nngp_fn.__name__
    else:
        raise ValueError('No finite (`fn`) or infinite (`nngp_fn`) functions '
                         'provided, the layer will not do anything.')

    if nngp_fn is None:
        kernel_fn = None

    else:
        if d_nngp_fn is None:
            warnings.warn(
                'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
                'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.'
            )
            d_nngp_fn = np.vectorize(grad(nngp_fn))

        def kernel_fn(k: Kernel) -> Kernel:
            cov1, nngp, cov2, ntk = k.cov1, k.nngp, k.cov2, k.ntk

            var1 = get_diagonal(cov1, k.diagonal_batch, k.diagonal_spatial)
            var2 = get_diagonal(cov2, k.diagonal_batch, k.diagonal_spatial)

            if ntk is not None:
                ntk *= _vmap_2d(d_nngp_fn, nngp, var1, var2, False,
                                k.diagonal_spatial)

            nngp = _vmap_2d(nngp_fn, nngp, var1, var2, False,
                            k.diagonal_spatial)
            cov1 = _vmap_2d(nngp_fn, cov1, var1, None, k.diagonal_batch,
                            k.diagonal_spatial)
            if cov2 is not None:
                cov2 = _vmap_2d(nngp_fn, cov2, var2, None, k.diagonal_batch,
                                k.diagonal_spatial)
            return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk)

    return _elementwise(fn, name, kernel_fn)
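A sketch of the automatic `d_nngp_fn`, using the arcsine kernel from the docstring example above: `grad` differentiates with respect to the first argument (`cov12`) by default, and `vectorize` applies the resulting scalar derivative elementwise.

import jax
import jax.numpy as jnp

def nngp_fn(cov12, var1, var2):
    prod = (1 + 2 * var1) * (1 + 2 * var2)
    return jnp.arcsin(2 * cov12 / jnp.sqrt(prod)) * 2 / jnp.pi

d_nngp_fn = jnp.vectorize(jax.grad(nngp_fn))
print(d_nngp_fn(jnp.zeros(3), jnp.ones(3), jnp.ones(3)))  # 4 / (3 * pi) each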
Example 27
 def test_matvec(self, left_shape, right_shape, result_shape):
   matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
   self.assertEqual(matvec(jnp.zeros(left_shape),
                           jnp.zeros(right_shape)).shape, result_shape)
Example 28
 def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
     """Computes y = f(x) and log|det J(f)(x)|."""
     fn = jnp.vectorize(_rational_quadratic_spline_fwd,
                        signature='(),(n),(n),(n)->(),()')
     y, logdet = fn(x, self._x_pos, self._y_pos, self._knot_slopes)
     return y, logdet
Example 29
 def test_mean(self, shape, result_shape):
   mean = jnp.vectorize(jnp.mean, signature='(n)->()')
   self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)
Example 30
def _conv_general_dilated_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
                            cts_in: ShapedArray) -> np.ndarray:
    if idx != 1:
        raise NotImplementedError(eqn, idx)

    lhs = invals[1 if idx == 0 else 0]
    rhs = invals[idx]
    ndim = cts_in.ndim

    lhs_spec, rhs_spec, out_spec = eqn.params['dimension_numbers']
    precision = eqn.params['precision']

    n_groups_f = eqn.params['feature_group_count']
    n_groups_b = eqn.params['batch_group_count']

    n_channels_in = lhs.shape[lhs_spec[1]]
    n_batch_in = lhs.shape[lhs_spec[0]]
    group_size_out = rhs.shape[rhs_spec[0]] // (n_groups_f * n_groups_b)
    group_size_in = n_channels_in // n_groups_f
    batch_size_in = n_batch_in // n_groups_b

    if isinstance(precision, tuple):
        if precision[0] == precision[1]:
            precision = precision[0]
        else:
            raise NotImplementedError(precision)

    filter_shape = tuple(rhs.shape[i] for i in range(ndim)
                         if i in rhs_spec[2:])

    j = lax.conv_general_dilated_patches(
        lhs=lhs,
        filter_shape=filter_shape,
        window_strides=eqn.params['window_strides'],
        padding=eqn.params['padding'],
        lhs_dilation=eqn.params['lhs_dilation'],
        rhs_dilation=eqn.params['rhs_dilation'],
        dimension_numbers=eqn.params['dimension_numbers'],
        precision=precision,
        preferred_element_type=eqn.params['preferred_element_type'])

    if n_groups_b > 1:
        j = np.moveaxis(j, (out_spec[0], out_spec[1]), (-1, -2))
        j = j.reshape(j.shape[:-2] + (n_channels_in, *filter_shape, n_groups_b,
                                      batch_size_in))
        j = np.moveaxis(j, (-1, -2), (-2, -1))

    else:
        j = np.moveaxis(j, out_spec[1], -1)
        rhs_shape = (n_groups_f, group_size_in) + filter_shape

        j = j.reshape(j.shape[:ndim - 1] + rhs_shape)
        j = np.moveaxis(j, (ndim - 1, ndim), (-1, -2))

    j = np.vectorize(np.diag, signature='(k)->(k,k)')(j)

    if n_groups_b > 1:
        j = np.moveaxis(j, tuple(range(ndim - 2,
                                       j.ndim)), [ndim + rhs_spec[1]] +
                        [ndim + i for i in sorted(rhs_spec[2:])] +
                        [out_spec[0], out_spec[1], ndim + rhs_spec[0]])

    else:
        j = np.moveaxis(j, tuple(range(ndim - 1, j.ndim)),
                        [ndim + i for i in sorted(rhs_spec[2:])] +
                        [ndim + rhs_spec[1], out_spec[1], ndim + rhs_spec[0]])

    eye = np.eye(group_size_out, dtype=lhs.dtype)
    eye = np.expand_dims(eye, [
        i for i in range(j.ndim) if i not in (out_spec[1], ndim + rhs_spec[0])
    ])
    j = np.kron(j, eye)
    return j