コード例 #1
0
    def test_odeint_required_dtypes(self):
        """Check that odeint rejects non-floating-point `y0` and `t`.

        Each argument is cast to int32 in turn; each case must raise a
        TypeError whose message names the offending argument.
        """
        with self.assertRaisesRegex(TypeError,
                                    '`y0` must have a floating point'):
            odes.odeint(self.func, math_ops.cast(self.y0, dtypes.int32),
                        [0, 1])

        with self.assertRaisesRegex(TypeError,
                                    '`t` must have a floating point'):
            odes.odeint(self.func, self.y0, math_ops.cast([0, 1],
                                                          dtypes.int32))
コード例 #2
0
    def test_odeint_riccati(self):
        """Solve the Riccati equation dy/dt = (y - t) ** 2 + 1.0, y(0) = 0.5.

        Its analytical solution is y = 1.0 / (2.0 - t) + t.
        """
        # odeint invokes `func` with the state first and the time second.
        func = lambda y, t: (y - t)**2 + 1.0
        t = np.linspace(0.0, 1.0, 11)
        y_solved = odes.odeint(func, np.float64(0.5), t)

        y_true = 1.0 / (2.0 - t) + t
        self.assertAllClose(y_true, y_solved)
コード例 #3
0
    def test_odeint_complex(self):
        """Solve the complex linear ODE dy/dt = k * y with y(0) = 1.0.

        The exact solution is y(t) = exp(k * t); k has a negative real part,
        so the trajectory is a decaying spiral in the complex plane.
        """
        rate = 1j - 0.1

        def derivative(y, t):
            return rate * y

        times = np.linspace(0.0, 1.0, 11)
        actual = odes.odeint(derivative, 1.0 + 0.0j, times)
        expected = np.exp(rate * times)
        self.assertAllClose(expected, actual)
コード例 #4
0
    def test_odeint_different_times(self):
        """Check that a denser output grid does not change the integration.

        The same problem is solved for an 11-point and a 101-point output
        grid over [0, 10]. The solutions must agree on the shared points, and
        the solver diagnostics must be identical.
        """
        coarse_times = np.linspace(0, 10, num=11, dtype=float)
        fine_times = np.linspace(0, 10, num=101, dtype=float)

        coarse_y, coarse_info = odes.odeint(
            self.func, self.y0, coarse_times, full_output=True)
        fine_y, fine_info = odes.odeint(
            self.func, self.y0, fine_times, full_output=True)

        # Every 10th point of the fine grid lines up with a coarse-grid point.
        self.assertAllClose(coarse_y, fine_y[::10])
        for key in ('num_func_evals', 'integrate_points', 'error_ratio'):
            self.assertAllEqual(coarse_info[key], fine_info[key])
コード例 #5
0
 def test_odeint_higher_rank(self):
     """Check that the solution shape is (num_times,) + the state's shape."""
     def derivative(y, t):
         return y

     times = np.linspace(0.0, 1.0, 11)
     initial = constant_op.constant(1.0, dtype=dtypes.float64)
     for state_shape in ((), (1,), (1, 1)):
         solution = odes.odeint(derivative,
                                array_ops.reshape(initial, state_shape), times)
         want = (len(times),) + state_shape
         self.assertEqual(solution.get_shape(), tensor_shape.TensorShape(want))
         self.assertEqual(solution.shape, want)
コード例 #6
0
    def test_odeint_exp(self):
        """Solve dy/dt = y with y(0) = 1.0, whose exact solution is exp(t).

        Also checks that a scalar initial state with 11 output times yields
        a solution with static shape [11].
        """
        times = np.linspace(0.0, 1.0, 11)
        initial = constant_op.constant(1.0, dtype=dtypes.float64)
        solution = odes.odeint(lambda y, t: y, initial, times)

        self.assertEqual(solution.get_shape(), tensor_shape.TensorShape([11]))
        self.assertAllClose(np.exp(times), solution)
