def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na, *amparams):
    # A: LinearOperator (*BA, q, q)
    # M: LinearOperator (*BM, q, q) or None

    # separate the sets of parameters
    params = amparams[:na]
    mparams = amparams[na:]

    config = set_default_option({}, fwd_options)
    ctx.bck_config = set_default_option({}, bck_options)

    method = config.pop("method")
    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        methods = {
            "davidson": davidson,
            "custom_exacteig": custom_exacteig,
        }
        method_fcn = get_method("symeig", methods, method)
        evals, evecs = method_fcn(A, neig, mode, M, **config)

    # save for the backward
    ctx.evals = evals  # (*BAM, neig)
    ctx.evecs = evecs  # (*BAM, na, neig)
    ctx.params = params
    ctx.A = A
    ctx.M = M
    ctx.mparams = mparams
    return evals, evecs
def forward(ctx, fcn: Callable[..., torch.Tensor], y0: torch.Tensor,
            options: Mapping[str, Any], bck_options: Mapping[str, Any],
            nparams: int, *allparams) -> torch.Tensor:
    # set default options
    config = set_default_option({
        "method": "broyden1",
    }, options)
    ctx.bck_options = set_default_option({"method": "exactsolve"}, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    with fcn.useobjparams(objparams):
        orig_method = config.pop("method")
        method = orig_method.lower()
        if method == "broyden1":
            y = broyden1(fcn, y0, params, **config)
        else:
            raise RuntimeError("Unknown rootfinder method: %s" % orig_method)

    ctx.fcn = fcn

    # split the params into tensor and non-tensor params
    ctx.nparams = nparams
    ctx.param_sep = TensorNonTensorSeparator(allparams)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(y, *tensor_params)
    return y
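# --- Illustrative sketch (not the library implementation) -------------------
# The forward above saves only the tensor params through ctx.save_for_backward
# (autograd can only track tensors), keeping the non-tensor params aside. A
# minimal sketch of what such a separator needs to do, assuming the real
# TensorNonTensorSeparator may differ in details:
import torch

class _TensorNonTensorSeparatorSketch(object):
    def __init__(self, params):
        self.params = list(params)
        self.tensor_idxs = [i for i, p in enumerate(self.params)
                            if isinstance(p, torch.Tensor)]

    def get_tensor_params(self):
        # the tensors to be passed to ctx.save_for_backward
        return [self.params[i] for i in self.tensor_idxs]

    def reconstruct_params(self, tensor_params):
        # put the (possibly restored) tensors back into their original slots
        out = list(self.params)
        for i, p in zip(self.tensor_idxs, tensor_params):
            out[i] = p
        return out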
def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na, *amparams):
    # A: LinearOperator (*BA, q, q)
    # M: LinearOperator (*BM, q, q) or None

    # separate the sets of parameters
    params = amparams[:na]
    mparams = amparams[na:]

    config = set_default_option({}, fwd_options)
    ctx.bck_config = set_default_option({}, bck_options)

    method = config["method"].lower()
    if method == "davidson":
        evals, evecs = davidson(A, params, neig, mode, M, mparams, **config)
    elif method == "custom_exacteig":
        evals, evecs = custom_exacteig(A, params, neig, mode, M, mparams,
                                       **config)
    else:
        raise RuntimeError("Unknown eigen decomposition method: %s" %
                           config["method"])

    # save for the backward
    ctx.evals = evals  # (*BAM, neig)
    ctx.evecs = evecs  # (*BAM, na, neig)
    ctx.params = params
    ctx.A = A
    ctx.M = M
    ctx.mparams = mparams
    return evals, evecs
def forward(ctx, A, B, E, M, posdef, fwd_options, bck_options, na, *all_params):
    # A: (*BA, nr, nr)
    # B: (*BB, nr, ncols)
    # E: (*BE, ncols) or None
    # M: (*BM, nr, nr) or None
    # all_params: list of tensors of any shape
    # returns: (*BABEM, nr, ncols)

    # separate the parameters for A and for M
    params = all_params[:na]
    mparams = all_params[na:]

    config = set_default_option({}, fwd_options)
    ctx.bck_config = set_default_option({}, bck_options)

    method = config["method"].lower()
    if torch.all(B == 0):  # special case
        dims = (*_get_batchdims(A, B, E, M), *B.shape[-2:])
        x = torch.zeros(dims, dtype=B.dtype, device=B.device)
    elif method == "custom_exactsolve":
        x = custom_exactsolve(A, params, B, E=E, M=M, mparams=mparams, **config)
    elif method == "gmres":
        x = wrap_gmres(A, params, B, E=E, M=M, mparams=mparams, posdef=posdef,
                       **config)
    elif method in ["lbfgs", "broyden"]:
        x = rootfinder_solve(method, A, params, B, E=E, M=M, mparams=mparams,
                             posdef=posdef, **config)
    else:
        raise RuntimeError("Unknown solve method: %s" % config["method"])

    ctx.A = A
    ctx.M = M
    ctx.E = E
    ctx.x = x
    ctx.posdef = posdef
    ctx.params = params
    ctx.mparams = mparams
    ctx.na = na
    return x
def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
    config = set_default_option({
        "method": "rk45",
    }, fwd_options)
    ctx.bck_config = set_default_option(config, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    orig_method = config.pop("method")
    method = orig_method.lower()
    try:
        solver = {
            "rk4": rk4_ivp,
            "rk38": rk38_ivp,
            "rk23": rk23_adaptive,
            "rk45": rk45_adaptive,
        }[method]
    except KeyError:
        # "method" has already been popped from config, so use orig_method
        raise RuntimeError("Unknown solve_ivp method: %s" % orig_method)
    yt = solver(pfcn, ts, y0, params, **config)

    # save the parameters for backward
    ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(ts, y0, *tensor_params)
    ctx.pfcn = pfcn
    ctx.nparams = nparams
    ctx.yt = yt
    ctx.ts_requires_grad = ts.requires_grad

    return yt
def forward(ctx, fcn, xl, xu, fwd_options, bck_options, nparams, dtype, device,
            *all_params):
    with fcn.disable_state_change():

        config = set_default_option({
            "method": "leggauss",
            "n": 100,
        }, fwd_options)
        ctx.bck_config = set_default_option(config, bck_options)

        params = all_params[:nparams]
        objparams = all_params[nparams:]

        # convert the boundaries to tensors
        xl = torch.as_tensor(xl, dtype=dtype, device=device)
        xu = torch.as_tensor(xu, dtype=dtype, device=device)

        # apply a transformation if the boundaries contain inf
        if _isinf(xl) or _isinf(xu):
            tfm = _TanInfTransform()

            @make_sibling(fcn)
            def fcn2(t, *params):
                ys = fcn(tfm.forward(t), *params)
                dxdt = tfm.dxdt(t)
                return ys * dxdt

            tl = tfm.x2t(xl)
            tu = tfm.x2t(xu)
        else:
            fcn2 = fcn
            tl = xl
            tu = xu

        method = config["method"].lower()
        if method == "leggauss":
            y = leggaussquad(fcn2, tl, tu, params, **config)
        else:
            raise RuntimeError("Unknown quad method: %s" % config["method"])

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(all_params)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.xltensor = isinstance(xl, torch.Tensor)
        ctx.xutensor = isinstance(xu, torch.Tensor)
        xlxu_tensor = ([xl] if ctx.xltensor else []) + \
                      ([xu] if ctx.xutensor else [])
        ctx.xlxu_nontensor = ([xl] if not ctx.xltensor else []) + \
                             ([xu] if not ctx.xutensor else [])
        ctx.save_for_backward(*xlxu_tensor, *tensor_params)
        ctx.fcn = fcn
        ctx.nparams = nparams
        return y
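# --- Illustrative sketch (not the library implementation) -------------------
# The infinite-boundary branch above relies on the substitution x = tan(t),
# which maps (-inf, inf) onto (-pi/2, pi/2) with dx/dt = 1/cos(t)^2. The
# standalone demo below (plain numpy/torch, independent of _TanInfTransform)
# checks the idea by integrating exp(-x^2) over the whole real line, which
# should give sqrt(pi) ~ 1.7724539.
def _demo_tan_transform_quad():
    import numpy as np
    import torch

    t_nodes, weights = np.polynomial.legendre.leggauss(100)  # nodes on (-1, 1)
    # rescale the nodes and weights from (-1, 1) to (-pi/2, pi/2)
    t = torch.as_tensor(t_nodes * (np.pi / 2))
    w = torch.as_tensor(weights * (np.pi / 2))

    fcn = lambda x: torch.exp(-x * x)
    integrand = fcn(torch.tan(t)) / torch.cos(t) ** 2  # f(tan t) * dx/dt
    val = (w * integrand).sum()
    assert abs(val.item() - np.sqrt(np.pi)) < 1e-4
    return val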
def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, fwd_options,
            bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
    # set up the default options
    config = set_default_option({
        "method": "mh",
    }, fwd_options)
    ctx.bck_config = set_default_option(config, bck_options)

    # split the parameters
    fparams = all_fpparams[:nfparams]
    fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
    pparams = all_fpparams[nfparams + nf_objparams:nfparams + nf_objparams +
                           npparams]
    pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

    # select the method for the sampling
    if xsamples is None:
        method = config["method"].lower()
        method_fcn = {
            "mh": mh,
            "_dummy1d": dummy1d,
            "mhcustom": mhcustom,
        }
        if method not in method_fcn:
            raise RuntimeError("Unknown mcquad method: %s" % config["method"])
        xsamples, wsamples = method_fcn[method](log_pfcn, x0, pparams, **config)
    epf = _integrate(ffcn, xsamples, wsamples, fparams)

    # save parameters for backward calculations
    ctx.xsamples = xsamples
    ctx.wsamples = wsamples
    ctx.ffcn = ffcn
    ctx.log_pfcn = log_pfcn
    ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
    ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
    ctx.nfparams = len(fparams)
    ctx.npparams = len(pparams)

    # save for backward
    ftensor_params = ctx.fparam_sep.get_tensor_params()
    ptensor_params = ctx.pparam_sep.get_tensor_params()
    ctx.nftensorparams = len(ftensor_params)
    ctx.nptensorparams = len(ptensor_params)
    ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

    return epf
def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0, *allparams):
    config = fwd_options
    ctx.bck_config = set_default_option(config, bck_options)

    params = allparams[:nparams]
    objparams = allparams[nparams:]

    method = config.pop("method")
    methods = {
        "rk4": rk4_ivp,
        "rk38": rk38_ivp,
        "rk23": rk23_adaptive,
        "rk45": rk45_adaptive,
    }
    solver = get_method("solve_ivp", methods, method)
    yt = solver(pfcn, ts, y0, params, **config)

    # save the parameters for backward
    ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
    tensor_params = ctx.param_sep.get_tensor_params()
    ctx.save_for_backward(ts, y0, *tensor_params)
    ctx.pfcn = pfcn
    ctx.nparams = nparams
    ctx.yt = yt
    ctx.ts_requires_grad = ts.requires_grad

    return yt
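# --- Illustrative sketch (not the library implementation) -------------------
# A minimal classical RK4 stepper in the spirit of rk4_ivp above (the actual
# solver's signature and details may differ), checked on dy/dt = -y, whose
# exact solution is y(t) = exp(-t).
def _demo_rk4():
    import torch

    def rk4_step(f, t, y, h):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    f = lambda t, y: -y
    ts = torch.linspace(0, 1, 101)
    y = torch.tensor(1.0)
    for i in range(len(ts) - 1):
        y = rk4_step(f, ts[i], y, ts[i + 1] - ts[i])
    assert torch.allclose(y, torch.exp(torch.tensor(-1.0)), atol=1e-6)
    return y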
def wrap_gmres(A, params, B, E=None, M=None, mparams=[], posdef=False,
               **options):
    # A: (*BA, nr, nr)
    # B: (*BB, nr, ncols)
    # E: (*BE, ncols) or None
    # M: (*BM, nr, nr) or None

    # NOTE: currently only works for batched B (1 batch dim), but unbatched A
    assert len(A.shape) == 2 and len(B.shape) == 3, \
        "Currently only works for batched B (1 batch dim), but unbatched A"

    # check the parameters
    msg = "GMRES can only do AX=B"
    assert A.shape[-2] == A.shape[-1], \
        "GMRES can only work for square operator for now"
    assert E is None, msg
    assert M is None, msg

    # set the default config options
    nbatch, na, ncols = B.shape
    config = set_default_option({
        "min_eps": 1e-9,
        "max_niter": 2 * na,
    }, options)
    min_eps = config["min_eps"]
    max_niter = config["max_niter"]

    B = B.transpose(-1, -2)  # (nbatch, ncols, na)

    # convert to numpy/scipy and solve column by column
    with A.uselinopparams(*params):
        op = A.scipy_linalg_op()
        B_np = B.detach().cpu().numpy()
        res_np = np.empty(B.shape, dtype=np.float64)
        for i in range(nbatch):
            for j in range(ncols):
                x, info = gmres(op, B_np[i, j, :], tol=min_eps, atol=1e-12,
                                maxiter=max_niter)
                if info > 0:
                    msg = "The GMRES iteration does not converge to the " \
                          "desired value (%.3e) after %d iterations" % \
                          (min_eps, info)
                    warnings.warn(msg)
                res_np[i, j, :] = x

    res = torch.tensor(res_np, dtype=B.dtype, device=B.device)
    res = res.transpose(-1, -2)  # (nbatch, na, ncols)
    return res
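# --- Illustrative sketch (not the library implementation) -------------------
# The wrapper above ultimately calls scipy.sparse.linalg.gmres once per batch
# and per column. The snippet below shows that same call pattern on a plain
# numpy matrix, which is what A.scipy_linalg_op() stands in for. (`tol` is the
# historical scipy argument name; recent scipy versions call it `rtol`.)
def _demo_gmres():
    import numpy as np
    from scipy.sparse.linalg import gmres

    rng = np.random.RandomState(0)
    na = 10
    A = rng.randn(na, na) + na * np.eye(na)  # diagonally dominant, well-posed
    b = rng.randn(na)
    x, info = gmres(A, b, tol=1e-9, atol=1e-12, maxiter=2 * na)
    assert info == 0  # 0 means converged
    assert np.allclose(A @ x, b, atol=1e-6)
    return x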
def gradrca(f, x0, jinv0=1.0, **options):
    # set up the default options
    config = set_default_option({
        "max_niter": 20,
        "norders": 2,
        "min_eps": 1e-6,
        "verbose": False,
    }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    norders = config["norders"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    x = x0
    onesvec = torch.ones_like(x0).unsqueeze(-1).to(x0.device) / \
        np.sqrt(nfeat)  # (nbatch, nfeat, 1)
    for i in range(config["max_niter"]):
        xg = x.detach().requires_grad_()
        with torch.enable_grad():
            dx = f(xg)
            vunit = (dx / dx.norm(dim=-1, keepdim=True)).detach()  # (nbatch, nfeat)
            loss = (dx * dx).sum(dim=-1)

        # collect the directional derivatives of the loss along vunit
        derivs = [loss]
        for j in range(norders):
            dldx = torch.autograd.grad(derivs[-1].sum(), (xg,),
                                       create_graph=(j < norders - 1))[0]
            dldlmbda = (dldx * vunit).sum(dim=-1)
            derivs.append(dldlmbda)

        if norders == 2:
            dstep = -derivs[1] / derivs[2]  # (nbatch,)
        elif norders == 3:
            dstep = (-derivs[2] + torch.sqrt(
                derivs[2] * derivs[2] - 2 * derivs[1] * derivs[3])) / derivs[3]
        else:
            raise RuntimeError("norders must be 2 or 3, got %d" % norders)
        dstep = jinv * dstep.unsqueeze(-1) * vunit

        if verbose:
            print("Iter %d: %.3e" % (i + 1, dx.detach().abs().max()))

        x = x + dstep
    return x
def forward(ctx, A, B, E, M, method, fwd_options, bck_options, na, *all_params):
    # A: (*BA, nr, nr)
    # B: (*BB, nr, ncols)
    # E: (*BE, ncols) or None
    # M: (*BM, nr, nr) or None
    # all_params: list of tensors of any shape
    # returns: (*BABEM, nr, ncols)

    # separate the parameters for A and for M
    params = all_params[:na]
    mparams = all_params[na:]

    config = set_default_option({}, fwd_options)
    ctx.bck_config = set_default_option({}, bck_options)

    if torch.all(B == 0):  # special case
        dims = (*_get_batchdims(A, B, E, M), *B.shape[-2:])
        x = torch.zeros(dims, dtype=B.dtype, device=B.device)
    else:
        with A.uselinopparams(*params), M.uselinopparams(
                *mparams) if M is not None else dummy_context_manager():
            methods = {
                "custom_exactsolve": custom_exactsolve,
                "scipy_gmres": wrap_gmres,
                "broyden1": broyden1_solve,
                "cg": cg,
                "bicgstab": bicgstab,
            }
            method_fcn = get_method("solve", methods, method)
            x = method_fcn(A, B, E, M, **config)

    ctx.e_is_none = E is None
    ctx.A = A
    ctx.M = M
    if ctx.e_is_none:
        ctx.save_for_backward(x, *all_params)
    else:
        ctx.save_for_backward(x, E, *all_params)
    ctx.na = na
    return x
def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na, *amparams):
    # A: LinearOperator (*BA, q, q)
    # M: LinearOperator (*BM, q, q) or None

    # separate the sets of parameters
    params = amparams[:na]
    mparams = amparams[na:]

    config = set_default_option({}, fwd_options)
    ctx.bck_config = set_default_option({
        "degen_atol": None,
        "degen_rtol": None,
    }, bck_options)

    # options for calculating the backward (not passed to the solver)
    alg_keys = ["degen_atol", "degen_rtol"]
    ctx.bck_alg_config = get_and_pop_keys(ctx.bck_config, alg_keys)

    method = config.pop("method")
    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        methods = {
            "davidson": davidson,
            "custom_exacteig": custom_exacteig,
        }
        method_fcn = get_method("symeig", methods, method)
        evals, evecs = method_fcn(A, neig, mode, M, **config)

    # save for the backward
    # evals: (*BAM, neig)
    # evecs: (*BAM, na, neig)
    ctx.save_for_backward(evals, evecs, *amparams)
    ctx.na = na
    ctx.A = A
    ctx.M = M
    return evals, evecs
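# --- Illustrative sketch (not the library implementation) -------------------
# get_and_pop_keys above splits ctx.bck_config into the keys consumed by the
# backward algorithm itself (degen_atol, degen_rtol) and the remaining keys
# forwarded to the solver. A minimal version, assuming the real helper may
# differ (e.g. in how it handles missing keys):
def _get_and_pop_keys_sketch(dct, keys):
    # move `keys` out of dct (mutating it) and return them as a new dict
    return {k: dct.pop(k) for k in keys}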
def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, method, fwd_options,
            bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
    # set up the default options
    config = fwd_options
    ctx.bck_config = set_default_option(config, bck_options)

    # split the parameters
    fparams = all_fpparams[:nfparams]
    fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
    pparams = all_fpparams[nfparams + nf_objparams:nfparams + nf_objparams +
                           npparams]
    pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

    # select the method for the sampling
    if xsamples is None:
        methods = {
            "mh": mh,
            "_dummy1d": dummy1d,
            "mhcustom": mhcustom,
        }
        method_fcn = get_method("mcquad", methods, method)
        xsamples, wsamples = method_fcn(log_pfcn, x0, pparams, **config)
    epf = _integrate(ffcn, xsamples, wsamples, fparams)

    # save parameters for backward calculations
    ctx.xsamples = xsamples
    ctx.wsamples = wsamples
    ctx.ffcn = ffcn
    ctx.log_pfcn = log_pfcn
    ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
    ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
    ctx.nfparams = len(fparams)
    ctx.npparams = len(pparams)
    ctx.method = method

    # save for backward
    ftensor_params = ctx.fparam_sep.get_tensor_params()
    ptensor_params = ctx.pparam_sep.get_tensor_params()
    ctx.nftensorparams = len(ftensor_params)
    ctx.nptensorparams = len(ptensor_params)
    ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

    return epf
def broyden(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with Broyden's method.

    Arguments
    ---------
    * f: callable
        Callable that takes params as the input and outputs nfeat outputs.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of the parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    raise RuntimeError("This method is unfinished. Please use other methods.")

    # set up the default options
    config = set_default_option({
        "max_niter": 20,
        "min_eps": 1e-6,
        "verbose": False,
    }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the Broyden iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    for i in range(config["max_niter"]):
        dxnew = -jinv * fx  # (nbatch, nfeat)
        xnew = x + dxnew  # (nbatch, nfeat)
        fxnew = f(xnew)
        dfnew = fxnew - fx

        # calculate the new jinv with the "good Broyden" rank-1 update:
        # jinv += (dx - jinv df) dx^T jinv / (dx^T jinv df)
        dxtnew_jinv = torch.bmm(dxnew.unsqueeze(1), jinv)  # (nbatch, 1, nfeat)
        jinv_dfnew = torch.bmm(jinv, dfnew.unsqueeze(-1))  # (nbatch, nfeat, 1)
        dxtnew_jinv_dfnew = torch.bmm(dxtnew_jinv,
                                      dfnew.unsqueeze(-1))  # (nbatch, 1, 1)
        jinvnew = jinv + torch.bmm(dxnew.unsqueeze(-1) - jinv_dfnew,
                                   dxtnew_jinv) / dxtnew_jinv_dfnew

        # update variables for the next iteration
        fx = fxnew
        jinv = jinvnew
        x = xnew

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, fx.abs().max()))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The Broyden iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, fx.abs().max())
        warnings.warn(msg)

    return x
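# --- Illustrative sketch (not the library implementation) -------------------
# Since the routine above is marked unfinished, here is a minimal, unbatched
# "good Broyden" iteration on f(x) = x - cos(x), whose root is the Dottie
# number (~0.7390851). It uses the same rank-1 inverse-Jacobian update:
#   jinv <- jinv + (dx - jinv df) dx^T jinv / (dx^T jinv df)
def _demo_broyden():
    import torch

    f = lambda x: x - torch.cos(x)
    x = torch.tensor([0.5])
    jinv = torch.eye(1)  # initial inverse Jacobian estimate
    fx = f(x)
    for _ in range(20):
        dx = -(jinv @ fx.unsqueeze(-1)).squeeze(-1)
        xnew = x + dx
        fxnew = f(xnew)
        df = fxnew - fx
        jinv_df = jinv @ df.unsqueeze(-1)    # (nfeat, 1)
        dxt_jinv = dx.unsqueeze(0) @ jinv    # (1, nfeat)
        denom = dxt_jinv @ df.unsqueeze(-1)  # (1, 1)
        jinv = jinv + (dx.unsqueeze(-1) - jinv_df) @ dxt_jinv / denom
        x, fx = xnew, fxnew
        if fx.abs().max() < 1e-9:
            break
    assert torch.allclose(x, torch.tensor([0.7390851]), atol=1e-5)
    return x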
def davidson(A, params, neig, mode, M=None, mparams=[], **options):
    """
    Iterative method to obtain the `neig` lowest eigenvalues and eigenvectors.
    This function is written so that the backpropagation can be done.
    It solves the eigendecomposition AV = MVE where V is the matrix of
    eigenvectors and E is the diagonal matrix consisting of the eigenvalues.

    Arguments
    ---------
    * A: LinearOperator instance (*BA, na, na)
        The linear operator object on which the eigenpairs are constructed.
    * params: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to A.
    * neig: int
        The number of eigenpairs to be retrieved.
    * mode: str
        Take the `neig` "lowest" or "uppest" of the eigenpairs.
    * M: LinearOperator instance (*BM, na, na) or None
        The transformation on the right hand side. If None, then M=I.
    * mparams: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to M.
    * **options:
        Iterative algorithm options.

    Returns
    -------
    * eigvals: torch.tensor (*BAM, neig)
    * eigvecs: torch.tensor (*BAM, na, neig)
        The `neig` lowest eigenpairs
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use a better strategy to get the estimate of the eigenvalues
    # (2) use a restart strategy
    config = set_default_option({
        "max_niter": 1000,
        "nguess": neig,  # number of initial guesses
        "min_eps": 1e-6,
        "verbose": False,
        "eps_cond": 1e-6,
        "v_init": "randn",
        "max_addition": neig,
    }, options)

    # get some of the options
    nguess = config["nguess"]
    max_niter = config["max_niter"]
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    eps_cond = config["eps_cond"]
    max_addition = config["max_addition"]

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(config["v_init"].lower(), dtype, device,
                           bcast_dims, na, nguess,
                           M=M, mparams=mparams)  # (*BAM, na, nguess)

        # estimate the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A, neig, mode,
                                             bcast_dims=bcast_dims, na=na,
                                             ntest=20,
                                             dtype=V.dtype, device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM, nguess, na)
            # AV is carried over from the previous iteration and extended only
            # for the new columns of V (see the end of the loop), instead of
            # recomputing A.mm(V) here. This works because the old columns of
            # V have already been orthogonalized, so they stay the same.
            # AV = A.mm(V)  # (*BAM, na, nguess)
            T = torch.matmul(VT, AV)  # (*BAM, nguess, nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig,
                mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" %
                          (i + 1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # the initial estimate of the eigenvalues really helps here
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM, neig)
            else:
                z = eigvalT  # (*BAM, neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams)  # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = eigvalT - eigvalT_pred  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM, na, nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
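# --- Illustrative sketch (not the library implementation) -------------------
# The core of each Davidson iteration above is a Rayleigh-Ritz step: project A
# onto the subspace spanned by the orthonormal columns of V, diagonalize the
# small matrix T = V^T A V, and lift the eigenvectors back as V @ evec. The
# standalone check below (plain torch, no LinearOperator) uses a V spanning
# the full space, for which this recovers the lowest eigenpairs of A exactly.
def _demo_rayleigh_ritz():
    import torch

    torch.manual_seed(0)
    na, neig = 6, 2
    A = torch.randn(na, na)
    A = (A + A.T) / 2  # symmetrize

    V, _ = torch.linalg.qr(torch.randn(na, na))  # orthonormal basis
    T = V.T @ (A @ V)                            # projected operator
    eigvalT, eigvecT = torch.linalg.eigh(T)      # sorted from the lowest
    eigvals, eigvecs = eigvalT[:neig], V @ eigvecT[:, :neig]

    # the residual A v - lambda v vanishes when the subspace is exact
    resid = A @ eigvecs - eigvecs * eigvals
    assert resid.abs().max() < 1e-4
    return eigvals, eigvecs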
def diis(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with the DIIS method.

    Arguments
    ---------
    * f: callable
        Callable that takes params as the input and outputs nfeat outputs.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of the parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    # set up the default options
    config = set_default_option({
        "max_niter": 20,
        "min_eps": 1e-6,
        "max_memory": 20,
        "minit": 10,
        "verbose": False,
    }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    max_memory = config["max_memory"]
    minit = config["minit"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    bestcrit = float("inf")

    nbatch, nfeat = fx.shape
    x_history = torch.empty((nbatch, max_memory, nfeat),
                            dtype=x.dtype, device=x.device)
    e_history = torch.empty((nbatch, max_memory, nfeat),
                            dtype=x.dtype, device=x.device)
    mfill = 0
    midx = 0

    for i in range(config["max_niter"]):
        if mfill < 2:
            dxnew = -jinv * fx  # (nbatch, nfeat)
            xnew = x + dxnew  # (nbatch, nfeat)
        else:
            # construct the matrix B
            fx_tensor = e_history[:, :mfill, :]  # (nbatch, m, nfeat)
            Bul = torch.matmul(fx_tensor,
                               fx_tensor.transpose(-2, -1))  # (nbatch, m, m)
            Bu = torch.cat((Bul, -torch.ones(Bul.shape[0], Bul.shape[1], 1,
                                             dtype=Bul.dtype,
                                             device=Bul.device)),
                           dim=-1)  # (nbatch, m, m+1)
            B = torch.cat((Bu, -torch.ones(Bu.shape[0], 1, Bu.shape[-1],
                                           dtype=Bu.dtype, device=Bu.device)),
                          dim=-2)  # (nbatch, m+1, m+1)
            B[:, -1, -1] = 0.0

            # solve the linear equation to get the coefficients
            a = torch.zeros(B.shape[0], B.shape[1], 1, dtype=B.dtype,
                            device=B.device)  # (nbatch, m+1, 1)
            a[:, -1] = -1.0
            c = torch.solve(a, B)[0].squeeze(-1)[:, :mfill]  # (nbatch, m)
            x_tensor = x_history[:, :mfill, :]  # (nbatch, m, nfeat)
            xnew = (x_tensor * c.unsqueeze(-1)).sum(dim=1)  # (nbatch, nfeat)

        fxnew = f(xnew)  # (nbatch, nfeat)
        dx = xnew - x

        # update variables for the next iteration
        fx = fxnew
        x = xnew

        # add to the history
        if i >= minit:
            x_history[:, midx, :] = x
            e_history[:, midx, :] = fx
            midx = (midx + 1) % max_memory
            mfill = (mfill + 1) if mfill < max_memory else mfill

        # get the best results
        crit = fx.abs().max()
        if crit < bestcrit:
            bestcrit = crit
            bestx = x

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, crit))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The DIIS iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestcrit)
        warnings.warn(msg)

    return bestx
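# --- Illustrative sketch (not the library implementation) -------------------
# The DIIS step above solves the bordered linear system
#   [ B    -1 ] [ c      ]   [  0 ]
#   [ -1^T  0 ] [ lambda ] = [ -1 ],   B_ij = <e_i, e_j>,
# which minimizes |sum_i c_i e_i|^2 subject to sum_i c_i = 1. The standalone
# check below verifies the constraint and that the mixed error is no larger
# than the best stored error.
def _demo_diis_coefficients():
    import torch

    torch.manual_seed(0)
    m, nfeat = 4, 8
    e = torch.randn(m, nfeat)  # error vectors of the last m iterates

    B = torch.empty(m + 1, m + 1)
    B[:m, :m] = e @ e.T
    B[:m, -1] = -1.0
    B[-1, :m] = -1.0
    B[-1, -1] = 0.0
    a = torch.zeros(m + 1, 1)
    a[-1] = -1.0

    c = torch.linalg.solve(B, a).squeeze(-1)[:m]  # mixing coefficients
    assert torch.allclose(c.sum(), torch.tensor(1.0), atol=1e-5)
    e_mix = (c.unsqueeze(-1) * e).sum(dim=0)
    assert e_mix.norm() <= e.norm(dim=-1).min() + 1e-6
    return c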
def selfconsistent(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with a damped self-consistent (fixed-point)
    iteration.

    Arguments
    ---------
    * f: callable
        Callable that takes params as the input and outputs nfeat outputs.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of the parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    # set up the default options
    config = set_default_option({
        "max_niter": 20,
        "min_eps": 1e-6,
        "beta": 0.9,  # contribution of the new delta_n to the total delta_n
        "jinvdecay": 1.0,
        "decayevery": 100,
        "verbose": False,
    }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    beta = config["beta"]
    jinvdecay = config["jinvdecay"]
    decayevery = config["decayevery"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the self-consistent iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    dx = torch.zeros_like(x).to(x.device)
    bestcrit = float("inf")
    for i in range(config["max_niter"]):
        dxnew = -jinv * fx  # (nbatch, nfeat)
        dx = (1 - beta) * dx + beta * dxnew
        xnew = x + dx  # (nbatch, nfeat)
        fxnew = f(xnew)

        # update variables for the next iteration
        fx = fxnew
        x = xnew
        if (i + 1) % decayevery == 0:
            jinv = jinv * jinvdecay

        # get the best results
        crit = fx.abs().max()
        if crit < bestcrit:
            bestcrit = crit
            bestx = x

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, crit))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The selfconsistent iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestcrit)
        warnings.warn(msg)

    return bestx
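# --- Illustrative sketch (not the library implementation) -------------------
# The routine above is a damped fixed-point iteration: each step mixes the new
# update into the accumulated one, dx <- (1-beta)*dx + beta*dxnew. The
# standalone demo below applies the same mixing to f(x) = x - cos(x).
def _demo_selfconsistent_mixing():
    import torch

    f = lambda x: x - torch.cos(x)  # root at the Dottie number (~0.7390851)
    x = torch.tensor([0.5])
    dx = torch.zeros_like(x)
    beta, jinv = 0.9, 0.5
    for _ in range(100):
        dxnew = -jinv * f(x)
        dx = (1 - beta) * dx + beta * dxnew
        x = x + dx
        if f(x).abs().max() < 1e-9:
            break
    assert torch.allclose(x, torch.tensor([0.7390851]), atol=1e-5)
    return x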
def lbfgs(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with the L-BFGS method.

    Arguments
    ---------
    * f: callable
        Callable that takes params as the input and outputs nfeat outputs.
    * x0: torch.tensor (*, nfeat)
        Initial value of the parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    config = set_default_option({
        "max_niter": 20,
        "min_eps": 1e-6,
        "max_memory": 10,
        "alpha0": 1.0,
        "linesearch": False,
        "verbose": False,
    }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    max_memory = config["max_memory"]
    verbose = config["verbose"]
    linesearch = config["linesearch"]
    alpha = config["alpha0"]

    # set up the initial jinv and the memories
    H0 = _set_jinv0_diag(jinv0, x0)  # (*, nfeat)
    sk_history = []
    yk_history = []
    rk_history = []

    def _apply_Vk(rk, sk, yk, grad):
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        # applies V_k @ grad with V_k = I - rho_k y_k s_k^T
        return grad - (sk * grad).sum(dim=-1, keepdim=True) * rk * yk

    def _apply_VkT(rk, sk, yk, grad):
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        # applies V_k^T @ grad with V_k^T = I - rho_k s_k y_k^T
        return grad - (yk * grad).sum(dim=-1, keepdim=True) * rk * sk

    def _apply_Hk(H0, sk_hist, yk_hist, rk_hist, gk):
        # H0: (*, nfeat)
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        # gk: (*, nfeat)
        # applies H_k = V_k^T H_(k-1) V_k + rho_k s_k s_k^T recursively
        nhist = len(sk_hist)
        if nhist == 0:
            return H0 * gk

        k = nhist - 1
        rk = rk_hist[k]
        sk = sk_hist[k]
        yk = yk_hist[k]

        # get the last term (rho_k s_k s_k^T) @ gk
        rksksk = (sk * gk).sum(dim=-1, keepdim=True) * rk * sk

        # apply V_k, then H_(k-1), then V_k^T
        grad = gk
        grad = _apply_Vk(rk, sk, yk, grad)
        grad = _apply_Hk(H0, sk_hist[:k], yk_hist[:k], rk_hist[:k], grad)
        grad = _apply_VkT(rk, sk, yk, grad)

        return grad + rksksk

    def _line_search(xk, gk, dk, g):
        if linesearch:
            dx, dg, nit = line_search(dk, xk, gk, g)
            return xk + dx, gk + dg
        else:
            return xk + alpha * dk, g(xk + alpha * dk)

    # perform the main iteration
    xk = x0
    gk = f(xk)
    bestgk = gk.abs().max()
    bestx = x0
    stop_reason = "max_niter"
    for k in range(config["max_niter"]):
        dk = -_apply_Hk(H0, sk_history, yk_history, rk_history, gk)
        xknew, gknew = _line_search(xk, gk, dk, f)

        # store the history
        sk = xknew - xk  # (*, nfeat)
        yk = gknew - gk
        rhok = 1.0 / (sk * yk).sum(dim=-1, keepdim=True)  # (*, 1)
        sk_history.append(sk)
        yk_history.append(yk)
        rk_history.append(rhok)
        if len(sk_history) > max_memory:
            sk_history = sk_history[-max_memory:]
            yk_history = yk_history[-max_memory:]
            rk_history = rk_history[-max_memory:]

        # update for the next iteration
        xk = xknew
        gk = gknew

        # save the best point
        maxgk = gk.abs().max()
        if maxgk < bestgk:
            bestx = xk
            bestgk = maxgk

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (k + 1, maxgk))
        if maxgk < min_eps:
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The L-BFGS iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestgk)
        warnings.warn(msg)

    return bestx
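# --- Illustrative sketch (not the library implementation) -------------------
# A standalone check (plain torch) of the recursion used by _apply_Hk above:
# applying H_1 = V_1^T H_0 V_1 + rho_1 s_1 s_1^T matrix-free, with
# V_1 = I - rho_1 y_1 s_1^T, must agree with building the dense H_1.
def _demo_lbfgs_recursion():
    import torch

    torch.manual_seed(0)
    nfeat = 5
    s = torch.randn(nfeat)
    y = torch.randn(nfeat)
    rho = 1.0 / (s * y).sum()
    g = torch.randn(nfeat)
    H0 = torch.eye(nfeat)

    # dense form
    V = torch.eye(nfeat) - rho * torch.outer(y, s)
    dense = (V.T @ H0 @ V + rho * torch.outer(s, s)) @ g

    # matrix-free form, mirroring _apply_Vk / _apply_Hk / _apply_VkT
    vg = g - (s * g).sum() * rho * y           # V_1 @ g
    hvg = vg                                   # H_0 = I here
    vthvg = hvg - (y * hvg).sum() * rho * s    # V_1^T @ (H_0 V_1 g)
    matfree = vthvg + (s * g).sum() * rho * s  # + rho s (s^T g)

    assert torch.allclose(dense, matfree, atol=1e-5)
    return matfree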