def test_adjoint(self):
        """Gradients through odeint_adjoint must match plain odeint (dopri5).

        Backprop a fixed random cotangent through both solvers and compare
        the gradients w.r.t. the time points and the model parameters a, b.
        """
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)

        ys = torchdiffeq.odeint(f, y0, t_points, method='dopri5')
        torch.manual_seed(0)  # fixed seed so both runs use the same cotangent
        gradys = torch.rand_like(ys)
        ys.backward(gradys)

        # reg_y0_grad = y0.grad
        reg_t_grad = t_points.grad
        reg_a_grad = f.a.grad
        reg_b_grad = f.b.grad

        # Fresh problem instance so gradients do not accumulate across runs.
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)

        ys = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
        ys.backward(gradys)

        # adj_y0_grad = y0.grad
        adj_t_grad = t_points.grad
        adj_a_grad = f.a.grad
        adj_b_grad = f.b.grad

        # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
        self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
        self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
        self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)
    def test_rk4(self):
        """Gradcheck odeint with the fixed-step rk4 solver."""
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)

        def solve_rk4(y0, t_points):
            return torchdiffeq.odeint(f, y0, t_points, method='rk4')

        self.assertTrue(torch.autograd.gradcheck(solve_rk4, (y0, t_points)))
# Example #3 (0 votes) — scraped-snippet separator
 def test_adaptive_heun(self):
     """Solve every test problem with adaptive_heun and compare to the
     analytic solution within error_tol."""
     for ode in problems.PROBLEMS:
         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
         y = torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
         with self.subTest(ode=ode):
             self.assertLess(rel_error(sol, y), error_tol)
# Example #4 (0 votes) — scraped-snippet separator
    def test_odeint(self):
        """Check the scipy_solver wrapper across SciPy methods, dtypes,
        devices, problems and both time directions."""
        scipy_solvers = ['RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA']
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for solver in scipy_solvers:
                        for ode in PROBLEMS:
                            tol = 1e-3

                            with self.subTest(reverse=reverse,
                                              dtype=dtype,
                                              device=device,
                                              ode=ode,
                                              solver=solver):
                                f, y0, t_points, sol = construct_problem(
                                    dtype=dtype,
                                    device=device,
                                    ode=ode,
                                    reverse=reverse)
                                y = torchdiffeq.odeint(
                                    f, y0, t_points,
                                    method='scipy_solver',
                                    options={"solver": solver})
                                # Shape must match before comparing values.
                                self.assertTrue(sol.shape == y.shape)
                                self.assertLess(rel_error(sol, y), tol)
# Example #5 (0 votes) — scraped-snippet separator
    def test_adjoint(self):
        """Solve each reversed test problem with odeint_adjoint (dopri5) and
        compare against the analytic solution.
        """
        for ode in problems.PROBLEMS.keys():
            # Bug fix: the loop variable `ode` was previously not passed to
            # construct_problem, so every iteration tested the same default
            # problem while the subTest label claimed otherwise.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE,
                                                              reverse=True,
                                                              ode=ode)

            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            with self.subTest(ode=ode):
                self.assertLess(rel_error(sol, y), error_tol)
# Example #6 (0 votes) — scraped-snippet separator
    def test_adaptive_heun_gradient(self):
        """Gradcheck each component of a tuple-valued ODE solved with
        adaptive_heun."""
        f, y0, t_points, sol = construct_problem(TEST_DEVICE)

        def tuple_f(t, y):
            # Apply the same vector field to both components of the tuple.
            return f(t, y[0]), f(t, y[1])

        for idx in range(2):
            def component(y0, t_points, idx=idx):  # bind idx at definition time
                return torchdiffeq.odeint(tuple_f, (y0, y0), t_points,
                                          method='adaptive_heun')[idx]
            self.assertTrue(torch.autograd.gradcheck(component, (y0, t_points)))
