def test_polynomial_fit(self, dtype):
        """Checks that 4th-order interpolation reproduces a quartic exactly.

        The interpolation coefficients are computed from the quartic's values
        at t = 0, 10 and 5 and its derivatives at t = 0 and 10. Evaluating the
        interpolant on a grid over [0, 10] must then agree with the quartic.
        """
        # Coefficients listed from the x**4 term down to the constant term.
        # If `dtype` is real, `tf.cast` keeps only the real parts.
        raw = [1 + 2j, 0.3 - 1j, 3.5 - 3.7j, 0.5 - 0.1j, 0.1 + 0.1j]
        # ascending[k] is the coefficient of x**k.
        ascending = [tf.cast(c, dtype) for c in reversed(raw)]

        def polynomial(x):
            return tf.add_n([c * x**k for k, c in enumerate(ascending)])

        def derivative(x):
            # d/dx sum_k a_k x**k == sum_k (k + 1) * a_{k+1} * x**k
            return tf.add_n(
                [c * x**k * (k + 1) for k, c in enumerate(ascending[1:])])

        step = 10.0
        coeffs = rk_util._fourth_order_interpolation_coefficients(
            polynomial(0.0), polynomial(step), polynomial(step / 2),
            derivative(0.0), derivative(step), step)
        times = np.linspace(0.0, step, dtype=np.float32)
        y_fit = tf.stack([
            rk_util.evaluate_interpolation(coeffs, 0.0, step, t)
            for t in times
        ])
        self.assertAllClose(y_fit, polynomial(times))
# Example #2
def _interpolate_solution_at(target_time, solver_state):
  """Evaluates the 4th-order interpolant of the last step at `target_time`.

  The interpolating coefficients stored on `solver_state` describe the
  solution over the most recent step. They are evaluated between
  `solver_state.last_step_start` and `solver_state.current_time`.

  Args:
    target_time: Floating `Tensor`, the time at which the solution is wanted.
      The caller must ensure it lies within the last step, i.e.
      `solver_state.last_step_start` <= `target_time` <=
      `solver_state.current_time`. This function does not check it.
    solver_state: `_DopriSolverInternalState` - solver state.
  Returns:
    A `(solution, coefficients)` pair:
      solution: Interpolated solution at `target_time`.
      coefficients: The interpolating coefficients read from `solver_state`,
        returned unchanged.
  """
  state = solver_state
  coefficients = state.interpolating_coefficients
  solution = rk_util.evaluate_interpolation(
      coefficients,
      state.last_step_start,
      state.current_time,
      target_time)
  return solution, coefficients