def test_adjoint(self):
    """Compare dopri5 gradients from ``odeint`` and ``odeint_adjoint``.

    The same problem is solved twice with dopri5: once through
    ``torchdiffeq.odeint`` and once through ``torchdiffeq.odeint_adjoint``.
    Identical random output cotangents are backpropagated through each solve.
    The gradients with respect to the time points and the parameters
    ``f.a`` / ``f.b`` must then agree to within ``eps``.
    """
    # First pass: backpropagate through odeint.
    f, y0, t_points, _ = construct_problem(TEST_DEVICE)
    ys = torchdiffeq.odeint(f, y0, t_points, method='dopri5')
    torch.manual_seed(0)  # fixed seed so the cotangents are reproducible
    gradys = torch.rand_like(ys)
    ys.backward(gradys)
    direct_grads = (t_points.grad, f.a.grad, f.b.grad)

    # Second pass: a freshly constructed problem solved with the adjoint
    # interface, driven by the very same cotangents.
    f, y0, t_points, _ = construct_problem(TEST_DEVICE)
    ys = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
    ys.backward(gradys)
    adjoint_grads = (t_points.grad, f.a.grad, f.b.grad)

    # NOTE: the y0-gradient comparison is disabled in this test; only the
    # time-point and parameter gradients are compared (in this order).
    for reg_grad, adj_grad in zip(direct_grads, adjoint_grads):
        self.assertLess(max_abs(reg_grad - adj_grad), eps)
def test_rk4(self):
    """Run ``torch.autograd.gradcheck`` through an rk4 solve.

    Gradients are checked with respect to both the initial state ``y0`` and
    the time points ``t_points``.
    """
    f, y0, t_points, _ = construct_problem(TEST_DEVICE)

    def solve(initial_state, times):
        return torchdiffeq.odeint(f, initial_state, times, method='rk4')

    passed = torch.autograd.gradcheck(solve, (y0, t_points))
    self.assertTrue(passed)
def test_adaptive_heun(self):
    """Check ``adaptive_heun`` against the reference solution of every problem.

    Each entry of ``problems.PROBLEMS`` is solved, and its relative error
    against the reference solution must be below ``error_tol``.

    Problem construction and the solve run inside the per-ODE ``subTest``.
    A failure or exception for one ODE is therefore reported under that
    ODE's name and does not prevent the remaining ODEs from being checked.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
            y = torchdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
            self.assertLess(rel_error(sol, y), error_tol)
def test_odeint(self):
    """Compare the ``scipy_solver`` backend against reference solutions.

    The test sweeps these dimensions:
    - time direction (``reverse``);
    - dtype;
    - device;
    - each SciPy solver name forwarded via ``options={"solver": ...}``;
    - every problem in ``PROBLEMS``.

    The result must have the same shape as the reference solution and a
    relative error below ``1e-3``.
    """
    scipy_solvers = ['RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA']
    eps = 1e-3  # loop-invariant tolerance, hoisted out of the sweep
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for solver in scipy_solvers:
                    for ode in PROBLEMS:
                        with self.subTest(reverse=reverse, dtype=dtype, device=device,
                                          ode=ode, solver=solver):
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device, ode=ode, reverse=reverse)
                            y = torchdiffeq.odeint(f, y0, t_points, method='scipy_solver',
                                                   options={"solver": solver})
                            # assertEqual reports both shapes on failure; the
                            # former assertTrue(a == b) only reported "False".
                            self.assertEqual(sol.shape, y.shape)
                            self.assertLess(rel_error(sol, y), eps)
def test_adjoint(self):
    """Check ``odeint_adjoint`` (dopri5) forward solutions on every problem.

    Each entry of ``problems.PROBLEMS`` is built with ``reverse=True`` and
    solved through the adjoint interface. The relative error against the
    reference solution must be below ``error_tol``.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # ``ode`` must be forwarded to construct_problem. Otherwise every
            # iteration builds the same default problem and only the subTest
            # label changes.
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)
def test_adaptive_heun_gradient(self):
    """Gradcheck ``adaptive_heun`` on a tuple-valued state ``(y0, y0)``.

    The dynamics ``f`` are applied to each component of the tuple
    separately. Gradcheck is run on each output component in turn, with
    respect to ``y0`` and ``t_points``.
    """
    f, y0, t_points, _ = construct_problem(TEST_DEVICE)

    def tuple_f(t, y):
        return f(t, y[0]), f(t, y[1])

    for component in range(2):
        # ``solve`` is used immediately within this iteration, so it sees
        # the current ``component``.
        def solve(y0, t_points):
            outputs = torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method='adaptive_heun')
            return outputs[component]

        self.assertTrue(torch.autograd.gradcheck(solve, (y0, t_points)))
