Esempio n. 1
0
    def solve_forward(self, t0, tvals, y0, y_out, *, max_retries=None):
        """Integrate forward with ``CVodeF`` and record the state at ``tvals``.

        Re-initialises the CVODES forward and adjoint memory at ``t0`` with
        the state ``y0`` and writes the state reached at ``tvals[i]`` into
        ``y_out[i, :]``.

        Parameters
        ----------
        t0 : float
            Start time of the integration.
        tvals : iterable of float
            Output times. Row ``i`` of ``y_out`` receives the state at
            ``tvals[i]``. Presumably ordered in the integration direction --
            TODO confirm with callers.
        y0 : array
            Initial state, copied into the solver's state buffer.
        y_out : array
            Output array, written row by row (``y_out[i, :]``).
        max_retries : int or None, optional
            Maximum number of ``CVodeF`` calls per output time before giving
            up while they keep returning ``CV_TOO_MUCH_WORK``. ``None`` (the
            default) retries without limit, as before.

        Raises
        ------
        SolverError
            If ``CVodeF`` returns a non-zero code other than
            ``CV_TOO_MUCH_WORK``, or if ``max_retries`` is exhausted.
        """
        CVodeReInit = lib.CVodeReInit
        CVodeAdjReInit = lib.CVodeAdjReInit
        CVodeF = lib.CVodeF
        ode = self._ode
        TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

        state_data = self._state_buffer.data
        state_c_ptr = self._state_buffer.c_ptr

        state_data[:] = y0

        # CVodeF reports the time it actually reached through this pointer.
        time_p = ffi.new('double*')
        time_p[0] = t0

        # Output argument of CVodeF (presumably the number of checkpoints
        # stored so far -- see the CVODES adjoint docs).
        n_check = ffi.new('int*')
        n_check[0] = 0

        check(CVodeReInit(ode, t0, state_c_ptr))
        check(CVodeAdjReInit(ode))

        for i, t in enumerate(tvals):
            if t == t0:
                # Nothing to integrate. Write into row ``i`` (not always row 0)
                # so an output time equal to t0 lands in its own row.
                y_out[i, :] = y0
                continue

            n_tries = 0
            retval = TOO_MUCH_WORK
            while retval == TOO_MUCH_WORK:
                if max_retries is not None and n_tries >= max_retries:
                    raise SolverError(f"Too many solver retries before time={t}.")
                n_tries += 1
                # CV_TOO_MUCH_WORK is not fatal: calling again continues the
                # integration towards ``t``.
                retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL,
                                n_check)
                if retval != TOO_MUCH_WORK and retval != 0:
                    raise SolverError(
                        "Bad sundials return code while solving ode: %s (%s)" %
                        (ERRORS[retval], retval))
            y_out[i, :] = state_data
Esempio n. 2
0
    def current_stats(self) -> Dict[str, Any]:
        """Return a small dictionary of current solver statistics.

        At present only ``"order"`` is reported: the integer that
        ``CVodeGetCurrentOrder`` writes for the solver at ``self.c_ptr``.
        Failures are reported through ``check_code``.
        """
        order_ptr = ffi.new('int*', 1)
        status = lib.CVodeGetCurrentOrder(self.c_ptr, order_ptr)
        check_code(status)
        stats = {}
        stats["order"] = order_ptr[0]
        return stats
