Exemple #1
0
 def one_rng(rng):
     """Forward/backward round-trip check for randomly initialised dynamics.

     Integrates the autonomous dynamics forward from the origin, reporting the
     state at every entry of ``times``. Then, for each recorded time ``T``, it
     integrates the negated dynamics over ``[0, T]`` starting from the state
     reached at ``T``.

     Args:
         rng: PRNG key used to initialise the dynamics parameters.

     Returns:
         The squared Euclidean norm of each reconstructed starting point, one
         entry per element of ``times``. Values should be close to zero when
         the backward solve accurately inverts the forward one, since the
         forward solve starts at the origin.
     """
     _, params = dynamics_init(rng, (x_dim, ))

     def forward_field(state, _t):
         return dynamics(params, state)

     def backward_field(state, _t):
         return -dynamics(params, state)

     origin = jnp.zeros((x_dim, ))
     fwd_x_T = ode.odeint(forward_field, origin, times, mxstep=1e9)

     def rewind(duration, end_state):
         # Row 1 of the solution is the state after running backwards for `duration`.
         span = jnp.array([0.0, duration])
         return ode.odeint(backward_field, end_state, span, mxstep=1e9)[1]

     bwd_x_0 = vmap(rewind, in_axes=(0, 0))(times, fwd_x_T)
     return jnp.sum(bwd_x_0**2, axis=-1)
        def eval_from_x0(policy_params, x0, total_time):
            """Solve the augmented ODE forward from ``x0``, then backward again.

            The augmented initial state is ``[0, x0, zeros(z_dim)]``. The forward
            pass integrates ``ofunc`` over ``[0, total_time]``. The backward pass
            restarts from the forward end state with the time-reversed field,
            i.e. it solves in ``s = -t`` over ``[-total_time, 0]``.

            Returns:
                ``(y_fwd, y_bwd)``: two-row trajectories. ``y_bwd`` is flipped so
                that its rows line up with ``y_fwd`` (row 0 at ``t = 0``, row 1
                at ``t = total_time``).
            """
            # odeint treats y0 as the state at the first grid time, so the grid
            # has to start at the initial time 0.
            fwd_times = jp.array([0.0, total_time])
            leading = jp.zeros((1, ))
            trailing = jp.zeros((z_dim, ))
            y0 = jp.concatenate((leading, x0, trailing))
            solver_opts = dict(mxstep=1e6)

            y_fwd = ode.odeint(ofunc, y0, fwd_times, policy_params, **solver_opts)

            def reversed_field(y, s, *args):
                return -ofunc(y, -s, *args)

            # Similar to, but not exactly the same as, the backward solve done by
            # reverse-mode AD: adaptive step sizes can differ there because the
            # other parameters are integrated as well.
            y_bwd = ode.odeint(reversed_field, y_fwd[1], -fwd_times[::-1],
                               policy_params, **solver_opts)
            return y_fwd, y_bwd[::-1]
Exemple #3
0
 def update(r):
     """Integrate ``step_partial`` from a point offset by ``r`` and return two columns.

     The offset vector ``r * saep * v1`` is folded into a two-component
     initial state. The first component is ``r_axis`` plus the ``ct``/``st``
     combination of the offset's first two entries. The second component is
     ``z_axis`` plus the offset's third entry. The state is then solved at
     ``t_eval``.

     Returns:
         The first and second columns of the solution.
     """
     offset = r * saep * v1
     start_first = r_axis + ct * offset[0] + st * offset[1]
     start_second = z_axis + offset[2]
     y = np.asarray((start_first, start_second))
     trajectory = odeint(step_partial, y, t_eval)
     return trajectory[:, 0], trajectory[:, 1]
Exemple #4
0
 def test_disable_jit_odeint_with_vmap(self):
     """vmap over odeint must not crash while jit is disabled.

     Regression test for https://github.com/google/jax/issues/2598. It passes
     as long as no exception is raised.
     """
     with jax.disable_jit():
         times = jnp.array([0.0, 1.0])
         initial_states = jnp.zeros((5, 2))

         def identity_field(state, _t):
             return state

         def solve(x0):
             return odeint(identity_field, x0, times)

         jax.vmap(solve)(initial_states)
