Exemplo n.º 1
0
 def integrate(self, t):
     """Integrate from ``t[0]`` and return the solution at every time in ``t``.

     Args:
         t: Increasing times (validated by ``_assert_increasing``); ``t[0]``
             is the time associated with ``self.y0``. Cast to float64 and
             moved to the device of ``self.y0[0]`` before use.

     Returns:
         A tuple with one tensor per state component, each stacking that
         component's value at every entry of ``t`` along a new leading axis.
     """
     _assert_increasing(t)
     states = [cast_double(self.y0)]
     t = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
     self.before_integrate(t)
     states.extend(
         cast_double(self.advance(t[idx])) for idx in range(1, t.shape[0]))
     by_component = zip(*states)
     return tuple(tf.stack(component) for component in by_component)
Exemplo n.º 2
0
def _interp_fit_dopri5(y0,
                       y1,
                       k,
                       dt,
                       tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
    """Fit an interpolating polynomial to the results of a Runge-Kutta step.

    A midpoint state is rebuilt from the stage derivatives ``k`` using the
    ``DPS_C_MID`` weights. The first and last stage derivatives of each
    component serve as the slopes at the two ends of the step. The fit itself
    is delegated to ``_interp_fit``.

    Args:
        y0: State at the start of the step, one entry per component
            (double-cast here).
        y1: State at the end of the step.
        k: Per-component sequences of Runge-Kutta stage derivatives.
        dt: Step size (double-cast here).
        tableau: Accepted for interface compatibility; it is not read here,
            since the midpoint weights come from ``DPS_C_MID``.

    Returns:
        The result of ``_interp_fit`` for these inputs.
    """
    step = cast_double(dt)
    start = cast_double(y0)

    midpoint = tuple(
        start_i + _scaled_dot_product(step, DPS_C_MID, stages)
        for start_i, stages in zip(start, k)
    )
    slope_start = tuple(stages[0] for stages in k)
    slope_end = tuple(stages[-1] for stages in k)
    return _interp_fit(start, y1, midpoint, slope_start, slope_end, step)
Exemplo n.º 3
0
        def augmented_dynamics(t, y_aug):
            """Dynamics of the original system augmented with its adjoint.

            Only the first ``2 * n_tensors`` entries of ``y_aug`` are read:
            the state ``y`` and its adjoint ``adj_y``. The trailing adj_time
            and adj_params entries are ignored.

            Returns:
                ``(*func(t, y), *vjp_y, vjp_t, vjp_params)``. Each VJP is the
                gradient of ``func(t, y)`` seeded with ``-adj_y`` with respect
                to ``y``, ``t`` and ``f_params`` respectively.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # Seed the backward pass with -adj_y so the results are
            # vector-Jacobian products, not plain gradients of sum(func_eval).
            gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(func_eval,
                                                     (t, ) + y + f_params,
                                                     output_gradients=gradys)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient returns None for sources the output does not
            # depend on; substitute zeros of the matching shape/dtype.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: emit a scalar zero placeholder.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
Exemplo n.º 4
0
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take an arbitrary Runge-Kutta step and estimate error.

    Args:
        func: Function to evaluate like `func(t, y)` to compute the time derivative
            of `y`; must return one derivative per state component.
        y0: Initial value for the state, one tensor per component (double-cast here).
        f0: Initial value for the derivative, computed from `func(t0, y0)`
            (double-cast here).
        t0: Scalar initial time; converted to the dtype/device of `y0[0]`.
        dt: Scalar size of the desired time step; converted like `t0`.
        tableau: _ButcherTableau describing how to take the Runge-Kutta step
            (reads `alpha`, `beta`, `c_sol` and `c_error`).

    Returns:
        Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
        the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`,
        estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for
        calculating these terms.
    """
    y0 = cast_double(y0)
    f0 = cast_double(f0)

    dtype = y0[0].dtype
    device = y0[0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # k[j] accumulates the stage derivatives of state component j, seeded with f0.
    k = tuple([f0_] for f0_ in f0)
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, cast_double(beta_i), k_)
                   for y0_, k_ in zip(y0, k))
        for k_, f_ in zip(k, func(ti, yi)):
            k_.append(cast_double(f_))

    # When c_sol[-1] == 0 and c_sol[:-1] == beta[-1] (true for Dormand-Prince),
    # the last stage state `yi` already equals the solution and the extra
    # combination below is skipped to save a few FLOPs.
    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_)
                   for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
