Example #1
0
 def __init__(self,
              func: nn.Module,
              order=1,
              sensitivity='autograd',
              s_span=torch.linspace(0, 1, 2),
              solver='rk4',
              atol=1e-4,
              rtol=1e-4,
              intloss=None):
     """Set up the Neural DE around a user-provided vector field.

     :param func: module parametrizing the vector field; it is wrapped in a `DEFunc`
                  before being handed to the parent constructor.
     :type func: nn.Module
     :param order: forwarded to `DEFunc` and to the parent constructor.
     :param sensitivity: sensitivity method name; the value 'adjoint' additionally
                         creates an `Adjoint` helper stored on `self.adjoint`.
     :param s_span: integration span, forwarded to the parent constructor.
     :type s_span: torch.Tensor
     :param solver: solver name, forwarded to the parent constructor.
     :param atol: absolute tolerance, forwarded to the parent constructor.
     :param rtol: relative tolerance, forwarded to the parent constructor.
     :param intloss: stored on `self.intloss` and passed to `Adjoint` when used.
     """
     wrapped_func = DEFunc(func, order)
     super().__init__(func=wrapped_func,
                      order=order,
                      sensitivity=sensitivity,
                      s_span=s_span,
                      solver=solver,
                      atol=atol,
                      rtol=rtol)
     # NOTE(review): `self.defunc` is presumably set by the parent constructor -- confirm.
     self.nfe = self.defunc.nfe
     self.intloss = intloss
     # data-control state: no control input and not flagged as controlled yet
     self.u = None
     self.controlled = False
     if sensitivity == 'adjoint':
         self.adjoint = Adjoint(self.defunc, intloss)
Example #2
0
    def __init__(self,
                 func: nn.Module,
                 order=1,
                 sensitivity='autograd',
                 s_span=torch.linspace(0, 1, 2),
                 solver='rk4',
                 atol=1e-4,
                 rtol=1e-4,
                 intloss=None):
        """Store the Neural DE configuration on the instance.

        :param func: module parametrizing the vector field; wrapped in a `DEFunc`
                     and stored as `self.defunc`.
        :type func: nn.Module
        :param order: passed to `DEFunc` and stored as `self.order`.
        :param sensitivity: sensitivity method name; the value 'adjoint' additionally
                            creates an `Adjoint` helper stored on `self.adjoint`.
        :param s_span: integration span, stored as `self.s_span`.
        :type s_span: torch.Tensor
        :param solver: solver name, stored as `self.solver`.
        :param atol: absolute tolerance, stored as `self.atol`.
        :param rtol: relative tolerance, stored as `self.rtol`.
        :param intloss: stored as `self.intloss` and passed to `Adjoint` when used.
        """
        super().__init__()
        # NOTE: a `compat_check(defaults)` call is currently disabled here.
        # TODO: remove controlled from input args
        self.defunc = DEFunc(func, order)
        self.order = order
        self.sensitivity = sensitivity
        self.s_span = s_span
        self.solver = solver
        self.nfe = self.defunc.nfe
        self.rtol = rtol
        self.atol = atol
        self.intloss = intloss
        # data-control state: no control input and not flagged as controlled yet
        self.u = None
        self.controlled = False

        if sensitivity == 'adjoint':
            self.adjoint = Adjoint(self.intloss)