def manifold_ode_log_prob(params: List, rng: random.PRNGKey, num_samples: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Forward model of the neural manifold ODE.

    The base distribution is uniform on the sphere. Computes both the samples
    from the forward model and the log-probability of the generated samples.

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key.
        num_samples: Number of samples used to estimate the KL divergence.

    Returns:
        sph: Samples generated from the neural manifold ODE under the forward
            model.
        log_prob: The log-probability of the generated samples.

    """
    # The first key is discarded; the 3-way split is kept so that the sampled
    # values are unchanged.
    _, rng_x, rng_b = random.split(rng, 3)
    # Base samples, requested with shape (num_samples, 4).
    x = sample_uniform(rng_x, [num_samples, 4])
    # Chart base point: x plus small Gaussian noise (scale 0.01), passed
    # through project_to_sphere.
    b = project_to_sphere(x + 0.01 * random.normal(rng_b, x.shape))
    # Tangent-space coordinates of x at b, and the exp-map log-det there.
    v = log(b, x)
    fldj = log_det_jac_exp(b, v)
    vector_field = ambient_to_spherical_vector_field(lambda x, t: net(params, stacked(x, t)))
    # Only the trace-augmented field is needed here.
    _, divfunc = spherical_to_chart_vector_field(b, vector_field)
    # Integrate the tangent vector jointly with a trace accumulator that starts
    # at zero (one entry per sample) over t in [0, 1]; keep the final values.
    init = (v, jnp.zeros(len(v)))
    time = jnp.array([0.0, 1.0])
    tang, trace = tuple(_[-1] for _ in odeint(divfunc, init, time))
    sph = exp(b, tang)
    ildj = -log_det_jac_exp(b, tang)
    log_prob = uniform_log_density(x) + fldj + ildj + trace
    return sph, log_prob
def manifold_reverse_ode_log_prob(params: List, rng: random.PRNGKey, revx: jnp.ndarray) -> jnp.ndarray:
    """Log-likelihood of observations under the neural manifold ODE.

    The chart dynamics are integrated backwards in time starting from the
    observations. The change-of-variables correction is accumulated
    continuously along the way.

    Args:
        params: Parameters of the neural manifold ODE.
        rng: Pseudo-random number generator key; used to jitter the chart
            base point.
        revx: Observations whose log-likelihood under the neural ODE model
            should be computed.

    Returns:
        rev_log_prob: The log-probability of the observations.

    """
    noise = 0.01 * random.normal(rng, revx.shape)
    b = project_to_sphere(revx + noise)
    vrev = log(b, revx)
    revfldj = log_det_jac_exp(b, vrev)

    field = ambient_to_spherical_vector_field(lambda x, t: net(params, stacked(x, t)))
    _, divfunc = spherical_to_chart_vector_field(b, field)

    def revfunc(state, t):
        # Time-reversed augmented dynamics: evaluate at 1 - t, negate every component.
        return tuple(-component for component in divfunc(state, 1.0 - t))

    revinit = (vrev, jnp.zeros(len(vrev)))
    span = jnp.array([0.0, 1.0])
    revtang, revtrace = (leaf[-1] for leaf in odeint(revfunc, revinit, span))
    revsph = exp(b, revtang)
    revildj = -log_det_jac_exp(b, revtang)
    return uniform_log_density(revsph) - revfldj - revildj - revtrace
Exemple #7
0
 def _run_static(cls, T, x0, theta, rtol=1e-5, atol=1e-3, mxstep=500):
     '''
     x0 is shape (d,)
     theta is shape (nargs,)
     '''
     t = np.arange(T, dtype='float32')
     return odeint(cls.dx_dt, x0, t, *theta)
 def integrate(self, x0, n_steps):
   """Integrate the system's equations of motion from ``x0``.

   The trajectory is sampled at the ``n_steps`` times
   0, dt, 2*dt, ..., (n_steps - 1) * dt, where ``dt`` is ``self.dt``.

   Returns:
     The ``odeint`` solution, one entry per sample time.
   """
   dt = self.dt
   sample_times = jnp.asarray([k * dt for k in range(n_steps)])
   return odeint(self.equation_of_motion, x0, sample_times)
Exemple #9
0
def model(N, y=None):
    """Population model: ODE trajectory observed with log-normal noise.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # Keep the sample sites in this order: under numpyro's seed handler the
    # order of the sites determines the keys they draw.
    z_init_prior = dist.LogNormal(jnp.log(10), 1).expand([2])
    z_init = numpyro.sample("z_init", z_init_prior)

    # Measurement times 0, 1, ..., N-1.
    ts = jnp.arange(float(N))

    # alpha, beta, gamma, delta for dz_dt, truncated below at zero.
    theta_prior = dist.TruncatedNormal(
        low=0.0,
        loc=jnp.array([1.0, 0.05, 1.0, 0.05]),
        scale=jnp.array([0.5, 0.05, 0.5, 0.05]),
    )
    theta = numpyro.sample("theta", theta_prior)

    # One solution row per measurement time (N x 2).
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)

    # Per-population measurement noise, then the likelihood of the data.
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    numpyro.sample("y", dist.LogNormal(jnp.log(z), sigma), obs=y)
 def solve_ode(init_condition, times):
     """Integrate ``f`` from ``init_condition``, reporting the state at ``times``.

     Uses tight tolerances (rtol = atol = 1e-10). Any further solver options
     come from ``kwargs`` in the enclosing scope.
     """
     tight = {"rtol": 1e-10, "atol": 1e-10}
     return odeint(f, init_condition, t=times, **tight, **kwargs)
  def eval_from_x0(policy_params, x0, total_time):
    """Solve the augmented system forward to ``total_time`` and back again.

    The state is the pytree ``(0., 0., x0)``.

    Returns:
      ``(y0, yT, y0_bwd)``:
        - the initial state;
        - the forward solution at ``total_time``;
        - the state recovered by integrating the time-reversed field from
          ``yT`` back to time 0.
    """
    # odeint treats y0 as the state at ts[0], so the grid must begin at 0.
    ts = jnp.array([0.0, total_time])
    y0 = (jnp.zeros(()), jnp.zeros(()), x0)
    solver_opts = dict(mxstep=1e6)
    last = itemgetter(-1)

    y_fwd = ode.odeint(ofunc, y0, ts, policy_params, **solver_opts)
    yT = tree_map(last, y_fwd)

    def reversed_field(y, s, *args):
      # Negate every leaf of the forward field evaluated at time -s.
      return tree_map(jnp.negative, ofunc(y, -s, *args))

    # Similar to, but not exactly the same as, the backward solve done by
    # reverse-mode AD: step sizes can vary there because the other parameters
    # are integrated as well.
    y_bwd = ode.odeint(reversed_field, yT, -ts[::-1], policy_params, **solver_opts)
    return y0, yT, tree_map(last, y_bwd)
