예제 #1
0
파일: jet_test.py 프로젝트: wayfeng/jax
 def test_add_any(self):
   # https://github.com/google/jax/issues/5217
   f = lambda x, eps: x * eps + eps + x
   def g(eps):
     x = jnp.array(1.)
     return jax.grad(f)(x, eps)
   jet(g, (1.,), ([1.],))  # doesn't crash
예제 #2
0
def sol_recursive(f, z, t):
    """Recursively compute higher-order derivatives of the ODE dynamics.

    The state ``z`` is flattened and augmented with the time ``t`` so that
    ``jet`` can be applied to a single-argument function. The module-level
    names ``reg`` and ``REGS`` (defined outside this function) decide how many
    ``jet`` passes are performed.

    Returns a pair: the dynamics ``f(z, t)`` and a second array shaped like
    ``z`` (all zeros when ``reg == "none"``, otherwise the second-to-last
    series term returned by the final ``jet`` call).
    """
    if reg == "none":
        return f(z, t), jnp.zeros_like(z)

    state_shape = z.shape
    augmented = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

    def g(aug):
        """Dynamics of the augmented state; the time component has slope 1."""
        state, time = jnp.reshape(aug[:-1], state_shape), aug[-1]
        flow = jnp.ravel(f(state, time))
        return jnp.concatenate((flow, jnp.array([1.])))

    extra_passes = REGS.index(reg)

    # Seed pass with an all-ones series, then feed each result back in.
    primal, series = jet(g, (augmented, ), ((jnp.ones_like(augmented), ), ))
    series = list(series)
    for _ in range(extra_passes + 1):
        primal, series = jet(g, (augmented, ), ((primal, *series), ))
        series = list(series)

    dynamics = jnp.reshape(primal[:-1], state_shape)
    higher_term = jnp.reshape(series[-2][:-1], state_shape)
    return (dynamics, higher_term)
예제 #3
0
파일: jet_test.py 프로젝트: wayfeng/jax
  def test_inst_zero(self):
    def f(x):
      return 2.
    def g(x):
      return 2. + 0 * x
    x = jnp.ones(1)
    order = 3
    f_out_primals, f_out_series = jet(f, (x, ), ([jnp.ones_like(x) for _ in range(order)], ))
    assert f_out_series is not zero_series

    g_out_primals, g_out_series = jet(g, (x, ), ([jnp.ones_like(x) for _ in range(order)], ))

    assert g_out_primals == f_out_primals
    assert g_out_series == f_out_series
예제 #4
0
파일: jet_test.py 프로젝트: yotarok/jax
    def check_jet(self,
                  fun,
                  primals,
                  series,
                  atol=1e-5,
                  rtol=1e-5,
                  check_dtypes=True):
        """Assert that ``jet`` agrees with the ``jvp_taylor`` reference.

        Both the primal output and the series terms are compared with
        ``assertAllClose`` using the given tolerances. ``jvp_taylor`` is
        defined elsewhere in the file.
        """
        compare_kwargs = dict(atol=atol, rtol=rtol, check_dtypes=check_dtypes)

        actual_y, actual_terms = jet(fun, primals, series)
        reference_y, reference_terms = jvp_taylor(fun, primals, series)

        self.assertAllClose(actual_y, reference_y, **compare_kwargs)

        # TODO(duvenaud): Lower zero_series to actual zeros automatically.
        if actual_terms == zero_series:
            actual_terms = tree_map(np.zeros_like, reference_terms)

        self.assertAllClose(actual_terms, reference_terms, **compare_kwargs)
예제 #5
0
파일: jet_test.py 프로젝트: yotarok/jax
    def check_jet_finite(self,
                         fun,
                         primals,
                         series,
                         atol=1e-5,
                         rtol=1e-5,
                         check_dtypes=True):
        """Assert that ``jet`` agrees with ``jvp_taylor``, ignoring non-finite values.

        Every non-finite entry (inf or nan) on either side is mapped to nan
        before the comparison, so both results only need to be non-finite in
        the same places. ``jvp_taylor`` is defined elsewhere in the file.
        """
        y, terms = jet(fun, primals, series)
        expected_y, expected_terms = jvp_taylor(fun, primals, series)

        # TODO(duvenaud): Lower zero_series to actual zeros automatically.
        # This must happen before ``terms`` is converted to an array: after the
        # conversion ``terms`` can never be the sentinel, and comparing an
        # array with ``==`` produces an elementwise result rather than a bool.
        if terms is zero_series:
            terms = tree_map(np.zeros_like, expected_terms)

        def _convert(x):
            # Collapse +/-inf and nan to a single marker value.
            return np.where(np.isfinite(x), x, np.nan)

        y = _convert(y)
        expected_y = _convert(expected_y)

        terms = _convert(np.asarray(terms))
        expected_terms = _convert(np.asarray(expected_terms))

        self.assertAllClose(y,
                            expected_y,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)

        self.assertAllClose(terms,
                            expected_terms,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)
