Ejemplo n.º 1
0
    def __init__(self,
                 func,
                 y0,
                 rtol,
                 atol,
                 first_step=None,
                 safety=0.9,
                 ifactor=10.0,
                 dfactor=0.2,
                 max_num_steps=2**31 - 1,
                 **unused_kwargs):
        """Record the problem definition and step-control settings.

        Args:
            func: callable stored unchanged as `self.func`.
            y0: indexable collection of tensors; `y0[0].device` decides where
                the control tensors below are placed.
            rtol: stored as given when `_is_iterable(rtol)` is true, otherwise
                repeated into a list with one entry per element of `y0`.
            atol: handled exactly like `rtol`.
            first_step: stored unchanged; `None` by default.
            safety: converted to a float64 tensor on `y0[0].device`.
            ifactor: converted to a float64 tensor on `y0[0].device`.
            dfactor: converted to a float64 tensor on `y0[0].device`.
            max_num_steps: converted to an int32 tensor on `y0[0].device`.
            **unused_kwargs: forwarded to `_handle_unused_kwargs` and then
                discarded.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0

        # Scalar tolerances are broadcast to one entry per state component.
        if _is_iterable(rtol):
            self.rtol = rtol
        else:
            self.rtol = [rtol] * len(y0)
        if _is_iterable(atol):
            self.atol = atol
        else:
            self.atol = [atol] * len(y0)

        self.first_step = first_step

        # All control tensors live on the device of the first state component.
        target_device = y0[0].device

        def _on_device(value, dtype):
            return _convert_to_tensor(value, dtype=dtype, device=target_device)

        self.safety = _on_device(safety, torch.float64)
        self.ifactor = _on_device(ifactor, torch.float64)
        self.dfactor = _on_device(dfactor, torch.float64)
        self.max_num_steps = _on_device(max_num_steps, torch.int32)
def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate the interpolating polynomial at time `t`.

    The three time arguments are converted to the dtype and device of
    `coefficients[0][0]` before use.

    Args:
        coefficients: list of Tensor coefficients as created by `interp_fit`;
            each entry is itself a sequence with one tensor per state
            component.
        t0: start of the interpolation interval.
        t1: end of the interpolation interval.
        t: time at which to evaluate; must satisfy `t0 <= t <= t1`.

    Returns:
        Tuple with one entry per state component, each the result of
        `_dot_product` between that component's coefficients and the powers
        of the normalised time, highest power first.

    Raises:
        AssertionError: if `t` lies outside `[t0, t1]`.
    """
    reference = coefficients[0][0]
    dtype, device = reference.dtype, reference.device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
    t = _convert_to_tensor(t, dtype=dtype, device=device)

    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)

    # Position of t inside the interval, rescaled to [0, 1].
    x = ((t - t0) / (t1 - t0)).type(dtype).to(device)

    # powers[i] == x**i; always holds at least the constant and linear terms.
    powers = [torch.tensor(1).type(dtype).to(device), x]
    while len(powers) < len(coefficients):
        powers.append(powers[-1] * x)

    interpolated = []
    for component_coefficients in zip(*coefficients):
        # Coefficients are paired with the powers in descending order.
        interpolated.append(_dot_product(component_coefficients, reversed(powers)))
    return tuple(interpolated)
