def __init__(self, func: nn.Module, order=1, sensitivity='autograd', s_span=torch.linspace(0, 1, 2),
             solver='rk4', atol=1e-4, rtol=1e-4, intloss=None):
    """Construct the Neural DE, delegating shared setup to the parent constructor.

    ``func`` is wrapped as ``DEFunc(func, order)`` and passed, together with
    the integration settings, to ``super().__init__``. An ``Adjoint`` helper
    is built only when ``sensitivity == 'adjoint'``.

    :param func: function parametrizing the vector field.
    :param order: order of the differential equation, forwarded to ``DEFunc``.
    :param sensitivity: gradient computation method name.
    :param s_span: integration points.
    :param solver: solver name.
    :param atol: absolute tolerance.
    :param rtol: relative tolerance.
    :param intloss: optional integral loss, forwarded to ``Adjoint``.
    """
    super().__init__(
        func=DEFunc(func, order),
        order=order,
        sensitivity=sensitivity,
        s_span=s_span,
        solver=solver,
        atol=atol,
        rtol=rtol,
    )
    # Relies on the parent constructor having set ``self.defunc``.
    self.nfe = self.defunc.nfe
    self.intloss = intloss
    # Data-control state: no control input yet, and not controlled.
    self.u = None
    self.controlled = False
    if sensitivity == 'adjoint':
        self.adjoint = Adjoint(self.defunc, intloss)
def __init__(self, func: nn.Module, order=1, sensitivity='autograd', s_span=torch.linspace(0, 1, 2),
             solver='rk4', atol=1e-4, rtol=1e-4, intloss=None):
    """Construct the Neural DE, storing every setting directly on the instance.

    ``func`` is wrapped as ``DEFunc(func, order)`` and stored as ``self.defunc``.
    An ``Adjoint`` helper is built from ``intloss`` only when
    ``sensitivity == 'adjoint'``.

    :param func: function parametrizing the vector field.
    :param order: order of the differential equation, forwarded to ``DEFunc``.
    :param sensitivity: gradient computation method name.
    :param s_span: integration points.
    :param solver: solver name (stored as-is).
    :param atol: absolute tolerance.
    :param rtol: relative tolerance.
    :param intloss: optional integral loss, forwarded to ``Adjoint``.
    """
    super().__init__()
    # TODO: remove `controlled` from the input arguments.
    self.defunc = DEFunc(func, order)
    self.order = order
    self.sensitivity = sensitivity
    self.s_span = s_span
    self.solver = solver
    self.nfe = self.defunc.nfe
    self.rtol = rtol
    self.atol = atol
    self.intloss = intloss
    # Data-control state: no control input yet, and not controlled.
    self.u = None
    self.controlled = False
    if sensitivity == 'adjoint':
        self.adjoint = Adjoint(self.intloss)
