def test_polynomial_fit(self, dtype):
    """Checks that 4th order interpolation reproduces a quartic exactly.

    The interpolation coefficients are built from five conditions:
    - the polynomial's values at 0, 10 and 5;
    - its derivatives at 0 and 10;
    - plus the trailing 10.0 (presumably the step size).

    Five conditions pin down a degree-4 polynomial, so the interpolant
    evaluated anywhere on [0, 10] should match the polynomial itself.
    """
    # Listed highest power first: p(x) = a4*x**4 + a3*x**3 + ... + a0.
    raw_coefficients = [1 + 2j, 0.3 - 1j, 3.5 - 3.7j, 0.5 - 0.1j, 0.1 + 0.1j]
    coefficients = [tf.cast(c, dtype) for c in raw_coefficients]
    # Ascending order (a0, a1, ..., a4) so that list index == power.
    ascending = coefficients[::-1]

    def polynomial(x):
        return tf.add_n([a * x**k for k, a in enumerate(ascending)])

    def derivative(x):
        # d/dx (a_k * x**k) = a_k * x**(k - 1) * k; the constant a0 drops out.
        return tf.add_n(
            [a * x**(k - 1) * k for k, a in enumerate(ascending[1:], start=1)])

    t_start, t_end = 0.0, 10.0
    midpoint = 5.0
    step = t_end - t_start
    interp_coeffs = rk_util._fourth_order_interpolation_coefficients(
        polynomial(t_start), polynomial(t_end), polynomial(midpoint),
        derivative(t_start), derivative(t_end), step)

    sample_times = np.linspace(t_start, t_end, dtype=np.float32)
    fitted = tf.stack([
        rk_util.evaluate_interpolation(interp_coeffs, t_start, t_end, t)
        for t in sample_times
    ])
    self.assertAllClose(fitted, polynomial(sample_times))
def _interpolate_solution_at(target_time, solver_state):
    """Evaluates the last step's 4th order interpolant at `target_time`.

    Args:
      target_time: Floating `Tensor`, the time at which the solution is wanted.
        It is expected to lie within the most recent step, i.e.
        `solver_state.last_step_start` <= `target_time` <=
        `solver_state.current_time`. This function does not check that.
      solver_state: `_DopriSolverInternalState` - solver state. Its
        `interpolating_coefficients`, `last_step_start` and `current_time`
        fields are read.

    Returns:
      A `(solution, coefficients)` pair:
        solution: The solution at `target_time` obtained by interpolation.
        coefficients: The interpolating coefficients taken from
          `solver_state` and used to construct the solution.
    """
    step_coefficients = solver_state.interpolating_coefficients
    # The interpolant spans [last_step_start, current_time] of the last step.
    solution = rk_util.evaluate_interpolation(
        step_coefficients,
        solver_state.last_step_start,
        solver_state.current_time,
        target_time,
    )
    return solution, step_coefficients