Esempio n. 3
0
    def solve(self,
              t0: float,
              tvals: np.ndarray,
              y0: np.ndarray,
              y_out: np.ndarray,
              forward_sens: Optional[np.ndarray] = None,
              checkpointing: bool = False,
              *,
              max_retries: Optional[int] = None) -> None:
        """Integrate from ``t0`` with ``CVodeF`` and record the state at ``tvals``.

        Re-initialises the CVODES forward and adjoint memory, writes the state
        at ``tvals[i]`` into ``y_out[i, :]``, and finally calls
        ``self.mark_changed(False)``.

        Parameters
        ----------
        t0 : float
            Start time.
        tvals : np.ndarray
            Output times. Row ``i`` of ``y_out`` corresponds to ``tvals[i]``.
        y0 : np.ndarray
            Initial state. If its dtype equals ``self._problem.state_dtype``,
            it is reinterpreted as a float64 view before use.
        y_out : np.ndarray
            Output array, written row by row.
        forward_sens : np.ndarray, optional
            Accepted for interface compatibility; not used by this method.
        checkpointing : bool, optional
            Accepted for interface compatibility; not used by this method.
        max_retries : int or None, optional
            Maximum number of ``CVodeF`` calls per output time before giving
            up while they keep returning ``CV_TOO_MUCH_WORK``. ``None`` (the
            default) retries without limit, as before.

        Raises
        ------
        SolverError
            If ``CVodeF`` returns a non-zero code other than
            ``CV_TOO_MUCH_WORK``, or if ``max_retries`` is exhausted.
        """
        CVodeReInit = lib.CVodeReInit
        CVodeAdjReInit = lib.CVodeAdjReInit
        CVodeF = lib.CVodeF
        ode = self.c_ptr
        TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

        state_data = self._state_buffer.data
        state_c_ptr = self._state_buffer.c_ptr

        # Accept a record of the problem's state dtype by viewing it as floats.
        if y0.dtype == self._problem.state_dtype:
            y0 = y0[None].view(np.float64)
        state_data[:] = y0

        time_p = ffi.new('double*')
        time_p[0] = t0

        # Output argument of CVodeF (presumably the checkpoint count).
        n_check = ffi.new('int*')
        n_check[0] = 0

        check(CVodeReInit(ode, t0, state_c_ptr))
        check(CVodeAdjReInit(ode))

        for i, t in enumerate(tvals):
            if t == t0:
                # No integration needed. Use row ``i`` so t0 at any position
                # fills its own row.
                y_out[i, :] = y0
                continue

            n_tries = 0
            retval = TOO_MUCH_WORK
            while retval == TOO_MUCH_WORK:
                if max_retries is not None and n_tries >= max_retries:
                    raise SolverError(f"Too many solver retries before time={t}.")
                n_tries += 1
                # CV_TOO_MUCH_WORK is recoverable: call again to keep going.
                retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL,
                                n_check)
                if retval != TOO_MUCH_WORK and retval != 0:
                    raise SolverError(
                        "Bad sundials return code while solving ode: %s (%s)" %
                        (ERRORS[retval], retval))
            y_out[i, :] = state_data

        self.mark_changed(False)
Esempio n. 4
0
    def solve_forward(self, t0, tvals, y0, y_out, *, max_retries=5):
        """Integrate forward with ``CVodeF`` and record the state at ``tvals``.

        Re-initialises the CVODES forward and adjoint memory at ``t0`` and
        writes the state at ``tvals[i]`` into ``y_out[i, :]``.

        Parameters
        ----------
        t0 : float
            Start time.
        tvals : iterable of float
            Output times. Row ``i`` of ``y_out`` corresponds to ``tvals[i]``.
        y0 : array
            Initial state. If its dtype equals ``self._problem.state_dtype``,
            it is reinterpreted as a float64 view before use.
        y_out : array
            Output array, written row by row.
        max_retries : int, optional
            Number of ``CVodeF`` calls allowed per output time while they keep
            returning ``CV_TOO_MUCH_WORK``.

        Raises
        ------
        SolverError
            If ``CVodeF`` returns a non-zero code other than
            ``CV_TOO_MUCH_WORK``, or if ``max_retries`` calls all return
            ``CV_TOO_MUCH_WORK``.
        """
        CVodeReInit = lib.CVodeReInit
        CVodeAdjReInit = lib.CVodeAdjReInit
        CVodeF = lib.CVodeF
        ode = self._ode
        TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

        state_data = self._state_buffer.data
        state_c_ptr = self._state_buffer.c_ptr

        if y0.dtype == self._problem.state_dtype:
            y0 = y0[None].view(np.float64)
        state_data[:] = y0

        time_p = ffi.new('double*')
        time_p[0] = t0

        # Output argument of CVodeF (presumably the checkpoint count).
        n_check = ffi.new('int*')
        n_check[0] = 0

        check(CVodeReInit(ode, t0, state_c_ptr))
        check(CVodeAdjReInit(ode))

        for i, t in enumerate(tvals):
            if t == t0:
                # No integration needed. Use row ``i`` (not always 0) so the
                # initial state lands in the row that corresponds to ``t``.
                y_out[i, :] = y0
                continue

            for retry in range(max_retries):
                retval = CVodeF(ode, t, state_c_ptr, time_p, lib.CV_NORMAL,
                                n_check)
                if retval == 0:
                    # CV_NORMAL mode is expected to stop exactly at ``t``.
                    assert time_p[0] == t
                    break
                if retval != TOO_MUCH_WORK:
                    error = ERRORS[retval]
                    raise SolverError(
                        f"Solving ode failed before time={t}: {error} ({retval})"
                    )
            else:
                raise SolverError(f"Too many solver retries before time={t}.")

            y_out[i, :] = state_data