Example #3
0
class NeuralODE(NeuralDETemplate):
    """General Neural ODE class

    :param func: function parametrizing the vector field.
    :type func: nn.Module
    :param order: forwarded to `DEFunc` and to the parent template.
    :type order: int
    :param sensitivity: sensitivity method used by `forward`; one of 'autograd',
                        'adjoint' or 'torchdiffeq_adjoint'.
    :type sensitivity: str
    :param s_span: integration points handed to the solver in `forward`.
    :type s_span: torch.Tensor
    :param solver: solver name. Either used as the `method` keyword of the
                   `torchdiffeq` integrators, or a key of `SCIPY_SOLVERS`.
    :type solver: str
    :param atol: absolute tolerance handed to the solver.
    :type atol: float
    :param rtol: relative tolerance handed to the solver.
    :type rtol: float
    :param intloss: optional integral loss. When set together with 'autograd'
                    sensitivity, one leading input column is excluded from the
                    data-control input (see `_prep_odeint`).
    """
    def __init__(self,
                 func: nn.Module,
                 order=1,
                 sensitivity='autograd',
                 s_span=torch.linspace(0, 1, 2),
                 solver='rk4',
                 atol=1e-4,
                 rtol=1e-4,
                 intloss=None):
        super().__init__(func=DEFunc(func, order),
                         order=order,
                         sensitivity=sensitivity,
                         s_span=s_span,
                         solver=solver,
                         atol=atol,
                         rtol=rtol)
        self.nfe = self.defunc.nfe
        self.intloss = intloss
        self.u, self.controlled = None, False  # data-control
        if sensitivity == 'adjoint':
            self.adjoint = Adjoint(self.defunc, intloss)

        # Replaces the plain solver name in `self.solver` with a kwargs dict.
        self._solver_checks(solver, sensitivity)

    def _solver_checks(self, solver, sensitivity):
        """Validate `solver` and store the solver keyword arguments in `self.solver`.

        After this call `self.solver` is a mapping meant to be unpacked with `**`
        into the integrator call: `{'method': solver}` by default, or the
        `SCIPY_SOLVERS` entry for a SciPy solver.

        :param solver: solver name given to the constructor.
        :param sensitivity: sensitivity method given to the constructor.
        :raises KeyError: if `solver` starts with "scipy" but is not in `SCIPY_SOLVERS`.
        :raises ValueError: if a SciPy solver is combined with 'autograd' sensitivity.
        """
        self.solver = {'method': solver}

        if solver.startswith("scipy") and solver not in SCIPY_SOLVERS:
            available_scipy_solvers = ", ".join(SCIPY_SOLVERS.keys())
            raise KeyError("Invalid Scipy Solver specified." +
                           " Supported Scipy Solvers are: " +
                           available_scipy_solvers)

        elif solver in SCIPY_SOLVERS:
            warnings.warn(
                UserWarning("CUDA is not available with SciPy solvers."))

            if sensitivity == 'autograd':
                raise ValueError(
                    "SciPy Solvers do not work with autograd." +
                    " Use adjoint sensitivity with SciPy Solvers.")

            self.solver = SCIPY_SOLVERS[solver]

    def _prep_odeint(self, x: torch.Tensor):
        """Prepare the module state before an integration and return `x` unchanged.

        Moves `s_span` to the device of `x`, resamples the noise of trace-estimator
        modules and sets the data-control input `u` of every submodule exposing one.

        :param x: input data; indexed as `x[:, k:]`, so at least 2-dimensional.
        :type x: torch.Tensor
        """
        self.s_span = self.s_span.to(x.device)

        # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
        excess_dims = 0
        if self.intloss is not None and self.sensitivity == 'autograd':
            excess_dims += 1

        # Single traversal of the submodules:
        #  * aux. operations required for some jacobian trace CNF estimators e.g Hutchinson's
        #  * collection of the modules that take a data-control input `u`
        controlled_modules = []
        for _, module in self.defunc.named_modules():
            if hasattr(module, 'trace_estimator'):
                if module.noise_dist is not None:
                    module.noise = module.noise_dist.sample((x.shape[0], ))
                excess_dims += 1
            if hasattr(module, 'u'):
                controlled_modules.append(module)

        # data-control set routine. Is performed once at the beginning of odeint since the
        # control is fixed to IC. It runs after the traversal because the slice
        # depends on the final value of `excess_dims`.
        for module in controlled_modules:
            self.controlled = True
            module.u = x[:, excess_dims:].detach()

        return x

    def forward(self, x: torch.Tensor):
        """Integrate `x` over `self.s_span` with the configured sensitivity method.

        :param x: input data
        :type x: torch.Tensor
        :raises ValueError: if `self.sensitivity` is not a supported method.
        """
        x = self._prep_odeint(x)
        switcher = {
            'autograd': self._autograd,
            'adjoint': self._adjoint,
            'torchdiffeq_adjoint': self._torchdiffeq_adjoint
        }
        odeint = switcher.get(self.sensitivity)
        if odeint is None:
            # Fail with an explicit message instead of "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown sensitivity '{self.sensitivity}'. "
                f"Supported values are: {', '.join(switcher)}")
        out = odeint(x)
        return out

    def trajectory(self, x: torch.Tensor, s_span: torch.Tensor):
        """Returns a data-flow trajectory at `s_span` points

        The trajectory is always computed with `torchdiffeq.odeint`, whatever the
        configured sensitivity method is.

        :param x: input data
        :type x: torch.Tensor
        :param s_span: collections of points to evaluate the function at e.g torch.linspace(0, 1, 100) for a 100 point trajectory
                       between 0 and 1
        :type s_span: torch.Tensor
        """
        x = self._prep_odeint(x)
        sol = torchdiffeq.odeint(self.defunc,
                                 x,
                                 s_span,
                                 rtol=self.rtol,
                                 atol=self.atol,
                                 **self.solver)
        return sol

    def sensitivity_trajectory(self, x: torch.Tensor,
                               grad_output: torch.Tensor,
                               s_span: torch.Tensor):
        """Returns the trajectory of the adjoint dynamics at `s_span` points

        Only available when the model was built with `sensitivity='adjoint'`.

        :param x: input data
        :type x: torch.Tensor
        :param grad_output: passed with the forward solution to
                            `Adjoint._init_adjoint_state` to build the initial adjoint state.
        :type grad_output: torch.Tensor
        :param s_span: points at which the adjoint dynamics are evaluated.
        :type s_span: torch.Tensor
        """
        assert self.sensitivity == 'adjoint', 'Sensitivity trajectory only available for `adjoint`'
        # NOTE(review): `torch.autograd.Variable` is a legacy wrapper; kept to preserve behavior.
        x = torch.autograd.Variable(x, requires_grad=True)
        sol = self(x)
        adj0 = self.adjoint._init_adjoint_state(sol, grad_output)
        self.adjoint.flat_params = flatten(self.defunc.parameters())
        self.adjoint.func = self.defunc
        self.adjoint.f_params = tuple(self.defunc.parameters())
        # `self.solver` is a kwargs dict (see `_solver_checks`), so it has to be
        # unpacked like in the other integrator calls, not passed as `method`.
        adj_sol = torchdiffeq.odeint(self.adjoint.adjoint_dynamics,
                                     adj0,
                                     s_span,
                                     rtol=self.rtol,
                                     atol=self.atol,
                                     **self.solver)
        return adj_sol

    def _autograd(self, x):
        """Integrate with `torchdiffeq.odeint` and return the last point of the solution."""
        self.defunc.intloss, self.defunc.sensitivity = self.intloss, self.sensitivity
        return torchdiffeq.odeint(self.defunc,
                                  x,
                                  self.s_span,
                                  rtol=self.rtol,
                                  atol=self.atol,
                                  **self.solver)[-1]

    def _adjoint(self, x):
        """Integrate through the `Adjoint` helper and return its output as is."""
        return self.adjoint(self.defunc,
                            x,
                            self.s_span,
                            rtol=self.rtol,
                            atol=self.atol,
                            **self.solver)

    def _torchdiffeq_adjoint(self, x):
        """Integrate with `torchdiffeq.odeint_adjoint` and return the last point of the solution."""
        return torchdiffeq.odeint_adjoint(
            self.defunc,
            x,
            self.s_span,
            rtol=self.rtol,
            atol=self.atol,
            **self.solver,
            adjoint_options=dict(norm=make_norm(x)))[-1]