# Example #7 (0 votes) — scraped-snippet separator
 def test_bosh3(self):
     """Solve every test problem with bosh3; compared against a looser
     tolerance since the solver is less accurate."""
     for ode in problems.PROBLEMS:
         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
         y = torchdiffeq.odeint(f, y0, t_points, method='bosh3')
         with self.subTest(ode=ode):
             self.assertLess(rel_error(sol, y), large_error_tol)
    def test_odeint(self):
        """Gradcheck odeint w.r.t. y0 and t_points for every method/device
        combination."""
        for device in DEVICES:
            for method in METHODS:

                with self.subTest(device=device, method=method):
                    f, y0, t_points, _ = construct_problem(device=device)

                    def solve(y0, t_points, method=method):
                        return torchdiffeq.odeint(f, y0, t_points,
                                                  method=method)

                    self.assertTrue(
                        torch.autograd.gradcheck(solve, (y0, t_points)))
# Example #9 (0 votes) — scraped-snippet separator
    def test_gradient(self):
        """Gradcheck each component of a tuple-valued ODE for every adaptive
        method on every device."""
        for device in DEVICES:
            f, y0, t_points, sol = construct_problem(device=device)

            def tuple_f(t, y):
                # Same vector field applied to both tuple components.
                return f(t, y[0]), f(t, y[1])

            for method in ADAPTIVE_METHODS:

                with self.subTest(device=device, method=method):
                    for idx in range(2):
                        def component(y0, t_points, idx=idx, method=method):
                            return torchdiffeq.odeint(
                                tuple_f, (y0, y0), t_points,
                                method=method)[idx]
                        self.assertTrue(
                            torch.autograd.gradcheck(component, (y0, t_points)))
# Example #10 (0 votes) — scraped-snippet separator
    def test_adaptive_heun(self):
        """A tuple-valued ODE solved with adaptive_heun must match the scalar
        solution componentwise."""
        f, y0, t_points, sol = construct_problem(TEST_DEVICE)

        def tuple_f(t, y):
            return f(t, y[0]), f(t, y[1])

        tuple_y = torchdiffeq.odeint(tuple_f, (y0, y0), t_points,
                                     method='adaptive_heun')
        for component in tuple_y:
            self.assertLess((sol - component).max(), eps)
# Example #11 (0 votes) — scraped-snippet separator
    def test_odeint(self):
        """Requesting a single time point must return the initial condition
        (essentially) exactly, for every method/dtype/device/problem."""
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for method in METHODS:
                        for ode in PROBLEMS:

                            with self.subTest(reverse=reverse, dtype=dtype,
                                              device=device, ode=ode,
                                              method=method):
                                f, y0, t_points, sol = construct_problem(
                                    dtype=dtype, device=device,
                                    ode=ode, reverse=reverse)
                                # Only the first time point: no actual stepping.
                                y = torchdiffeq.odeint(f, y0, t_points[:1],
                                                       method=method)
                                self.assertLess((sol[0] - y).abs().max(), 1e-12)
# Example #12 (0 votes) — scraped-snippet separator
 def test_dopri8(self):
     """Solve every test problem with dopri8 at tight solver tolerances and
     compare to the analytic solution."""
     for ode in problems.PROBLEMS:
         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
         y = torchdiffeq.odeint(f, y0, t_points, method='dopri8',
                                rtol=1e-12, atol=1e-14)
         with self.subTest(ode=ode):
             self.assertLess(rel_error(sol, y), error_tol)
# Example #13 (0 votes) — scraped-snippet separator
    def test_odeint(self):
        """Event handling: integrate with an `event_fn` that crosses zero when
        the state reaches sol[2], and check the returned event time/state
        against t_points[2] / sol[2].
        """
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for method in METHODS:

                        # TODO: remove after event handling gets enabled.
                        if method == 'scipy_solver':
                            continue

                        for ode in ("constant", "sine"):
                            with self.subTest(reverse=reverse,
                                              dtype=dtype,
                                              device=device,
                                              ode=ode,
                                              method=method):
                                # Looser tolerances for the less accurate
                                # fixed-step methods.
                                if method == "explicit_adams":
                                    tol = 7e-2
                                elif method == "euler":
                                    tol = 5e-3
                                else:
                                    tol = 1e-4

                                f, y0, t_points, sol = construct_problem(
                                    dtype=dtype,
                                    device=device,
                                    ode=ode,
                                    reverse=reverse)

                                def event_fn(t, y):
                                    # Crosses zero when the state reaches sol[2].
                                    return torch.sum(y - sol[2])

                                # Fixed-step solvers need an explicit step size;
                                # cubic interpolation is used to locate the event
                                # between steps.
                                if method in FIXED_METHODS:
                                    options = {
                                        "step_size": 0.01,
                                        "interp": "cubic"
                                    }
                                else:
                                    options = {}

                                # Only the first two time points are passed; the
                                # event is expected to fire near t_points[2]
                                # (checked by the final assertion below).
                                t, y = torchdiffeq.odeint(f,
                                                          y0,
                                                          t_points[0:2],
                                                          event_fn=event_fn,
                                                          method=method,
                                                          options=options)
                                # Last state returned is the state at the event.
                                y = y[-1]
                                self.assertLess(rel_error(sol[2], y), tol)
                                self.assertLess(rel_error(t_points[2], t), tol)