Ejemplo n.º 3
0
    def __init__(self,
                 func,
                 y0,
                 rtol,
                 atol,
                 implicit=True,
                 max_order=_MAX_ORDER,
                 safety=0.9,
                 ifactor=10.0,
                 dfactor=0.2,
                 **unused_kwargs):
        """Record the problem definition and step-control settings.

        Args:
            func: callable stored unchanged as `self.func`.
            y0: indexable collection of tensors; `y0[0].device` decides where
                the control tensors below are placed.
            rtol: stored as given when `_is_iterable(rtol)` is true, otherwise
                repeated into a list with one entry per element of `y0`.
            atol: handled exactly like `rtol`.
            implicit: stored unchanged as `self.implicit`.
            max_order: clamped into `[_MIN_ORDER, _MAX_ORDER]` and stored as
                an `int`.
            safety: converted to a float64 tensor on `y0[0].device`.
            ifactor: converted to a float64 tensor on `y0[0].device`.
            dfactor: converted to a float64 tensor on `y0[0].device`.
            **unused_kwargs: forwarded to `_handle_unused_kwargs` and then
                discarded.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0

        # Scalar tolerances are broadcast to one entry per state component.
        if _is_iterable(rtol):
            self.rtol = rtol
        else:
            self.rtol = [rtol] * len(y0)
        if _is_iterable(atol):
            self.atol = atol
        else:
            self.atol = [atol] * len(y0)

        self.implicit = implicit

        # Cap from above first, then raise to the minimum supported order.
        capped_order = min(max_order, _MAX_ORDER)
        self.max_order = int(max(_MIN_ORDER, capped_order))

        # All control tensors are float64 on the first component's device.
        target_device = y0[0].device

        def _as_float64(value):
            return _convert_to_tensor(value,
                                      dtype=torch.float64,
                                      device=target_device)

        self.safety = _as_float64(safety)
        self.ifactor = _as_float64(ifactor)
        self.dfactor = _as_float64(dfactor)
Ejemplo n.º 4
0
 def advance(self, final_t):
     """Step the solver state forward until it reaches `final_t`.

     `final_t` is converted to a tensor matching `self.vcabm_state.prev_t[0]`,
     then `_adaptive_adams_step` is applied repeatedly while the state's
     latest time is still below the target.

     Args:
         final_t: time to integrate up to.

     Returns:
         `self.vcabm_state.y_n` once the state's latest time equals `final_t`.

     Raises:
         AssertionError: if the loop ends with the state's latest time not
             exactly equal to `final_t`.
     """
     target_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
     while True:
         if not (target_t > self.vcabm_state.prev_t[0]):
             break
         self.vcabm_state = self._adaptive_adams_step(self.vcabm_state,
                                                      target_t)
     # The step routine is expected to land exactly on the target time.
     assert target_t == self.vcabm_state.prev_t[0]
     return self.vcabm_state.y_n
Ejemplo n.º 5
0
 def before_integrate(self, t):
     """Build the initial Runge-Kutta state for an integration over `t`.

     The initial step size is taken from `self.first_step` when it is set;
     otherwise it is chosen by `_select_initial_step`. The result is stored
     in `self.rk_state`.

     Args:
         t: tensor of time points; `t[0]` is used as the start time and
             `t.dtype` / `t.device` determine the step-size tensor.
     """
     if self.first_step is None:
         # NOTE(review): the literal 4 is presumably the method order minus
         # one expected by `_select_initial_step` -- confirm against it.
         first_step = _select_initial_step(self.func, t[0], self.y0, 4,
                                           self.rtol, self.atol).to(t)
     else:
         # Use the step size the caller configured. Previously this branch
         # substituted a hard-coded 0.01 and ignored `self.first_step`.
         first_step = _convert_to_tensor(self.first_step,
                                         dtype=t.dtype,
                                         device=t.device)
     # Each state component starts with 7 copies of itself as placeholder
     # interpolation coefficients.
     self.rk_state = _RungeKuttaState(
         self.y0, self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0],
         first_step, tuple(map(lambda x: [x] * 7, self.y0)))
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take one Runge-Kutta step described by `tableau` and estimate its error.

    Args:
        func: callable evaluated as `func(t, y)`, returning one derivative
            tensor per state component.
        y0: tuple of state tensors at the start of the step.
        f0: tuple of derivative tensors at the start of the step, i.e.
            `func(t0, y0)`.
        t0: start time; converted to the dtype/device of `y0[0]`.
        dt: step size; converted to the dtype/device of `y0[0]`.
        tableau: object exposing `alpha`, `beta`, `c_sol` and `c_error`
            coefficient sequences.

    Returns:
        Tuple `(y1, f1, y1_error, k)`: the state after the step, the last
        stage derivative of each component, the per-component error estimate,
        and the per-component lists of stage derivatives `k`.
    """
    reference = y0[0]
    dtype, device = reference.dtype, reference.device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # One growing list of stage derivatives per state component, seeded
    # with the derivative at the start of the step.
    k = tuple([f0_] for f0_ in f0)
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
        stage_derivatives = func(ti, yi)
        for k_, f_ in zip(k, stage_derivatives):
            k_.append(f_)

    # When the solution weights coincide with the last stage's weights, the
    # final `yi` already is the solution and need not be recomputed.
    last_stage_is_solution = (tableau.c_sol[-1] == 0
                              and tableau.c_sol[:-1] == tableau.beta[-1])
    if not last_stage_is_solution:
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
Ejemplo n.º 7
0
def _optimal_step_size(last_step,
                       mean_error_ratio,
                       safety=0.9,
                       ifactor=10.0,
                       dfactor=0.2,
                       order=5):
    """Calculate the optimal size for the next Runge-Kutta step."""
    if mean_error_ratio == 0:
        return last_step * ifactor
    if mean_error_ratio < 1:
        dfactor = _convert_to_tensor(1,
                                     dtype=torch.float64,
                                     device=mean_error_ratio.device)
    error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
    exponent = torch.tensor(1 / order).type_as(last_step)
    factor = torch.max(1 / ifactor,
                       torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor
Ejemplo n.º 8
0
 def before_integrate(self, t):
     """Build the initial Runge-Kutta state for an integration over `t`.

     The derivative at the start point is evaluated once and reused both for
     the automatic step-size selection and for the stored state. The initial
     step size is taken from `self.first_step` when it is set; otherwise it
     is chosen by `_select_initial_step`. The result is stored in
     `self.rk_state`.

     Args:
         t: tensor of time points; `t[0]` is used as the start time and
             `t.dtype` / `t.device` determine the step-size tensor.
     """
     f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
     if self.first_step is None:
         # Only the first component's tolerances are passed on here.
         # NOTE(review): the literal 4 is presumably the method order minus
         # one expected by `_select_initial_step` -- confirm against it.
         first_step = _select_initial_step(self.func,
                                           t[0],
                                           self.y0,
                                           4,
                                           self.rtol[0],
                                           self.atol[0],
                                           f0=f0).to(t)
     else:
         # Use the step size the caller configured. Previously this branch
         # substituted a hard-coded 0.01 and ignored `self.first_step`.
         first_step = _convert_to_tensor(self.first_step,
                                         dtype=t.dtype,
                                         device=t.device)
     # Five copies of the initial state serve as placeholder interpolation
     # coefficients.
     self.rk_state = _RungeKuttaState(self.y0,
                                      f0,
                                      t[0],
                                      t[0],
                                      first_step,
                                      interp_coeff=[self.y0] * 5)