def integrate(self, t):
    """Advance the solver to every requested time and stack the results.

    The first output is ``self.y0`` itself; each later output is whatever
    ``self.advance`` returns for the corresponding entry of ``t``.

    Args:
        t: Output times. Checked by ``_assert_increasing`` before use, then
            cast to ``tf.float64`` and moved to the device of ``self.y0[0]``.

    Returns:
        A tuple with one ``tf.stack``-ed tensor per state component, stacked
        across the output times.
    """
    _assert_increasing(t)
    states = [self.y0]

    times = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
    self.before_integrate(times)

    # The entry at index 0 is already covered by self.y0, so start at 1.
    num_times = times.shape[0]
    states.extend(self.advance(times[k]) for k in range(1, num_times))

    # Transpose "list of per-time tuples" into "per-component tuples over
    # time", then stack each component.
    return tuple(tf.stack(component) for component in zip(*states))
def integrate(self, t):
    """Step along a fixed time grid and report the state at each time in ``t``.

    The grid is produced by ``self.grid_constructor(self.func, self.y0, t)``
    and must begin at ``t[0]`` and end at ``t[-1]``. For every grid step
    ``[t0, t1]`` the state is advanced by the increment returned from
    ``self.step_func``; every requested time ``t[j]`` with ``t[j] <= t1`` that
    has not been emitted yet is then evaluated through
    ``self._linear_interp(t0, t1, y0, y1, t[j])``.

    Args:
        t: Output times. Checked by ``_assert_increasing`` and cast to the
            dtype of ``self.y0[0]``.

    Returns:
        A tuple with one ``tf.stack``-ed tensor per state component, stacked
        across the output times. The first entry is ``cast_double(self.y0)``.

    Raises:
        AssertionError: If the constructed grid does not start at ``t[0]``
            and end at ``t[-1]``.
    """
    _assert_increasing(t)
    t = tf.cast(t, self.y0[0].dtype)
    time_grid = self.grid_constructor(self.func, self.y0, t)
    # The grid has to span exactly the requested interval.
    assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1])
    time_grid = move_to_device(time_grid, self.y0[0].device)

    solution = [cast_double(self.y0)]

    j = 1  # index of the next requested output time; t[0] is the initial state
    y0 = cast_double(self.y0)
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        dy = self.step_func(self.func, t0, t1 - t0, y0)
        y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

        # Emit every requested time that falls inside this step. y0 must
        # still be the state at t0 here so the interpolation sees both ends
        # of the step.
        while j < t.shape[0] and t1 >= t[j]:
            y = self._linear_interp(t0, t1, y0, y1, t[j])
            solution.append(y)
            j += 1

        # Only now move on: the end of this step starts the next one.
        y0 = y1

    # Transpose "list of per-time tuples" into "per-component tuples over
    # time", then stack each component.
    return tuple(map(tf.stack, tuple(zip(*solution))))