예제 #6
0
파일: jet_test.py 프로젝트: romanngg/jax
    def check_jet_finite(self,
                         fun,
                         primals,
                         series,
                         atol=1e-5,
                         rtol=1e-5,
                         check_dtypes=True):
        """Assert that ``jet`` agrees with ``jvp_taylor``, ignoring non-finite values.

        Every non-finite entry (inf or nan) on either side is mapped to nan
        before the comparison. ``jvp_taylor`` is defined elsewhere in the file.
        """
        # Convert to jax arrays to ensure dtype canonicalization.
        # ``jax.tree_util.tree_map`` is used instead of the top-level
        # ``jax.tree_map`` alias, which is deprecated in newer JAX releases.
        primals = jax.tree_util.tree_map(jnp.asarray, primals)
        series = jax.tree_util.tree_map(jnp.asarray, series)

        y, terms = jet(fun, primals, series)
        expected_y, expected_terms = jvp_taylor(fun, primals, series)

        def _convert(x):
            # Collapse +/-inf and nan to a single marker value.
            return jnp.where(jnp.isfinite(x), x, jnp.nan)

        y = _convert(y)
        expected_y = _convert(expected_y)

        terms = _convert(jnp.asarray(terms))
        expected_terms = _convert(jnp.asarray(expected_terms))

        self.assertAllClose(y,
                            expected_y,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)

        self.assertAllClose(terms,
                            expected_terms,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)
예제 #7
0
    def check_jet_finite(self,
                         fun,
                         primals,
                         series,
                         atol=1e-5,
                         rtol=1e-5,
                         check_dtypes=True):
        """Assert that ``jet`` agrees with ``jvp_taylor``, ignoring non-finite values.

        Every non-finite entry (inf or nan) on either side is mapped to nan
        before the comparison. ``jvp_taylor`` is defined elsewhere in the file.
        """

        def nan_where_not_finite(values):
            is_finite = jnp.isfinite(values)
            return jnp.where(is_finite, values, jnp.nan)

        got_y, got_terms = jet(fun, primals, series)
        want_y, want_terms = jvp_taylor(fun, primals, series)

        got_y = nan_where_not_finite(got_y)
        want_y = nan_where_not_finite(want_y)

        got_terms = nan_where_not_finite(jnp.asarray(got_terms))
        want_terms = nan_where_not_finite(jnp.asarray(want_terms))

        # Primal outputs first, then the series terms.
        for got, want in ((got_y, want_y), (got_terms, want_terms)):
            self.assertAllClose(got,
                                want,
                                atol=atol,
                                rtol=rtol,
                                check_dtypes=check_dtypes)
예제 #8
0
파일: jet_test.py 프로젝트: romanngg/jax
    def check_jet(self,
                  fun,
                  primals,
                  series,
                  atol=1e-5,
                  rtol=1e-5,
                  check_dtypes=True):
        """Assert that ``jet`` agrees with the ``jvp_taylor`` reference.

        Both the primal output and the series terms are compared with
        ``assertAllClose`` using the given tolerances. ``jvp_taylor`` is
        defined elsewhere in the file.
        """
        # Convert to jax arrays to ensure dtype canonicalization.
        # ``jax.tree_util.tree_map`` is used instead of the top-level
        # ``jax.tree_map`` alias, which is deprecated in newer JAX releases.
        primals = jax.tree_util.tree_map(jnp.asarray, primals)
        series = jax.tree_util.tree_map(jnp.asarray, series)

        y, terms = jet(fun, primals, series)
        expected_y, expected_terms = jvp_taylor(fun, primals, series)

        self.assertAllClose(y,
                            expected_y,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)

        self.assertAllClose(terms,
                            expected_terms,
                            atol=atol,
                            rtol=rtol,
                            check_dtypes=check_dtypes)