Esempio n. 5
0
    def _init_backward(self, checkpoint_n, interpolation, *,
                       rtol=1e-10, atol=1e-10,
                       quad_rtol=1e-10, quad_atol=1e-10):
        """Set up CVODES adjoint (backward) integration and its quadrature.

        Calls ``CVodeAdjInit`` and creates the backward problem, storing its
        index in ``self._odeB``. It then configures, in order:

        - the backward state buffer and scalar tolerances;
        - a dense linear solver and the Jacobian function;
        - the user data;
        - the backward quadrature (``self._quad_buffer`` and
          ``self._quad_buffer_out``).

        Every CVODES call's result is passed through ``check``.

        Parameters
        ----------
        checkpoint_n : int
            Forwarded to ``CVodeAdjInit`` (presumably the number of steps
            between checkpoints -- see the CVODES docs).
        interpolation : int
            Forwarded to ``CVodeAdjInit`` as the interpolation type.
        rtol, atol : float, optional
            Relative / absolute tolerance for the backward state, passed to
            ``CVodeSStolerancesB``. Defaults are the previously hard-coded
            ``1e-10``.
        quad_rtol, quad_atol : float, optional
            Relative / absolute tolerance for the backward quadrature, passed
            to ``CVodeQuadSStolerancesB``. Defaults are the previously
            hard-coded ``1e-10``.
        """
        check(lib.CVodeAdjInit(self._ode, checkpoint_n, interpolation))

        # Initialized by CVodeCreateB (receives the backward problem index).
        backward_ode = ffi.new('int*')
        check(
            lib.CVodeCreateB(self._ode, self._adjoint_solver_type,
                             backward_ode))
        self._odeB = backward_ode[0]

        self._state_bufferB = sunode.empty_vector(self._problem.n_states)
        check(
            lib.CVodeInitB(self._ode, self._odeB, self._adj_rhs.cffi, 0.,
                           self._state_bufferB.c_ptr))

        # Argument order of the SUNDIALS call is (reltol, abstol).
        check(lib.CVodeSStolerancesB(self._ode, self._odeB, rtol, atol))

        linsolver = check(
            lib.SUNLinSol_Dense(self._state_bufferB.c_ptr, self._jacB))
        check(
            lib.CVodeSetLinearSolverB(self._ode, self._odeB, linsolver,
                                      self._jacB))

        # Kept on ``self`` so the cffi callback object stays referenced.
        self._jac_funcB = self._problem.make_sundials_adjoint_jac_dense()
        check(lib.CVodeSetJacFnB(self._ode, self._odeB, self._jac_funcB.cffi))

        user_data_p = ffi.cast(
            'void *', ffi.addressof(ffi.from_buffer(self._user_data.data)))
        check(lib.CVodeSetUserDataB(self._ode, self._odeB, user_data_p))

        self._quad_buffer = sunode.empty_vector(self._problem.n_params)
        self._quad_buffer_out = sunode.empty_vector(self._problem.n_params)
        check(
            lib.CVodeQuadInitB(self._ode, self._odeB, self._quad_rhs.cffi,
                               self._quad_buffer.c_ptr))

        check(lib.CVodeQuadSStolerancesB(self._ode, self._odeB, quad_rtol,
                                         quad_atol))
        # Include the quadrature variables in the error control.
        check(lib.CVodeSetQuadErrConB(self._ode, self._odeB, 1))