コード例 #7
0
 def test_odeint_all_dtypes(self):
     """Check odeint for each listed combination of state and time dtypes.

     The state may be real or complex and the time grid float32 or float64.
     The solution of dy/dt = y, y(0) = 1 must approximate exp(t), and it
     must keep the dtype of the initial state.
     """
     def derivative(y, t):
         return y

     times = np.linspace(0.0, 1.0, 11)
     expected = np.exp(times)
     state_dtypes = (dtypes.float32, dtypes.float64,
                     dtypes.complex64, dtypes.complex128)
     time_dtypes = (dtypes.float32, dtypes.float64)
     for y0_dtype in state_dtypes:
         for t_dtype in time_dtypes:
             solution = odes.odeint(derivative,
                                    math_ops.cast(1.0, y0_dtype),
                                    math_ops.cast(times, t_dtype))
             self.assertAllClose(solution, expected, rtol=1e-5)
             self.assertEqual(dtypes.as_dtype(solution.dtype), y0_dtype)
コード例 #8
0
    def test_odeint_5th_order_accuracy(self):
        """Check that the dopri5 step count scales as expected for 5th order.

        Shrinking `atol` by a factor of 1000 (with `rtol=0`) should grow the
        number of integration points by about 1000**0.2. The check allows a
        1% relative deviation.
        """
        span = [0, 20]
        shared_kwargs = {
            'full_output': True,
            'method': 'dopri5',
            'options': {'max_num_steps': 2000},
        }

        def count_points(atol):
            _, info = odes.odeint(self.func, self.y0, span, rtol=0, atol=atol,
                                  **shared_kwargs)
            return array_ops.size(info['integrate_points'],
                                  out_type=dtypes.float32)

        loose = count_points(1e-6)
        tight = count_points(1e-9)
        self.assertAllClose(loose * 1000**0.2, tight, rtol=0.01)
コード例 #9
0
    def test_odeint_2d_linear(self):
        """Solve the 2D linear system y' = A y with A = [[3, 4], [-4, 3]].

        With y1(0) = 0.0 and y2(0) = 1.0, the analytical solution is
            y1 = sin(4.0 * t) * exp(3.0 * t),
            y2 = cos(4.0 * t) * exp(3.0 * t).
        The state is a 2x1 column vector so that `matmul` applies directly.
        """
        coeffs = constant_op.constant([[3.0, 4.0], [-4.0, 3.0]],
                                      dtype=dtypes.float64)

        def derivative(y, t):
            return math_ops.matmul(coeffs, y)

        initial = constant_op.constant([[0.0], [1.0]], dtype=dtypes.float64)
        times = np.linspace(0.0, 1.0, 11)
        solution = odes.odeint(derivative, initial, times)

        envelope = np.exp(3.0 * times)
        exact = np.stack([np.sin(4.0 * times) * envelope,
                          np.cos(4.0 * times) * envelope], axis=1)
        # Append the column axis: (num_times, 2) -> (num_times, 2, 1).
        exact = exact[:, :, np.newaxis]
        self.assertAllClose(exact, solution, atol=1e-5)
コード例 #10
0
    def test_odeint_runtime_errors(self):
        """Check the errors odeint raises for invalid arguments.

        - Supplying `options` without a `method` raises ValueError.
        - `max_num_steps=0` with method 'dopri5' raises InvalidArgumentError
          mentioning `max_num_steps`.
        - Output times that are not increasing ([1, 0]) raise
          InvalidArgumentError.
        """
        with self.assertRaisesRegex(ValueError,
                                    'cannot supply `options` without'):
            odes.odeint(self.func,
                        self.y0, [0, 1],
                        options={'first_step': 1.0})

        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    'max_num_steps'):
            odes.odeint(self.func,
                        self.y0, [0, 1],
                        method='dopri5',
                        options={'max_num_steps': 0})

        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    'monotonic increasing'):
            odes.odeint(self.func, self.y0, [1, 0])