예제 #9
0
    def _taylor_coefficient_generator(*, f, y0):
        """Generate Taylor coefficients.

        Generate Taylor-series-coefficients of the ODE solution `x(t)` via generating
        Taylor-series-coefficients of `g(t)=f(x(t))` via ``jax.experimental.jet()``.
        """

        # This is the 0th Taylor coefficient of x(t) at t=t0.
        x_primals = y0
        yield (x_primals, )

        # This contains the higher-order, unnormalised
        # Taylor coefficients of x(t) at t=t0.
        # We know them because of the ODE.
        x_series = (f(y0), )

        while True:
            yield (x_primals, ) + x_series

            # jet() computes a Taylor approximation of g(t) := f(x(t))
            # The output is the zeroth Taylor approximation g(t_0) ('primals')
            # as well its higher-order Taylor coefficients ('series')
            g_primals, g_series = jet(fun=f,
                                      primals=(x_primals, ),
                                      series=(x_series, ))

            # For ODEs \dot y(t) = f(y(t)),
            # The nth Taylor coefficient of y is the
            # (n-1)th Taylor coefficient of g(t) = f(y(t)).
            # This way, by augmenting x0 with the Taylor series
            # approximating g(t) = f(y(t)), we increase the order
            # of the approximation by 1.
            x_series = (g_primals, *g_series)
예제 #10
0
def _taylormode(f, z0, t0, order):
    """Taylor-mode automatic differentiation for initialisation.

    Inspired by the implementation in
    https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py
    """

    try:
        import jax.numpy as jnp
        from jax.config import config
        from jax.experimental.jet import jet

        config.update("jax_enable_x64", True)
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    def total_derivative(z_t):
        """Total derivative."""
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(t, z))
        dt = jnp.array([1.0])
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    z_shape = z0.shape
    z_t = jnp.concatenate((jnp.ravel(z0), jnp.array([t0])))

    derivs = []

    derivs.extend(z0)
    if order == 0:
        return jnp.array(derivs)
    (y0, [*yns]) = jet(total_derivative, (z_t,), ((jnp.ones_like(z_t),),))
    derivs.extend(y0[:-1])
    if order == 1:
        return jnp.array(derivs)

    order = order - 2
    for _ in range(order + 1):
        (y0, [*yns]) = jet(total_derivative, (z_t,), ((y0, *yns),))
        derivs.extend(yns[-2][:-1])

    return jnp.array(derivs)
예제 #11
0
파일: jet_test.py 프로젝트: wayfeng/jax
  def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
                check_dtypes=True):
    """Assert that ``jet`` agrees with the ``jvp_taylor`` reference.

    Both the primal output and the series terms are compared with
    ``assertAllClose``. ``jvp_taylor`` is defined elsewhere in the file.
    """
    got_y, got_terms = jet(fun, primals, series)
    want_y, want_terms = jvp_taylor(fun, primals, series)

    # Primal outputs first, then the series terms.
    for got, want in ((got_y, want_y), (got_terms, want_terms)):
      self.assertAllClose(got, want, atol=atol, rtol=rtol,
                          check_dtypes=check_dtypes)
예제 #12
0
def sol_recursive(f, z, t):
  """
  Recursively compute higher order derivatives of dynamics of ODE.
  """
  z_shape = z.shape
  z_t = jnp.concatenate((jnp.ravel(z), jnp.array([t])))

  def g(z_t):
    """
    Closure to expand z.
    """
    z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
    dz = jnp.ravel(f(z, t))
    dt = jnp.array([1.])
    dz_t = jnp.concatenate((dz, dt))
    return dz_t

  (y0, [y1h]) = jet(g, (z_t, ), ((jnp.ones_like(z_t), ), ))
  (y0, [y1, y2h]) = jet(g, (z_t, ), ((y0, y1h,), ))

  return (jnp.reshape(y0[:-1], z_shape), [jnp.reshape(y1[:-1], z_shape)])
