Example #1
    def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na,
                *amparams):
        # A: LinearOperator (*BA, q, q)
        # M: LinearOperator (*BM, q, q) or None

        # separate the sets of parameters
        params = amparams[:na]
        mparams = amparams[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option(
            {
                # "method": ???
            },
            bck_options)

        method = config.pop("method")
        with A.uselinopparams(*params), M.uselinopparams(
                *mparams) if M is not None else dummy_context_manager():
            methods = {
                "davidson": davidson,
                "custom_exacteig": custom_exacteig,
            }
            method_fcn = get_method("symeig", methods, method)
            evals, evecs = method_fcn(A, neig, mode, M, **config)

        # save for the backward
        ctx.evals = evals  # (*BAM, neig)
        ctx.evecs = evecs  # (*BAM, na, neig)
        ctx.params = params
        ctx.A = A
        ctx.M = M
        ctx.mparams = mparams
        return evals, evecs
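Several of the forward implementations above bind the optional M operator with `M.uselinopparams(*mparams) if M is not None else dummy_context_manager()`. The `dummy_context_manager` helper is not part of the snippets; a minimal sketch of such a no-op context manager (only the name comes from the code above, the body is an assumption):

import contextlib

@contextlib.contextmanager
def dummy_context_manager():
    # No-op stand-in used when there is no M operator whose parameters
    # need to be bound for the duration of the with-block.
    yield None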
Example #2
    def forward(ctx, fcn: Callable[..., torch.Tensor], y0: torch.Tensor,
                options: Mapping[str, Any], bck_options: Mapping[str, Any],
                nparams: int, *allparams) -> torch.Tensor:

        # set default options
        config = set_default_option({
            "method": "broyden1",
        }, options)
        ctx.bck_options = set_default_option({"method": "exactsolve"},
                                             bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        with fcn.useobjparams(objparams):

            orig_method = config.pop("method")
            method = orig_method.lower()
            if method == "broyden1":
                y = broyden1(fcn, y0, params, **config)
            else:
                raise RuntimeError("Unknown rootfinder method: %s" %
                                   orig_method)

        ctx.fcn = fcn

        # split tensors and non-tensors params
        ctx.nparams = nparams
        ctx.param_sep = TensorNonTensorSeparator(allparams)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(y, *tensor_params)

        return y
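Every example routes its user options through `set_default_option(defaults, options)`. The helper itself is not shown in any snippet; a plausible sketch, assuming it simply overlays the user-supplied options on top of the defaults and returns a fresh dict (so callers can safely pop keys):

def set_default_option(defaults, options):
    # Hypothetical sketch: copy the defaults, then override them with the
    # user-supplied options; the inputs are left untouched.
    res = dict(defaults)
    res.update(options if options is not None else {})
    return res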
Example #3
    def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na,
                *amparams):
        # A: LinearOperator (*BA, q, q)
        # M: LinearOperator (*BM, q, q) or None

        # separate the sets of parameters
        params = amparams[:na]
        mparams = amparams[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option(
            {
                # "method": ???
            },
            bck_options)

        method = config["method"].lower()
        if method == "davidson":
            evals, evecs = davidson(A, params, neig, mode, M, mparams,
                                    **config)
        elif method == "custom_exacteig":
            evals, evecs = custom_exacteig(A, params, neig, mode, M, mparams,
                                           **config)
        else:
            raise RuntimeError("Unknown eigen decomposition method: %s" %
                               config["method"])

        # save for the backward
        ctx.evals = evals  # (*BAM, neig)
        ctx.evecs = evecs  # (*BAM, na, neig)
        ctx.params = params
        ctx.A = A
        ctx.M = M
        ctx.mparams = mparams
        return evals, evecs
Example #4
    def forward(ctx, A, B, E, M, posdef, fwd_options, bck_options, na,
                *all_params):
        # A: (*BA, nr, nr)
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols) or None
        # M: (*BM, nr, nr) or None
        # all_params: list of tensor of any shape
        # returns: (*BABEM, nr, ncols)

        # separate the parameters for A and for M
        params = all_params[:na]
        mparams = all_params[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option({}, bck_options)

        method = config["method"].lower()

        if torch.all(B == 0):  # special case
            dims = (*_get_batchdims(A, B, E, M), *B.shape[-2:])
            x = torch.zeros(dims, dtype=B.dtype, device=B.device)
        elif method == "custom_exactsolve":
            x = custom_exactsolve(A,
                                  params,
                                  B,
                                  E=E,
                                  M=M,
                                  mparams=mparams,
                                  **config)
        elif method == "gmres":
            x = wrap_gmres(A,
                           params,
                           B,
                           E=E,
                           M=M,
                           mparams=mparams,
                           posdef=posdef,
                           **config)
        elif method in ["lbfgs", "broyden"]:
            x = rootfinder_solve(method,
                                 A,
                                 params,
                                 B,
                                 E=E,
                                 M=M,
                                 mparams=mparams,
                                 posdef=posdef,
                                 **config)
        else:
            raise RuntimeError("Unknown solve method: %s" % config["method"])

        ctx.A = A
        ctx.M = M
        ctx.E = E
        ctx.x = x
        ctx.posdef = posdef
        ctx.params = params
        ctx.mparams = mparams
        ctx.na = na
        return x
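The all-zero shortcut above builds the output shape from `_get_batchdims(A, B, E, M)`, which is not included in the snippet. A sketch under the assumption that it broadcasts the leading (batch) dimensions of the operands:

import torch

def _get_batchdims(A, B, E, M):
    # Hypothetical sketch: broadcast the batch dims of A (*BA, nr, nr),
    # B (*BB, nr, ncols), E (*BE, ncols) and M (*BM, nr, nr).
    shapes = [A.shape[:-2], B.shape[:-2]]
    if E is not None:
        shapes.append(E.shape[:-1])
    if M is not None:
        shapes.append(M.shape[:-2])
    return torch.broadcast_shapes(*shapes)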
Example #5
    def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0,
                *allparams):
        config = set_default_option({
            "method": "rk45",
        }, fwd_options)
        ctx.bck_config = set_default_option(config, bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        orig_method = config.pop("method")
        method = orig_method.lower()
        try:
            solver = {
                "rk4": rk4_ivp,
                "rk38": rk38_ivp,
                "rk23": rk23_adaptive,
                "rk45": rk45_adaptive,
            }[method]
        except KeyError:
            raise RuntimeError("Unknown solve_ivp method: %s" % orig_method)
        yt = solver(pfcn, ts, y0, params, **config)

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(ts, y0, *tensor_params)
        ctx.pfcn = pfcn
        ctx.nparams = nparams
        ctx.yt = yt
        ctx.ts_requires_grad = ts.requires_grad

        return yt
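`ctx.save_for_backward` only accepts tensors, which is why the parameters are first routed through `TensorNonTensorSeparator`. The class is not shown in these snippets; a minimal sketch of the idea, assuming `varonly=True` restricts the split to tensors that require grad:

import torch

class TensorNonTensorSeparator:
    # Hypothetical sketch: remember which positions of `params` hold tensors
    # so only tensors are passed to ctx.save_for_backward, and the full tuple
    # can be reconstructed in backward.
    def __init__(self, params, varonly=False):
        self.params = list(params)
        self.tensor_idxs = [
            i for i, p in enumerate(self.params)
            if isinstance(p, torch.Tensor) and (p.requires_grad or not varonly)
        ]

    def get_tensor_params(self):
        return [self.params[i] for i in self.tensor_idxs]

    def reconstruct_params(self, tensor_params):
        res = list(self.params)
        for i, p in zip(self.tensor_idxs, tensor_params):
            res[i] = p
        return res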
Example #6
    def forward(ctx, fcn, xl, xu, fwd_options, bck_options, nparams, dtype,
                device, *all_params):

        with fcn.disable_state_change():

            config = set_default_option({
                "method": "leggauss",
                "n": 100,
            }, fwd_options)
            ctx.bck_config = set_default_option(config, bck_options)

            params = all_params[:nparams]
            objparams = all_params[nparams:]

            # convert to tensor
            xl = torch.as_tensor(xl, dtype=dtype, device=device)
            xu = torch.as_tensor(xu, dtype=dtype, device=device)

            # apply transformation if the boundaries contain inf
            if _isinf(xl) or _isinf(xu):
                tfm = _TanInfTransform()

                @make_sibling(fcn)
                def fcn2(t, *params):
                    ys = fcn(tfm.forward(t), *params)
                    dxdt = tfm.dxdt(t)
                    return ys * dxdt

                tl = tfm.x2t(xl)
                tu = tfm.x2t(xu)
            else:
                fcn2 = fcn
                tl = xl
                tu = xu

            method = config["method"].lower()
            if method == "leggauss":
                y = leggaussquad(fcn2, tl, tu, params, **config)
            else:
                raise RuntimeError("Unknown quad method: %s" %
                                   config["method"])

            # save the parameters for backward
            ctx.param_sep = TensorNonTensorSeparator(all_params)
            tensor_params = ctx.param_sep.get_tensor_params()
            ctx.xltensor = isinstance(xl, torch.Tensor)
            ctx.xutensor = isinstance(xu, torch.Tensor)
            xlxu_tensor = ([xl] if ctx.xltensor else []) + \
                          ([xu] if ctx.xutensor else [])
            ctx.xlxu_nontensor = ([xl] if not ctx.xltensor else []) + \
                                 ([xu] if not ctx.xutensor else [])
            ctx.save_for_backward(*xlxu_tensor, *tensor_params)
            ctx.fcn = fcn
            ctx.nparams = nparams
            return y
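The only forward method handled here is "leggauss", i.e. fixed-order Gauss-Legendre quadrature via `leggaussquad(fcn2, tl, tu, params, **config)`. The routine itself is not shown; a sketch of the standard construction, assuming the `n` option is the quadrature order (the real implementation may differ):

import numpy as np

def leggaussquad(fcn, xl, xu, params, n=100, **unused):
    # Hypothetical sketch of fixed-order Gauss-Legendre quadrature:
    # map the nodes from [-1, 1] to [xl, xu] and sum the weighted integrand.
    xi, wi = np.polynomial.legendre.leggauss(n)
    half = 0.5 * (xu - xl)
    mid = 0.5 * (xu + xl)
    res = 0.0
    for x, w in zip(xi, wi):
        res = res + w * fcn(mid + half * float(x), *params)
    return half * res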
Example #7
    def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, fwd_options,
                bck_options, nfparams, nf_objparams, npparams, *all_fpparams):
        # set up the default options
        config = set_default_option({
            "method": "mh",
        }, fwd_options)
        ctx.bck_config = set_default_option(config, bck_options)

        # split the parameters
        fparams = all_fpparams[:nfparams]
        fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
        pparams = all_fpparams[nfparams + nf_objparams:nfparams +
                               nf_objparams + npparams]
        pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

        # select the method for the sampling
        if xsamples is None:
            method = config["method"].lower()
            method_fcn = {
                "mh": mh,
                "_dummy1d": dummy1d,
                "mhcustom": mhcustom,
            }
            if method not in method_fcn:
                raise RuntimeError("Unknown mcquad method: %s" %
                                   config["method"])
            xsamples, wsamples = method_fcn[method](log_pfcn, x0, pparams,
                                                    **config)
        epf = _integrate(ffcn, xsamples, wsamples, fparams)

        # save parameters for backward calculations
        ctx.xsamples = xsamples
        ctx.wsamples = wsamples
        ctx.ffcn = ffcn
        ctx.log_pfcn = log_pfcn
        ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
        ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
        ctx.nfparams = len(fparams)
        ctx.npparams = len(pparams)

        # save for backward
        ftensor_params = ctx.fparam_sep.get_tensor_params()
        ptensor_params = ctx.pparam_sep.get_tensor_params()
        ctx.nftensorparams = len(ftensor_params)
        ctx.nptensorparams = len(ptensor_params)
        ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

        return epf
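The expectation value itself comes from `_integrate(ffcn, xsamples, wsamples, fparams)`, which is not part of the snippet. A sketch assuming it is simply the weighted average of `ffcn` over the drawn samples, with the weights already normalized by the sampler:

def _integrate(ffcn, xsamples, wsamples, fparams):
    # Hypothetical sketch: E_p[f] ~ sum_i w_i * f(x_i)
    res = 0.0
    for x, w in zip(xsamples, wsamples):
        res = res + w * ffcn(x, *fparams)
    return res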
Example #8
    def forward(ctx, pfcn, ts, fwd_options, bck_options, nparams, y0,
                *allparams):
        config = fwd_options
        ctx.bck_config = set_default_option(config, bck_options)

        params = allparams[:nparams]
        objparams = allparams[nparams:]

        method = config.pop("method")
        methods = {
            "rk4": rk4_ivp,
            "rk38": rk38_ivp,
            "rk23": rk23_adaptive,
            "rk45": rk45_adaptive,
        }
        solver = get_method("solve_ivp", methods, method)
        yt = solver(pfcn, ts, y0, params, **config)

        # save the parameters for backward
        ctx.param_sep = TensorNonTensorSeparator(allparams, varonly=True)
        tensor_params = ctx.param_sep.get_tensor_params()
        ctx.save_for_backward(ts, y0, *tensor_params)
        ctx.pfcn = pfcn
        ctx.nparams = nparams
        ctx.yt = yt
        ctx.ts_requires_grad = ts.requires_grad

        return yt
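Compared to Example #5, this variant delegates the lookup to `get_method("solve_ivp", methods, method)`. The helper is not shown; a sketch assuming it resolves the lower-cased name against the dict and lists the known choices on failure:

def get_method(algname, methods, method):
    # Hypothetical sketch: resolve the method name and raise an informative
    # error that lists the available implementations.
    try:
        return methods[method.lower()]
    except KeyError:
        raise RuntimeError("Unknown %s method: %s (available: %s)" %
                           (algname, method, ", ".join(methods.keys())))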
Example #9
def wrap_gmres(A,
               params,
               B,
               E=None,
               M=None,
               mparams=[],
               posdef=False,
               **options):
    # A: (*BA, nr, nr)
    # B: (*BB, nr, ncols)
    # E: (*BE, ncols) or None
    # M: (*BM, nr, nr) or None

    # NOTE: currently only works for batched B (1 batch dim), but unbatched A
    assert len(A.shape) == 2 and len(B.shape) == 3, \
        "Currently only works for batched B (1 batch dim), but unbatched A"

    # check the parameters
    msg = "GMRES can only do AX=B"
    assert A.shape[-2] == A.shape[-1], \
        "GMRES can only work for square operator for now"
    assert E is None, msg
    assert M is None, msg

    # set the default config options
    nbatch, na, ncols = B.shape
    config = set_default_option({
        "min_eps": 1e-9,
        "max_niter": 2 * na,
    }, options)
    min_eps = config["min_eps"]
    max_niter = config["max_niter"]

    B = B.transpose(-1, -2)  # (nbatch, ncols, na)

    # convert to numpy and solve each system with scipy's gmres
    with A.uselinopparams(*params):
        op = A.scipy_linalg_op()
        B_np = B.detach().numpy()
        res_np = np.empty(B.shape, dtype=np.float64)
        for i in range(nbatch):
            for j in range(ncols):
                x, info = gmres(op,
                                B_np[i, j, :],
                                tol=min_eps,
                                atol=1e-12,
                                maxiter=max_niter)
                if info > 0:
                    msg = "The GMRES iteration does not converge to the desired value "\
                          "(%.3e) after %d iterations" % \
                          (config["min_eps"], info)
                    warnings.warn(msg)
                res_np[i, j, :] = x

        res = torch.tensor(res_np, dtype=B.dtype, device=B.device)
        res = res.transpose(-1, -2)  # (nbatch, na, ncols)
        return res
Example #10
def gradrca(f, x0, jinv0=1.0, **options):
    # set up the default options
    config = set_default_option(
        {
            "max_niter": 20,
            "norders": 2,
            "min_eps": 1e-6,
            "verbose": False,
        }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    norders = config["norders"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    x = x0
    onesvec = torch.ones_like(x0).unsqueeze(-1).to(x0.device) / np.sqrt(
        nfeat)  # (nbatch, nfeat, 1)
    for i in range(config["max_niter"]):
        xg = x.detach().requires_grad_()
        with torch.enable_grad():
            dx = f(xg)
            vunit = (dx /
                     dx.norm(dim=-1, keepdim=True)).detach()  # (nbatch, nfeat)
            loss = (dx * dx).sum(dim=-1)
            derivs = [loss]
            for j in range(norders):
                dldx = torch.autograd.grad(derivs[-1].sum(), (xg, ),
                                           create_graph=(j < norders - 1))[0]
                dldlmbda = (dldx * vunit).sum(dim=-1)
                derivs.append(dldlmbda)

        if norders == 2:
            dstep = (-derivs[1] / derivs[2])  # (nbatch,)
        elif norders == 3:
            dstep = (-derivs[2] +
                     torch.sqrt(derivs[2] * derivs[2] -
                                2 * derivs[1] * derivs[3])) / (derivs[3])
        else:
            raise RuntimeError("Order 4 or higher is not defined.")
        dstep = (jinv * dstep.unsqueeze(-1) * vunit)

        if verbose:
            print("Iter %d: %.3e" % (i + 1, dx.detach().abs().max()))

        x = x + dstep
    return x
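gradrca, broyden, diis, and selfconsistent all start from `jinv = _set_jinv0(jinv0, x0)`. The helper is not included in these snippets; a sketch under the assumption that a float jinv0 is broadcast to an elementwise (diagonal) approximation of the inverse Jacobian with the shape of x0:

import torch

def _set_jinv0(jinv0, x0):
    # Hypothetical sketch: a float becomes a constant elementwise scaling
    # shaped like x0; a tensor is used as given.
    if isinstance(jinv0, torch.Tensor):
        return jinv0
    return torch.full_like(x0, jinv0)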
Example #11
    def forward(ctx, A, B, E, M, method, fwd_options, bck_options, na,
                *all_params):
        # A: (*BA, nr, nr)
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols) or None
        # M: (*BM, nr, nr) or None
        # all_params: list of tensor of any shape
        # returns: (*BABEM, nr, ncols)

        # separate the parameters for A and for M
        params = all_params[:na]
        mparams = all_params[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option({}, bck_options)

        if torch.all(B == 0):  # special case
            dims = (*_get_batchdims(A, B, E, M), *B.shape[-2:])
            x = torch.zeros(dims, dtype=B.dtype, device=B.device)
        else:
            with A.uselinopparams(*params), M.uselinopparams(
                    *mparams) if M is not None else dummy_context_manager():
                methods = {
                    "custom_exactsolve": custom_exactsolve,
                    "scipy_gmres": wrap_gmres,
                    "broyden1": broyden1_solve,
                    "cg": cg,
                    "bicgstab": bicgstab,
                }
                method_fcn = get_method("solve", methods, method)
                x = method_fcn(A, B, E, M, **config)

        ctx.e_is_none = E is None
        ctx.A = A
        ctx.M = M
        if ctx.e_is_none:
            ctx.save_for_backward(x, *all_params)
        else:
            ctx.save_for_backward(x, E, *all_params)
        ctx.na = na
        return x
Example #12
    def forward(ctx, A, neig, mode, M, fwd_options, bck_options, na,
                *amparams):
        # A: LinearOperator (*BA, q, q)
        # M: LinearOperator (*BM, q, q) or None

        # separate the sets of parameters
        params = amparams[:na]
        mparams = amparams[na:]

        config = set_default_option({}, fwd_options)
        ctx.bck_config = set_default_option(
            {
                "degen_atol": None,
                "degen_rtol": None,
            }, bck_options)

        # options for calculating the backward (not for `solve`)
        alg_keys = ["degen_atol", "degen_rtol"]
        ctx.bck_alg_config = get_and_pop_keys(ctx.bck_config, alg_keys)

        method = config.pop("method")
        with A.uselinopparams(*params), M.uselinopparams(
                *mparams) if M is not None else dummy_context_manager():
            methods = {
                "davidson": davidson,
                "custom_exacteig": custom_exacteig,
            }
            method_fcn = get_method("symeig", methods, method)
            evals, evecs = method_fcn(A, neig, mode, M, **config)

        # save for the backward
        # evals: (*BAM, neig)
        # evecs: (*BAM, na, neig)
        ctx.save_for_backward(evals, evecs, *amparams)
        ctx.na = na
        ctx.A = A
        ctx.M = M
        return evals, evecs
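The degeneracy tolerances are options for the backward algorithm rather than for the solver, so they are split off with `get_and_pop_keys(ctx.bck_config, alg_keys)`. A sketch of such a helper, assuming it removes the listed keys from the dict and returns them as a new dict:

def get_and_pop_keys(dct, keys):
    # Hypothetical sketch: pop each requested key out of `dct` and collect
    # the popped values in a separate dict.
    return {key: dct.pop(key) for key in keys}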
Example #13
    def forward(ctx, ffcn, log_pfcn, x0, xsamples, wsamples, method,
                fwd_options, bck_options, nfparams, nf_objparams, npparams,
                *all_fpparams):
        # set up the default options
        config = fwd_options
        ctx.bck_config = set_default_option(config, bck_options)

        # split the parameters
        fparams = all_fpparams[:nfparams]
        fobjparams = all_fpparams[nfparams:nfparams + nf_objparams]
        pparams = all_fpparams[nfparams + nf_objparams:nfparams +
                               nf_objparams + npparams]
        pobjparams = all_fpparams[nfparams + nf_objparams + npparams:]

        # select the method for the sampling
        if xsamples is None:
            methods = {
                "mh": mh,
                "_dummy1d": dummy1d,
                "mhcustom": mhcustom,
            }
            method_fcn = get_method("mcquad", methods, method)
            xsamples, wsamples = method_fcn(log_pfcn, x0, pparams, **config)
        epf = _integrate(ffcn, xsamples, wsamples, fparams)

        # save parameters for backward calculations
        ctx.xsamples = xsamples
        ctx.wsamples = wsamples
        ctx.ffcn = ffcn
        ctx.log_pfcn = log_pfcn
        ctx.fparam_sep = TensorNonTensorSeparator((*fparams, *fobjparams))
        ctx.pparam_sep = TensorNonTensorSeparator((*pparams, *pobjparams))
        ctx.nfparams = len(fparams)
        ctx.npparams = len(pparams)
        ctx.method = method

        # save for backward
        ftensor_params = ctx.fparam_sep.get_tensor_params()
        ptensor_params = ctx.pparam_sep.get_tensor_params()
        ctx.nftensorparams = len(ftensor_params)
        ctx.nptensorparams = len(ptensor_params)
        ctx.save_for_backward(epf, *ftensor_params, *ptensor_params)

        return epf
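When no samples are supplied, the forward draws them with the selected sampler ("mh" by default in Example #7, i.e. Metropolis-Hastings). The samplers themselves are not shown; a compact, unbatched sketch of what an mh-style sampler could look like, using a symmetric Gaussian proposal and uniform weights (nsamples, nburnout, and step_size are illustrative option names, not necessarily the library's):

import torch

def mh_sketch(log_pfcn, x0, pparams, nsamples=1000, nburnout=100,
              step_size=0.1, **unused):
    # Hypothetical Metropolis-Hastings sketch; assumes log_pfcn(x, *pparams)
    # returns a scalar tensor with the unnormalized log-probability of x.
    x = x0
    logp = log_pfcn(x, *pparams)
    samples = []
    for i in range(nburnout + nsamples):
        xprop = x + step_size * torch.randn_like(x)
        logp_prop = log_pfcn(xprop, *pparams)
        # accept with probability min(1, p(xprop) / p(x))
        if torch.rand(()) < torch.exp(logp_prop - logp):
            x, logp = xprop, logp_prop
        if i >= nburnout:
            samples.append(x)
    xsamples = torch.stack(samples, dim=0)  # (nsamples, *x0.shape)
    wsamples = torch.full((len(samples),), 1.0 / len(samples))
    return xsamples, wsamples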
Example #14
def broyden(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with Broyden's method.

    Arguments
    ---------
    * f: callable
        Callable that takes the parameters as input and returns an nfeat-sized output.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    raise RuntimeError("This method is unfinished. Please use other methods.")

    # set up the default options
    config = set_default_option(
        {
            "max_niter": 20,
            "min_eps": 1e-6,
            "verbose": False,
        }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the Broyden iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    for i in range(config["max_niter"]):
        dxnew = -jinv * fx  # (nbatch, nfeat)
        xnew = x + dxnew  # (nbatch, nfeat)
        fxnew = f(xnew)
        dfnew = fxnew - fx

        # calculate the new jinv
        xtnew_jinv = torch.bmm(xnew.unsqueeze(1), jinv)  # (nbatch, 1, nfeat)
        jinv_dfnew = torch.bmm(jinv, dfnew.unsqueeze(-1))  # (nbatch, nfeat, 1)
        xtnew_jinv_dfnew = torch.bmm(xtnew_jinv,
                                     dfnew.unsqueeze(-1))  # (nbatch, 1, 1)
        jinvnew = jinv + torch.bmm(dxnew - jinv_dfnew,
                                   xtnew_jinv) / xtnew_jinv_dfnew

        # update variables for the next iteration
        fx = fxnew
        jinv = jinvnew
        x = xnew

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, fx.abs().max()))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The Broyden iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, fx.abs().max())
        warnings.warn(msg)

    return x
Example #15
def davidson(A, params, neig, mode, M=None, mparams=[], **options):
    """
    Iterative methods to obtain the `neig` lowest eigenvalues and eigenvectors.
    This function is written so that the backpropagation can be done.
    It solves the generalized eigendecomposition A V = M V E, where V is the matrix
    of eigenvectors and E is the diagonal matrix of eigenvalues.

    Arguments
    ---------
    * A: LinearOperator instance (*BA, na, na)
        The linear operator object on which the eigenpairs are constructed.
    * params: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to A.
    * neig: int
        The number of eigenpairs to be retrieved.
    * mode: str
        Take the `neig` "lowest" or "uppest" of the eigenpairs.
    * M: LinearOperator instance (*BM, na, na) or None
        The transformation on the right hand side. If None, then M=I.
    * mparams: list of differentiable torch.tensor of any shapes
        List of differentiable torch.tensor to be put to M.
    * **options:
        Iterative algorithm options.

    Returns
    -------
    * eigvals: torch.tensor (*BAM, neig)
    * eigvecs: torch.tensor (*BAM, na, neig)
        The `neig` lowest eigenpairs
    """
    # TODO: optimize for large linear operator and strict min_eps
    # Ideas:
    # (1) use better strategy to get the estimate on eigenvalues
    # (2) use restart strategy
    config = set_default_option(
        {
            "max_niter": 1000,
            "nguess": neig,  # number of initial guess
            "min_eps": 1e-6,
            "verbose": False,
            "eps_cond": 1e-6,
            "v_init": "randn",
            "max_addition": neig,
        },
        options)

    # get some of the options
    nguess = config["nguess"]
    max_niter = config["max_niter"]
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    eps_cond = config["eps_cond"]
    max_addition = config["max_addition"]

    # get the shape of the transformation
    na = A.shape[-1]
    if M is None:
        bcast_dims = A.shape[:-2]
    else:
        bcast_dims = get_bcasted_dims(A.shape[:-2], M.shape[:-2])
    dtype = A.dtype
    device = A.device

    # TODO: A to use params
    prev_eigvals = None
    prev_eigvalT = None
    stop_reason = "max_niter"
    shift_is_eigvalT = False
    idx = torch.arange(neig).unsqueeze(-1)  # (neig, 1)

    with A.uselinopparams(*params), M.uselinopparams(
            *mparams) if M is not None else dummy_context_manager():
        # set up the initial guess
        V = _set_initial_v(config["v_init"].lower(),
                           dtype,
                           device,
                           bcast_dims,
                           na,
                           nguess,
                           M=M,
                           mparams=mparams)  # (*BAM, na, nguess)
        # V = V.reshape(*bcast_dims, na, nguess) # (*BAM, na, nguess)

        # estimating the lowest eigenvalues
        eig_est, rms_eig = _estimate_eigvals(A,
                                             neig,
                                             mode,
                                             bcast_dims=bcast_dims,
                                             na=na,
                                             ntest=20,
                                             dtype=V.dtype,
                                             device=V.device)

        best_resid = float("inf")
        AV = A.mm(V)
        for i in range(max_niter):
            VT = V.transpose(-2, -1)  # (*BAM,nguess,na)
            # Can be optimized by saving AV from the previous iteration and only
            # operate AV for the new V. This works because the old V has already
            # been orthogonalized, so it will stay the same
            # AV = A.mm(V) # (*BAM,na,nguess)
            T = torch.matmul(VT, AV)  # (*BAM,nguess,nguess)

            # eigvals are sorted from the lowest
            # eval: (*BAM, nguess), evec: (*BAM, nguess, nguess)
            eigvalT, eigvecT = torch.symeig(T, eigenvectors=True)
            eigvalT, eigvecT = _take_eigpairs(
                eigvalT, eigvecT, neig,
                mode)  # (*BAM, neig) and (*BAM, nguess, neig)

            # calculate the eigenvectors of A
            eigvecA = torch.matmul(V, eigvecT)  # (*BAM, na, neig)

            # calculate the residual
            AVs = torch.matmul(AV, eigvecT)  # (*BAM, na, neig)
            LVs = eigvalT.unsqueeze(-2) * eigvecA  # (*BAM, na, neig)
            if M is not None:
                LVs = M.mm(LVs)
            resid = AVs - LVs  # (*BAM, na, neig)

            # print information and check convergence
            max_resid = resid.abs().max()
            if prev_eigvalT is not None:
                deigval = eigvalT - prev_eigvalT
                max_deigval = deigval.abs().max()
                if verbose:
                    print("Iter %3d (guess size: %d): resid: %.3e, devals: %.3e" % \
                          (i+1, nguess, max_resid, max_deigval))

            if max_resid < best_resid:
                best_resid = max_resid
                best_eigvals = eigvalT
                best_eigvecs = eigvecA
            if max_resid < min_eps:
                break
            if AV.shape[-1] == AV.shape[-2]:
                break
            prev_eigvalT = eigvalT

            # apply the preconditioner
            # a good initial guess of the eigenvalues actually helps a lot
            if not shift_is_eigvalT:
                z = eig_est  # (*BAM,neig)
            else:
                z = eigvalT  # (*BAM,neig)
            # if A.is_precond_set():
            #     t = A.precond(-resid, *params, biases=z, M=M, mparams=mparams) # (nbatch, na, neig)
            # else:
            t = -resid  # (*BAM, na, neig)

            # set the estimate of the eigenvalues
            if not shift_is_eigvalT:
                eigvalT_pred = eigvalT + torch.einsum(
                    '...ae,...ae->...e', eigvecA, A.mm(t))  # (*BAM, neig)
                diff_eigvalT = (eigvalT - eigvalT_pred)  # (*BAM, neig)
                if diff_eigvalT.abs().max() < rms_eig * 1e-2:
                    shift_is_eigvalT = True
                else:
                    change_idx = eig_est > eigvalT
                    next_value = eigvalT - 2 * diff_eigvalT
                    eig_est[change_idx] = next_value[change_idx]

            # orthogonalize t with the rest of the V
            t = to_fortran_order(t)
            Vnew = torch.cat((V, t), dim=-1)
            if Vnew.shape[-1] > Vnew.shape[-2]:
                Vnew = Vnew[..., :Vnew.shape[-2]]
            nadd = Vnew.shape[-1] - V.shape[-1]
            nguess = nguess + nadd
            if M is not None:
                MV_ = M.mm(Vnew)
                V, R = tallqr(Vnew, MV=MV_)
            else:
                V, R = tallqr(Vnew)
            AVnew = A.mm(V[..., -nadd:])  # (*BAM,na,nadd)
            AVnew = to_fortran_order(AVnew)
            AV = torch.cat((AV, AVnew), dim=-1)

    eigvals = best_eigvals  # (*BAM, neig)
    eigvecs = best_eigvecs  # (*BAM, na, neig)
    return eigvals, eigvecs
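Inside the loop, `_take_eigpairs` keeps the `neig` requested pairs from the full decomposition of the small projected matrix T. A sketch, assuming the eigenvalues arrive sorted in ascending order as torch.symeig returns them:

def _take_eigpairs(eival, eivec, neig, mode):
    # Hypothetical sketch: eival (*B, n), eivec (*B, n, n), sorted ascending.
    # "lowest" keeps the first neig pairs, "uppest" keeps the last neig pairs.
    if mode == "lowest":
        return eival[..., :neig], eivec[..., :neig]
    else:  # "uppest"
        return eival[..., -neig:], eivec[..., -neig:]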
Example #16
def diis(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with the DIIS method.

    Arguments
    ---------
    * f: callable
        Callable that takes the parameters as input and returns an nfeat-sized output.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    # set up the default options
    config = set_default_option(
        {
            "max_niter": 20,
            "min_eps": 1e-6,
            "max_memory": 20,
            "minit": 10,
            "verbose": False,
        }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    max_memory = config["max_memory"]
    minit = config["minit"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    bestcrit = float("inf")
    nbatch, nfeat = fx.shape
    x_history = torch.empty((nbatch, max_memory, nfeat),
                            dtype=x.dtype,
                            device=x.device)
    e_history = torch.empty((nbatch, max_memory, nfeat),
                            dtype=x.dtype,
                            device=x.device)
    mfill = 0
    midx = 0
    for i in range(config["max_niter"]):
        if mfill < 2:
            dxnew = -jinv * fx  # (nbatch, nfeat)
            xnew = x + dxnew  # (nbatch, nfeat)
        else:
            # construct the matrix B
            fx_tensor = e_history[:, :mfill, :]  # (nbatch, m, nfeat)
            Bul = torch.matmul(fx_tensor,
                               fx_tensor.transpose(-2, -1))  # (nbatch, m, m)
            Bu = torch.cat((Bul, -torch.ones(Bul.shape[0],
                                             Bul.shape[1],
                                             1,
                                             dtype=Bul.dtype,
                                             device=Bul.device)),
                           dim=-1)  # (nbatch, m, m+1)
            B = torch.cat((Bu, -torch.ones(
                Bu.shape[0], 1, Bu.shape[-1], dtype=Bu.dtype,
                device=Bu.device)),
                          dim=-2)  # (nbatch, m+1, m+1)
            B[:, -1, -1] = 0.0

            # solve the linear equation to get the coefficients
            a = torch.zeros(B.shape[0],
                            B.shape[1],
                            1,
                            dtype=B.dtype,
                            device=B.device)  # (nbatch, m+1, 1)
            a[:, -1] = -1.0
            c = torch.solve(a, B)[0].squeeze(-1)[:, :mfill]  # (nbatch, m)

            x_tensor = x_history[:, :mfill, :]  # (nbatch, m, nfeat)
            xnew = (x_tensor * c.unsqueeze(-1)).sum(dim=1)  # (nbatch, nfeat)

        fxnew = f(xnew)  # (nbatch, nfeat)
        dx = xnew - x

        # update variables for the next iteration
        fx = fxnew
        x = xnew

        # add the history
        if i >= minit:
            x_history[:, midx, :] = x
            # e_history[:,midx,:] = dx
            e_history[:, midx, :] = fx
            midx = (midx + 1) % max_memory
            mfill = (mfill + 1) if mfill < max_memory else mfill

        # get the best results
        crit = fx.abs().max()
        if crit < bestcrit:
            bestcrit = crit
            bestx = x

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, crit))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The DIIS iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestcrit)
        warnings.warn(msg)

    return bestx
Example #17
def selfconsistent(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with a damped self-consistent (fixed-point) iteration.

    Arguments
    ---------
    * f: callable
        Callable that takes the parameters as input and returns an nfeat-sized output.
    * x0: torch.tensor (nbatch, nfeat)
        Initial value of parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    # set up the default options
    config = set_default_option(
        {
            "max_niter": 20,
            "min_eps": 1e-6,
            # contribution of the new delta_n to the total delta_n
            "beta": 0.9,
            "jinvdecay": 1.0,
            "decayevery": 100,
            "verbose": False,
        },
        options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    verbose = config["verbose"]
    beta = config["beta"]
    jinvdecay = config["jinvdecay"]
    decayevery = config["decayevery"]

    # pull out the parameters of x0
    nbatch, nfeat = x0.shape
    device = x0.device
    dtype = x0.dtype

    # set up the initial jinv
    jinv = _set_jinv0(jinv0, x0)

    # perform the Broyden iterations
    x = x0
    fx = f(x0)  # (nbatch, nfeat)
    stop_reason = "max_niter"
    dx = torch.zeros_like(x).to(x.device)
    bestcrit = float("inf")
    for i in range(config["max_niter"]):
        dxnew = -jinv * fx  # (nbatch, nfeat)
        dx = (1 - beta) * dx + beta * dxnew
        xnew = x + dx  # (nbatch, nfeat)
        fxnew = f(xnew)
        dfnew = fxnew - fx

        # update variables for the next iteration
        fx = fxnew
        x = xnew
        if (i + 1) % decayevery == 0:
            jinv = jinv * jinvdecay

        # get the best results
        crit = fx.abs().max()
        if crit < bestcrit:
            bestcrit = crit
            bestx = x

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (i + 1, crit))
        if torch.allclose(fx, torch.zeros_like(fx), atol=min_eps):
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The selfconsistent iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestcrit)
        warnings.warn(msg)

    return bestx
Example #18
def lbfgs(f, x0, jinv0=1.0, **options):
    """
    Solve the root finder problem with the L-BFGS method.

    Arguments
    ---------
    * f: callable
        Callable that takes the parameters as input and returns an nfeat-sized output.
    * x0: torch.tensor (*, nfeat)
        Initial value of parameters to be put in the function, f.
    * jinv0: float or torch.tensor (nbatch, nfeat, nfeat)
        The initial inverse of the Jacobian. If float, it will be the diagonal.
    * options: dict or None
        Options of the function.

    Returns
    -------
    * x: torch.tensor (nbatch, nfeat)
        The x that approximates f(x) = 0.
    """
    config = set_default_option(
        {
            "max_niter": 20,
            "min_eps": 1e-6,
            "max_memory": 10,
            "alpha0": 1.0,
            "linesearch": False,
            "verbose": False,
        }, options)

    # pull out the options for fast access
    min_eps = config["min_eps"]
    max_memory = config["max_memory"]
    verbose = config["verbose"]
    linesearch = config["linesearch"]
    alpha = config["alpha0"]

    # set up the initial jinv and the memories
    H0 = _set_jinv0_diag(jinv0, x0)  # (*, nfeat)
    sk_history = []
    yk_history = []
    rk_history = []

    def _apply_Vk(rk, sk, yk, grad):
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        return grad - (sk * grad).sum(dim=-1, keepdim=True) * rk * yk

    def _apply_VkT(rk, sk, yk, grad):
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        return grad - (yk * grad).sum(dim=-1, keepdim=True) * rk * sk

    def _apply_Hk(H0, sk_hist, yk_hist, rk_hist, gk):
        # H0: (*, nfeat)
        # sk: (*, nfeat)
        # yk: (*, nfeat)
        # rk: (*, 1)
        # gk: (*, nfeat)
        nhist = len(sk_hist)
        if nhist == 0:
            return H0 * gk

        k = nhist - 1
        rk = rk_hist[k]
        sk = sk_hist[k]
        yk = yk_hist[k]

        # get the last term (rk * sk * sk.T)
        rksksk = (sk * gk).sum(dim=-1, keepdim=True) * rk * sk

        # calculate the V_(k-1)
        grad = gk
        grad = _apply_Vk(rk_hist[k], sk_hist[k], yk_hist[k], grad)
        grad = _apply_Hk(H0, sk_hist[:k], yk_hist[:k], rk_hist[:k], grad)
        grad = _apply_VkT(rk_hist[k], sk_hist[k], yk_hist[k], grad)
        return grad + rksksk

    def _line_search(xk, gk, dk, g):
        if linesearch:
            dx, dg, nit = line_search(dk, xk, gk, g)
            return xk + dx, gk + dg
        else:
            return xk + alpha * dk, g(xk + alpha * dk)

    # perform the main iteration
    xk = x0
    gk = f(xk)
    bestgk = gk.abs().max()
    bestx = x0
    stop_reason = "max_niter"
    for k in range(config["max_niter"]):
        dk = -_apply_Hk(H0, sk_history, yk_history, rk_history, gk)
        xknew, gknew = _line_search(xk, gk, dk, f)

        # store the history
        sk = xknew - xk  # (*, nfeat)
        yk = gknew - gk
        inv_rhok = 1.0 / (sk * yk).sum(dim=-1, keepdim=True)  # (*, 1)
        sk_history.append(sk)
        yk_history.append(yk)
        rk_history.append(inv_rhok)
        if len(sk_history) > max_memory:
            sk_history = sk_history[-max_memory:]
            yk_history = yk_history[-max_memory:]
            rk_history = rk_history[-max_memory:]

        # update for the next iteration
        xk = xknew
        # alphakold = alphak
        gk = gknew

        # save the best point
        maxgk = gk.abs().max()
        if maxgk < bestgk:
            bestx = xk
            bestgk = maxgk

        # check the stopping condition
        if verbose:
            print("Iter %3d: %.3e" % (k + 1, gk.abs().max()))
        if maxgk < min_eps:
            stop_reason = "min_eps"
            break

    if stop_reason != "min_eps":
        msg = "The L-BFGS iteration does not converge to the required accuracy."
        msg += "\nRequired: %.3e. Achieved: %.3e" % (min_eps, bestgk)
        warnings.warn(msg)

    return bestx
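A minimal usage sketch of the lbfgs root finder above on a batched toy problem f(x) = x - cos(x), whose root is near 0.7390851. This assumes set_default_option, _set_jinv0_diag, and the other helpers referenced by lbfgs are available; the tolerance and iteration count are illustrative:

import torch

def f(x):
    # simple batched root-finding target: f(x) = x - cos(x)
    return x - torch.cos(x)

x0 = torch.zeros(4, 3)  # (nbatch, nfeat)
xroot = lbfgs(f, x0, jinv0=1.0, max_niter=50, min_eps=1e-8)
print(xroot)  # every entry should be close to 0.7390851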