Exemplo n.º 5
0
def _interp_coeff_tsit5(t0, dt, eval_t):
    """Return the seven Tsit5 interpolation weights at ``eval_t``.

    ``eval_t`` is first normalized to the step, ``s = (eval_t - t0) / dt``,
    and double-cast. Each weight is then a fixed polynomial in ``s``.

    Returns:
        The list ``[b1, b2, b3, b4, b5, b6, b7]``.
    """
    s = cast_double((eval_t - t0) / dt)
    s2 = s ** 2
    return [
        -1.0530884977290216 * s * (s - 1.3299890189751412) * (s2 - 1.4364028541716351 * s + 0.7139816917074209),
        0.1017 * s2 * (s2 - 2.1966568338249754 * s + 1.2949852507374631),
        2.490627285651252793 * s2 * (s2 - 2.38535645472061657 * s + 1.57803468208092486),
        -16.54810288924490272 * (s - 1.21712927295533244) * (s - 0.61620406037800089) * s2,
        47.37952196281928122 * (s - 1.203071208372362603) * (s - 0.658047292653547382) * s2,
        -34.87065786149660974 * (s - 1.2) * (s - 0.666666666666666667) * s2,
        2.5 * (s - 1) * (s - 0.6) * s2,
    ]
Exemplo n.º 6
0
def rk4_step_func(func, t, dt, y, k1=None):
    """Return the classic fourth-order Runge-Kutta increment for one step.

    Args:
        func: Callable ``func(t, y)`` returning one derivative per state
            component.
        t: Time at the start of the step.
        dt: Step size.
        y: State components at time ``t``.
        k1: Optional precomputed ``func(t, y)``; evaluated here when ``None``.

    Returns:
        Tuple of increments ``(k1 + 2*k2 + 2*k3 + k4) * dt / 6``, one per
        state component (the new state itself is not formed here).
    """
    if k1 is None:
        k1 = func(t, y)
    k1 = cast_double(k1)

    # Cast every stage like k1 so stages are never combined at mixed precision.
    k2 = cast_double(
        func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1))))
    k3 = cast_double(
        func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2))))
    k4 = cast_double(
        func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3))))
    return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6)
                 for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
Exemplo n.º 7
0
 def before_integrate(self, t):
     """Initialise ``self.rk_state`` before integrating over the times ``t``.

     The initial step size comes from ``_select_initial_step`` when
     ``self.first_step`` is ``None``; otherwise the caller-supplied
     ``self.first_step`` is used, converted to ``t``'s dtype and device.

     The state is seeded with:
       - ``self.y0``;
       - the double-cast derivative ``self.func(t[0], self.y0)``;
       - ``t[0]`` for both time fields;
       - the chosen first step;
       - seven copies of each ``self.y0`` component, presumably placeholder
         interpolation coefficients (TODO confirm against ``_RungeKuttaState``).
     """
     if self.first_step is None:
         first_step = _convert_to_tensor(_select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol),
                                         device=t.device)
     else:
         # Honour the requested first step rather than a hard-coded 0.01.
         first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
     self.rk_state = _RungeKuttaState(
         self.y0,
         cast_double(self.func(t[0], self.y0)), t[0], t[0], first_step,
         tuple([y0_] * 7 for y0_ in self.y0)
     )