예제 #13
0
def initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv):
    """Initialize an ODE filter with Taylor-mode automatic differentiation.

    This requires JAX. For an explanation of what happens ``under the hood``, see [1]_.

    The implementation is inspired by the implementation in
    https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py

    Parameters
    ----------
    f
        ODE vector field. Called as ``f(t, y)``.
    y0
        Initial value.
    t0
        Initial time point.
    prior
        Prior distribution used for the ODE solver. For instance an integrated Brownian motion prior (``IBM``).
        Only ``prior.ordint`` is read by this function.
    initrv
        Initial random variable. Not read by this implementation.

    Returns
    -------
    Normal
        Estimated initial random variable. Compatible with the specified prior.
        Its covariance and Cholesky factor are all-zero matrices.

    Raises
    ------
    ImportError
        If JAX cannot be imported.

    Notes
    -----
    Enables the global ``jax_enable_x64`` flag as a side effect.

    References
    ----------
    .. [1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers,
       *arXiv:2012.10106*, 2020.

    Examples
    --------

    >>> import sys, pytest
    >>> if not sys.platform.startswith('linux'):
    ...     pytest.skip()

    >>> from dataclasses import astuple
    >>> from probnum.randvars import Normal
    >>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
    >>> from probnum.statespace import IBM

    Compute the initial values of the restricted three-body problem as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(threebody_jax())
    >>> print(y0)
    [ 0.994       0.          0.         -2.00158511]

    >>> prior = IBM(ordint=3, spatialdim=4)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [ 0.994       0.          0.         -2.00158511]
    >>> print(improved_initrv.mean)
    [ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
      0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
      0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
     -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]

    Compute the initial values of the van-der-Pol oscillator as follows

    >>> f, t0, tmax, y0, df, *_ = astuple(vanderpol_jax())
    >>> print(y0)
    [2. 0.]
    >>> prior = IBM(ordint=3, spatialdim=2)
    >>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
    >>> improved_initrv = initialize_odefilter_with_taylormode(f, y0, t0, prior, initrv)
    >>> print(prior.proj2coord(0) @ improved_initrv.mean)
    [2. 0.]
    >>> print(improved_initrv.mean)
    [    2.     0.    -2.    60.     0.    -2.    60. -1798.]
    >>> print(improved_initrv.std)
    [0. 0. 0. 0. 0. 0. 0. 0.]
    """

    try:
        import jax
        import jax.numpy as jnp
        from jax.experimental.jet import jet
    except ImportError as err:
        raise ImportError(
            "Cannot perform Taylor-mode initialisation without optional "
            "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
        ) from err

    # Use the top-level config object: importing it from the ``jax.config``
    # submodule fails on JAX releases that no longer ship that submodule, and
    # that failure would be misreported as "jax is not installed".
    jax.config.update("jax_enable_x64", True)

    order = prior.ordint
    z_shape = y0.shape
    spatialdim = len(y0)

    def total_derivative(z_t):
        """Total derivative of the time-augmented state ``(z, t)``."""
        z, t = jnp.reshape(z_t[:-1], z_shape), z_t[-1]
        dz = jnp.ravel(f(t, z))
        dt = jnp.array([1.0])  # d/dt of the appended time component
        dz_t = jnp.concatenate((dz, dt))
        return dz_t

    def derivs_to_normal(derivs):
        """Reorder the stacked derivatives and wrap them in a ``Normal``."""
        # Everything is passed through np.asarray so that only NumPy arrays
        # (not JAX arrays) are handed to the project's own classes.
        all_derivs = statespace.Integrator._convert_derivwise_to_coordwise(
            np.asarray(jnp.array(derivs)), ordint=order, spatialdim=spatialdim)

        return randvars.Normal(
            np.asarray(all_derivs),
            cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
        )

    z_t = jnp.concatenate((jnp.ravel(y0), jnp.array([t0])))

    derivs = []

    # Order 0: only the initial value itself is needed; jet is never called.
    derivs.extend(y0)
    if order == 0:
        return derivs_to_normal(derivs)

    # Seed pass: only the primal output (the first derivative) is meaningful,
    # the series output serves as a placeholder for the next pass.
    (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((jnp.ones_like(z_t), ), ))
    derivs.extend(dy0[:-1])
    if order == 1:
        return derivs_to_normal(derivs)

    # Each further pass feeds the previous result back in; the second-to-last
    # series term is the newly obtained derivative. The trailing entry of each
    # term belongs to the time component and is dropped.
    for _ in range(1, order):
        (dy0, [*yns]) = jet(total_derivative, (z_t, ), ((dy0, *yns), ))
        derivs.extend(yns[-2][:-1])

    return derivs_to_normal(derivs)