def test_bosh3(self):
    """Check ``bosh3`` against the reference solution of every problem.

    The relative error must be below ``large_error_tol``. Construction and
    the solve run inside the per-ODE ``subTest``, so an exception for one ODE
    is attributed to it and the remaining ODEs are still checked.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
            y = torchdiffeq.odeint(f, y0, t_points, method='bosh3')
            # Seems less accurate so we increase the tolerance
            self.assertLess(rel_error(sol, y), large_error_tol)
def test_odeint(self):
    """Gradcheck every solver in ``METHODS`` on every device in ``DEVICES``.

    Gradients are checked with respect to both the initial state and the
    time points.
    """
    for device in DEVICES:
        for method in METHODS:
            with self.subTest(device=device, method=method):
                f, y0, t_points, _ = construct_problem(device=device)

                def solve(initial_state, times):
                    return torchdiffeq.odeint(f, initial_state, times, method=method)

                passed = torch.autograd.gradcheck(solve, (y0, t_points))
                self.assertTrue(passed)
def test_gradient(self):
    """Gradcheck each adaptive solver on a tuple-valued state, per device.

    For each device, the dynamics ``f`` are applied to both components of a
    ``(y0, y0)`` state. Gradcheck is then run on each output component for
    every method in ``ADAPTIVE_METHODS``.
    """
    for device in DEVICES:
        f, y0, t_points, _ = construct_problem(device=device)

        def tuple_f(t, y):
            return f(t, y[0]), f(t, y[1])

        for method in ADAPTIVE_METHODS:
            with self.subTest(device=device, method=method):
                for component in range(2):
                    def solve(y0, t_points):
                        outputs = torchdiffeq.odeint(tuple_f, (y0, y0), t_points, method=method)
                        return outputs[component]

                    self.assertTrue(torch.autograd.gradcheck(solve, (y0, t_points)))
def test_adaptive_heun(self):
    """Check ``adaptive_heun`` on the tuple state ``(y0, y0)`` against ``sol``.

    Both components evolve under the same dynamics ``f``, so each must match
    the reference solution. The largest absolute deviation per component must
    be below ``eps``.
    """
    f, y0, t_points, sol = construct_problem(TEST_DEVICE)

    def tuple_f(t, y):
        return f(t, y[0]), f(t, y[1])

    tuple_y0 = (y0, y0)
    tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adaptive_heun')
    # Take the absolute value before the max. A signed max of (sol - y)
    # only bounds undershoot and lets any overshoot of ``sol`` pass.
    max_error0 = (sol - tuple_y[0]).abs().max()
    max_error1 = (sol - tuple_y[1]).abs().max()
    self.assertLess(max_error0, eps)
    self.assertLess(max_error1, eps)
def test_odeint(self):
    """A solve over a single time point must reproduce ``sol[0]``.

    For every combination of time direction, dtype, device, method and
    problem, integrating over ``t_points[0:1]`` must match the reference
    value at that point to within ``1e-12`` (max absolute difference).
    """
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for method in METHODS:
                    for ode in PROBLEMS:
                        with self.subTest(reverse=reverse, dtype=dtype, device=device,
                                          ode=ode, method=method):
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device, ode=ode, reverse=reverse)
                            first_time = t_points[0:1]
                            result = torchdiffeq.odeint(f, y0, first_time, method=method)
                            deviation = (sol[0] - result).abs().max()
                            self.assertLess(deviation, 1e-12)
def test_dopri8(self):
    """Check ``dopri8`` against the reference solution of every problem.

    The solver runs with ``rtol=1e-12`` and ``atol=1e-14``, and the relative
    error must be below ``error_tol``. Construction and the solve run inside
    the per-ODE ``subTest``, so an exception for one ODE is attributed to it
    and the remaining ODEs are still checked.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
            y = torchdiffeq.odeint(f, y0, t_points, method='dopri8', rtol=1e-12, atol=1e-14)
            self.assertLess(rel_error(sol, y), error_tol)
