コード例 #1
0
ファイル: solvers.py プロジェクト: tongtongliuliu/tfdiffeq
 def integrate(self, t):
     """Solve the ODE adaptively and return the state at every time in ``t``.

     ``t`` is first checked by ``_assert_increasing``. It is then cast to
     float64 and moved to the device of ``self.y0[0]``.
     ``self.before_integrate`` is invoked once with the moved tensor. After
     that, ``self.advance`` is called for each of ``t[1:]`` in order.
     ``self.y0`` itself is reported as the state at ``t[0]``.

     Args:
         t: sequence of output times.

     Returns:
         A tuple holding one tensor per state component. Each tensor is built
         by ``tf.stack`` over the per-time states, starting with ``self.y0``.
     """
     _assert_increasing(t)
     trajectory = [self.y0]
     grid = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
     self.before_integrate(grid)
     trajectory.extend(self.advance(grid[k]) for k in range(1, grid.shape[0]))
     # Regroup from a list of per-time states into per-component series,
     # then stack each series into a single tensor.
     per_component = zip(*trajectory)
     return tuple(tf.stack(series) for series in per_component)
コード例 #2
0
ファイル: solvers.py プロジェクト: dbxmcf/node
    def integrate(self, t):
        """Integrate the ODE on a fixed time grid and sample it at the times in ``t``.

        ``t`` is first checked by ``_assert_increasing`` and cast to the dtype
        of ``self.y0[0]``. The grid comes from ``self.grid_constructor`` and
        must start at ``t[0]`` and end at ``t[-1]``.

        Each grid interval ``[t0, t1]`` is advanced with one call to
        ``self.step_func``. Every requested time ``t[j]`` that has not been
        emitted yet and is ``<= t1`` is then produced by
        ``self._linear_interp`` from the states at ``t0`` and ``t1``.

        Args:
            t: sequence of output times. ``t[0]`` is the time of ``self.y0``.

        Returns:
            A tuple with one ``tf.stack``-ed tensor per state component. The
            first entry along the stacked axis is ``cast_double(self.y0)``,
            followed by one entry for each of ``t[1:]``.

        Raises:
            AssertionError: (when assertions are enabled) if the constructed
                grid does not start at ``t[0]`` and end at ``t[-1]``.
        """
        _assert_increasing(t)
        t = tf.cast(t, self.y0[0].dtype)
        time_grid = self.grid_constructor(self.func, self.y0, t)
        # The grid must span exactly the requested interval; otherwise some
        # output times could never be reached by the stepping loop below.
        assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1]), (
            "time grid must start at t[0] and end at t[-1]"
        )
        time_grid = move_to_device(time_grid, self.y0[0].device)

        # The state tuple is only ever rebound (never mutated), so the same
        # cast result can seed both the running state and the output list.
        y0 = cast_double(self.y0)
        solution = [y0]

        j = 1
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

            # Emit every pending output time that falls at or before t1.
            # This must happen BEFORE y0 is advanced, because the
            # interpolation needs the state at the START of the interval.
            # Advancing first made y0 and y1 identical, so every interpolated
            # sample collapsed to the end-of-step state.
            while j < t.shape[0] and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1

            y0 = y1

        return tuple(map(tf.stack, tuple(zip(*solution))))