Exemple #12
0
def loss(t1, flat_p, omega, U_T):
    '''
    Infidelity between the evolved propagator and the target ``U_T``.

    Integrates dU/dt = -i (omega * sz + A(t, flat_p, t1) * sx) U from the
    identity over [0, t1]. Returns 1 - |tr(U_T^dagger U(t1)) / D|^2, where D
    is the size of ``U_T``. This is a pure function.
    '''
    D, _ = jnp.shape(U_T)
    identity = jnp.eye(D, dtype=jnp.complex128)
    time_grid = jnp.linspace(0., t1, 5)

    def rhs(U, t, t_final, w, params):
        hamiltonian = w * sz + A(t, params, t_final) * sx
        return -1.0j * hamiltonian @ U

    trajectory = odeint(rhs, identity, time_grid, t1, omega, flat_p,
                        rtol=1.4e-10, atol=1.4e-10)
    U_F = trajectory[-1]
    overlap = jnp.trace(U_T.conj().T @ U_F) / D
    return 1 - jnp.abs(overlap)**2
Exemple #13
0
    def kernel(a):
        """Integrate ``dndz(1/t - t) * cosmo.g(a, t)`` over t from ``cosmo._amin`` to ``a``.

        The integral is solved as an ODE with initial value 0. The value at the
        upper limit is returned.
        """
        def integrand(_y, t, a):
            # NOTE(review): the original author flagged this as possibly
            # insufficient; double-check the change of variables.
            weight = dndz(1. / t - t)
            return weight * cosmo.g(a, t)

        limits = np.array([cosmo._amin, a])
        values = odeint(integrand, 0., limits, a)
        return values[1]
