def __init__(self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2,
             max_num_steps=2**31 - 1, **unused_kwargs):
    _handle_unused_kwargs(self, unused_kwargs)
    del unused_kwargs

    self.func = func
    self.y0 = y0
    self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
    self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
    self.first_step = first_step
    self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
    self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
    self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
    self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)
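# A tiny sketch of the tolerance broadcasting above (plain Python; the inline
# isinstance check stands in for `_is_iterable`): a scalar rtol/atol is
# expanded to one entry per state tensor in the y0 tuple.
import torch

_y0 = (torch.zeros(3), torch.zeros(5))
_rtol = 1e-7
_rtol = _rtol if isinstance(_rtol, (tuple, list)) else [_rtol] * len(_y0)
assert _rtol == [1e-7, 1e-7]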
def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate polynomial interpolation at the given time point.

    Args:
        coefficients: list of Tensor coefficients as created by `interp_fit`.
        t0: scalar float64 Tensor giving the start of the interval.
        t1: scalar float64 Tensor giving the end of the interval.
        t: scalar float64 Tensor giving the desired interpolation point.

    Returns:
        Polynomial interpolation of the coefficients at time `t`.
    """
    dtype = coefficients[0][0].dtype
    device = coefficients[0][0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
    t = _convert_to_tensor(t, dtype=dtype, device=device)

    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)

    # Work in the normalized coordinate x = (t - t0) / (t1 - t0) and build the
    # powers [1, x, x**2, ...]. Coefficients are stored highest-order first,
    # hence the `reversed(xs)` in the dot product below.
    x = ((t - t0) / (t1 - t0)).type(dtype).to(device)

    xs = [torch.tensor(1).type(dtype).to(device), x]
    for _ in range(2, len(coefficients)):
        xs.append(xs[-1] * x)

    return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))
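# A minimal, self-contained sketch of the same evaluation scheme in plain
# torch, to make the convention concrete: `_eval_poly_demo` is a hypothetical
# name, not part of this module.
import torch

def _eval_poly_demo(coeffs, t0, t1, t):
    x = (t - t0) / (t1 - t0)
    xs = [torch.ones_like(x), x]
    for _ in range(2, len(coeffs)):
        xs.append(xs[-1] * x)
    # coeffs[0] multiplies the highest power, matching `reversed(xs)` above.
    return sum(c * p for c, p in zip(coeffs, reversed(xs)))

# 2x^2 - x + 0.5 on [0, 2]: t = 1.0 gives x = 0.5, so the value is 0.5.
_coeffs = [torch.tensor(c) for c in (2.0, -1.0, 0.5)]
assert torch.isclose(_eval_poly_demo(_coeffs, torch.tensor(0.), torch.tensor(2.), torch.tensor(1.)),
                     torch.tensor(0.5))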
def __init__(self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0,
             dfactor=0.2, **unused_kwargs):
    _handle_unused_kwargs(self, unused_kwargs)
    del unused_kwargs

    self.func = func
    self.y0 = y0
    self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
    self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
    self.implicit = implicit
    self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER)))
    self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
    self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
    self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
def advance(self, final_t):
    final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
    while final_t > self.vcabm_state.prev_t[0]:
        # Each adaptive step is capped at final_t, so the loop lands on it
        # exactly rather than overshooting (checked by the assert below).
        self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t)
    assert final_t == self.vcabm_state.prev_t[0]
    return self.vcabm_state.y_n
def before_integrate(self, t):
    if self.first_step is None:
        first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
    else:
        # Respect a user-supplied first step (previously a hardcoded 0.01 was
        # used here, silently ignoring `self.first_step`).
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    self.rk_state = _RungeKuttaState(
        self.y0,
        self.func(t[0].type_as(self.y0[0]), self.y0),
        t[0], t[0], first_step,
        tuple(map(lambda x: [x] * 7, self.y0))
    )
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take an arbitrary Runge-Kutta step and estimate error.

    Args:
        func: Function to evaluate like `func(t, y)` to compute the time
            derivative of `y`.
        y0: Tensor initial value for the state.
        f0: Tensor initial value for the derivative, computed from `func(t0, y0)`.
        t0: float64 scalar Tensor giving the initial time.
        dt: float64 scalar Tensor giving the size of the desired time step.
        tableau: _ButcherTableau describing how to take the Runge-Kutta step.

    Returns:
        Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
        the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at
        `t1`, estimated error at `t1`, and a list of Runge-Kutta coefficients
        `k` used for calculating these terms.
    """
    dtype = y0[0].dtype
    device = y0[0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # Accumulate the stage derivatives k_i, seeding each per-tensor list with f0.
    k = tuple(map(lambda x: [x], f0))
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
        for k_, f_ in zip(k, func(ti, yi)):
            k_.append(f_)

    # When c_sol[-1] == 0 and c_sol[:-1] == beta[-1] (true for Dormand-Prince),
    # the final stage already evaluated the solution, so `yi` from the last
    # loop iteration is `y1` and recomputing it would waste a few FLOPs.
    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
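# A self-contained sketch of the stage loop above for scalar state, using
# Heun's 2nd-order method with an embedded Euler error estimate. The field
# names mirror how `tableau` is read in `_runge_kutta_step`, but `_Tableau`,
# `_heun`, and `_rk_step_demo` are illustrative stand-ins, not the library's
# Dormand-Prince machinery.
import torch
from collections import namedtuple

_Tableau = namedtuple('_Tableau', 'alpha beta c_sol c_error')
_heun = _Tableau(
    alpha=[1.0],            # one extra stage, evaluated at t0 + dt
    beta=[[1.0]],           # that stage uses the Euler predictor y0 + dt*k1
    c_sol=[0.5, 0.5],       # trapezoidal combination of k1, k2
    c_error=[-0.5, 0.5],    # Heun minus Euler = embedded error estimate
)

def _rk_step_demo(func, y0, f0, t0, dt, tableau):
    k = [f0]
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = y0 + dt * sum(b * k_ for b, k_ in zip(beta_i, k))
        k.append(func(ti, yi))
    y1 = y0 + dt * sum(c * k_ for c, k_ in zip(tableau.c_sol, k))
    y1_error = dt * sum(c * k_ for c, k_ in zip(tableau.c_error, k))
    # (The real function also returns the derivative at t1 and the k list.)
    return y1, y1_error

# One step of y' = y from y(0) = 1 with dt = 0.1: y1 = 1.105 vs exp(0.1)
# ~ 1.10517, with an error estimate of 0.005.
_y1, _err = _rk_step_demo(lambda t, y: y, torch.tensor(1.0), torch.tensor(1.0),
                          torch.tensor(0.0), torch.tensor(0.1), _heun)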
def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
    """Calculate the optimal size for the next Runge-Kutta step."""
    if mean_error_ratio == 0:
        return last_step * ifactor
    if mean_error_ratio < 1:
        # The step was accepted; never shrink it, only hold or grow.
        dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
    # `mean_error_ratio` is a mean of squared error ratios, so undo the
    # squaring before applying the standard order-dependent exponent.
    error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
    exponent = torch.tensor(1 / order).type_as(last_step)
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor
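# A worked numeric check of the controller above (a sketch, not library API):
# safety=0.9, ifactor=10, dfactor=0.2, order=5, and an accepted step with
# mean_error_ratio = 0.01.
import torch

_last_step = torch.tensor(0.1, dtype=torch.float64)
_error_ratio = torch.sqrt(torch.tensor(0.01, dtype=torch.float64))      # 0.1
# The step was accepted (ratio < 1), so dfactor is treated as 1.
_factor = torch.max(torch.tensor(1 / 10.0, dtype=torch.float64),
                    torch.min(_error_ratio ** (1 / 5) / 0.9,
                              torch.tensor(1.0, dtype=torch.float64)))
# 0.1**0.2 / 0.9 ~ 0.701, so the next step is ~0.1 / 0.701 ~ 0.143:
# the step grows by roughly 1.4x.
_next_step = _last_step / _factor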
def before_integrate(self, t):
    f0 = self.func(t[0].type_as(self.y0[0]), self.y0)
    if self.first_step is None:
        first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t)
    else:
        # Respect a user-supplied first step (previously a hardcoded 0.01 was
        # used here, silently ignoring `self.first_step`).
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
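# Sketch of the classic initial-step heuristic (Hairer, Nørsett & Wanner,
# "Solving Ordinary Differential Equations I", Sec. II.4) for a single state
# tensor. It is an assumption that `_select_initial_step` implements a variant
# of this; `_select_initial_step_demo` is a hypothetical name.
import torch

def _select_initial_step_demo(func, t0, y0, order, rtol, atol, f0=None):
    scale = atol + rtol * y0.abs()
    if f0 is None:
        f0 = func(t0, y0)
    d0 = (y0 / scale).norm()
    d1 = (f0 / scale).norm()
    # Probe step: small enough that the Euler increment is ~1% of the state.
    h0 = torch.tensor(1e-6) if (d0 < 1e-5 or d1 < 1e-5) else 0.01 * d0 / d1
    y1 = y0 + h0 * f0
    f1 = func(t0 + h0, y1)
    d2 = ((f1 - f0) / scale).norm() / h0   # crude second-derivative estimate
    if d1 <= 1e-15 and d2 <= 1e-15:
        h1 = torch.max(torch.tensor(1e-6), h0 * 1e-3)
    else:
        # Choose h so the order-(order+1) local error term is roughly 0.01.
        h1 = (0.01 / torch.max(d1, d2)) ** (1.0 / (order + 1))
    return torch.min(100 * h0, h1)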