Exemplo n.º 8
0
def rk4_alt_step_func(func, t, dt, y, k1=None):
    """Smaller error with slightly more compute.

    Returns the increment of Kutta's 3/8-rule fourth-order method, i.e.
    ``(k1 + 3*k2 + 3*k3 + k4) * dt / 8`` per state component. Each stage
    derivative is double-cast. ``k1`` may be supplied precomputed as
    ``func(t, y)``.
    """
    stage1 = cast_double(func(t, y) if k1 is None else k1)

    state2 = tuple(y_i + dt * s1 / 3 for y_i, s1 in zip(y, stage1))
    stage2 = cast_double(func(t + dt / 3, state2))

    state3 = tuple(y_i + dt * (s1 / -3 + s2)
                   for y_i, s1, s2 in zip(y, stage1, stage2))
    stage3 = cast_double(func(t + dt * 2 / 3, state3))

    state4 = tuple(y_i + dt * (s1 - s2 + s3)
                   for y_i, s1, s2, s3 in zip(y, stage1, stage2, stage3))
    stage4 = cast_double(func(t + dt, state4))

    weight = dt / 8
    return tuple((s1 + 3 * s2 + 3 * s3 + s4) * weight
                 for s1, s2, s3, s4 in zip(stage1, stage2, stage3, stage4))
Exemplo n.º 9
0
    def integrate(self, t):
        """Step across the solver's fixed time grid and sample the solution at ``t``.

        ``t`` is cast to the dtype of ``self.y0[0]``. The grid comes from
        ``self.grid_constructor`` and must start at ``t[0]`` and end at
        ``t[-1]``. Each grid interval is advanced with ``self.step_func``.
        Every requested time falling in an interval is filled with
        ``self._linear_interp`` between that interval's start and end states.

        Returns:
            A tuple with one tensor per state component, stacking the values
            at every entry of ``t`` along a new leading axis.
        """
        _assert_increasing(t)
        t = tf.cast(t, self.y0[0].dtype)
        time_grid = self.grid_constructor(self.func, self.y0, t)
        assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1])
        time_grid = move_to_device(time_grid, self.y0[0].device)

        y0 = cast_double(self.y0)
        solution = [y0]

        j = 1
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

            # y0 must still be the state at t0 here, so it is only advanced
            # after every output time in (t0, t1] has been interpolated.
            while j < t.shape[0] and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1

            y0 = y1

        return tuple(map(tf.stack, tuple(zip(*solution))))
Exemplo n.º 10
0
        def augmented_dynamics(t, y_aug):
            """Dynamics of the original system augmented with its adjoint.

            Only the first ``2 * n_tensors`` entries of ``y_aug`` are read:
            the state ``y`` and its adjoint ``adj_y``. The trailing adj_time
            and adj_params entries are ignored.

            Returns:
                ``(*func(t, y), *vjp_y, vjp_t, vjp_params)``. Each VJP is the
                gradient of ``func(t, y)`` seeded with ``-adj_y`` with respect
                to ``y``, ``t`` and ``f_params`` respectively.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # Seed the backward pass with -adj_y so the results are
            # vector-Jacobian products, not plain gradients of sum(func_eval).
            gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(func_eval,
                                                     (t, ) + y + f_params,
                                                     output_gradients=gradys)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient returns None for sources the output does not
            # depend on; substitute zeros of the matching shape/dtype.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: emit a scalar zero placeholder.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
Exemplo n.º 11
0
def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Evaluate the Tsit5 interpolant of the step ``[t0, t1]`` at ``eval_t``.

    For each component sequence ``coeffs`` in ``k``, returns
    ``coeffs[0] + _scaled_dot_product(t1 - t0, weights, coeffs)``. Here
    ``weights`` come from ``_interp_coeff_tsit5``, and the step length is
    computed from double-cast ``t1`` and ``t0``.

    NOTE(review): ``coeffs[0]`` is used as the base state; confirm that the
    first entry of each ``k`` sequence holds the state at ``t0`` rather than
    a stage derivative.
    """
    step = cast_double(t1) - cast_double(t0)
    weights = _interp_coeff_tsit5(t0, step, eval_t)
    values = []
    for coeffs in k:
        values.append(coeffs[0] + _scaled_dot_product(step, weights, coeffs))
    return tuple(values)
