def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, fwd_options,
            bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
    # set up the default options
    config = set_default_option({
        "method": "mh",
    }, fwd_options)
    ctx.bck_config = set_default_option(config, bck_options)

    # split the parameters
    fparams = all_fpparams[:nfparams]
    fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
    pparams = all_fpparams[nfparams + nf_objparams:nfparams + nf_objparams + npparams]
    pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

    # select the method for the sampling
    if xsamples is None:
        method = config["method"].lower()
        method_fcn = {
            "mh": mh,
            "_dummy1d": dummy1d,
            "mhcustom": mhcustom,
        }
        if method not in method_fcn:
            raise RuntimeError("Unknown mcquad method: %s" % config["method"])
        xsamples, wsamples = method_fcn[method](log_pfcn, x0, pparams, **config)
    epf = _integrate(ffcn, xsamples, wsamples, fparams)

    # save parameters for backward calculations
    ctx.xsamples = xsamples
    ctx.wsamples = wsamples
    ctx.ffcn = ffcn
    ctx.log_pfcn = log_pfcn
    ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
    ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
    ctx.nfparams = len(fparams)
    ctx.npparams = len(pparams)

    # save for backward
    ftensor_params = ctx.fparam_sep.get_tensor_params()
    ptensor_params = ctx.pparam_sep.get_tensor_params()
    ctx.nftensorparams = len(ftensor_params)
    ctx.nptensorparams = len(ptensor_params)
    ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

    return epf
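# The epf returned above is the Monte Carlo expectation of ffcn under the sampled
# distribution. A minimal sketch of what _integrate plausibly computes, assuming the
# weights sum to one (illustrative only, not the library's actual implementation):
def _integrate_sketch(ffcn, xsamples, wsamples, fparams):
    res = 0.0
    for x, w in zip(xsamples, wsamples):
        # weighted average of the integrand over the drawn samples
        res = res + ffcn(x, *fparams) * w
    return res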
def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
    config = fwd_options
    ctx.bck_config = set_default_option(config, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    method = config.pop("method")
    methods = {
        "rk4": rk4_ivp,
        "rk38": rk38_ivp,
        "rk23": rk23_adaptive,
        "rk45": rk45_adaptive,
    }
    solver = get_method("solve_ivp", methods, method)
    yt = solver(pfcn, ts, y0, params, **config)

    # save the parameters for backward
    ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(ts, y0, *tensor_params)
    ctx.pfcn = pfcn
    ctx.nparams = nparams
    ctx.yt = yt
    ctx.ts_requires_grad = ts.requires_grad

    return yt
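# Several of these forwards dispatch through get_method(name, methods, method). A
# minimal sketch of such a dispatcher, assuming it accepts either a string key or a
# user-supplied callable (an illustrative assumption, not the actual helper):
def get_method_sketch(algname, methods, method):
    if callable(method):
        return method
    key = method.lower()
    if key not in methods:
        raise RuntimeError("Unknown %s method: %s" % (algname, method))
    return methods[key]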
def forward(ctx, fcn, y0, options, bck_options, nparams, *allparams):
    # set default options
    config = options
    ctx.bck_options = bck_options

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    with fcn.useobjparams(objparams):
        method = config.pop("method")
        methods = {
            "broyden1": broyden1,
            "broyden2": broyden2,
            "linearmixing": linearmixing,
        }
        method_fcn = get_method("rootfinder", methods, method)
        y = method_fcn(fcn, y0, params, **config)

    ctx.fcn = fcn

    # split tensors and non-tensors params
    ctx.nparams = nparams
    ctx.param_sep = TensorNonTensorSeparator(allparams)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(y, *tensor_params)
    return y
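# Every forward here stores only the tensor entries of its parameter list with
# ctx.save_for_backward, using TensorNonTensorSeparator to remember where they came
# from. A minimal sketch of that idea, assuming it tracks tensor positions so the
# mixed list can be reconstructed later (illustrative only, not the real class):
import torch

class TensorNonTensorSeparatorSketch(object):
    def __init__(self, params, varonly=True):
        self.params = list(params)
        # indices of the tensor parameters (only differentiable ones if varonly)
        self.tensor_idxs = [
            i for i, p in enumerate(self.params)
            if isinstance(p, torch.Tensor) and (p.requires_grad or not varonly)
        ]

    def get_tensor_params(self):
        return [self.params[i] for i in self.tensor_idxs]

    def reconstruct_params(self, tensor_params):
        # put (possibly substituted) tensors back into their original slots
        res = list(self.params)
        for i, p in zip(self.tensor_idxs, tensor_params):
            res[i] = p
        return res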
def forward(ctx, fcn: Callable[..., torch.Tensor], y0: torch.Tensor,
            options: Mapping[str, Any], bck_options: Mapping[str, Any],
            nparams: int, *allparams) -> torch.Tensor:
    # set default options
    config = set_default_option({
        "method": "broyden1",
    }, options)
    ctx.bck_options = set_default_option({"method": "exactsolve"}, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    with fcn.useobjparams(objparams):
        orig_method = config.pop("method")
        method = orig_method.lower()
        if method == "broyden1":
            y = broyden1(fcn, y0, params, **config)
        else:
            raise RuntimeError("Unknown rootfinder method: %s" % orig_method)

    ctx.fcn = fcn

    # split tensors and non-tensors params
    ctx.nparams = nparams
    ctx.param_sep = TensorNonTensorSeparator(allparams)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(y, *tensor_params)
    return y
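# set_default_option(default, options) appears in several of these forwards. A
# minimal sketch of the intended merge, assuming user-provided options simply
# override the defaults (an assumption for illustration, not the actual helper):
def set_default_option_sketch(default, options):
    res = dict(default)
    res.update(options)
    return res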
def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
    config = fwd_options
    ctx.bck_config = set_default_option(config, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    orig_method = config.pop("method")
    method = orig_method.lower()
    try:
        solver = {
            "rk4": rk4_ivp,
            "rk38": rk38_ivp,
            "rk23": rk23_adaptive,
            "rk45": rk45_adaptive,
        }[method]
    except KeyError:
        # "method" has already been popped from config, so report the saved name
        raise RuntimeError("Unknown solve_ivp method: %s" % orig_method)
    yt = solver(pfcn, ts, y0, params, **config)

    # save the parameters for backward
    ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(ts, y0, *tensor_params)
    ctx.pfcn = pfcn
    ctx.nparams = nparams
    ctx.yt = yt
    ctx.ts_requires_grad = ts.requires_grad

    return yt
def forward(ctx, fcn, y0, fwd_fcn, is_opt_method, options, bck_options, nparams,
            *allparams):
    # fcn: a function that returns what has to be 0 (will be used in the
    #      backward, not used in the forward). For minimization, it is
    #      the gradient
    # fwd_fcn: a function that will be executed in the forward method
    #      (unused in the backward)
    # This class is also used for minimization, where fcn and fwd_fcn might
    # be slightly different

    # set default options
    config = options
    ctx.bck_options = bck_options

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    with fwd_fcn.useobjparams(objparams):
        method = config.pop("method")
        methods = _RF_METHODS if not is_opt_method else _OPT_METHODS
        name = "rootfinder" if not is_opt_method else "minimizer"
        method_fcn = get_method(name, methods, method)
        y = method_fcn(fwd_fcn, y0, params, **config)

    ctx.fcn = fcn
    ctx.is_opt_method = is_opt_method

    # split tensors and non-tensors params
    ctx.nparams = nparams
    ctx.param_sep = TensorNonTensorSeparator(allparams)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(y, *tensor_params)
    return y
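# As the comments above note, for minimization fcn is the gradient of the scalar
# objective that fwd_fcn evaluates. A minimal sketch of that relation using plain
# torch.autograd (grad_of and its detach behaviour are illustrative assumptions):
import torch

def grad_of(fwd_fcn):
    def fcn(y, *params):
        with torch.enable_grad():
            y2 = y.detach().requires_grad_()
            z = fwd_fcn(y2, *params)                # scalar objective
            gy, = torch.autograd.grad(z, (y2,))     # d(objective)/dy
        return gy                                   # what the rootfinder drives to 0
    return fcn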
def __init__(self, fcn: Callable[..., torch.Tensor], params: Sequence[Any],
             idx: int, is_hermitian=False) -> None:
    # TODO: check if fcn has kwargs

    # run once to get the shapes and numels
    yparam = params[idx]
    with torch.enable_grad():
        yout = fcn(*params)  # (*nout)
        v = torch.ones_like(yout).to(yout.device).requires_grad_()  # (*nout)
        dfdy, = torch.autograd.grad(yout, (yparam,), grad_outputs=v,
                                    create_graph=True)  # (*nin)

    inshape = yparam.shape
    outshape = yout.shape
    nin = torch.numel(yparam)
    nout = torch.numel(yout)

    super(_Jac, self).__init__(
        shape=(nout, nin),
        is_hermitian=is_hermitian,
        dtype=yparam.dtype,
        device=yparam.device)

    self.fcn = fcn
    self.yparam = yparam
    self.params = list(params)
    self.objparams = fcn.objparams()
    self.yout = yout
    self.v = v
    self.idx = idx
    self.dfdy = dfdy
    self.inshape = inshape
    self.outshape = outshape
    self.nin = nin
    self.nout = nout

    # params tensor is the LinearOperator's parameters
    self.param_sep = TensorNonTensorSeparator(params)
    self.params_tensor = self.param_sep.get_tensor_params()
    self.id_params_tensor = [id(param) for param in self.params_tensor]
    self.id_objparams_tensor = [id(param) for param in self.objparams]
def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, method, fwd_options,
            bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
    # set up the default options
    config = fwd_options
    ctx.bck_config = set_default_option(config, bck_options)

    # split the parameters
    fparams = all_fpparams[:nfparams]
    fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
    pparams = all_fpparams[nfparams + nf_objparams:nfparams + nf_objparams + npparams]
    pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

    # select the method for the sampling
    if xsamples is None:
        methods = {
            "mh": mh,
            "_dummy1d": dummy1d,
            "mhcustom": mhcustom,
        }
        method_fcn = get_method("mcquad", methods, method)
        xsamples, wsamples = method_fcn(log_pfcn, x0, pparams, **config)
    epf = _integrate(ffcn, xsamples, wsamples, fparams)

    # save parameters for backward calculations
    ctx.xsamples = xsamples
    ctx.wsamples = wsamples
    ctx.ffcn = ffcn
    ctx.log_pfcn = log_pfcn
    ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
    ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
    ctx.nfparams = len(fparams)
    ctx.npparams = len(pparams)
    ctx.method = method

    # save for backward
    ftensor_params = ctx.fparam_sep.get_tensor_params()
    ptensor_params = ctx.pparam_sep.get_tensor_params()
    ctx.nftensorparams = len(ftensor_params)
    ctx.nptensorparams = len(ptensor_params)
    ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

    return epf
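# The "mh" entry above is a Metropolis-Hastings sampler called as
# mh(log_pfcn, x0, pparams, **config) and returning (xsamples, wsamples). A minimal
# sketch with a Gaussian random-walk proposal; the option names, lack of burn-in,
# and uniform weights here are assumptions, not the actual sampler:
import torch

def mh_sketch(log_pfcn, x0, pparams, nsamples=1000, step_size=1.0, **unused_config):
    x = x0
    logp = log_pfcn(x, *pparams)
    xsamples = []
    for _ in range(nsamples):
        xnew = x + step_size * torch.randn_like(x)
        logpnew = log_pfcn(xnew, *pparams)
        # accept with probability min(1, p(xnew) / p(x))
        if torch.rand(()) < torch.exp(logpnew - logp):
            x, logp = xnew, logpnew
        xsamples.append(x)
    wsamples = torch.full((len(xsamples),), 1.0 / len(xsamples), dtype=x0.dtype)
    return xsamples, wsamples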
def forward(ctx, fcn, xl, xu, fwd_options, bck_options, nparams, dtype, device,
            *all_params):
    with fcn.disable_state_change():

        config = set_default_option({
            "method": "leggauss",
            "n": 100,
        }, fwd_options)
        ctx.bck_config = set_default_option(config, bck_options)

        params = all_params[:nparams]
        objparams = all_params[nparams:]

        # convert to tensor
        xl = torch.as_tensor(xl, dtype=dtype, device=device)
        xu = torch.as_tensor(xu, dtype=dtype, device=device)

        # apply transformation if the boundaries contain inf
        if _isinf(xl) or _isinf(xu):
            tfm = _TanInfTransform()

            @make_sibling(fcn)
            def fcn2(t, *params):
                ys = fcn(tfm.forward(t), *params)
                dxdt = tfm.dxdt(t)
                return ys * dxdt

            tl = tfm.x2t(xl)
            tu = tfm.x2t(xu)
        else:
            fcn2 = fcn
            tl = xl
            tu = xu

        method = config["method"].lower()
        if method == "leggauss":
            y = leggaussquad(fcn2, tl, tu, params, **config)
        else:
            raise RuntimeError("Unknown quad method: %s" % config["method"])

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(all_params)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.xltensor = isinstance(xl, torch.Tensor)
        ctx.xutensor = isinstance(xu, torch.Tensor)
        xlxu_tensor = ([xl] if ctx.xltensor else []) + \
                      ([xu] if ctx.xutensor else [])
        ctx.xlxu_nontensor = ([xl] if not ctx.xltensor else []) + \
                             ([xu] if not ctx.xutensor else [])
        ctx.save_for_backward(*xlxu_tensor, *tensor_params)
        ctx.fcn = fcn
        ctx.nparams = nparams
        return y
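# The infinite-boundary branch above substitutes a finite variable t for x. A minimal
# sketch of such a transform, assuming the tan substitution x = tan(t), so
# t = arctan(x) and dx/dt = 1/cos(t)^2; the real _TanInfTransform may differ:
import torch

class _TanInfTransformSketch(object):
    def forward(self, t):
        # map t in (-pi/2, pi/2) to x in (-inf, inf)
        return torch.tan(t)

    def dxdt(self, t):
        # Jacobian of the substitution, multiplied into the integrand
        return 1.0 / torch.cos(t) ** 2

    def x2t(self, x):
        # map the (possibly infinite) boundary into t-space
        return torch.atan(torch.as_tensor(x))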
class _Jac(LinearOperator):
    def __init__(self, fcn: Callable[..., torch.Tensor], params: Sequence[Any],
                 idx: int, is_hermitian=False) -> None:
        # TODO: check if fcn has kwargs

        # run once to get the shapes and numels
        yparam = params[idx]
        with torch.enable_grad():
            yout = fcn(*params)  # (*nout)
            v = torch.ones_like(yout).to(yout.device).requires_grad_()  # (*nout)
            dfdy, = torch.autograd.grad(yout, (yparam,), grad_outputs=v,
                                        create_graph=True)  # (*nin)

        inshape = yparam.shape
        outshape = yout.shape
        nin = torch.numel(yparam)
        nout = torch.numel(yout)

        super(_Jac, self).__init__(
            shape=(nout, nin),
            is_hermitian=is_hermitian,
            dtype=yparam.dtype,
            device=yparam.device)

        self.fcn = fcn
        self.yparam = yparam
        self.params = list(params)
        self.objparams = fcn.objparams()
        self.yout = yout
        self.v = v
        self.idx = idx
        self.dfdy = dfdy
        self.inshape = inshape
        self.outshape = outshape
        self.nin = nin
        self.nout = nout

        # params tensor is the LinearOperator's parameters
        self.param_sep = TensorNonTensorSeparator(params)
        self.params_tensor = self.param_sep.get_tensor_params()
        self.id_params_tensor = [id(param) for param in self.params_tensor]
        self.id_objparams_tensor = [id(param) for param in self.objparams]

    def _getparamnames(self, prefix: str = "") -> Sequence[str]:
        return [prefix + "yparam"] + \
               [prefix + ("params_tensor[%d]" % i) for i in range(len(self.params_tensor))] + \
               [prefix + ("objparams[%d]" % i) for i in range(len(self.objparams))]

    def _mv(self, gy: torch.Tensor) -> torch.Tensor:
        # gy: (..., nin)
        # returns: (..., nout)

        # if the object parameter is still the same, then use the pre-calculated values
        if self.__param_tensors_unchanged():
            v = self.v
            dfdy = self.dfdy
        # otherwise, reevaluate by replacing the parameters with the new tensor params
        else:
            with torch.enable_grad(), self.fcn.useobjparams(self.objparams):
                self.__update_params()
                yparam = self.params[self.idx]
                yout = self.fcn(*self.params)  # (*nout)
                v = torch.ones_like(yout).to(yout.device).requires_grad_()  # (*nout)
                dfdy, = torch.autograd.grad(yout, (yparam,), grad_outputs=v,
                                            create_graph=True)  # (*nin)

        gy1 = gy.reshape(-1, self.nin)  # (nbatch, nin)
        nbatch = gy1.shape[0]
        dfdyfs = []
        for i in range(nbatch):
            dfdyf, = torch.autograd.grad(
                dfdy, (v,), grad_outputs=gy1[i].reshape(self.inshape),
                retain_graph=True,
                create_graph=torch.is_grad_enabled())  # (*nout)
            dfdyfs.append(dfdyf.unsqueeze(0))
        dfdyfs = torch.cat(dfdyfs, dim=0)  # (nbatch, *nout)

        res = dfdyfs.reshape(*gy.shape[:-1], self.nout)  # (..., nout)
        res = connect_graph(res, self.params_tensor)
        return res

    def _rmv(self, gout: torch.Tensor) -> torch.Tensor:
        # gout: (..., nout)
        # self.yfcn: (*nin)
        if self.__param_tensors_unchanged():
            yout = self.yout
            yparam = self.yparam
        else:
            with torch.enable_grad(), self.fcn.useobjparams(self.objparams):
                self.__update_params()
                yparam = self.params[self.idx]
                yout = self.fcn(*self.params)  # (*nout)

        gout1 = gout.reshape(-1, self.nout)  # (nbatch, nout)
        nbatch = gout1.shape[0]
        dfdy = []
        for i in range(nbatch):
            one_dfdy, = torch.autograd.grad(
                yout, (yparam,), grad_outputs=gout1[i].reshape(self.outshape),
                retain_graph=True,
                create_graph=torch.is_grad_enabled())  # (*nin)
            dfdy.append(one_dfdy.unsqueeze(0))
        dfdy = torch.cat(dfdy, dim=0)  # (nbatch, *nin)

        res = dfdy.reshape(*gout.shape[:-1], self.nin)  # (..., nin)
        res = connect_graph(res, self.params_tensor)
        return res  # (..., nin)

    def __param_tensors_unchanged(self):
        return [id(param) for param in self.params_tensor] == self.id_params_tensor and \
               [id(param) for param in self.objparams] == self.id_objparams_tensor

    def __update_params(self):
        self.params = self.param_sep.reconstruct_params(self.params_tensor)
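# A hedged usage sketch of _Jac as a matrix-free Jacobian operator. The stub below
# only supplies the objparams()/useobjparams() interface that _Jac expects from its
# function argument; it is an illustrative assumption, not the library's wrapper, and
# the public mv() call is assumed to route to _mv() as usual for a LinearOperator.
import contextlib
import torch

class _FcnStub(object):
    def __init__(self, f):
        self.f = f

    def __call__(self, *params):
        return self.f(*params)

    def objparams(self):
        return []

    @contextlib.contextmanager
    def useobjparams(self, objparams):
        yield

# f(y, a) = a * y**2, so df/dy = diag(2 * a * y)
f = _FcnStub(lambda y, a: a * y ** 2)
y = torch.tensor([1.0, 2.0], requires_grad=True)
a = torch.tensor(3.0, requires_grad=True)
jac = _Jac(f, (y, a), idx=0)
print(jac.mv(torch.tensor([1.0, 1.0])))  # expected roughly [6., 12.] = (df/dy) @ [1, 1]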