class NeuralODE(NeuralDETemplate):
    """General Neural ODE class.

    Wraps the vector field ``func`` in ``DEFunc(func, order)`` and integrates
    it over ``s_span`` with the configured solver and sensitivity method.

    :param func: function parametrizing the vector field.
    :type func: nn.Module
    :param order: order of the differential equation, forwarded to ``DEFunc``.
    :type order: int
    :param sensitivity: gradient computation method; one of ``'autograd'``,
        ``'adjoint'`` or ``'torchdiffeq_adjoint'``.
    :type sensitivity: str
    :param s_span: integration points.
    :type s_span: torch.Tensor
    :param solver: solver name. Names listed in ``SCIPY_SOLVERS`` are mapped
        to the corresponding SciPy solver options; any other name is used as
        the torchdiffeq ``method``.
    :type solver: str
    :param atol: absolute tolerance.
    :param rtol: relative tolerance.
    :param intloss: optional integral loss, forwarded to ``Adjoint`` and, for
        ``'autograd'``, to ``DEFunc``.
    """

    def __init__(self, func: nn.Module, order=1, sensitivity='autograd', s_span=torch.linspace(0, 1, 2),
                 solver='rk4', atol=1e-4, rtol=1e-4, intloss=None):
        super().__init__(func=DEFunc(func, order), order=order, sensitivity=sensitivity,
                         s_span=s_span, solver=solver, atol=atol, rtol=rtol)
        self.nfe = self.defunc.nfe
        self.intloss = intloss
        self.u, self.controlled = None, False  # data-control
        if sensitivity == 'adjoint':
            self.adjoint = Adjoint(self.defunc, intloss)
        self._solver_checks(solver, sensitivity)

    def _solver_checks(self, solver, sensitivity):
        """Validate ``solver`` and store the solver keyword arguments in ``self.solver``.

        ``self.solver`` becomes a mapping that is splatted as keyword arguments
        into the solver calls: ``{'method': solver}`` for regular solvers, or the
        ``SCIPY_SOLVERS[solver]`` entry for SciPy solvers.

        :raises KeyError: if ``solver`` starts with ``"scipy"`` but is not in ``SCIPY_SOLVERS``.
        :raises ValueError: if a SciPy solver is combined with ``'autograd'`` sensitivity.
        """
        self.solver = {'method': solver}
        if solver.startswith("scipy") and solver not in SCIPY_SOLVERS:
            available_scipy_solvers = ", ".join(SCIPY_SOLVERS.keys())
            raise KeyError("Invalid Scipy Solver specified." +
                           " Supported Scipy Solvers are: " + available_scipy_solvers)
        elif solver in SCIPY_SOLVERS:
            warnings.warn(
                UserWarning("CUDA is not available with SciPy solvers."))
            if sensitivity == 'autograd':
                raise ValueError(
                    "SciPy Solvers do not work with autograd." +
                    " Use adjoint sensitivity with SciPy Solvers.")
            self.solver = SCIPY_SOLVERS[solver]

    def _prep_odeint(self, x: torch.Tensor):
        """Prepare auxiliary state before integrating ``x`` and return ``x`` unchanged.

        Moves ``s_span`` to ``x``'s device. For every submodule exposing a
        ``trace_estimator`` (e.g. Hutchinson's CNF trace estimator), it samples
        fresh noise. For every submodule with a ``u`` attribute, it sets the
        data-control input.
        """
        self.s_span = self.s_span.to(x.device)
        # Count the leading columns of x (along dim 1) that are not data: one for
        # an integral loss accumulated with autograd, and one per trace estimator
        # (CNF divergence propagation).
        excess_dims = 0
        if self.intloss is not None and self.sensitivity == 'autograd':
            excess_dims += 1
        # Single traversal of the submodules. Data-control modules are only
        # collected here, because their slice of x depends on the final excess_dims.
        controlled_modules = []
        for module in self.defunc.modules():
            if hasattr(module, 'trace_estimator'):
                if module.noise_dist is not None:
                    module.noise = module.noise_dist.sample((x.shape[0],))
                excess_dims += 1
            if hasattr(module, 'u'):
                controlled_modules.append(module)
        # Data-control is set once, at the beginning of odeint, since the control
        # is fixed to the initial condition.
        for module in controlled_modules:
            self.controlled = True
            module.u = x[:, excess_dims:].detach()
        return x

    def forward(self, x: torch.Tensor):
        """Integrate ``x`` over ``s_span`` with the configured sensitivity method.

        :param x: initial condition.
        :type x: torch.Tensor
        :raises ValueError: if ``self.sensitivity`` is not a supported method.
        """
        switcher = {
            'autograd': self._autograd,
            'adjoint': self._adjoint,
            'torchdiffeq_adjoint': self._torchdiffeq_adjoint,
        }
        odeint = switcher.get(self.sensitivity)
        if odeint is None:
            # Without this check an unknown name surfaces as "'NoneType' object is not callable".
            raise ValueError(f"Unknown sensitivity '{self.sensitivity}'. "
                             f"Supported: {', '.join(switcher)}")
        x = self._prep_odeint(x)
        return odeint(x)

    def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        """Returns a data-flow trajectory at `s_span` points

        :param x: input data
        :type x: torch.Tensor
        :param s_span: collections of points to evaluate the function at e.g torch.linspace(0, 1, 100) for a 100 point trajectory between 0 and 1
        :type s_span: torch.Tensor
        :returns: the solution evaluated at every point of ``s_span``.
        """
        x = self._prep_odeint(x)
        sol = torchdiffeq.odeint(self.defunc, x, s_span, rtol=self.rtol, atol=self.atol, **self.solver)
        return sol

    def sensitivity_trajectory(self, x: torch.Tensor, grad_output: torch.Tensor, s_span: torch.Tensor):
        """Return the adjoint-state trajectory evaluated at ``s_span`` points.

        Only available when ``sensitivity == 'adjoint'``.

        :param x: input data.
        :type x: torch.Tensor
        :param grad_output: passed with the forward solution to
            ``self.adjoint._init_adjoint_state`` to build the initial adjoint state.
        :type grad_output: torch.Tensor
        :param s_span: points at which the adjoint dynamics are evaluated.
        :type s_span: torch.Tensor
        """
        assert self.sensitivity == 'adjoint', 'Sensitivity trajectory only available for `adjoint`'
        x = torch.autograd.Variable(x, requires_grad=True)
        sol = self(x)
        adj0 = self.adjoint._init_adjoint_state(sol, grad_output)
        self.adjoint.flat_params = flatten(self.defunc.parameters())
        self.adjoint.func = self.defunc
        self.adjoint.f_params = tuple(self.defunc.parameters())
        # self.solver is a kwargs mapping (see _solver_checks). It must be splatted
        # like in the other odeint calls, not passed as ``method``.
        adj_sol = torchdiffeq.odeint(self.adjoint.adjoint_dynamics, adj0, s_span,
                                     rtol=self.rtol, atol=self.atol, **self.solver)
        return adj_sol

    def _autograd(self, x):
        """Integrate with torchdiffeq, backpropagating through the solver; return the final state."""
        self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
        return torchdiffeq.odeint(self.defunc, x, self.s_span, rtol=self.rtol,
                                  atol=self.atol, **self.solver)[-1]

    def _adjoint(self, x):
        """Integrate through this module's ``Adjoint`` helper."""
        return self.adjoint(self.defunc, x, self.s_span, rtol=self.rtol,
                            atol=self.atol, **self.solver)

    def _torchdiffeq_adjoint(self, x):
        """Integrate with ``torchdiffeq.odeint_adjoint``; return the final state."""
        return torchdiffeq.odeint_adjoint(
            self.defunc, x, self.s_span, rtol=self.rtol, atol=self.atol, **self.solver,
            adjoint_options=dict(norm=make_norm(x)))[-1]