Exemplo n.º 12
0
 def step_func(self, func, t, dt, y):
     """Forward-Euler increment: ``dt * func(t, y)`` for each state component."""
     increments = []
     for derivative in cast_double(func(t, y)):
         increments.append(dt * derivative)
     return tuple(increments)
Exemplo n.º 13
0
 def step_func(self, func, t, dt, y):
     """Heun (explicit trapezoidal) increment for one step of size ``dt``.

     An Euler predictor gives the state at ``t + dt``. The increment is then
     ``dt / 2`` times the sum of the derivatives at both ends, per state
     component. Both derivative evaluations are double-cast.
     """
     slope_start = cast_double(func(t, y))
     predicted = tuple(y_i + dt * s for y_i, s in zip(y, slope_start))
     slope_end = cast_double(func(t + dt, predicted))
     half_step = dt / 2.
     return tuple(half_step * (a + b) for a, b in zip(slope_start, slope_end))
Exemplo n.º 14
0
 def step_func(self, func, t, dt, y):
     """Explicit-midpoint increment: ``dt`` times the derivative at the half step."""
     slope = cast_double(func(t, y))
     half_state = tuple(y_i + s * dt / 2 for y_i, s in zip(y, slope))
     mid_slope = cast_double(func(t + dt / 2, half_state))
     return tuple(dt * m for m in mid_slope)