# Example #14 (0 votes) — scraped-snippet separator
    def test_adjoint(self):
        """odeint_adjoint (default solver) matches the analytic solution for
        every problem/dtype/device and both time directions."""
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for ode in PROBLEMS:
                        # The linear problem is harder; allow a looser tolerance.
                        tol = 2e-3 if ode == 'linear' else 1e-4

                        with self.subTest(reverse=reverse, dtype=dtype,
                                          device=device, ode=ode):
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device,
                                ode=ode, reverse=reverse)
                            y = torchdiffeq.odeint_adjoint(f, y0, t_points)
                            self.assertLess(rel_error(sol, y), tol)
# Example #15 (0 votes) — scraped-snippet separator
    def test_forward(self):
        """Tuple-valued ODEs must match the scalar solution componentwise for
        every adaptive method, with a dtype-dependent tolerance."""
        for dtype in DTYPES:
            tol = EPS[dtype]
            for device in DEVICES:
                f, y0, t_points, sol = construct_problem(dtype=dtype,
                                                         device=device)

                def tuple_f(t, y):
                    return f(t, y[0]), f(t, y[1])

                for method in ADAPTIVE_METHODS:

                    with self.subTest(dtype=dtype, device=device,
                                      method=method):
                        tuple_y = torchdiffeq.odeint(tuple_f, (y0, y0),
                                                     t_points, method=method)
                        for component in tuple_y:
                            self.assertLess((sol - component).max(), tol)
# Example #16 (0 votes) — scraped-snippet separator
    def test_adjoint(self):
        """Event handling through odeint_adjoint: the returned event time and
        state must match the final entry of the known solution, and gradients
        must flow through both outputs.
        """
        f, y0, t_points, sol = construct_problem(device="cpu", ode="constant")

        def event_fn(t, y):
            # Crosses zero when the state reaches the last solution value.
            return torch.sum(y - sol[-1])

        t, y = torchdiffeq.odeint_adjoint(f,
                                          y0,
                                          t_points[0:2],
                                          event_fn=event_fn,
                                          method="dopri5")
        # Last returned state is the state at the event.
        y = y[-1]
        self.assertLess(rel_error(sol[-1], y), 1e-4)
        self.assertLess(rel_error(t_points[-1], t), 1e-4)

        # Make sure adjoint mode backward code can still be run.
        # retain_graph=True keeps the graph alive for the second backward below.
        t.backward(retain_graph=True)
        y.sum().backward()
# Example #17 (0 votes) — scraped-snippet separator
    def test_odeint(self):
        """odeint matches analytic solutions across methods, dtypes, devices,
        problems and both time directions, with per-case tolerances."""
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for method in METHODS:

                        # dopri8 needs tighter solver tolerances to reach the
                        # expected accuracy.
                        if method == 'dopri8' and dtype == torch.float64:
                            kwargs = dict(rtol=1e-12, atol=1e-14)
                        elif method == 'dopri8' and dtype == torch.float32:
                            kwargs = dict(rtol=1e-7, atol=1e-7)
                        else:
                            kwargs = dict()

                        # Non-adaptive methods are only checked on the
                        # constant problem.
                        odes = PROBLEMS if method in ADAPTIVE_METHODS else (
                            'constant', )
                        for ode in odes:
                            if method in ('adaptive_heun', 'bosh3'):
                                tol = 4e-3
                            elif ode == 'linear':
                                tol = 2e-3
                            else:
                                tol = 1e-4

                            with self.subTest(reverse=reverse, dtype=dtype,
                                              device=device, ode=ode,
                                              method=method):
                                f, y0, t_points, sol = construct_problem(
                                    dtype=dtype,
                                    device=device,
                                    ode=ode,
                                    reverse=reverse)
                                y = torchdiffeq.odeint(f, y0, t_points,
                                                       method=method, **kwargs)
                                self.assertLess(rel_error(sol, y), tol)
