def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
    f = VanDerPol(stiffness)
    x = torch.randn(1024, 2, requires_grad=True)

    t0 = time.time()
    prob = ODEProblem(f, sensitivity=sensitivity, interpolator=interpolator, solver=solver,
                      atol=1e-4, rtol=1e-4, atol_adjoint=1e-4, rtol_adjoint=1e-4)
    t_eval, sol_torchdyn = prob.odeint(x, t_span)
    t_end1 = time.time() - t0

    t0 = time.time()
    sol_torchdiffeq = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
    t_end2 = time.time() - t0

    # High-accuracy reference solution for the gradient check.
    true_sol = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-9, rtol=1e-9)

    t0 = time.time()
    grad1 = torch.autograd.grad(sol_torchdyn[-1].sum(), x)[0]
    t_end1 = time.time() - t0

    t0 = time.time()
    grad2 = torch.autograd.grad(sol_torchdiffeq[-1].sum(), x)[0]
    t_end2 = time.time() - t0

    grad_true = torch.autograd.grad(true_sol[-1].sum(), x)[0]
    err1 = (grad1 - grad_true).abs().sum(1)
    err2 = (grad2 - grad_true).abs().sum(1)
    assert (err1 <= 1e-3).all() and (err1.mean() <= err2.mean())
def test_seminorm(self):
    torch.manual_seed(3456786)  # test can be flaky
    for dtype in DTYPES:
        for device in DEVICES:
            for method in ADAPTIVE_METHODS:
                if method == 'adaptive_heun':
                    # Adaptive heun is consistently an awful choice with seminorms, it seems. My guess is that it's
                    # consistently overconfident with its step sizes, and that having seminorms turned off means
                    # that it actually gets it right more often.
                    continue
                if dtype == torch.float32 and method == 'dopri8':
                    continue
                with self.subTest(dtype=dtype, device=device, method=method):
                    x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
                    t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                    norm_f = _NeuralF(width=256, oscillate=True).to(device, dtype)
                    out = torchdiffeq.odeint_adjoint(norm_f, x0, t, atol=3e-7, method=method)
                    norm_f.nfe = 0
                    out.sum().backward()

                    seminorm_f = _NeuralF(width=256, oscillate=True).to(device, dtype)
                    with torch.no_grad():
                        for norm_param, seminorm_param in zip(norm_f.parameters(), seminorm_f.parameters()):
                            seminorm_param.copy_(norm_param)
                    out = torchdiffeq.odeint_adjoint(seminorm_f, x0, t, atol=1e-6, method=method,
                                                     adjoint_options=dict(norm='seminorm'))
                    seminorm_f.nfe = 0
                    out.sum().backward()

                    self.assertLessEqual(seminorm_f.nfe, norm_f.nfe)
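# --- Hedged standalone sketch (not from the test suite above): the seminorm
# adjoint option in minimal user-facing form. Assumes torchdiffeq >= 0.2.0,
# where `adjoint_options=dict(norm='seminorm')` makes the backward pass use a
# seminorm that ignores the parameter-adjoint component of the state; `_Decay`
# is a made-up dynamics module for illustration.
import torch
import torchdiffeq


class _Decay(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(-0.5))

    def forward(self, t, y):
        return self.a * y


func = _Decay()
y0 = torch.tensor([1.0, 2.0])
t = torch.tensor([0.0, 1.0])
y = torchdiffeq.odeint_adjoint(func, y0, t, atol=1e-6, rtol=1e-6,
                               adjoint_options=dict(norm='seminorm'))
y[-1].sum().backward()  # parameter gradients via the (seminorm) adjoint pass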
def forward(self, x):
    self.integration_time = self.integration_time.type_as(x)
    if self.adjoint:
        out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=1e-4, atol=1e-4)
    else:
        out = odeint(self.odefunc, x, self.integration_time, rtol=1e-4, atol=1e-4)
    return out[1]
def forward(self, first_point, time_steps_to_predict, backwards=False):
    """Decode the trajectory through the ODE solver."""
    n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
    n_dims = first_point.size()[-1]

    if not self.ode_adjoint:
        pred_y = odeint(self.ode_func, first_point, time_steps_to_predict,
                        rtol=self.odeint_rtol, atol=self.odeint_atol, method=self.ode_method)
    else:
        pred_y = odeint_adjoint(self.ode_func, first_point, time_steps_to_predict,
                                rtol=self.odeint_rtol, atol=self.odeint_atol, method=self.ode_method)
    pred_y = pred_y.permute(1, 2, 0, 3)

    assert torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001
    assert pred_y.size()[0] == n_traj_samples
    assert pred_y.size()[1] == n_traj
    return pred_y
def forward(self, x, eval_times=None):
    """Solves ODE starting from x.

    Parameters
    ----------
    x : torch.Tensor
        Shape (batch_size, self.odefunc.data_dim)

    eval_times : None or torch.Tensor
        If None, returns solution of ODE at final time t=1. If torch.Tensor
        then returns full ODE trajectory evaluated at points in eval_times.
    """
    # Forward pass corresponds to solving ODE, so reset number of function
    # evaluations counter
    self.odefunc.nfe = 0

    if eval_times is None:
        integration_time = torch.tensor([0, 1]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.odefunc.augment_dim > 0:
        if self.is_conv:
            # Add augmentation
            batch_size, channels, height, width = x.shape
            aug = torch.zeros(batch_size, self.odefunc.augment_dim,
                              height, width, device=x.device)
            # Shape (batch_size, channels + augment_dim, height, width)
            x_aug = torch.cat([x, aug], 1)
        else:
            # Add augmentation
            aug = torch.zeros(x.shape[0], self.odefunc.augment_dim, device=x.device)
            # Shape (batch_size, data_dim + augment_dim)
            x_aug = torch.cat([x, aug], 1)
    else:
        x_aug = x

    if self.adjoint:
        out = odeint_adjoint(self.odefunc, x_aug, integration_time,
                             rtol=self.tol, atol=self.tol, method=self.method,
                             options={'max_num_steps': self.max_num_steps})
    else:
        out = odeint(self.odefunc, x_aug, integration_time,
                     rtol=self.tol, atol=self.tol, method=self.method,
                     options={'max_num_steps': self.max_num_steps})

    if eval_times is None:
        return out[1]  # Return only final time
    else:
        return out
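# --- Hedged standalone sketch (illustrative, not the implementation above):
# the augmentation idea reduced to a self-contained example. Zero channels are
# concatenated onto the state before solving, so the ODE evolves in a
# higher-dimensional space. `AugmentedFunc` and all sizes are hypothetical.
import torch
import torchdiffeq


class AugmentedFunc(torch.nn.Module):
    def __init__(self, data_dim, augment_dim):
        super().__init__()
        self.net = torch.nn.Linear(data_dim + augment_dim, data_dim + augment_dim)

    def forward(self, t, x):
        return self.net(x)


data_dim, augment_dim = 2, 3
func = AugmentedFunc(data_dim, augment_dim)
x = torch.randn(8, data_dim)
aug = torch.zeros(x.shape[0], augment_dim, device=x.device)  # augmentation
x_aug = torch.cat([x, aug], 1)  # shape (batch, data_dim + augment_dim)
t = torch.tensor([0.0, 1.0])
out = torchdiffeq.odeint_adjoint(func, x_aug, t, rtol=1e-4, atol=1e-4)[1]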
def test_adjoint(self):
    for ode in problems.PROBLEMS.keys():
        # Pass the looped-over problem name, otherwise every iteration
        # tests the same default problem.
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode, reverse=True)
        y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)
def _adjoint(self, x):
    return torchdiffeq.odeint_adjoint(self.defunc, x, self.s_span,
                                      rtol=self.st['rtol'], atol=self.st['atol'],
                                      method=self.st['solver'])
def reverse(self, z, c):
    _logpz = torch.zeros(z.shape[0], 1).to(z)
    states = (z, _logpz, c)

    if self.train_T:
        integration_times = torch.stack([
            torch.tensor(0.0).to(z),
            self.sqrt_end_time * self.sqrt_end_time
        ]).to(z)
    else:
        integration_times = torch.tensor([0., self.T], requires_grad=False).to(z)
    # Integrate backwards in time for the reverse pass.
    integration_times = self._flip(integration_times, 0)

    # Refresh the odefunc statistics.
    self.odefunc.before_odeint()

    state_t = odeint_adjoint(
        self.odefunc,
        states,
        integration_times.to(z),
        atol=self.test_atol,
        rtol=self.test_rtol,
        method=self.test_solver,
    )

    if len(integration_times) == 2:
        state_t = tuple(s[1] for s in state_t)

    x_t, _, c = state_t[:3]
    return x_t, c
def forward(self, x, eval_times=None):
    dt = self.T / self.time_steps
    if eval_times is None:
        integration_time = torch.tensor([0, self.T]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.dynamics.augment_dim > 0:
        x = x.view(x.size(0), -1)
        aug = torch.zeros(x.shape[0], self.dynamics.augment_dim).to(self.device)
        x_aug = torch.cat([x, aug], 1)
    else:
        x_aug = x

    # Fixed-step Euler integration with step size dt.
    if self.adjoint:
        out = odeint_adjoint(self.dynamics, x_aug, integration_time,
                             method='euler', options={'step_size': dt})
    else:
        out = odeint(self.dynamics, x_aug, integration_time,
                     method='euler', options={'step_size': dt})

    if eval_times is None:
        return out[1]
    else:
        return out
def forward(self, x, eval_times=None, only_last=True):
    if eval_times is None:
        integration_time = torch.tensor([0, 1]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.odefunc.augment_dim > 0:
        batch_size, channels, height, width = x.shape
        # Create the augmentation on the same device as x so torch.cat works on GPU.
        aug = torch.zeros(batch_size, self.odefunc.augment_dim, height, width,
                          device=x.device)
        x_aug = torch.cat([x, aug], 1)
    else:
        x_aug = x

    out = odeint_adjoint(self.odefunc, x_aug, integration_time,
                         rtol=self.tol, atol=self.tol, method='dopri5',
                         options={'max_num_steps': 1000})

    if only_last:
        return out[1]  # Return only final time
    else:
        return out
def forward(self, x, c, log_det):
    _logpx = torch.zeros(x.shape[0], 1).to(x)
    states = (x, _logpx, c)

    if self.train_T:
        integration_times = torch.tensor([0.0, self.sqrt_end_time * self.sqrt_end_time]).to(x)
    else:
        integration_times = torch.tensor([0.0, self.T]).to(x)

    # Refresh the odefunc statistics.
    self.odefunc.before_odeint()

    # One tolerance per component of the tuple state.
    state_t = odeint_adjoint(
        self.odefunc,
        states,
        integration_times.to(x),
        atol=[self.atol, self.atol, self.atol],
        rtol=[self.rtol, self.rtol, self.rtol],
        method=self.solver,
        options=self.solver_options,
    )

    if len(integration_times) == 2:
        state_t = tuple(s[1] for s in state_t)

    z_t, logpz_t, c = state_t[:3]
    log_det = log_det - logpz_t
    return z_t, c, log_det
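# --- Hedged standalone sketch: the tuple-state pattern used by the CNF
# snippets above, reduced to a minimal example. torchdiffeq accepts a tuple of
# tensors as the state, with per-component `atol`/`rtol` lists; the dynamics
# module and its equations here are made up for illustration.
import torch
import torchdiffeq


class _TupleDynamics(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.tensor(0.1))

    def forward(self, t, states):
        z, logpz = states
        # Return one time derivative per state component, shapes matching.
        return self.w * z, -self.w * logpz.new_ones(logpz.shape)


z0 = torch.randn(4, 2)
logp0 = torch.zeros(4, 1)
t = torch.tensor([0.0, 1.0])
zs, logps = torchdiffeq.odeint_adjoint(
    _TupleDynamics(), (z0, logp0), t,
    atol=[1e-5, 1e-5], rtol=[1e-5, 1e-5])
z1, logp1 = zs[1], logps[1]  # states at t=1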
def forward(self, x, eval_times=None):
    # Forward pass corresponds to solving ODE, so reset number of function
    # evaluations counter
    self.odefunc.nfe = 0

    if eval_times is None:
        integration_time = torch.tensor([0, 1]).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.adjoint:
        out = odeint_adjoint(self.odefunc, x, integration_time,
                             rtol=self.tol, atol=self.tol, method='dopri5',
                             options={'max_num_steps': MAX_NUM_STEPS})
    else:
        out = odeint(self.odefunc, x, integration_time,
                     rtol=self.tol, atol=self.tol, method='dopri5',
                     options={'max_num_steps': MAX_NUM_STEPS})

    if eval_times is None:
        return out[1]  # Return only final time
    else:
        return out
def forward(self, x, eval_times=None, only_last=True):
    if eval_times is None:
        # integration_time = torch.tensor([0, 1]).float().type_as(x)
        # Note: torch.arange excludes the endpoint, so this evaluates at 0.0, 0.1, ..., 0.9.
        integration_time = torch.arange(0, 1, 0.1).float().type_as(x)
    else:
        integration_time = eval_times.type_as(x)

    if self.odefun.augment_dim > 0:
        # Create the augmentation on the same device as x so torch.cat works on GPU.
        aug = torch.zeros(x.shape[0], self.odefun.augment_dim, device=x.device)
        x_aug = torch.cat([x, aug], 1)
    else:
        x_aug = x

    out = odeint_adjoint(self.odefun, x_aug, integration_time,
                         rtol=self.tol, atol=self.tol, method='dopri5',
                         options={'max_num_steps': 1000})

    if only_last:
        return out[-1]
    else:
        return out
def _run_ode(self, *xs, dynamics, **kwargs):
    # TODO: kwargs should be parsed to avoid conflicts!
    assert all(x.shape[0] == xs[0].shape[0] for x in xs[1:])
    n_batch = xs[0].shape[0]
    logp_init = torch.zeros(n_batch, 1).to(xs[0])
    state = [*xs, logp_init]
    ts = torch.linspace(0.0, self._t_max, self._n_time_steps).to(xs[0])
    kwargs = {**self._kwargs, **kwargs}

    if not self._use_checkpoints:
        from torchdiffeq import odeint_adjoint
        *ys, dlogp = odeint_adjoint(
            dynamics,
            state,
            t=ts,
            method=self._integrator_method,
            rtol=self._integrator_rtol,
            atol=self._integrator_atol,
            options=kwargs
        )
        # Keep only the final time point of each trajectory.
        ys = [y[-1] for y in ys]
        dlogp = dlogp[-1]
    else:
        # raise NotImplementedError()
        from anode.adjoint import odesolver_adjoint
        state = odesolver_adjoint(dynamics, state, options=kwargs)
        # Assumes odesolver_adjoint returns the final state in the same
        # layout as the initial state list.
        *ys, dlogp = state

    return (*ys, dlogp)
def test_adjoint(self):
    """Test against dopri5."""
    f, y0, t_points, _ = construct_problem(TEST_DEVICE)

    func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
    ys = func(y0, t_points)
    torch.manual_seed(0)
    gradys = torch.rand_like(ys)
    ys.backward(gradys)

    # reg_y0_grad = y0.grad
    reg_t_grad = t_points.grad
    reg_a_grad = f.a.grad
    reg_b_grad = f.b.grad

    f, y0, t_points, _ = construct_problem(TEST_DEVICE)

    func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
    ys = func(y0, t_points)
    ys.backward(gradys)

    # adj_y0_grad = y0.grad
    adj_t_grad = t_points.grad
    adj_a_grad = f.a.grad
    adj_b_grad = f.b.grad

    # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
    self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
    self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
    self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)
def test_seminorm(self):
    for dtype in DTYPES:
        for device in DEVICES:
            for method in ADAPTIVE_METHODS:
                with self.subTest(dtype=dtype, device=device, method=method):
                    if dtype == torch.float32:
                        tol = 1e-6
                    else:
                        tol = 1e-8

                    x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
                    t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                    ode_f = _NeuralF(width=1024, oscillate=True).to(device, dtype)

                    out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method)
                    ode_f.nfe = 0
                    out.sum().backward()
                    default_nfe = ode_f.nfe

                    out = torchdiffeq.odeint_adjoint(ode_f, x0, t, atol=tol, rtol=tol, method=method,
                                                     adjoint_options=dict(norm='seminorm'))
                    ode_f.nfe = 0
                    out.sum().backward()
                    seminorm_nfe = ode_f.nfe

                    self.assertLessEqual(seminorm_nfe, default_nfe)
def forward(self, y0, t=None, **kwargs):
    if self.t is not None and t is None:
        t = self.t
    elif self.t is None and t is not None:
        pass
    else:
        raise ValueError(
            '`t` must be provided exactly once: either when defining the '
            'NeuralODE or when calling it'
        )

    if kwargs.get('event_fn') is None:
        solution = odeint_adjoint(self.func, y0=y0, t=t, **kwargs)
        if self.last:
            solution = solution[-1]
        return solution
    else:
        event_t, solution = odeint_adjoint(self.func, y0=y0, t=t, **kwargs)
        return event_t, solution
def evolve(self, h, time_diff):
    t = torch.tensor([0, time_diff.item()],
                     dtype=time_diff.dtype, device=time_diff.device)
    out = torchdiffeq.odeint_adjoint(func=self.func, y0=h, t=t, method='rk4')
    return out[1]
def _torchdiffeq_adjoint(self, x):
    return torchdiffeq.odeint_adjoint(
        self.defunc, x, self.s_span,
        rtol=self.rtol, atol=self.atol,
        **self.solver,
        adjoint_options=dict(norm=make_norm(x)))[-1]
def test_adjoint(self):
    for device in DEVICES:
        for method in METHODS:
            with self.subTest(device=device, method=method):
                f, y0, t_points, _ = construct_problem(device=device)
                func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method=method)
                self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))
def forward(self, x):
    if self.norm:
        adjoint_options = dict(norm=common.make_norm(x))
    else:
        adjoint_options = None
    z = torchdiffeq.odeint_adjoint(self.func, x[:, 0], self.times,
                                   rtol=self.rtol, atol=self.atol,
                                   adjoint_options=adjoint_options)
    loss = torch.nn.functional.mse_loss(x, z.transpose(0, 1))
    return loss
def forward(self, x):
    self.integration_time = self.integration_time.type_as(x)
    # Actual ODE solver used in forward propagation
    out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=tol, atol=tol)
    return out[1]
def forward_batched(self, x: torch.Tensor, nn: int, indices: list, timestamps: set):
    """Modified forward for ODE batches with different integration times."""
    # Sets are unordered; odeint requires a monotonically increasing time grid.
    timestamps = torch.tensor(sorted(timestamps))
    if self.adjoint_flag:
        out = torchdiffeq.odeint_adjoint(self.odefunc, x, timestamps,
                                         rtol=self.rtol, atol=self.atol, method=self.method)
    else:
        out = torchdiffeq.odeint(self.odefunc, x, timestamps,
                                 rtol=self.rtol, atol=self.atol, method=self.method)
    out = self._build_batch(out, nn, indices).reshape(x.shape)
    return out
def forward(self, x: torch.Tensor, T: int = 1):
    self.integration_time = torch.tensor([0, T]).float()
    self.integration_time = self.integration_time.type_as(x)
    if self.adjoint_flag:
        out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time,
                                         rtol=self.rtol, atol=self.atol, method=self.method)
    else:
        out = torchdiffeq.odeint(self.odefunc, x, self.integration_time,
                                 rtol=self.rtol, atol=self.atol, method=self.method)
    return out[-1]
def integrate(self, t0, t1, x, logpx, tol=None, method=None, norm=None, intermediate_states=0):
    """
    Args:
        t0: (N,)
        t1: (N,)
        x: (N, ...)
        logpx: (N,)
    """
    self.nfe = 0
    tol = tol or self.tol
    method = method or self.method

    e = torch.randn_like(x)[:, :self.dim]
    energy = torch.zeros(1).to(x)
    jacnorm = torch.zeros(1).to(x)
    initial_state = (t0, t1, e, x, logpx, energy, jacnorm)

    if intermediate_states > 1:
        tt = torch.linspace(self.start_time, self.end_time, intermediate_states).to(t0)
    else:
        tt = torch.tensor([self.start_time, self.end_time]).to(t0)

    solution = odeint_adjoint(
        self,
        initial_state,
        tt,
        rtol=tol,
        atol=tol,
        method=method,
    )

    if intermediate_states > 1:
        y = solution[3]
        _, _, _, _, logpy, energy, jacnorm = tuple(s[-1] for s in solution)
    else:
        _, _, _, y, logpy, energy, jacnorm = tuple(s[-1] for s in solution)

    regularization = (self.energy_regularization * (energy - energy.detach())
                      + self.jacnorm_regularization * (jacnorm - jacnorm.detach()))
    return y, logpy + regularization  # hacky method to introduce regularization
def test_adjoint(self):
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for ode in PROBLEMS:
                    if ode == 'linear':
                        eps = 2e-3
                    else:
                        eps = 1e-4

                    with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode):
                        f, y0, t_points, sol = construct_problem(dtype=dtype, device=device,
                                                                 ode=ode, reverse=reverse)
                        y = torchdiffeq.odeint_adjoint(f, y0, t_points)
                        self.assertLess(rel_error(sol, y), eps)
def trajectory(self, x: torch.Tensor, s_span: torch.Tensor, method='odeint', **kwargs):
    if method == 'adjoint':
        solution = odeint_adjoint(self.func, x, s_span, **kwargs)
    elif method == 'odeint':
        solution = odeint(self.func, x, s_span, **kwargs)
    else:
        raise ValueError('`method` must be either `adjoint` or `odeint`')
    return solution
def forward(self, X, i=-1):
    '''
    Input:
        X - torch Tensor of size (N, C, W, H).
        i - int index corresponding to time point of ODE estimation of the solution.
    Output:
        X_ti - torch Tensor of size (N, C', W', H'). Estimation of X at time point t[i].
    '''
    return torchdiffeq.odeint_adjoint(self.f, X, self.t, rtol=TOL, atol=TOL)[i]
def forward(self, vt, x):
    integration_time_vector = vt.type_as(x)
    if self.adjoint:
        out = ode.odeint_adjoint(self.odefunc, x, integration_time_vector,
                                 rtol=self.rtol, atol=self.atol, method=self.method)
    else:
        out = ode.odeint(self.odefunc, x, integration_time_vector,
                         rtol=self.rtol, atol=self.atol, method=self.method)
    # Return only the terminal state unless the full trajectory is requested.
    return out[-1] if self.terminal else out  # 100 * 400 * 10
def test_adjoint(self):
    f, y0, t_points, sol = construct_problem(device="cpu", ode="constant")

    def event_fn(t, y):
        return torch.sum(y - sol[-1])

    t, y = torchdiffeq.odeint_adjoint(f, y0, t_points[0:2], event_fn=event_fn, method="dopri5")
    y = y[-1]
    self.assertLess(rel_error(sol[-1], y), 1e-4)
    self.assertLess(rel_error(t_points[-1], t), 1e-4)

    # Make sure adjoint mode backward code can still be run.
    t.backward(retain_graph=True)
    y.sum().backward()
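# --- Hedged standalone sketch: event handling in minimal form, mirroring the
# test above. With `event_fn`, torchdiffeq integrates from t[0] until the event
# function crosses zero and returns `(event_t, solution)`; the dynamics and the
# threshold here are made up for illustration.
import torch
import torchdiffeq


class _Growth(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rate = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, t, y):
        return self.rate * y


def event_fn(t, y):
    return y.sum() - 5.0  # event triggers when the state sum reaches 5


func = _Growth()
y0 = torch.tensor([1.0])
t = torch.tensor([0.0, 10.0])
event_t, y = torchdiffeq.odeint_adjoint(func, y0, t, event_fn=event_fn,
                                        method='dopri5')
event_t.backward()  # gradients flow through the event time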