Example #1
def solvefcn(amat, bmat):
    # is_hermitian=hermit is required to force the Hermitian status in the numerical gradient
    alinop = LinearOperator.m(amat, is_hermitian=hermit)
    x = solve(A=alinop,
              B=bmat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
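
The closure above relies on variables from its enclosing test (`hermit`, `fwd_options`, `bck_options`), which are not shown. Below is a minimal, self-contained sketch of how such a closure might be driven with `torch.autograd.gradcheck`; the imports, option values, and matrix setup are illustrative assumptions, not part of the original test.

import torch
from xitorch import LinearOperator
from xitorch.linalg import solve

# illustrative stand-ins for the enclosing test's variables
hermit = True
fwd_options = {"method": "exactsolve"}
bck_options = {"method": "exactsolve"}

def solvefcn(amat, bmat):
    alinop = LinearOperator.m(amat, is_hermitian=hermit)
    return solve(A=alinop, B=bmat,
                 fwd_options=fwd_options, bck_options=bck_options)

# symmetric, well-conditioned A so the Hermitian flag is consistent
na = 3
amat = torch.rand(na, na, dtype=torch.float64)
amat = amat + amat.transpose(-2, -1) + na * torch.eye(na, dtype=torch.float64)
bmat = torch.rand(na, 2, dtype=torch.float64)

torch.autograd.gradcheck(solvefcn, (amat.requires_grad_(), bmat.requires_grad_()))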
Example #2
def solvefcn(amat, bmat, emat):
    alinop = LinearOperator.m(amat)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              **fwd_options,
              bck_options=bck_options)
    return x
Example #3
def test_solve_nonsquare_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 2), dtype=dtype, device=device)
    mat2 = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand((3, 1), dtype=dtype, device=device)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
Example #4
def solvefcn(amat, mmat, bmat, emat):
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    x = solve(A=alinop,
              B=bmat,
              E=emat,
              M=mlinop,
              **fwd_options,
              bck_options=bck_options)
    return x
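
For reference, a sketch of what the fully generalized call above is expected to compute, assuming the xitorch convention that `solve` with `E` and `M` returns X satisfying A X - M X diag(E) = B. The shapes, default options, and residual check below are illustrative assumptions.

import torch
from xitorch import LinearOperator
from xitorch.linalg import solve

torch.manual_seed(0)
na, ncols = 4, 2
dtype = torch.float64

# symmetric, diagonally dominated A and M so each (A - e_j * M) stays well conditioned
amat = torch.rand(na, na, dtype=dtype)
amat = amat + amat.transpose(-2, -1) + na * torch.eye(na, dtype=dtype)
mmat = torch.rand(na, na, dtype=dtype)
mmat = (mmat + mmat.transpose(-2, -1)) * 0.5 + na * torch.eye(na, dtype=dtype)
bmat = torch.rand(na, ncols, dtype=dtype)
emat = torch.rand(ncols, dtype=dtype) * 0.1   # keep E small

x = solve(A=LinearOperator.m(amat), B=bmat, E=emat, M=LinearOperator.m(mmat))

# residual of A X - M X diag(E) = B (the assumed convention)
res = amat @ x - mmat @ (x * emat) - bmat
print(res.abs().max())   # expected to be near zero if the assumption holds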
Example #5
    def backward(ctx, grad_yout):
        param_sep = ctx.param_sep
        yout = ctx.saved_tensors[0]
        nparams = ctx.nparams
        fcn = ctx.fcn

        # merge the tensor and nontensor parameters
        tensor_params = ctx.saved_tensors[1:]
        allparams = param_sep.reconstruct_params(tensor_params)
        params = allparams[:nparams]
        objparams = allparams[nparams:]

        # dL/df
        with ctx.fcn.useobjparams(objparams):

            jac_dfdy = jac(fcn, params=(yout, *params), idxs=[0])[0]
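            # implicit-function-theorem step: solve (df/dy)^T g = -dL/dy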
            gyfcn = solve(A=jac_dfdy.H,
                          B=-grad_yout.reshape(-1, 1),
                          bck_options=ctx.bck_options,
                          **ctx.bck_options)
            gyfcn = gyfcn.reshape(grad_yout.shape)

            # get the grad for the params
            with torch.enable_grad():
                tensor_params_copy = [
                    p.clone().requires_grad_() for p in tensor_params
                ]
                allparams_copy = param_sep.reconstruct_params(
                    tensor_params_copy)
                params_copy = allparams_copy[:nparams]
                objparams_copy = allparams_copy[nparams:]
                with ctx.fcn.useobjparams(objparams_copy):
                    yfcn = fcn(yout, *params_copy)

            grad_tensor_params = torch.autograd.grad(
                yfcn,
                tensor_params_copy,
                grad_outputs=gyfcn,
                create_graph=torch.is_grad_enabled(),
                allow_unused=True)
            grad_nontensor_params = [
                None for _ in range(param_sep.nnontensors())
            ]
            grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                                       grad_nontensor_params)

        return (None, None, None, None, None, None, None, *grad_params)
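
The `backward` above effectively applies implicit differentiation to a solution y of f(y, params) = 0: it first solves (df/dy)^T g = -dL/dy with `solve`, then back-propagates g through f to obtain the parameter gradients. A toy sketch of the same pattern in plain PyTorch (scalar problem, illustrative only, not the library's code):

import torch

# toy root problem: y solves f(y, theta) = y**3 + theta*y - 1 = 0
theta = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

# forward: Newton iterations (no autograd graph through the iterations)
y = torch.tensor(0.5, dtype=torch.float64)
for _ in range(50):
    y = y - (y**3 + theta.detach() * y - 1) / (3 * y**2 + theta.detach())

grad_y = torch.tensor(1.0, dtype=torch.float64)   # pretend dL/dy = 1

# "solve" step: df/dy is a scalar here, so the linear solve is a division
g = -grad_y / (3 * y**2 + theta.detach())

# back-propagate g through f with respect to theta (y stays detached)
f = y**3 + theta * y - 1
grad_theta, = torch.autograd.grad(f, theta, grad_outputs=g)

# compare with the analytic implicit derivative dy/dtheta = -y / (3*y**2 + theta)
print(grad_theta.item(), (-y / (3 * y**2 + theta.detach())).item())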
Example #6
def test_solve_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    shapes = [
        #   A      B      M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype, device=device)
        bmat = torch.rand(bshape, dtype=dtype, device=device)
        mmat = torch.rand(mshape, dtype=dtype, device=device) + \
            torch.eye(mshape[-1], dtype=dtype, device=device)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
Example #7
    def backward(ctx, grad_evals, grad_evecs):
        # grad_evals: (*BAM, neig)
        # grad_evecs: (*BAM, na, neig)

        # get the variables from ctx
        evals = ctx.evals
        evecs = ctx.evecs
        M = ctx.M
        A = ctx.A

        # the loss function where the gradient will be retrieved
        # warning: if not all params are connected to the output of A,
        # it could cause an infinite loop because pytorch will keep looking
        # for the *params node and propagate further backward via the `evecs`
        # path. So make sure all the *params are connected in the graph.
        with torch.enable_grad():
            params = [p.clone().requires_grad_() for p in ctx.params]
            with A.uselinopparams(*params):
                loss = A.mm(evecs)  # (*BAM, na, neig)

        # calculate the contributions from the eigenvalues
        gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

        # calculate the contributions from the eigenvectors
        with M.uselinopparams(
                *ctx.mparams) if M is not None else dummy_context_manager():
            # orthogonalize the grad_evecs with evecs
            B = ortho(grad_evecs, evecs, dim=-2, M=M, mright=False)
            with A.uselinopparams(*ctx.params):
                gevecs = solve(A,
                               -B,
                               evals,
                               M,
                               fwd_options=ctx.bck_config,
                               bck_options=ctx.bck_config)
            # orthogonalize gevecs w.r.t. evecs
            gevecsA = ortho(gevecs, evecs, dim=-2, M=M, mright=True)

        # accumulate the gradient contributions
        gaccumA = gevalsA + gevecsA
        grad_params = torch.autograd.grad(
            outputs=(loss, ),
            inputs=params,
            grad_outputs=(gaccumA, ),
            create_graph=torch.is_grad_enabled(),
        )

        grad_mparams = []
        if ctx.M is not None:
            with torch.enable_grad():
                mparams = [p.clone().requires_grad_() for p in ctx.mparams]
                with M.uselinopparams(*mparams):
                    mloss = M.mm(evecs)  # (*BAM, na, neig)
            gevalsM = -gevalsA * evals.unsqueeze(-2)
            gevecsM = -gevecsA * evals.unsqueeze(-2)

            # the contribution from the parallel elements
            gevecsM_par = (-0.5 *
                           torch.einsum("...ae,...ae->...e", grad_evecs, evecs)
                           ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

            gaccumM = gevalsM + gevecsM + gevecsM_par
            grad_mparams = torch.autograd.grad(
                outputs=(mloss, ),
                inputs=mparams,
                grad_outputs=(gaccumM, ),
                create_graph=torch.is_grad_enabled(),
            )

        return (None, None, None, None, None, None, None, *grad_params,
                *grad_mparams)
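
A plain-PyTorch sanity check (illustrative, not the library's code) of the eigenvalue term `gevalsA` above: for a symmetric A with eigenpairs (lambda_i, v_i), d(lambda_i)/dA = v_i v_i^T, so propagating `grad_evals.unsqueeze(-2) * evecs` through A @ evecs should match the eigenvalue gradient that autograd computes for `torch.linalg.eigh` (up to its symmetric-gradient convention).

import torch

torch.manual_seed(0)
n = 4
a = torch.rand(n, n, dtype=torch.float64)
amat = (a + a.T).requires_grad_()

evals, evecs = torch.linalg.eigh(amat)
grad_evals = torch.rand(n, dtype=torch.float64)

# reference: autograd through the eigenvalues only
grad_ref, = torch.autograd.grad(evals, amat, grad_outputs=grad_evals)

# the gevalsA route: vector-Jacobian product of loss = A @ V with gevalsA
V = evecs.detach()
gevalsA = grad_evals.unsqueeze(-2) * V       # (n, neig)
grad_manual = gevalsA @ V.T                  # equals V @ diag(grad_evals) @ V.T
print(torch.allclose(grad_ref, grad_manual))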
Example #8
def solvefcn(amat, bmat, emat, mmat):
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    x = solve(A=alinop, B=bmat, E=emat, M=mlinop, **fwd_options)
    return x
Example #9
def solvefcn(amat, bmat):
    alinop = LinearOperator.m(amat)
    x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
    return x
Example #10
    def backward(ctx, grad_evals, grad_evecs):
        # grad_evals: (*BAM, neig)
        # grad_evecs: (*BAM, na, neig)

        # get the variables from ctx
        evals, evecs = ctx.saved_tensors[:2]
        na = ctx.na
        amparams = ctx.saved_tensors[2:]
        params = amparams[:na]
        mparams = amparams[na:]

        M = ctx.M
        A = ctx.A
        degen_atol: Optional[float] = ctx.bck_alg_config["degen_atol"]
        degen_rtol: Optional[float] = ctx.bck_alg_config["degen_rtol"]

        # set the default values of degen_*tol
        dtype = evals.dtype
        if degen_atol is None:
            degen_atol = torch.finfo(dtype).eps**0.6
        if degen_rtol is None:
            degen_rtol = torch.finfo(dtype).eps**0.4

        # check the degeneracy
        if degen_atol > 0 or degen_rtol > 0:
            # idx_degen: (*BAM, neig, neig)
            idx_degen, isdegenerate = _check_degen(evals, degen_atol,
                                                   degen_rtol)
        else:
            isdegenerate = False
        if not isdegenerate:
            idx_degen = None

        # the loss function where the gradient will be retrieved
        # warning: if not all params are connected to the output of A,
        # it could cause an infinite loop because pytorch will keep looking
        # for the *params node and propagate further backward via the `evecs`
        # path. So make sure all the *params are connected in the graph.
        with torch.enable_grad():
            params = [p.clone().requires_grad_() for p in params]
            with A.uselinopparams(*params):
                loss = A.mm(evecs)  # (*BAM, na, neig)

        # if degenerate, check the conditions for finite derivative
        if is_debug_enabled() and isdegenerate:
            xtg = torch.matmul(evecs.transpose(-2, -1), grad_evecs)
            req1 = idx_degen * (xtg - xtg.transpose(-2, -1))
            reqtol = xtg.abs().max() * grad_evecs.shape[-2] * torch.finfo(
                grad_evecs.dtype).eps

            if not torch.all(torch.abs(req1) <= reqtol):
                # if the requirements are not satisfied, raise a warning
                msg = (
                    "Degeneracy appears but the loss function seem to depend "
                    "strongly on the eigenvector. The gradient might be incorrect.\n"
                )
                msg += "Eigenvalues:\n%s\n" % str(evals)
                msg += "Degenerate map:\n%s\n" % str(idx_degen)
                msg += "Requirements (should be all 0s):\n%s" % str(req1)
                warnings.warn(MathWarning(msg))

        # calculate the contributions from the eigenvalues
        gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

        # calculate the contributions from the eigenvectors
        with M.uselinopparams(
                *mparams) if M is not None else dummy_context_manager():
            # orthogonalize the grad_evecs with evecs
            B = _ortho(grad_evecs, evecs, D=idx_degen, M=M, mright=False)

            with A.uselinopparams(*params):
                gevecs = solve(A,
                               -B,
                               evals,
                               M,
                               bck_options=ctx.bck_config,
                               **ctx.bck_config)  # (*BAM, na, neig)

            # orthogonalize gevecs w.r.t. evecs
            gevecsA = _ortho(gevecs, evecs, D=None, M=M, mright=True)

        # accumulate the gradient contributions
        gaccumA = gevalsA + gevecsA
        grad_params = torch.autograd.grad(
            outputs=(loss, ),
            inputs=params,
            grad_outputs=(gaccumA, ),
            create_graph=torch.is_grad_enabled(),
        )

        grad_mparams = []
        if ctx.M is not None:
            with torch.enable_grad():
                mparams = [p.clone().requires_grad_() for p in mparams]
                with M.uselinopparams(*mparams):
                    mloss = M.mm(evecs)  # (*BAM, na, neig)
            gevalsM = -gevalsA * evals.unsqueeze(-2)
            gevecsM = -gevecsA * evals.unsqueeze(-2)

            # the contribution from the parallel elements
            gevecsM_par = (-0.5 *
                           torch.einsum("...ae,...ae->...e", grad_evecs, evecs)
                           ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

            gaccumM = gevalsM + gevecsM + gevecsM_par
            grad_mparams = torch.autograd.grad(
                outputs=(mloss, ),
                inputs=mparams,
                grad_outputs=(gaccumM, ),
                create_graph=torch.is_grad_enabled(),
            )

        return (None, None, None, None, None, None, None, *grad_params,
                *grad_mparams)
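
For concreteness, the default degeneracy tolerances implied by the code above (eps**0.6 and eps**0.4 of the eigenvalues' dtype) can be inspected directly; the values in the trailing comment are approximate.

import torch

for dtype in (torch.float32, torch.float64):
    eps = torch.finfo(dtype).eps
    # defaults used when degen_atol / degen_rtol are not given
    print(dtype, "degen_atol=%.1e" % eps**0.6, "degen_rtol=%.1e" % eps**0.4)
# float64 gives roughly degen_atol ~ 4e-10 and degen_rtol ~ 5e-07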