Example No. 1
    def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, fwd_options,
                bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
        # set up the default options
        config = set_default_option({
            "method": "mh",
        }, fwd_options)
        ctx.bck_config = set_default_option(config, bck_options)

        # split the parameters
        fparams = all_fpparams[:nfparams]
        fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
        pparams = all_fpparams[nfparams + nf_objparams:nfparams +
                               nf_objparams + npparams]
        pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

        # select the method for the sampling
        if xsamples is None:
            method = config["method"].lower()
            method_fcn = {
                "mh": mh,
                "_dummy1d": dummy1d,
                "mhcustom": mhcustom,
            }
            if method not in method_fcn:
                raise RuntimeError("Unknown mcquad method: %s" %
                                   config["method"])
            xsamples, wsamples = method_fcn[method](log_pfcn, x0, pparams,
                                                    **config)
        epf = _integrate(ffcn, xsamples, wsamples, fparams)

        # save parameters for backward calculations
        ctx.xsamples = xsamples
        ctx.wsamples = wsamples
        ctx.ffcn = ffcn
        ctx.log_pfcn = log_pfcn
        ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
        ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
        ctx.nfparams = len(fparams)
        ctx.npparams = len(pparams)

        # save for backward
        ftensor_params = ctx.fparam_sep.get_tensor_params()
        ptensor_params = ctx.pparam_sep.get_tensor_params()
        ctx.nftensorparams = len(ftensor_params)
        ctx.nptensorparams = len(ptensor_params)
        ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

        return epf
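
The forward above reduces the samples with a helper `_integrate`, which is not shown in this listing. A minimal self-contained sketch of what such a weighted Monte Carlo reduction could look like (an assumption, not the library's implementation; `ffcn` is assumed to map one sample plus `fparams` to a tensor):

import torch

def _integrate_sketch(ffcn, xsamples, wsamples, fparams):
    # Weighted Monte Carlo estimate: E_p[f] ~ sum_i w_i * f(x_i) / sum_i w_i
    fs = torch.stack([ffcn(x, *fparams) for x in xsamples], dim=0)    # (nsamples, *fout)
    ws = torch.stack([torch.as_tensor(w, dtype=fs.dtype, device=fs.device)
                      for w in wsamples], dim=0)                      # (nsamples,)
    ws = ws.reshape(-1, *([1] * (fs.ndim - 1)))                       # broadcast over *fout
    return (fs * ws).sum(dim=0) / ws.sum()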
Example No. 2
    def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0,
                *allparams):
        config = fwd_options
        ctx.bck_config = set_default_option(config, bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        method = config.pop("method")
        methods = {
            "rk4": rk4_ivp,
            "rk38": rk38_ivp,
            "rk23": rk23_adaptive,
            "rk45": rk45_adaptive,
        }
        solver = get_method("solve_ivp", methods, method)
        yt = solver(pfcn, ts, y0, params, **config)

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(ts, y0, *tensor_params)
        ctx.pfcn = pfcn
        ctx.nparams = nparams
        ctx.yt = yt
        ctx.ts_requires_grad = ts.requires_grad

        return yt
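
The solvers (`rk4_ivp`, `rk38_ivp`, ...) are not shown in this example. For reference, a minimal fixed-step RK4 integrator with the same `(pfcn, ts, y0, params, **config)` call signature could look like this (an illustrative sketch, not the library's implementation):

import torch

def rk4_ivp_sketch(pfcn, ts, y0, params, **config):
    # pfcn(t, y, *params) -> dy/dt; ts: (nt,); y0: (*ny)
    yt = [y0]
    y = y0
    for i in range(len(ts) - 1):
        t, h = ts[i], ts[i + 1] - ts[i]
        k1 = pfcn(t, y, *params)
        k2 = pfcn(t + h / 2, y + h / 2 * k1, *params)
        k3 = pfcn(t + h / 2, y + h / 2 * k2, *params)
        k4 = pfcn(t + h, y + h * k3, *params)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        yt.append(y)
    return torch.stack(yt, dim=0)  # (nt, *ny)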
Example No. 3
    def forward(ctx, fcn, y0, options, bck_options, nparams, *allparams):

        # set default options
        config = options
        ctx.bck_options = bck_options

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        with fcn.useobjparams(objparams):

            method = config.pop("method")
            methods = {
                "broyden1": broyden1,
                "broyden2": broyden2,
                "linearmixing": linearmixing,
            }
            method_fcn = get_method("rootfinder", methods, method)
            y = method_fcn(fcn, y0, params, **config)

        ctx.fcn = fcn

        # split tensors and non-tensors params
        ctx.nparams = nparams
        ctx.param_sep = TensorNonTensorSeparator(allparams)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(y, *tensor_params)

        return y
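
The rootfinder backends (`broyden1`, `broyden2`, `linearmixing`) share the call signature `(fcn, y0, params, **config)`. As a rough illustration of that contract, a damped fixed-point iteration in the spirit of `linearmixing` could look like this (a sketch with assumed option names, not the library's implementation):

import torch

def linearmixing_sketch(fcn, y0, params, alpha=1.0, maxiter=100, f_tol=1e-8, **config):
    # Damped fixed-point iteration y <- y - alpha * fcn(y, *params);
    # broyden1/broyden2 would use quasi-Newton Jacobian updates instead.
    y = y0
    for _ in range(maxiter):
        f = fcn(y, *params)
        if f.abs().max() < f_tol:
            break
        y = y - alpha * f
    return y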
Example No. 4
    def forward(ctx, fcn: Callable[..., torch.Tensor], y0: torch.Tensor,
                options: Mapping[str, Any], bck_options: Mapping[str, Any],
                nparams: int, *allparams) -> torch.Tensor:

        # set default options
        config = set_default_option({
            "method": "broyden1",
        }, options)
        ctx.bck_options = set_default_option({"method": "exactsolve"},
                                             bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        with fcn.useobjparams(objparams):

            orig_method = config.pop("method")
            method = orig_method.lower()
            if method == "broyden1":
                y = broyden1(fcn, y0, params, **config)
            else:
                raise RuntimeError("Unknown rootfinder method: %s" %
                                   orig_method)

        ctx.fcn = fcn

        # split tensors and non-tensors params
        ctx.nparams = nparams
        ctx.param_sep = TensorNonTensorSeparator(allparams)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(y, *tensor_params)

        return y
Example No. 5
    def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
        config = fwd_options
        ctx.bck_config = set_default_option(config, bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        orig_method = config.pop("method")
        method = orig_method.lower()
        try:
            solver = {
                "rk4": rk4_ivp,
                "rk38": rk38_ivp,
                "rk23": rk23_adaptive,
                "rk45": rk45_adaptive,
            }[method]
        except KeyError:
            raise RuntimeError("Unknown solve_ivp method: %s" % config["method"])
        yt = solver(pfcn, ts, y0, params, **config)

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(ts, y0, *tensor_params)
        ctx.pfcn = pfcn
        ctx.nparams = nparams
        ctx.yt = yt
        ctx.ts_requires_grad = ts.requires_grad

        return yt
Example No. 6
    def forward(ctx, fcn, y0, fwd_fcn, is_opt_method, options, bck_options,
                nparams, *allparams):
        # fcn: a function that returns what has to be 0 (will be used in the
        #      backward, not used in the forward). For minimization, it is
        #      the gradient
        # fwd_fcn: a function that will be executed in the forward method
        #          (unused in the backward)
        # This class is also used for minimization, where fcn and fwd_fcn might
        # be slightly different

        # set default options
        config = options
        ctx.bck_options = bck_options

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        with fwd_fcn.useobjparams(objparams):

            method = config.pop("method")
            methods = _RF_METHODS if not is_opt_method else _OPT_METHODS
            name = "rootfinder" if not is_opt_method else "minimizer"
            method_fcn = get_method(name, methods, method)
            y = method_fcn(fwd_fcn, y0, params, **config)

        ctx.fcn = fcn
        ctx.is_opt_method = is_opt_method

        # split tensors and non-tensors params
        ctx.nparams = nparams
        ctx.param_sep = TensorNonTensorSeparator(allparams)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(y, *tensor_params)

        return y
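
The comments in this example note that for minimization `fcn` is the gradient of the objective while `fwd_fcn` is the objective itself, so the same implicit-function backward (built around `fcn(y*) = 0`) covers both cases. A hedged sketch of how such a gradient function could be derived from an objective with autograd (an illustration of the idea, not the library's mechanism):

import torch

def grad_fcn_from_objective(obj_fcn):
    # For minimization, the root-defining fcn is the gradient of the objective:
    # grad(obj)(y*, *params) = 0 at the optimum.
    def grad_fcn(y, *params):
        with torch.enable_grad():
            yg = y.detach().requires_grad_()
            out = obj_fcn(yg, *params)
            # keep the graph so a backward pass can differentiate through g
            g, = torch.autograd.grad(out, (yg,), create_graph=True)
        return g
    return grad_fcn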
Example No. 7
    def __init__(self,
                 fcn: Callable[..., torch.Tensor],
                 params: Sequence[Any],
                 idx: int,
                 is_hermitian=False) -> None:

        # TODO: check if fcn has kwargs

        # run once to get the shapes and numels
        yparam = params[idx]
        with torch.enable_grad():
            yout = fcn(*params)  # (*nout)
            v = torch.ones_like(yout).to(
                yout.device).requires_grad_()  # (*nout)
            dfdy, = torch.autograd.grad(yout, (yparam, ),
                                        grad_outputs=v,
                                        create_graph=True)  # (*nin)

        inshape = yparam.shape
        outshape = yout.shape
        nin = torch.numel(yparam)
        nout = torch.numel(yout)

        super(_Jac, self).__init__(shape=(nout, nin),
                                   is_hermitian=is_hermitian,
                                   dtype=yparam.dtype,
                                   device=yparam.device)

        self.fcn = fcn
        self.yparam = yparam
        self.params = list(params)
        self.objparams = fcn.objparams()
        self.yout = yout
        self.v = v
        self.idx = idx
        self.dfdy = dfdy
        self.inshape = inshape
        self.outshape = outshape
        self.nin = nin
        self.nout = nout

        # params tensor is the LinearOperator's parameters
        self.param_sep = TensorNonTensorSeparator(params)
        self.params_tensor = self.param_sep.get_tensor_params()
        self.id_params_tensor = [id(param) for param in self.params_tensor]
        self.id_objparams_tensor = [id(param) for param in self.objparams]
Example No. 8
    def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, method,
                fwd_options, bck_options, nfparams, nf_objparams, npparams,
                *all_fpparams):
        # set up the default options
        config = fwd_options
        ctx.bck_config = set_default_option(config, bck_options)

        # split the parameters
        fparams = all_fpparams[:nfparams]
        fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
        pparams = all_fpparams[nfparams + nf_objparams:nfparams +
                               nf_objparams + npparams]
        pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

        # select the method for the sampling
        if xsamples is None:
            methods = {
                "mh": mh,
                "_dummy1d": dummy1d,
                "mhcustom": mhcustom,
            }
            method_fcn = get_method("mcquad", methods, method)
            xsamples, wsamples = method_fcn(log_pfcn, x0, pparams, **config)
        epf = _integrate(ffcn, xsamples, wsamples, fparams)

        # save parameters for backward calculations
        ctx.xsamples = xsamples
        ctx.wsamples = wsamples
        ctx.ffcn = ffcn
        ctx.log_pfcn = log_pfcn
        ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
        ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
        ctx.nfparams = len(fparams)
        ctx.npparams = len(pparams)
        ctx.method = method

        # save for backward
        ftensor_params = ctx.fparam_sep.get_tensor_params()
        ptensor_params = ctx.pparam_sep.get_tensor_params()
        ctx.nftensorparams = len(ftensor_params)
        ctx.nptensorparams = len(ptensor_params)
        ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

        return epf
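
The `mh` sampler itself is not shown. A minimal random-walk Metropolis-Hastings sampler with the same `(log_pfcn, x0, pparams, **config)` signature, returning equally weighted samples, could look like this (an illustrative sketch with assumed option names, not the library's sampler):

import torch

def mh_sketch(log_pfcn, x0, pparams, nsamples=1000, step_size=1.0, **config):
    # Random-walk Metropolis-Hastings returning equally weighted samples.
    x = x0
    logp = log_pfcn(x, *pparams)
    xsamples, wsamples = [], []
    for _ in range(nsamples):
        xprop = x + step_size * torch.randn_like(x)
        logp_prop = log_pfcn(xprop, *pparams)
        if torch.rand((), device=x.device) < torch.exp(logp_prop - logp):  # accept
            x, logp = xprop, logp_prop
        xsamples.append(x)
        wsamples.append(torch.full((), 1.0 / nsamples, dtype=x.dtype, device=x.device))
    return xsamples, wsamples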
Example No. 9
    def forward(ctx, fcn, xl, xu, fwd_options, bck_options, nparams, dtype,
                device, *all_params):

        with fcn.disable_state_change():

            config = set_default_option({
                "method": "leggauss",
                "n": 100,
            }, fwd_options)
            ctx.bck_config = set_default_option(config, bck_options)

            params = all_params[:nparams]
            objparams = all_params[nparams:]

            # convert to tensor
            xl = torch.as_tensor(xl, dtype=dtype, device=device)
            xu = torch.as_tensor(xu, dtype=dtype, device=device)

            # apply transformation if the boundaries contain inf
            if _isinf(xl) or _isinf(xu):
                tfm = _TanInfTransform()

                @make_sibling(fcn)
                def fcn2(t, *params):
                    ys = fcn(tfm.forward(t), *params)
                    dxdt = tfm.dxdt(t)
                    return ys * dxdt

                tl = tfm.x2t(xl)
                tu = tfm.x2t(xu)
            else:
                fcn2 = fcn
                tl = xl
                tu = xu

            method = config["method"].lower()
            if method == "leggauss":
                y = leggaussquad(fcn2, tl, tu, params, **config)
            else:
                raise RuntimeError("Unknown quad method: %s" %
                                   config["method"])

            # save the parameters for backward
            ctx.param_sep = TensorNonTensorSeparator(all_params)
            tensor_params = ctx.param_sep.get_tensor_params()
            ctx.xltensor = isinstance(xl, torch.Tensor)
            ctx.xutensor = isinstance(xu, torch.Tensor)
            xlxu_tensor = ([xl] if ctx.xltensor else []) + \
                          ([xu] if ctx.xutensor else [])
            ctx.xlxu_nontensor = ([xl] if not ctx.xltensor else []) + \
                                 ([xu] if not ctx.xutensor else [])
            ctx.save_for_backward(*xlxu_tensor, *tensor_params)
            ctx.fcn = fcn
            ctx.nparams = nparams
            return y
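
`_TanInfTransform` is presumably a change of variables that maps an infinite integration domain onto a finite one; a minimal sketch of such a transform based on x = tan(t) (an assumption about its behavior, matching the `forward`, `dxdt`, and `x2t` methods used in `fcn2` above):

import torch

class TanInfTransformSketch:
    # Change of variables x = tan(t): t in (-pi/2, pi/2) covers x in (-inf, inf),
    # and the integrand becomes fcn(tan(t)) * dx/dt, as in fcn2 above.
    def forward(self, t):
        return torch.tan(t)

    def dxdt(self, t):
        return 1.0 / torch.cos(t) ** 2

    def x2t(self, x):
        return torch.atan(x)  # atan(+-inf) = +-pi/2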
Example No. 10
class _Jac(LinearOperator):
    def __init__(self,
                 fcn: Callable[..., torch.Tensor],
                 params: Sequence[Any],
                 idx: int,
                 is_hermitian=False) -> None:

        # TODO: check if fcn has kwargs

        # run once to get the shapes and numels
        yparam = params[idx]
        with torch.enable_grad():
            yout = fcn(*params)  # (*nout)
            v = torch.ones_like(yout).to(
                yout.device).requires_grad_()  # (*nout)
            dfdy, = torch.autograd.grad(yout, (yparam, ),
                                        grad_outputs=v,
                                        create_graph=True)  # (*nin)

        inshape = yparam.shape
        outshape = yout.shape
        nin = torch.numel(yparam)
        nout = torch.numel(yout)

        super(_Jac, self).__init__(shape=(nout, nin),
                                   is_hermitian=is_hermitian,
                                   dtype=yparam.dtype,
                                   device=yparam.device)

        self.fcn = fcn
        self.yparam = yparam
        self.params = list(params)
        self.objparams = fcn.objparams()
        self.yout = yout
        self.v = v
        self.idx = idx
        self.dfdy = dfdy
        self.inshape = inshape
        self.outshape = outshape
        self.nin = nin
        self.nout = nout

        # params tensor is the LinearOperator's parameters
        self.param_sep = TensorNonTensorSeparator(params)
        self.params_tensor = self.param_sep.get_tensor_params()
        self.id_params_tensor = [id(param) for param in self.params_tensor]
        self.id_objparams_tensor = [id(param) for param in self.objparams]

    def _getparamnames(self, prefix: str = "") -> Sequence[str]:
        return [prefix+"yparam"] + \
               [prefix+("params_tensor[%d]"%i) for i in range(len(self.params_tensor))] + \
               [prefix+("objparams[%d]"%i) for i in range(len(self.objparams))]

    def _mv(self, gy: torch.Tensor) -> torch.Tensor:
        # gy: (..., nin)
        # returns: (..., nout)

        # if the object parameter is still the same, then use the pre-calculated values
        if self.__param_tensors_unchanged():
            v = self.v
            dfdy = self.dfdy
        # otherwise, reevaluate by replacing the parameters with the new tensor params
        else:
            with torch.enable_grad(), self.fcn.useobjparams(self.objparams):
                self.__update_params()
                yparam = self.params[self.idx]
                yout = self.fcn(*self.params)  # (*nout)
                v = torch.ones_like(yout).to(
                    yout.device).requires_grad_()  # (*nout)
                dfdy, = torch.autograd.grad(yout, (yparam, ),
                                            grad_outputs=v,
                                            create_graph=True)  # (*nin)

        gy1 = gy.reshape(-1, self.nin)  # (nbatch, nin)
        nbatch = gy1.shape[0]
        dfdyfs = []
        for i in range(nbatch):
            dfdyf, = torch.autograd.grad(
                dfdy, (v, ),
                grad_outputs=gy1[i].reshape(self.inshape),
                retain_graph=True,
                create_graph=torch.is_grad_enabled())  # (*nout)
            dfdyfs.append(dfdyf.unsqueeze(0))
        dfdyfs = torch.cat(dfdyfs, dim=0)  # (nbatch, *nout)

        res = dfdyfs.reshape(*gy.shape[:-1], self.nout)  # (..., nout)
        res = connect_graph(res, self.params_tensor)
        return res

    def _rmv(self, gout: torch.Tensor) -> torch.Tensor:
        # gout: (..., nout)
        # self.yfcn: (*nin)
        if self.__param_tensors_unchanged():
            yout = self.yout
            yparam = self.yparam
        else:
            with torch.enable_grad(), self.fcn.useobjparams(self.objparams):
                self.__update_params()
                yparam = self.params[self.idx]
                yout = self.fcn(*self.params)  # (*nout)

        gout1 = gout.reshape(-1, self.nout)  # (nbatch, nout)
        nbatch = gout1.shape[0]
        dfdy = []
        for i in range(nbatch):
            one_dfdy, = torch.autograd.grad(
                yout, (yparam, ),
                grad_outputs=gout1[i].reshape(self.outshape),
                retain_graph=True,
                create_graph=torch.is_grad_enabled())  # (*nin)
            dfdy.append(one_dfdy.unsqueeze(0))
        dfdy = torch.cat(dfdy, dim=0)  # (nbatch, *nin)

        res = dfdy.reshape(*gout.shape[:-1], self.nin)  # (..., nin)
        res = connect_graph(res, self.params_tensor)
        return res  # (..., nin)

    def __param_tensors_unchanged(self):
        return [id(param) for param in self.params_tensor] == self.id_params_tensor and \
               [id(param) for param in self.objparams] == self.id_objparams_tensor

    def __update_params(self):
        self.params = self.param_sep.reconstruct_params(self.params_tensor)
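
`_mv` computes a Jacobian-vector product with two reverse-mode passes: the first builds the VJP against a dummy cotangent `v` with `create_graph=True`, and the second differentiates that VJP with respect to `v`, using the input vector as `grad_outputs`. A self-contained demonstration of the same trick on a toy function:

import torch

def jvp_via_double_vjp(fcn, x, gx):
    # Jacobian-vector product J @ gx via two reverse-mode passes, mirroring
    # _Jac._mv: build a VJP against a dummy cotangent v (create_graph=True),
    # then differentiate that VJP with respect to v with gx as grad_outputs.
    with torch.enable_grad():
        xd = x.detach().requires_grad_()
        y = fcn(xd)
        v = torch.ones_like(y).requires_grad_()  # dummy cotangent
        vjp, = torch.autograd.grad(y, (xd,), grad_outputs=v, create_graph=True)
        jvp, = torch.autograd.grad(vjp, (v,), grad_outputs=gx)
    return jvp

# fcn(x) = x**3 has J = diag(3 x^2), so J @ gx = 3 * x**2 * gx
x = torch.tensor([1.0, 2.0, 3.0])
gx = torch.tensor([1.0, 0.0, 1.0])
print(jvp_via_double_vjp(lambda t: t ** 3, x, gx))  # tensor([ 3., 0., 27.])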