def test_batching_auto_broadcast(
     self, interpolation_times, reference_times, reference_yields, results):
   """Checks `interpolate` on externally supplied (parameterized) inputs.

   The arguments come from outside this method, presumably a parameterization
   decorator above it (TODO confirm). Judging only by the name, the
   parameter sets presumably exercise broadcasting between the batch shapes
   of the inputs; verify against the decorator's cases.

   Args:
     interpolation_times: Times at which to interpolate.
     reference_times: Reference (knot) times of the curve(s).
     reference_yields: Yields observed at `reference_times`.
     results: Expected interpolated yields, compared with an absolute
       tolerance of 1e-6.
   """
   # Unlike the other tests in this section, the dtype is passed explicitly.
   dtype = tf.float64
   result = self.evaluate(
       constant_fwd_interpolation.interpolate(
           interpolation_times, reference_times,
           reference_yields, dtype=dtype))
   expected_result = np.array(results)
   np.testing.assert_allclose(result, expected_result, atol=1e-6)
  def test_extrapolation(self):
    """Query times outside the reference grid on both sides.

    0.5 precedes the first knot (1.0) and 35.0 follows the last knot (30.0).
    The expected values equal the first and last reference yields
    respectively, i.e. flat extrapolation at both ends.
    """
    query_times = [0.5, 35.0]
    knot_times = [1.0, 2.0, 6.0, 8.0, 18.0, 30.0]
    knot_yields = [0.01, 0.02, 0.015, 0.014, 0.02, 0.025]

    interpolated = constant_fwd_interpolation.interpolate(
        query_times, knot_times, knot_yields)
    actual = self.evaluate(interpolated)

    # First and last entries of `knot_yields`.
    want = np.array([0.01, 0.025])
    np.testing.assert_allclose(actual, want, atol=1e-6)

  def test_correctness(self):
    """Interpolates a single (unbatched) curve at knots and between knots.

    At the knots (6, 8, 18, 30) the expected value is the reference yield
    itself. The in-between expected values are consistent with holding the
    forward rate constant on each interval: for t1 < t < t2,
    y(t) = (y1*t1 + f*(t - t1)) / t with f = (y2*t2 - y1*t1) / (t2 - t1).
    For example, t=3 gives f = 0.0125 and y = (0.04 + 0.0125) / 3 = 0.0175.
    """
    query_times = [1., 3., 6., 7., 8., 15., 18., 25., 30.]
    knot_times = [0.0, 2.0, 6.0, 8.0, 18.0, 30.0]
    knot_yields = [0.01, 0.02, 0.015, 0.014, 0.02, 0.025]

    interpolated = constant_fwd_interpolation.interpolate(
        query_times, knot_times, knot_yields)
    actual = self.evaluate(interpolated)

    want = np.array([
        0.02, 0.0175, 0.015, 0.01442857, 0.014,
        0.01904, 0.02, 0.0235, 0.025,
    ])
    np.testing.assert_allclose(actual, want, atol=1e-6)

  def test_batching(self):
    """Interpolates a batch of two rows that share one reference curve.

    Both batch rows use identical reference times and yields. The expected
    rows therefore repeat the unbatched values for the same query times used
    in `test_correctness`.
    """
    curve_times = [0.0, 2.0, 6.0, 8.0, 18.0, 30.0]
    curve_yields = [0.01, 0.02, 0.015, 0.014, 0.02, 0.025]
    num_rows = 2

    query_times = [[1., 3., 6., 7.], [8., 15., 18., 25.]]
    # Duplicate the single curve once per batch row.
    knot_times = [list(curve_times) for _ in range(num_rows)]
    knot_yields = [list(curve_yields) for _ in range(num_rows)]

    actual = self.evaluate(
        constant_fwd_interpolation.interpolate(
            query_times, knot_times, knot_yields))
    want = np.array([
        [0.02, 0.0175, 0.015, 0.01442857],
        [0.014, 0.01904, 0.02, 0.0235],
    ])
    np.testing.assert_allclose(actual, want, atol=1e-6)