def test_odeint(self):
    """Check event handling with ``event_fn`` zero at ``sol[2]``.

    Only ``t_points[0:2]`` is requested. The returned event time and final
    state must nevertheless match ``t_points[2]`` and ``sol[2]`` within a
    method-dependent relative tolerance. The 'constant' and 'sine' problems
    are covered for every direction, dtype, device and method except
    ``scipy_solver``.
    """
    # Looser tolerances for specific methods; everything else uses 1e-4.
    method_tolerances = {"explicit_adams": 7e-2, "euler": 5e-3}
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for method in METHODS:
                    # TODO: remove after event handling gets enabled.
                    if method == 'scipy_solver':
                        continue
                    for ode in ("constant", "sine"):
                        with self.subTest(reverse=reverse, dtype=dtype, device=device,
                                          ode=ode, method=method):
                            tol = method_tolerances.get(method, 1e-4)
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device, ode=ode, reverse=reverse)

                            def event_fn(t, y):
                                return torch.sum(y - sol[2])

                            # Fixed-grid solvers need an explicit step size and
                            # interpolation scheme to locate the event.
                            options = ({"step_size": 0.01, "interp": "cubic"}
                                       if method in FIXED_METHODS else {})
                            event_t, trajectory = torchdiffeq.odeint(
                                f, y0, t_points[0:2], event_fn=event_fn,
                                method=method, options=options)
                            final_state = trajectory[-1]
                            self.assertLess(rel_error(sol[2], final_state), tol)
                            self.assertLess(rel_error(t_points[2], event_t), tol)
def test_adjoint(self):
    """Check ``odeint_adjoint`` forward solutions against reference solutions.

    The test covers every time direction, dtype, device and problem, using
    the solver's default method. The 'linear' problem is allowed a relative
    error of ``2e-3``; all others ``1e-4``.
    """
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for ode in PROBLEMS:
                    eps = 2e-3 if ode == 'linear' else 1e-4
                    with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode):
                        f, y0, t_points, sol = construct_problem(
                            dtype=dtype, device=device, ode=ode, reverse=reverse)
                        solution = torchdiffeq.odeint_adjoint(f, y0, t_points)
                        self.assertLess(rel_error(sol, solution), eps)
def test_forward(self):
    """Check every adaptive solver on the tuple state ``(y0, y0)``.

    For each dtype (tolerance ``EPS[dtype]``), device and method in
    ``ADAPTIVE_METHODS``, both output components must match the reference
    solution. The largest absolute deviation per component must be below
    the tolerance.
    """
    for dtype in DTYPES:
        eps = EPS[dtype]
        for device in DEVICES:
            f, y0, t_points, sol = construct_problem(dtype=dtype, device=device)

            def tuple_f(t, y):
                return f(t, y[0]), f(t, y[1])

            tuple_y0 = (y0, y0)
            for method in ADAPTIVE_METHODS:
                with self.subTest(dtype=dtype, device=device, method=method):
                    tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method=method)
                    # Absolute value before the max: a signed max of
                    # (sol - y) would let any overshoot of ``sol`` pass.
                    max_error0 = (sol - tuple_y[0]).abs().max()
                    max_error1 = (sol - tuple_y[1]).abs().max()
                    self.assertLess(max_error0, eps)
                    self.assertLess(max_error1, eps)
def test_adjoint(self):
    """Check event handling through ``odeint_adjoint``, then run backward.

    The problem is the 'constant' problem on CPU. ``event_fn`` is zero at
    ``sol[-1]``, so the returned event time and final state must match
    ``t_points[-1]`` and ``sol[-1]`` (relative error below ``1e-4``). The
    test then backpropagates from both outputs.
    """
    f, y0, t_points, sol = construct_problem(device="cpu", ode="constant")

    def event_fn(t, y):
        return torch.sum(y - sol[-1])

    event_t, trajectory = torchdiffeq.odeint_adjoint(
        f, y0, t_points[0:2], event_fn=event_fn, method="dopri5")
    final_state = trajectory[-1]
    self.assertLess(rel_error(sol[-1], final_state), 1e-4)
    self.assertLess(rel_error(t_points[-1], event_t), 1e-4)

    # Make sure adjoint mode backward code can still be run.
    # The graph is retained on the first backward so the second can reuse it.
    event_t.backward(retain_graph=True)
    final_state.sum().backward()