Esempio n. 6
0
    def solve_backward(self,
                       t0,
                       tend,
                       tvals,
                       grads,
                       grad_out,
                       lamda_out,
                       lamda_all_out=None,
                       quad_all_out=None,
                       max_retries=50):
        """Integrate the adjoint problem backward from ``t0`` down to ``tend``.

        The time axis is split at the output times ``tvals`` (visited in
        reverse). After reaching ``tvals[k]``, ``grads[k]`` is subtracted from
        the adjoint state. Intervals with ``t_lower >= t_upper`` are skipped
        without integrating.

        Parameters
        ----------
        t0 : float
            Time at which the backward integration starts.
        tend : float
            Time at which the backward integration ends.
        tvals : sequence of float
            Output times of the forward solve.
        grads : sequence of arrays
            ``grads[k]`` is subtracted from the adjoint state at ``tvals[k]``.
        grad_out : array
            Receives the final quadrature values.
        lamda_out : array
            Receives the final adjoint state.
        lamda_all_out, quad_all_out : array, optional
            If given, row ``k`` receives the adjoint state / quadrature right
            after ``grads[k]`` was subtracted.
        max_retries : int, optional
            Number of ``CVodeB`` calls allowed per interval while they keep
            returning ``CV_TOO_MUCH_WORK``.

        Raises
        ------
        SolverError
            If ``CVodeB`` fails or ``max_retries`` is exhausted.
        """
        CVodeReInitB = lib.CVodeReInitB
        CVodeQuadReInitB = lib.CVodeQuadReInitB
        CVodeGetQuadB = lib.CVodeGetQuadB
        CVodeB = lib.CVodeB
        CVodeGetB = lib.CVodeGetB
        ode = self._ode
        odeB = self._odeB
        TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

        state_data = self._state_bufferB.data
        state_c_ptr = self._state_bufferB.c_ptr

        quad_data = self._quad_buffer.data
        quad_c_ptr = self._quad_buffer.c_ptr

        quad_out_data = self._quad_buffer_out.data
        quad_out_c_ptr = self._quad_buffer_out.c_ptr

        state_data[:] = 0
        quad_data[:] = 0
        quad_out_data[:] = 0

        time_p = ffi.new('double*')
        time_p[0] = t0

        # Break points in backward order: t0, tvals[-1], ..., tvals[0], tend.
        # Interval i runs from ts[i] (upper) down to ts[i + 1] (lower).
        ts = [t0] + list(tvals[::-1]) + [tend]
        t_intervals = zip(ts[1:], ts[:-1])
        # Interval i (i < len(tvals)) ends at tvals[-(i + 1)] and is paired
        # with grads[-(i + 1)]; the final interval (ending at tend) gets None.
        grads = [None] + list(grads)

        for i, ((t_lower, t_upper),
                grad) in enumerate(zip(t_intervals, reversed(grads))):
            if t_lower < t_upper:
                check(CVodeReInitB(ode, odeB, t_upper, state_c_ptr))
                check(CVodeQuadReInitB(ode, odeB, quad_c_ptr))

                for retry in range(max_retries):
                    retval = CVodeB(ode, t_lower, lib.CV_NORMAL)
                    if retval == 0:
                        break
                    if retval != TOO_MUCH_WORK:
                        error = ERRORS[retval]
                        raise SolverError(
                            f"Solving ode failed between time {t_upper} and "
                            f"{t_lower}: {error} ({retval})")
                else:
                    raise SolverError(
                        f"Too many solver retries between time {t_upper} and {t_lower}."
                    )

                check(CVodeGetB(ode, odeB, time_p, state_c_ptr))
                check(CVodeGetQuadB(ode, odeB, time_p, quad_out_c_ptr))
                # Carry the quadrature over as the start of the next interval.
                quad_data[:] = quad_out_data[:]
                assert time_p[0] == t_lower, (time_p[0], t_lower)

            if grad is not None:
                state_data[:] -= grad

                # Row of tvals[-(i + 1)]. Plain ``-i`` would map i == 0 to
                # row 0 (since -0 == 0) and shift every other row by one.
                row = -(i + 1)
                if lamda_all_out is not None:
                    lamda_all_out[row, :] = state_data
                if quad_all_out is not None:
                    quad_all_out[row, :] = quad_data

        grad_out[:] = quad_out_data
        lamda_out[:] = state_data