Exemplo n.º 15
0
    def grad(*grad_output, variables=None):
        """Adjoint-method backward pass for the ODE solution ``ans`` at times ``t``.

        Walks the output times from last to first. Each interval is covered by
        integrating the adjoint-augmented dynamics backwards with ``odeint``,
        then adding the upstream gradient at that output time.

        Args:
            *grad_output: Upstream gradients, one per solution component,
                indexed by output time along the first axis.
            variables: Variables read by ``func``; gradients are returned for
                them.

        Returns:
            ``(adj_y, model_vars)``. ``adj_y`` is the state adjoint at ``t[0]``
            (one entry per state component). ``model_vars`` lists the
            parameter gradients reshaped to the shapes of ``variables``.
        """
        global _arguments
        func = _arguments.func
        method = _arguments.method
        options = _arguments.options
        rtol = _arguments.rtol
        atol = _arguments.atol

        # All variables flattened into one vector; sizes the parameter adjoint.
        flat_params = _flatten(variables)

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            """Evaluate ``func`` and its VJPs seeded with ``-adj_y`` at ``t``.

            Only the state and its adjoint (the first ``2 * n_tensors``
            entries of ``y_aug``) are read; adj_time and adj_params are
            ignored.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # Seed the backward pass with -adj_y so the results are
            # vector-Jacobian products, not plain gradients of sum(func_eval).
            gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(func_eval,
                                                     (t, ) + y + f_params,
                                                     output_gradients=gradys)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient returns None for sources the output does not
            # depend on; substitute zeros of the matching shape/dtype.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: emit a scalar zero placeholder.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        # NOTE(review): grad_output is always a tuple here (collected via
        # *grad_output), so the Tensor/Variable branches never run.
        if isinstance(grad_output, tf.Tensor) or isinstance(
                grad_output, tf.Variable):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype),
                                  t.device)
        for i in range(T - 1, 0, -1):

            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, tf.Tensor) or isinstance(
                    grad_output, tf.Variable):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i]
                                      for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)
            func_i = cast_double(func_i)

            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(
                    tf.matmul(tf.reshape(func_i_, [1, -1]),
                              tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
            adj_time = cast_double(adj_time)
            adj_time = adj_time - dLd_cur_t

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                count = _numel(adj_params)

                if count == 0:
                    adj_params = move_to_device(
                        tf.convert_to_tensor(0., dtype=adj_y[0].dtype),
                        adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)

            aug_ans = odeint(augmented_dynamics,
                             aug_y0,
                             tf.convert_to_tensor([t[i], t[i - 1]]),
                             rtol=rtol,
                             atol=atol,
                             method=method,
                             options=options)

            # Unpack aug_ans; index 1 is the value at t[i - 1].
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                          for adj_y_ in adj_y)
            if _check_len(adj_time) > 0: adj_time = adj_time[1]
            if _check_len(adj_params) > 0: adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1]
                          for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        # Reshape the flat parameter adjoint back into the variables' shapes.
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        model_vars = [
            tf.reshape(p, v_shape)
            for p, v_shape in zip(adj_params_splits, var_shapes)
        ]

        return (adj_y, model_vars)
Exemplo n.º 16
0
    def grad(*grad_output, variables=None):
        """Adjoint-method backward pass for the ODE solution ``ans`` at times ``t``.

        Walks the output times from last to first. Each interval is covered by
        integrating the adjoint-augmented dynamics backwards with ``odeint``,
        adding the upstream gradient at each output time and accumulating the
        sensitivity to each output time.

        Args:
            *grad_output: Upstream gradients, one per solution component,
                indexed by output time along the first axis.
            variables: Variables read by ``func``; gradients are returned for
                them.

        Returns:
            ``((*adj_y, time_vjps), adj_params_list)``.
              - ``adj_y``: the state adjoint at ``t[0]``, one per component.
              - ``time_vjps``: the per-output-time gradients, ordered from
                ``t[0]`` to ``t[-1]``.
              - ``adj_params_list``: the parameter gradients reshaped to the
                shapes of ``variables``.
        """
        global _arguments
        # All variables flattened into one vector; sizes the parameter adjoint.
        flat_params = _flatten(variables)

        func = _arguments.func
        method = _arguments.method
        options = _arguments.options
        rtol = _arguments.rtol
        atol = _arguments.atol

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            """Evaluate ``func`` and its VJPs seeded with ``-adj_y`` at ``t``.

            Only the state and its adjoint (the first ``2 * n_tensors``
            entries of ``y_aug``) are read; adj_time and adj_params are
            ignored.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # Seed the backward pass with -adj_y so the results are
            # vector-Jacobian products.
            gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(func_eval,
                                                     (t, ) + y + f_params,
                                                     output_gradients=gradys)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient returns None for sources the output does not
            # depend on; substitute zeros of the matching shape/dtype.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: emit a scalar zero placeholder.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        # NOTE(review): grad_output is always a tuple here (collected via
        # *grad_output), so the Tensor/Variable branches never run.
        if isinstance(grad_output, tf.Tensor) or isinstance(
                grad_output, tf.Variable):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype),
                                  t.device)
        time_vjps = []
        for i in range(T - 1, 0, -1):

            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, tf.Tensor) or isinstance(
                    grad_output, tf.Variable):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i]
                                      for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)
            func_i = cast_double(func_i)

            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(
                    tf.matmul(tf.reshape(func_i_, [1, -1]),
                              tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
            adj_time = cast_double(adj_time)
            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                count = _numel(adj_params)

                if count == 0:
                    adj_params = move_to_device(
                        tf.convert_to_tensor(0., dtype=adj_y[0].dtype),
                        adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)

            aug_ans = odeint(augmented_dynamics,
                             aug_y0,
                             tf.convert_to_tensor([t[i], t[i - 1]]),
                             rtol=rtol,
                             atol=atol,
                             method=method,
                             options=options)

            # Unpack aug_ans; index 1 is the value at t[i - 1].
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                          for adj_y_ in adj_y)
            if _check_len(adj_time) > 0: adj_time = adj_time[1]
            if _check_len(adj_params) > 0: adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1]
                          for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        # time_vjps was filled from t[-1] backwards; the adjoint for t[0] goes
        # last, and reversing restores chronological order.
        time_vjps.append(adj_time)
        time_vjps = tf.concat(time_vjps[::-1], 0)

        # Reshape the flat parameter adjoint back into the variables' shapes.
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        adj_params_list = [
            tf.reshape(p, v_shape)
            for p, v_shape in zip(adj_params_splits, var_shapes)
        ]

        return (*adj_y, time_vjps), adj_params_list