Exemple #14
0
        def g(x):
            """Sum of all final-state entries of two vmapped ODE solves seeded with ``x``."""
            t = jnp.array([0., 5.])

            def solve(y0):
                return odeint(dx_dt, y0, t)

            # Two initial conditions; axis 0 of `starts` is the vmap batch axis.
            starts = jnp.array([[x, 0.1], [x, 0.2]])
            solutions = jax.vmap(solve)(starts)
            return solutions[:, -1].sum()
Exemple #15
0
 def advance(x0, theta):
     """Advance ``x0`` across ``t_one_step`` using ``cls.dx_dt``.

     Solver settings (``rtol``, ``atol``, ``mxstep``) come from the enclosing
     scope. The new state is returned twice, presumably as the
     ``(carry, output)`` pair of a ``lax.scan`` body.
     """
     trajectory = odeint(cls.dx_dt, x0, t_one_step, *theta,
                         rtol=rtol, atol=atol, mxstep=mxstep)
     next_state = trajectory[1]
     return next_state, next_state
Exemple #16
0
 def solve_model_rk45(inputs):
     """Solve ``rhs_ode`` from ``y0`` at ``t_eval``, passing ``inputs`` as an extra argument.

     Tolerances and any extra solver options are read from the enclosing
     ``self``. The solution's axes are reversed before returning; for a 2-D
     ``(time, state)`` result, time ends up on the last axis.
     """
     options = dict(rtol=self.rtol, atol=self.atol)
     solution = odeint(rhs_ode, y0, t_eval, inputs, **options, **self.extra_options)
     return jnp.transpose(solution)
 def batch_warmup(self, x0: Array, total_steps: int) -> Array:
   """Spin the model up and return only its final state.

   Integrates the equations of motion from ``x0`` over
   ``[0, total_steps * self.dt]``. This is used to spin the model up to a
   statistically stationary regime.
   """
   horizon = total_steps * self.dt
   endpoints = jnp.asarray([0, horizon])
   return odeint(self.equation_of_motion, x0, endpoints)[-1]
Exemple #18
0
 def evally(policy_params, rng, x0, gamma):
     """Average policy cost at randomly sampled times along the trajectory.

     Keypoint times are drawn from an exponential distribution with rate
     ``-log(gamma)``, sorted, and prefixed with 0. This presumably gives a
     Monte-Carlo estimate of a gamma-discounted cost. The closed-loop system
     ``ofunc`` is solved at those times and ``cost_fn`` is averaged over the
     visited states.
     """
     # policy_params is the first argument because gradients are taken w.r.t. it.
     rate = -jp.log(gamma)
     sampled = random.exponential(rng, (num_keypoints, )) / rate
     t = jp.concatenate((jp.zeros((1, )), jp.sort(sampled)))
     print(f"t = {t}")

     x_t = ode.odeint(ofunc, x0, t, policy_params, rtol=1e-3, mxstep=1e6)
     print(f"x_t = {x_t}")

     def cost_at(state):
         return cost_fn(state, policy(policy_params, state))

     costs = vmap(cost_at)(x_t)
     print(f"costs = {costs}")
     return jp.mean(costs)