def test_odeint(self):
    """Check the forward accuracy of every solver against reference solutions.

    Adaptive methods are checked on every problem; other methods only on
    'constant'. dopri8 runs with dtype-specific ``rtol``/``atol``.

    The allowed relative error is:
    - ``4e-3`` for adaptive_heun and bosh3;
    - ``2e-3`` for the 'linear' problem;
    - ``1e-4`` otherwise.
    """
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for method in METHODS:
                    # dopri8 gets explicit, dtype-specific rtol/atol.
                    kwargs = {}
                    if method == 'dopri8':
                        if dtype == torch.float64:
                            kwargs = dict(rtol=1e-12, atol=1e-14)
                        elif dtype == torch.float32:
                            kwargs = dict(rtol=1e-7, atol=1e-7)

                    ode_names = PROBLEMS if method in ADAPTIVE_METHODS else ('constant',)
                    for ode in ode_names:
                        if method in ['adaptive_heun', 'bosh3']:
                            eps = 4e-3
                        elif ode == 'linear':
                            eps = 2e-3
                        else:
                            eps = 1e-4
                        with self.subTest(reverse=reverse, dtype=dtype, device=device,
                                          ode=ode, method=method):
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device, ode=ode, reverse=reverse)
                            solution = torchdiffeq.odeint(f, y0, t_points, method=method, **kwargs)
                            self.assertLess(rel_error(sol, solution), eps)
def test_euler(self):
    """Check the Euler method against the reference solution.

    The default problem is built with ``reverse=True``, and the relative
    error must be below ``error_tol``.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    solution = torchdiffeq.odeint(f, y0, t_points, method='euler')
    error = rel_error(sol, solution)
    self.assertLess(error, error_tol)
def test_rk4_classic(self):
    """Check ``rk4_classic`` against the default problem's reference solution.

    The relative error must be below ``error_tol``.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
    solution = torchdiffeq.odeint(f, y0, t_points, method='rk4_classic')
    error = rel_error(sol, solution)
    self.assertLess(error, error_tol)
def test_explicit_adams(self):
    """Check ``explicit_adams`` against the default problem's reference solution.

    The relative error must be below ``error_tol``.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
    solution = torchdiffeq.odeint(f, y0, t_points, method='explicit_adams')
    error = rel_error(sol, solution)
    self.assertLess(error, error_tol)
def test_adjoint(self):
    """Compare backprop-through-solver gradients with adjoint gradients.

    For every device and problem, with and without gradients on the time
    points, the same problem is solved twice with ``rtol=1e-9`` and
    ``atol=1e-12``: once with ``odeint`` and once with ``odeint_adjoint``.
    Both use their default method. The same random cotangents are
    backpropagated through each solve.

    The gradients with respect to ``y0``, ``t_points`` (when enabled) and
    every parameter of ``f`` must agree within a problem-specific tolerance.
    An unrecognised problem name raises ``RuntimeError``.
    """
    tolerances = {'constant': 1e-12, 'linear': 1e-5, 'sine': 5e-3}
    for device in DEVICES:
        for ode in PROBLEMS:
            for t_grad in (True, False):
                if ode not in tolerances:
                    raise RuntimeError
                eps = tolerances[ode]
                with self.subTest(device=device, ode=ode, t_grad=t_grad):
                    f, y0, t_points, _ = construct_problem(device=device, ode=ode)
                    t_points.requires_grad_(t_grad)

                    # Pass 1: backpropagate through odeint.
                    ys = torchdiffeq.odeint(f, y0, t_points, rtol=1e-9, atol=1e-12)
                    torch.manual_seed(0)  # reproducible cotangents
                    gradys = torch.rand_like(ys)
                    ys.backward(gradys)
                    # Clone, because .grad is zeroed in place before pass 2.
                    reg_y0_grad = y0.grad.clone()
                    reg_t_grad = t_points.grad.clone() if t_grad else None
                    reg_params_grads = [param.grad.clone() for param in f.parameters()]

                    # Reset gradients so pass 2 does not accumulate onto pass 1.
                    y0.grad.zero_()
                    if t_grad:
                        t_points.grad.zero_()
                    for param in f.parameters():
                        param.grad.zero_()

                    # Pass 2: adjoint method with the same cotangents.
                    ys = torchdiffeq.odeint_adjoint(f, y0, t_points, rtol=1e-9, atol=1e-12)
                    ys.backward(gradys)
                    adj_y0_grad = y0.grad
                    adj_t_grad = t_points.grad if t_grad else None
                    adj_params_grads = [param.grad for param in f.parameters()]

                    self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
                    if t_grad:
                        self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
                    for reg_grad, adj_grad in zip(reg_params_grads, adj_params_grads):
                        self.assertLess(max_abs(reg_grad - adj_grad), eps)
def test_dopri8(self):
    """A dopri8 solve over only the first time point must reproduce ``sol[0]``.

    The default problem is built with ``reverse=True``. The max absolute
    difference must be below ``error_tol``.
    """
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    first_time = t_points[:1]
    result = torchdiffeq.odeint(f, y0, first_time, method='dopri8')
    deviation = max_abs(sol[0] - result)
    self.assertLess(deviation, error_tol)