예제 #14
0
    def __call__(
        self,
        ivp: problems.InitialValueProblem,
        prior_process: randprocs.markov.MarkovProcess,
    ) -> randvars.RandomVariable:
        """Compute an initial random variable via Taylor-mode differentiation.

        The initial value ``ivp.y0`` and the derivatives of the ODE solution
        at ``ivp.t0``, up to ``prior_process.transition.num_derivatives``, are
        computed with ``jax.experimental.jet`` and returned as a ``Normal``
        whose covariance and Cholesky factor are all-zero matrices.

        Raises
        ------
        ImportError
            If JAX cannot be imported.

        Notes
        -----
        Enables the global ``jax_enable_x64`` flag as a side effect.
        """

        try:
            import jax
            import jax.numpy as jnp
            from jax.experimental.jet import jet
        except ImportError as err:
            raise ImportError(
                "Cannot perform Taylor-mode initialisation without optional "
                "dependencies jax and jaxlib. Try installing them via `pip install jax jaxlib`."
            ) from err

        # Use the top-level config object: importing it from the ``jax.config``
        # submodule fails on JAX releases that no longer ship that submodule,
        # and that failure would be misreported as "jax is not installed".
        jax.config.update("jax_enable_x64", True)

        num_derivatives = prior_process.transition.num_derivatives

        dt = jnp.array([1.0])

        def evaluate_ode_for_extended_state(extended_state, ivp=ivp, dt=dt):
            r"""Evaluate the ODE for an extended state (x(t), t).

            More precisely, compute the derivative of the stacked state (x(t), t) according to the ODE.
            This function implements a rewriting of non-autonomous as autonomous ODEs.
            This means that

            .. math:: \dot x(t) = f(t, x(t))

            becomes

            .. math:: \dot z(t) = \dot (x(t), t) = (f(t, x(t)), 1).

            Only considering autonomous ODEs makes the jet-implementation
            (and automatic differentiation in general) easier.
            """
            x, t = jnp.reshape(extended_state[:-1],
                               ivp.y0.shape), extended_state[-1]
            dx = ivp.f(t, x)
            dx_ravelled = jnp.ravel(dx)
            stacked_ode_eval = jnp.concatenate((dx_ravelled, dt))
            return stacked_ode_eval

        def derivs_to_normal_randvar(derivs, num_derivatives_in_prior):
            """Finalize the output in terms of creating a suitably sized random
            variable."""
            all_derivs = (randprocs.markov.integrator.convert.
                          convert_derivwise_to_coordwise(
                              np.asarray(derivs),
                              num_derivatives=num_derivatives_in_prior,
                              wiener_process_dimension=ivp.y0.shape[0],
                          ))

            # Wrap all inputs through np.asarray, because 'Normal's
            # do not like JAX 'DeviceArray's
            return randvars.Normal(
                mean=np.asarray(all_derivs),
                cov=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
                cov_cholesky=np.asarray(jnp.diag(jnp.zeros(len(derivs)))),
            )

        extended_state = jnp.concatenate(
            (jnp.ravel(ivp.y0), jnp.array([ivp.t0])))
        derivs = []

        # Corner case 1: num_derivatives == 0
        # Only the initial value is needed; jet is never called.
        derivs.extend(ivp.y0)
        if num_derivatives == 0:
            return derivs_to_normal_randvar(
                derivs=derivs, num_derivatives_in_prior=num_derivatives)

        # Corner case 2: num_derivatives == 1
        # Seed pass with an all-ones series: only the primal output (the
        # first derivative) is meaningful, the series output serves as a
        # placeholder for the next pass.
        initial_series = (jnp.ones_like(extended_state), )
        (initial_taylor_coefficient, [*remaining_taylor_coefficients]) = jet(
            fun=evaluate_ode_for_extended_state,
            primals=(extended_state, ),
            series=(initial_series, ),
        )
        derivs.extend(initial_taylor_coefficient[:-1])
        if num_derivatives == 1:
            return derivs_to_normal_randvar(
                derivs=derivs, num_derivatives_in_prior=num_derivatives)

        # Order > 1
        # Each pass feeds the coefficients found so far back into jet. The
        # primal output is discarded because it is the ODE evaluated at the
        # same (unchanged) extended state on every pass.
        for _ in range(1, num_derivatives):
            taylor_coefficients = (
                initial_taylor_coefficient,
                *remaining_taylor_coefficients,
            )
            (_, [*remaining_taylor_coefficients]) = jet(
                fun=evaluate_ode_for_extended_state,
                primals=(extended_state, ),
                series=(taylor_coefficients, ),
            )
            # The second-to-last term is the newly obtained derivative; its
            # trailing entry belongs to the time component and is dropped.
            derivs.extend(remaining_taylor_coefficients[-2][:-1])
        return derivs_to_normal_randvar(
            derivs=derivs, num_derivatives_in_prior=num_derivatives)