Exemple #19
0
def _forward_dormandprince(state, ts, params, diffusivity, stimuli, dt, dx):
    """Integrate ``step`` from ``state``, reporting the solution at ``ts``, via ``ode.odeint``.

    ``params``, ``diffusivity``, ``stimuli`` and ``dx`` are forwarded to
    ``step`` as extra positional arguments.

    NOTE(review): ``dt`` is accepted but unused because the solver picks its
    own step sizes (the name suggests adaptive Dormand-Prince). It is
    presumably kept for signature parity with fixed-step integrators;
    confirm with callers.
    """
    extra_args = (params, diffusivity, stimuli, dx)
    return ode.odeint(step, state, ts, *extra_args)
Exemple #20
0
  def check_against_scipy(self, fun, y0, tspace, *args, tol=1e-1):
    """Assert that JAX's odeint agrees with scipy's odeint on the same problem.

    ``fun`` takes the array module as its first argument
    (``fun(np_or_jnp, y, t, *args)``), so one definition serves both backends.
    Results are compared with ``atol = rtol = tol`` and dtypes are not checked.
    """
    np_y0, np_times = np.array(y0), np.array(tspace)
    reference = osp_integrate.odeint(partial(fun, np), np_y0, np_times, args)
    expected = jnp.asarray(reference)

    actual = odeint(partial(fun, jnp), jnp.array(np_y0), jnp.array(np_times), *args)
    self.assertAllClose(actual, expected, check_dtypes=False, atol=tol, rtol=tol)
 def make_sim(key):
     """Simulate ``n`` particles from random initial conditions.

     Initial states are drawn from a standard normal. The last column (the
     mass, per the layout ``[x, y, x', y', random_value, mass]``) is then set
     to 1. The packed state is integrated with ``odefunc`` over ``times``.

     Returns:
         The trajectory reshaped to ``(-1, *unpacked_shape)``.
     """
     # Shape (n, dim * 2 + params).
     x0 = random.normal(key, (n, total_dim))
     # Unit masses. `.at[...].set` replaces jax.ops.index_update, which has
     # been removed from JAX.
     x0 = x0.at[..., -1].set(1)
     trajectory = odeint(odefunc, x0.reshape(packed_shape), times, mxstep=2000)
     return trajectory.reshape(-1, *unpacked_shape)