# Example #18 (0 votes) — scraped-snippet separator
    def test_euler(self):
        """Euler solves the default problem (reversed time) within error_tol."""
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE,
                                                          reverse=True)
        y = torchdiffeq.odeint(f, y0, t_points, method='euler')
        self.assertLess(rel_error(sol, y), error_tol)
# Example #19 (0 votes) — scraped-snippet separator
    def test_rk4_classic(self):
        """Classic RK4 solves the default problem within error_tol."""
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
        y = torchdiffeq.odeint(f, y0, t_points, method='rk4_classic')
        self.assertLess(rel_error(sol, y), error_tol)
# Example #20 (0 votes) — scraped-snippet separator
    def test_explicit_adams(self):
        """Explicit Adams solves the default problem within error_tol."""
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
        y = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams')
        self.assertLess(rel_error(sol, y), error_tol)
    def test_adjoint(self):
        """
        Compare gradients (w.r.t. y0, the time points, and the model
        parameters) computed by backpropagating through the solver directly
        against those computed by the adjoint method, at tight solver
        tolerances, with problem-dependent comparison thresholds.
        """
        for device in DEVICES:
            for ode in PROBLEMS:
                for t_grad in (True, False):
                    # Per-problem comparison threshold.
                    if ode == 'constant':
                        eps = 1e-12
                    elif ode == 'linear':
                        eps = 1e-5
                    elif ode == 'sine':
                        eps = 5e-3
                    else:
                        raise RuntimeError

                    with self.subTest(device=device, ode=ode, t_grad=t_grad):
                        f, y0, t_points, _ = construct_problem(device=device,
                                                               ode=ode)
                        # Optionally track gradients w.r.t. the time points.
                        t_points.requires_grad_(t_grad)

                        # Baseline: backprop directly through solver internals.
                        ys = torchdiffeq.odeint(f,
                                                y0,
                                                t_points,
                                                rtol=1e-9,
                                                atol=1e-12)
                        torch.manual_seed(0)  # same cotangent for both runs
                        gradys = torch.rand_like(ys)
                        ys.backward(gradys)

                        # Clone the grads: .grad tensors are reused (zeroed
                        # below) before the adjoint run.
                        reg_y0_grad = y0.grad.clone()
                        reg_t_grad = t_points.grad.clone() if t_grad else None
                        reg_params_grads = []
                        for param in f.parameters():
                            reg_params_grads.append(param.grad.clone())

                        # Reset accumulated gradients so the adjoint run
                        # starts from zero.
                        y0.grad.zero_()
                        if t_grad:
                            t_points.grad.zero_()
                        for param in f.parameters():
                            param.grad.zero_()

                        # Adjoint run on the same problem instance.
                        ys = torchdiffeq.odeint_adjoint(f,
                                                        y0,
                                                        t_points,
                                                        rtol=1e-9,
                                                        atol=1e-12)
                        ys.backward(gradys)

                        adj_y0_grad = y0.grad
                        adj_t_grad = t_points.grad if t_grad else None
                        adj_params_grads = []
                        for param in f.parameters():
                            adj_params_grads.append(param.grad)

                        self.assertLess(max_abs(reg_y0_grad - adj_y0_grad),
                                        eps)
                        if t_grad:
                            self.assertLess(max_abs(reg_t_grad - adj_t_grad),
                                            eps)
                        for reg_grad, adj_grad in zip(reg_params_grads,
                                                      adj_params_grads):
                            self.assertLess(max_abs(reg_grad - adj_grad), eps)
# Example #22 (0 votes) — scraped-snippet separator
    def test_dopri8(self):
        """dopri8 queried at a single time point reproduces the initial
        condition of the reversed problem."""
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE,
                                                          reverse=True)
        y = torchdiffeq.odeint(f, y0, t_points[:1], method='dopri8')
        self.assertLess(max_abs(sol[0] - y), error_tol)