Esempio n. 7
0
    def solve(self,
              t0,
              tvals,
              y0,
              y_out,
              *,
              sens0=None,
              sens_out=None,
              max_retries=5):
        """Integrate with ``CVode`` from ``t0``, optionally with sensitivities.

        Writes the state at ``tvals[i]`` into ``y_out[i, :]``. When
        ``self._compute_sens`` is set, it also writes the forward
        sensitivities into ``sens_out[i, j, :]`` (one ``j`` per parameter).

        Parameters
        ----------
        t0 : float
            Start time.
        tvals : iterable of float
            Output times. Row ``i`` of the outputs corresponds to ``tvals[i]``.
        y0 : array
            Initial state. If its dtype equals ``self._problem.state_dtype``,
            it is reinterpreted as a float64 view before use.
        y_out : array
            Output array for the state, written row by row.
        sens0 : array, optional
            Initial sensitivities, indexed ``sens0[j, :]`` per parameter.
            Required when sensitivities are computed.
        sens_out : array, optional
            Output array for sensitivities, indexed ``sens_out[i, j, :]``.
            Required when sensitivities are computed.
        max_retries : int, optional
            Number of ``CVode`` calls allowed per output time while they keep
            returning ``CV_TOO_MUCH_WORK``.

        Raises
        ------
        ValueError
            If sensitivities are computed but ``sens0`` or ``sens_out`` is
            missing.
        SolverError
            If ``CVode`` fails or ``max_retries`` is exhausted.
        """
        if self._compute_sens and (sens0 is None or sens_out is None):
            raise ValueError(
                '"sens_out" and "sens0" are required when computing sensitivities.'
            )
        CVodeReInit = lib.CVodeReInit
        CVodeSensReInit = lib.CVodeSensReInit
        CVode = lib.CVode
        CVodeGetSens = lib.CVodeGetSens
        ode = self._ode
        TOO_MUCH_WORK = lib.CV_TOO_MUCH_WORK

        n_params = self._problem.n_params

        state_data = self._state_buffer.data
        state_c_ptr = self._state_buffer.c_ptr

        if self._compute_sens:
            sens_buffer_array = self._sens_buffer_array
            sens_data = tuple(buffer.data for buffer in self._sens_buffers)
            for i in range(n_params):
                sens_data[i][:] = sens0[i, :]

        if y0.dtype == self._problem.state_dtype:
            y0 = y0[None].view(np.float64)
        state_data[:] = y0

        time_p = ffi.new('double*')
        time_p[0] = t0

        check(CVodeReInit(ode, t0, state_c_ptr))
        if self._compute_sens:
            check(
                CVodeSensReInit(ode, self._sens_mode, self._sens_buffer_array))

        for i, t in enumerate(tvals):
            if t == t0:
                # No integration needed. Use row ``i`` (not always 0) so the
                # initial values land in the row that corresponds to ``t``.
                y_out[i, :] = y0
                if self._compute_sens:
                    sens_out[i, :, :] = sens0
                continue

            for retry in range(max_retries):
                retval = CVode(ode, t, state_c_ptr, time_p, lib.CV_NORMAL)
                if retval == 0:
                    assert time_p[0] == t
                    break
                if retval != TOO_MUCH_WORK:
                    error = ERRORS[retval]
                    raise SolverError(
                        f"Solving ode failed before time={t}: {error} ({retval})"
                    )
            else:
                raise SolverError(f"Too many solver retries before time={t}.")

            y_out[i, :] = state_data

            if self._compute_sens:
                # Previously a failing CVodeGetSens was ignored, which left
                # sens_out[i] unfilled. Report it like other CVODES calls.
                check(CVodeGetSens(ode, time_p, sens_buffer_array))
                for j in range(n_params):
                    sens_out[i, j, :] = sens_data[j]