Exemple #22
0
def ode_forward(rng: random.PRNGKey, net_params: List[jnp.ndarray],
                num_samples: int, num_dims: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sample from the neural ODE flow and compute the samples' log-density.

    Standard-normal base samples are pushed through the trace-augmented
    dynamics over t in [0, 1]. The accumulated trace is added to the prior
    log-density.

    Args:
        rng: Pseudo-random number generator key.
        net_params: Parameters of the network defining the vector field.
        num_samples: Number of samples to draw.
        num_dims: Dimensionality of each sample.

    Returns:
        xfwd: Samples after integrating the flow, shape (num_samples, num_dims).
        xsph: ``xfwd`` passed through ``project``.
        log_prob: Log-density of each sample, shape (num_samples,).
    """
    x = random.normal(rng, [num_samples, num_dims])
    log_prob_prior = jspst.norm.logpdf(x).sum(axis=-1)
    vector_field = lambda x, t: net(net_params, stacked(x, t))
    divfunc = primal_to_augmented(vector_field)
    # Trace accumulator starts at zero, one entry per sample.
    init = (x, jnp.zeros(len(x)))
    time = jnp.array([0.0, 1.0])
    xfwd, trace = tuple(_[-1] for _ in odeint(divfunc, init, time))
    log_prob = log_prob_prior + trace
    xsph = project(xfwd)
    return xfwd, xsph, log_prob
Exemple #23
0
def ode_reverse(net_params: List[jnp.ndarray],
                xrev: jnp.ndarray) -> jnp.ndarray:
    """Log-density of observations under the neural ODE flow.

    Integrates the time-reversed, negated trace-augmented dynamics from
    ``xrev`` over [0, 1]. The standard-normal prior is evaluated at the
    recovered base point, and the accumulated trace is subtracted.

    Args:
        net_params: Parameters of the network defining the vector field.
        xrev: Observations, one per row (the trace accumulator has one entry
            per row).

    Returns:
        rev_log_prob: Log-density of each observation, shape (len(xrev),).
    """
    vector_field = lambda x, t: net(net_params, stacked(x, t))
    divfunc = primal_to_augmented(vector_field)
    # Run the augmented field backwards: evaluate at 1 - t and negate each component.
    revfunc = lambda x, t: tuple(-_ for _ in divfunc(x, 1.0 - t))
    revinit = (xrev, jnp.zeros(len(xrev)))
    time = jnp.array([0.0, 1.0])
    yrev, revtrace = tuple(_[-1] for _ in odeint(revfunc, revinit, time))
    log_prob_prior = jspst.norm.logpdf(yrev).sum(axis=-1)
    rev_log_prob = log_prob_prior - revtrace
    return rev_log_prob
Exemple #24
0
def integrate_fieldline(startpoint, objective_dict):
    """Trace a field line of the coil field and plot it in 3-D.

    The curves ``r(objective_dict['p'], theta)`` are split into points (all
    but the last along axis 1) and differences between consecutive points.
    The field returned by ``biot_savart`` for those segments and the currents
    ``objective_dict['I_arr']`` is integrated as an ODE at 10000 samples over
    t in [0, 10]. The resulting curve is drawn on a new matplotlib figure;
    nothing is returned.

    NOTE(review): ``startpoint`` is ignored. Integration always starts at
    ``r_surf[0, 0, :]``. Confirm whether callers expect ``startpoint`` to be
    used.
    """
    curves = r(objective_dict['p'], theta)
    heads = curves[:, :-1, :]
    tails = curves[:, 1:, :]
    segments = heads - tails

    field = jit(
        lambda r_eval, t: biot_savart(r_eval, segments, heads, objective_dict['I_arr']))
    start = r_surf[0, 0, :]
    t_samples = np.linspace(0, 10, 10000)
    fieldline = odeint(field, start, t_samples)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    xs, ys, zs = fieldline[:, 0], fieldline[:, 1], fieldline[:, 2]
    ax.plot(xs, ys, zs)
Exemple #25
0
    def test_swoop(self):
        """Second-order reverse-mode gradient checks of odeint on the 'swoop' ODE.

        The check is run twice:
          - with a small initial condition and args (0.1, 0.2);
          - with a larger initial condition and args (0.1, 0.3).
        """
        def swoop(y, t, arg1, arg2):
            return np.array(y - np.sin(t) - np.cos(t) * arg1 + arg2)

        ts = np.array([0.1, 0.2])
        # Looser tolerance when float64 is effectively 32-bit (presumably x64 disabled).
        tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3

        y0 = np.linspace(0.1, 0.9, 10)
        integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.2)
        jtu.check_grads(integrate, (y0, ts),
                        modes=["rev"],
                        order=2,
                        rtol=tol,
                        atol=tol)

        big_y0 = np.linspace(1.1, 10.9, 10)
        # Integrate from the lambda's own argument and feed it big_y0, so the
        # check actually differentiates with respect to the large initial
        # condition (closing over big_y0 made its gradient trivially zero).
        integrate = lambda y0, ts: odeint(swoop, y0, ts, 0.1, 0.3)
        jtu.check_grads(integrate, (big_y0, ts),
                        modes=["rev"],
                        order=2,
                        rtol=tol,
                        atol=tol)
Exemple #26
0
def loss(t1, flat_p, psi_init, psi0):
    '''
    Infidelity between the evolved state and the target ``psi0``.

    Integrates dpsi/dt = -i Hmat(t, flat_p, t1) psi from ``psi_init`` over
    [0, t1]. Returns 1 - |<psi(t1)|psi0>|^2. This is a pure function.
    '''
    grid = jnp.linspace(0., t1, 5)

    def rhs(psi, t, t_final, params):
        return -1.0j * Hmat(t, params, t_final) @ psi

    states = odeint(rhs, psi_init, grid, t1, flat_p, rtol=1.4e-10, atol=1.4e-10)
    psi_final = states[-1, :]
    overlap = jnp.dot(jnp.conjugate(psi_final), psi0)
    return 1 - jnp.abs(overlap)**2
def state_and_costate_trajectories(initial_costate_and_final_time):
    """Propagates the ODE that defines the shooting method.

    Args:
        initial_costate_and_final_time: An array of shape (4,) containing (p_x(0), p_y(0), p_θ(0), t_f).

    Returns:
        A tuple (times, (states, costates)) where
            times: An array of shape (N,) with N = 20 evenly spaced time points spanning [0, t_f].
            states: An array of shape (N, 3) containing the states at `times`.
            costates: An array of shape (N, 3) containing the costates at `times`.
    """
    initial_costate = initial_costate_and_final_time[:-1]
    final_time = initial_costate_and_final_time[-1]
    times = jnp.linspace(0, final_time, 20)
    # State and costate are integrated jointly as a (state, costate) pytree,
    # starting from the module-level `initial_state`.
    return times, odeint(shooting_ode, (initial_state, initial_costate), times)
Exemple #28
0
    def model(self, t, X):
        """Probabilistic (numpyro-style) model: ODE-informed GP likelihood for three outputs.

        Samples, from their priors:
          - per-output GP quantities `noise`, `hyp` (passed to `self.RBF`) and `W`;
          - the ODE parameters J0, k1..k6, k, ka, q, KI, phi, Np, A;
          - one initial-condition entry `IC`.

        The observations are then conditioned on a multivariate normal. Its mean
        is the rescaled ODE solution and its covariance is block-diagonal, with
        one weighted-RBF block per output.

        NOTE(review): the `t` and `X` arguments are not used as inputs; `X` is
        rebound below from `self.X`. `ind` is not defined in this method and
        must come from an enclosing/module scope; confirm.
        """

        # GP priors, one entry per output (sample_shape=(self.D,)).
        noise = sample('noise', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))
        hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,))
        W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,))

        # ODE parameter priors; the trailing numbers are presumably reference
        # values for each parameter.
        J0 = sample('J0', dist.Uniform(1.0, 10.0)) # 2.5
        k1 = sample('k1', dist.Uniform(80., 120.0)) # 100.
        k2 = sample('k2', dist.Uniform(1., 10.0)) # 6.
        k3 = sample('k3', dist.Uniform(2., 20.0)) # 16.
        k4 = sample('k4', dist.Uniform(80., 120.0)) # 100.
        k5 = sample('k5', dist.Uniform(0.1, 2.0)) # 1.28
        k6 = sample('k6', dist.Uniform(2., 20.0)) # 12.
        k = sample('k', dist.Uniform(0.1, 2.0)) # 1.8
        ka = sample('ka', dist.Uniform(2., 20.0)) # 13.
        q = sample('q', dist.Uniform(1., 10.0)) # 4.
        KI = sample('KI', dist.Uniform(0.1, 2.0)) # 0.52
        phi = sample('phi', dist.Uniform(0.05, 1.0)) # 0.1
        Np = sample('Np', dist.Uniform(0.1, 2.0)) # 1.
        A = sample('A', dist.Uniform(1., 10.0)) #4.

        # Unknown initial condition for state index 4 (see x0 below).
        IC = sample('IC', dist.Uniform(0, 1))

        # compute kernel
        # Per-output covariance: W[k] * RBF on that output's times, plus
        # (noise[k] + jitter) on the diagonal.
        K_11 = W[0]*self.RBF(self.t_t[0], self.t_t[0], hyp[0]) + np.eye(self.N[0])*(noise[0] + self.jitter)
        K_22 = W[1]*self.RBF(self.t_t[1], self.t_t[1], hyp[1]) + np.eye(self.N[1])*(noise[1] + self.jitter)
        K_33 = W[2]*self.RBF(self.t_t[2], self.t_t[2], hyp[2]) + np.eye(self.N[2])*(noise[2] + self.jitter)
        # Assemble the block-diagonal covariance (outputs are independent).
        K = np.concatenate([np.concatenate([K_11, np.zeros((self.N[0], self.N[1])), np.zeros((self.N[0], self.N[2]))], axis = 1),
                            np.concatenate([np.zeros((self.N[1], self.N[0])), K_22, np.zeros((self.N[1], self.N[2]))], axis = 1),
                            np.concatenate([np.zeros((self.N[2], self.N[0])), np.zeros((self.N[2], self.N[1])), K_33], axis = 1)], axis = 0)

        # compute mean
        # Solve the 7-component ODE on all time points, then for output k take
        # rows self.i_t[k] of state column ind[k], rescaled by self.max_X[k].
        x0 = np.array([0.5, 1.9, 0.18, 0.15, IC, 0.1, 0.064])
        mut = odeint(self.dxdt, x0, self.t.flatten(), J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A)
        mu1 = mut[self.i_t[0],ind[0]] / self.max_X[0]
        mu2 = mut[self.i_t[1],ind[1]] / self.max_X[1]
        mu3 = mut[self.i_t[2],ind[2]] / self.max_X[2]
        mu = np.concatenate((mu1,mu2,mu3),axis=0)

        # Concat data
        # Mean and observations are both flattened in column-major ('F') order
        # so their entries line up.
        mu = mu.flatten('F')
        X = np.concatenate((self.X[0],self.X[1],self.X[2]),axis=0)
        X = X.flatten('F')

        # sample X according to the standard gaussian process formula
        sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X)
Exemple #29
0
    def bwd_spline_segment(ta, tb, args, Q, y_old, aug_tb):
        """Run the backward (adjoint) solve over a single spline segment [ta, tb].

        The forward state at each time is obtained from
        ``eval_spline(ta, tb, Q, y_old, .)``. Integration runs in negated time
        from ``-tb`` to ``-ta``, starting from the augmented state ``aug_tb``.
        The augmented state at ``-ta`` is returned (loose rtol/atol of 1e-3).
        """
        def adj_dynamics(aug, s, args, Q, y_old):
            # aug is a 3-tuple; only its middle entry is fed to the VJP.
            _t_bar, y_bar, _args_bar = aug
            t = -s
            y = unravel(eval_spline(ta, tb, Q, y_old, t))
            vjpfun = vjp(fun, t, y, args)[1]
            return vjpfun(y_bar)

        span = jnp.array([-tb, -ta])
        adj_path = ode.odeint(adj_dynamics, aug_tb, span, args, Q, y_old,
                              rtol=1e-3, atol=1e-3)
        last = itemgetter(-1)
        return tree_map(last, adj_path)
Exemple #30
0
        def make_sim(key):
            """Simulate ``n`` particles from random initial conditions.

            Initial states are drawn from a standard normal, then adjusted:
              - 'string' / 'string_ball': masses (last column) are set to 1;
                column 0 becomes ``arange(n)`` plus half of the drawn noise;
                column 2 is zeroed.
              - any other sim: masses are made positive via ``exp``.
                For 'charge' / 'superposition', the second-to-last column is
                additionally replaced by its sign.

            The packed state is then integrated with ``odefunc`` over ``times``.

            Returns:
                The trajectory reshaped to ``(-1, *unpacked_shape)``.
            """
            x0 = random.normal(key, (n, total_dim))
            # `.at[...].set` replaces jax.ops.index_update, which has been
            # removed from JAX.
            if sim in ['string', 'string_ball']:
                x0 = x0.at[..., -1].set(1)  # const mass
                x0 = x0.at[..., 0].set(np.arange(n) + x0[..., 0] * 0.5)
                x0 = x0.at[..., 2:3].set(0.0)
            else:
                x0 = x0.at[..., -1].set(np.exp(x0[..., -1]))  # all masses set to positive
                if sim in ['charge', 'superposition']:
                    x0 = x0.at[..., -2].set(np.sign(x0[..., -2]))  # charge is 1 or -1

            x_times = odeint(odefunc, x0.reshape(packed_shape), times, mxstep=2000)
            return x_times.reshape(-1, *unpacked_shape)