Ejemplo n.º 1
0
def symGivens2(a, b):
    """
    Stable Symmetric Givens rotation plus reflection

    Parameters

        a: (theano scalar) first element of a two-vector  [a; b]
        b: (theano scalar) second element of a two-vector [a; b]
    Returns

        c  cosine(theta), where theta is the implicit angle of
           rotation (counter-clockwise) in a plane-rotation
        s  sine(theta)
        d  two-norm of [a; b]

    Description:
        This method gives c and s such that
            [ c  s ][a] = [d],
            [ s -c ][b]   [0]
      where d = two norm of vector [a, b],
            c = a / sqrt(a^2 + b^2) = a / d,
            s = b / sqrt(a^2 + b^2) = b / d.
      The implementation guards against overflow in computing
         sqrt(a^2 + b^2).

      SEE ALSO:
         (1) Algorithm 4.9, stable *unsymmetric* Givens
         rotations in Golub and van Loan's book Matrix
         Computations, 3rd edition.
         (2) MATLAB's function PLANEROT.

      Observations:
          Implementing this function as a single op in C might improve speed
          considerably ..
    """
    c_branch1 = T.switch(T.eq(a, constantX(0)), constantX(1), T.sgn(a))
    c_branch21 = (a / b) * T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2)
    c_branch22 = T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)

    c_branch2 = T.switch(T.eq(a, constantX(0)), constantX(0), T.switch(T.gt(abs(b), abs(a)), c_branch21, c_branch22))
    c = T.switch(T.eq(b, constantX(0)), c_branch1, c_branch2)

    s_branch1 = T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2)
    s_branch2 = (b / a) * T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)
    s = T.switch(
        T.eq(b, constantX(0)),
        constantX(0),
        T.switch(T.eq(a, constantX(0)), T.sgn(b), T.switch(T.gt(abs(b), abs(a)), s_branch1, s_branch2)),
    )

    d_branch1 = b / (T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2))
    d_branch2 = a / (T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2))
    d = T.switch(
        T.eq(b, constantX(0)),
        abs(a),
        T.switch(T.eq(a, constantX(0)), abs(b), T.switch(T.gt(abs(b), abs(a)), d_branch1, d_branch2)),
    )
    return c, s, d
Ejemplo n.º 2
0
def minres(compute_Av,
           bs,
           rtol=constantX(1e-6),
           maxit=20,
           Ms=None,
           shift=constantX(0.),
           maxxnorm=constantX(1e15),
           Acondlim=constantX(1e16),
           profile=0):
    """
    Attempts to find the minimum-length and minimum-residual-norm
    solution :math:`x` to the system of linear equations :math:`A*x = b`
    or least squares problem :math:`\\min||Ax-b||`. 

    The n-by-n coefficient matrix A must be symmetric (but need not be
    positive definite or invertible). The right-hand-side column vector
    b must have length n.

    .. note:: This code is inspired from
        http://www.stanford.edu/group/SOL/software/minres.html .

    Parameters
    ----------
    compute_Av : callable
        Callable returing the symbolic expression for
        `Av` (the product of matrix A with some vector v).
        `v` should be a list of tensors, where the vector v means
        the vector obtain by concatenating and flattening all tensors in v
    bs : list
        List of Theano expressions. We are looking to compute `A^-1\dot bs`.
    rtol : float, optional
        Specifies the tolerance of the method.  Default is 1e-6.
    maxit : int, positive, optional
        Specifies the maximum number of iterations. Default is 20.
    Ms : list
        List of theano expression of same shape as `bs`. The method uses
        these to precondition with diag(Ms)
    shift : float, optional
        Default is 0.  Effectively solve the system (A - shift I) * x = b.
    maxxnorm : float, positive, optional
        Maximum bound on NORM(x). Default is 1e14.
    Acondlim : float, positive, optional
        Maximum bound on COND(A). Default is 1e15.
    show : bool
        If True, show iterations, otherwise suppress outputs. Default is
        False.

    Returns
    -------
    x : list
        List of Theano tensor representing the solution
    flag : tensor_like
        Theano int scalar - convergence flag

            0. beta1 = 0.  The exact solution is  x = 0.
            1. A solution to (poss. singular) Ax = b found, given rtol.
            2. Pseudoinverse solution for singular LS problem, given rtol.
            3. A solution to (poss. singular) Ax = b found, given eps.
            4. Pseudoinverse solution for singular LS problem, given eps.
            5. x has converged to an eigenvector.
            6. xnorm has exceeded maxxnorm.
            7. Acond has exceeded Acondlim.
            8. The iteration limit was reached.
            9. 10. It is a least squares problem but no converged
               solution yet.
    iter : int
        Iteration number at which x was computed: `0 <= iter <= maxit`.
    relres : float
        Real positive, the relative residual is defined as
        NORM(b-A*x)/(NORM(A) * NORM(x) + NORM(b)),
        computed recurrently here.  If flag is 1 or 3,  relres <= TOL.
    relAres : float
        Real positive, the relative-NORM(Ar) := NORM(Ar) / NORM(A)
        computed recurrently here. If flag is 2 or 4, relAres <= TOL.
    Anorm : float
        Real positive, estimate of matrix 2-norm of A.
    Acond : float
        Real positive, estimate of condition number of A with respect to
        2-norm.
    xnorm : float
        Non-negative positive, recurrently computed NORM(x)
    Axnorm : float
        Non-negative positive, recurrently computed NORM(A * x).

    References
    ----------
    .. [1] Choi, Sou-Cheng. Iterative Methods for Singular Linear 
           Equations and Least-Squares Problems, PhD Dissertation,
           Stanford University, 2006.
    """

    if not isinstance(bs, (tuple, list)):
        bs = [bs]
        return_as_list = False
    else:
        bs = list(bs)
        return_as_list = True

    eps = constantX(1e-23)

    # Initialise
    beta1 = sqrt_inner_product(bs)

    #------------------------------------------------------------------
    # Set up p and v for the first Lanczos vector v1.
    # p  =  beta1 P' v1,  where  P = C**(-1).
    # v is really P' v1.
    #------------------------------------------------------------------
    r3s = [b for b in bs]
    r2s = [b for b in bs]
    r1s = [b for b in bs]
    if Ms is not None:
        r3s = [b / m for b, m in zip(bs, Ms)]
        beta1 = sqrt_inner_product(r3s, bs)
    #------------------------------------------------------------------
    ## Initialize other quantities.
    # Note that Anorm has been initialized by IsOpSym6.
    # ------------------------------------------------------------------
    bnorm = beta1
    n_params = len(bs)

    def loop(niter,
             beta,
             betan,
             phi,
             Acond,
             cs,
             dbarn,
             eplnn,
             rnorm,
             sn,
             Tnorm,
             rnorml,
             xnorm,
             Dnorm,
             gamma,
             pnorm,
             gammal,
             Axnorm,
             relrnorm,
             relArnorml,
             Anorm,
             flag,
             *args):
        #-----------------------------------------------------------------
        ## Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, p = betak P vk,  where  P = C**(-1).
        # .... more description needed.
        #-----------------------------------------------------------------
        xs = args[0 * n_params: 1 * n_params]
        r1s = args[1 * n_params: 2 * n_params]
        r2s = args[2 * n_params: 3 * n_params]
        r3s = args[3 * n_params: 4 * n_params]
        dls = args[4 * n_params: 5 * n_params]
        ds = args[5 * n_params: 6 * n_params]
        betal = beta
        beta = betan
        vs = [r3 / beta for r3 in r3s]
        r3s, upds = compute_Av(*vs)

        r3s = [r3 - shift * v for r3, v in zip(r3s, vs)]
        r3s = [TT.switch(TT.ge(niter, constantX(1.)),
                         r3 - (beta / betal) * r1,
                         r3) for r3, r1 in zip(r3s, r1s)]

        alpha = inner_product(r3s, vs)
        r3s = [r3 - (alpha / beta) * r2 for r3, r2 in zip(r3s, r2s)]
        r1s = [r2 for r2 in r2s]
        r2s = [r3 for r3 in r3s]
        if Ms is not None:
            r3s = [r3 / M for r3, M in zip(r3s, Ms)]
            betan = sqrt_inner_product(r2s, r3s)
        else:
            betan = sqrt_inner_product(r3s)
        pnorml = pnorm
        pnorm = TT.switch(TT.eq(niter, constantX(0.)),
                          TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                          TT.sqrt(TT.sqr(alpha) + TT.sqr(betan) +
                                  TT.sqr(beta)))

        #-----------------------------------------------------------------
        ## Apply previous rotation Qk-1 to get
        #   [dlta_k epln_{k+1}] = [cs  sn][dbar_k    0      ]
        #   [gbar_k  dbar_{k+1} ]   [sn -cs][alpha_k beta_{k+1}].
        #-----------------------------------------------------------------
        dbar = dbarn
        epln = eplnn
        dlta = cs * dbar + sn * alpha
        gbar = sn * dbar - cs * alpha

        eplnn = sn * betan
        dbarn = -cs * betan

        ## Compute the current plane rotation Qk
        gammal2 = gammal
        gammal = gamma
        cs, sn, gamma = symGivens2(gbar, betan)
        tau = cs * phi
        phi = sn * phi
        Axnorm = TT.sqrt(TT.sqr(Axnorm) + TT.sqr(tau))
        # Update d

        dl2s = [dl for dl in dls]
        dls = [d for d in ds]
        ds = [TT.switch(TT.neq(gamma, constantX(0.)),
                        (v - epln * dl2 - dlta * dl) / gamma,
                        v)
              for v, dl2, dl in zip(vs, dl2s, dls)]
        d_norm = TT.switch(TT.neq(gamma, constantX(0.)),
                           sqrt_inner_product(ds),
                           constantX(numpy.inf))

        # Update x except if it will become too big
        xnorml = xnorm
        dl2s = [x for x in xs]
        xs = [x + tau * d for x, d in zip(xs, ds)]

        xnorm = sqrt_inner_product(xs)
        xs = [TT.switch(TT.ge(xnorm, maxxnorm),
                        dl2, x)
              for dl2, x in zip(dl2s, xs)]

        flag = TT.switch(TT.ge(xnorm, maxxnorm),
                         constantX(6.), flag)
        # Estimate various norms
        rnorml = rnorm  # ||r_{k-1}||
        Anorml = Anorm
        Acondl = Acond
        relrnorml = relrnorm
        flag_no_6 = TT.neq(flag, constantX(6.))
        Dnorm = TT.switch(flag_no_6,
                          TT.sqrt(TT.sqr(Dnorm) + TT.sqr(d_norm)),
                          Dnorm)
        xnorm = TT.switch(flag_no_6, sqrt_inner_product(xs), xnorm)
        rnorm = TT.switch(flag_no_6, phi, rnorm)
        relrnorm = TT.switch(flag_no_6,
                             rnorm / (Anorm * xnorm + bnorm),
                             relrnorm)
        Tnorm = TT.switch(flag_no_6,
                          TT.switch(TT.eq(niter, constantX(0.)),
                                    TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                                    TT.sqrt(TT.sqr(Tnorm) +
                                            TT.sqr(beta) +
                                            TT.sqr(alpha) +
                                            TT.sqr(betan))),
                          Tnorm)
        Anorm = TT.maximum(Anorm, pnorm)
        Acond = Anorm * Dnorm
        rootl = TT.sqrt(TT.sqr(gbar) + TT.sqr(dbarn))
        Anorml = rnorml * rootl
        relArnorml = rootl / Anorm

        #---------------------------------------------------------------
        # See if any of the stopping criteria are satisfied.
        # In rare cases, flag is already -1 from above (Abar = const*I).
        #---------------------------------------------------------------
        epsx = Anorm * xnorm * eps
        epsr = Anorm * xnorm * rtol
        #Test for singular Hk (hence singular A)
        # or x is already an LS solution (so again A must be singular).
        t1 = constantX(1) + relrnorm
        t2 = constantX(1) + relArnorml

        flag = TT.switch(
            TT.bitwise_or(TT.eq(flag, constantX(0)),
                          TT.eq(flag, constantX(6))),
            multiple_switch(TT.le(t1, constantX(1)),
                            constantX(3),
                            TT.le(t2, constantX(1)),
                            constantX(4),
                            TT.le(relrnorm, rtol),
                            constantX(1),
                            TT.le(Anorm, constantX(1e-20)),
                            constantX(12),
                            TT.le(relArnorml, rtol),
                            constantX(10),
                            TT.ge(epsx, beta1),
                            constantX(5),
                            TT.ge(xnorm, maxxnorm),
                            constantX(6),
                            TT.ge(niter, TT.cast(maxit,
                                                 theano.config.floatX)),
                            constantX(8),
                            flag),
            flag)

        flag = TT.switch(TT.lt(Axnorm, rtol * Anorm * xnorm),
                         constantX(11.),
                         flag)
        return [niter + constantX(1.),
                beta,
                betan,
                phi,
                Acond,
                cs,
                dbarn,
                eplnn,
                rnorm,
                sn,
                Tnorm,
                rnorml,
                xnorm,
                Dnorm,
                gamma,
                pnorm,
                gammal,
                Axnorm,
                relrnorm,
                relArnorml,
                Anorm,
                flag] + xs + r1s + r2s + r3s + dls + ds, upds, \
                theano.scan_module.scan_utils.until(TT.neq(flag, 0))

    states = []
    # 0 niter
    states.append(constantX(0))
    # 1 beta
    states.append(constantX(0))
    # 2 betan
    states.append(beta1)
    # 3 phi
    states.append(beta1)
    # 4 Acond
    states.append(constantX(1))
    # 5 cs
    states.append(constantX(-1))
    # 6 dbarn
    states.append(constantX(0))
    # 7 eplnn
    states.append(constantX(0))
    # 8 rnorm
    states.append(beta1)
    # 9 sn
    states.append(constantX(0))
    # 10 Tnorm
    states.append(constantX(0))
    # 11 rnorml
    states.append(beta1)
    # 12 xnorm
    states.append(constantX(0))
    # 13 Dnorm
    states.append(constantX(0))
    # 14 gamma
    states.append(constantX(0))
    # 15 pnorm
    states.append(constantX(0))
    # 16 gammal
    states.append(constantX(0))
    # 17 Axnorm
    states.append(constantX(0))
    # 18 relrnorm
    states.append(constantX(1))
    # 19 relArnorml
    states.append(constantX(1))
    # 20 Anorm
    states.append(constantX(0))
    # 21 flag
    states.append(constantX(0))
    xs = [TT.zeros_like(b) for b in bs]
    ds = [TT.zeros_like(b) for b in bs]
    dls = [TT.zeros_like(b) for b in bs]

    rvals, loc_updates = scan(
        loop,
        outputs_info=(states + xs + r1s + r2s + r3s + dls + ds),
        n_steps=maxit + numpy.int32(1),
        name='minres',
        profile=profile,
        mode=theano.Mode(linker='cvm'))
    assert isinstance(loc_updates, dict) and 'Ordered' in str(type(loc_updates))

    niters = TT.cast(rvals[0][-1], 'int32')
    flag = TT.cast(rvals[21][-1], 'int32')
    relres = rvals[18][-1]
    relAres = rvals[19][-1]
    Anorm = rvals[20][-1]
    Acond = rvals[4][-1]
    xnorm = rvals[12][-1]
    Axnorm = rvals[17][-1]
    sol = [x[-1] for x in rvals[22: 22 + n_params]]
    return (sol,
            flag,
            niters,
            relres,
            relAres,
            Anorm,
            Acond,
            xnorm,
            Axnorm,
            loc_updates)
Ejemplo n.º 3
0
def symGivens2(a, b):
    """
    Stable Symmetric Givens rotation plus reflection

    Parameters

        a: (theano scalar) first element of a two-vector  [a; b]
        b: (theano scalar) second element of a two-vector [a; b]
    Returns

        c  cosine(theta), where theta is the implicit angle of
           rotation (counter-clockwise) in a plane-rotation
        s  sine(theta)
        d  two-norm of [a; b]

    Description:
        This method gives c and s such that
            [ c  s ][a] = [d],
            [ s -c ][b]   [0]
      where d = two norm of vector [a, b],
            c = a / sqrt(a^2 + b^2) = a / d,
            s = b / sqrt(a^2 + b^2) = b / d.
      The implementation guards against overflow in computing
         sqrt(a^2 + b^2).

      SEE ALSO:
         (1) Algorithm 4.9, stable *unsymmetric* Givens
         rotations in Golub and van Loan's book Matrix
         Computations, 3rd edition.
         (2) MATLAB's function PLANEROT.

      Observations:
          Implementing this function as a single op in C might improve speed
          considerably ..
    """
    c_branch1 = T.switch(T.eq(a, constantX(0)), constantX(1), T.sgn(a))
    c_branch21 = (a / b) * T.sgn(b) / \
            T.sqrt(constantX(1) + (a / b) ** 2)
    c_branch22 = T.sgn(a) / T.sqrt(constantX(1) + (b / a)**2)

    c_branch2 = T.switch(
        T.eq(a, constantX(0)), constantX(0),
        T.switch(T.gt(abs(b), abs(a)), c_branch21, c_branch22))
    c = T.switch(T.eq(b, constantX(0)), c_branch1, c_branch2)

    s_branch1 = T.sgn(b) / T.sqrt(constantX(1) + (a / b)**2)
    s_branch2 = (b / a) * T.sgn(a) / T.sqrt(constantX(1) + (b / a)**2)
    s = T.switch(
        T.eq(b, constantX(0)), constantX(0),
        T.switch(T.eq(a, constantX(0)), T.sgn(b),
                 T.switch(T.gt(abs(b), abs(a)), s_branch1, s_branch2)))

    d_branch1 = b / (T.sgn(b) / T.sqrt(constantX(1) + (a / b)**2))
    d_branch2 = a / (T.sgn(a) / T.sqrt(constantX(1) + (b / a)**2))
    d = T.switch(
        T.eq(b, constantX(0)), abs(a),
        T.switch(T.eq(a, constantX(0)), abs(b),
                 T.switch(T.gt(abs(b), abs(a)), d_branch1, d_branch2)))
    return c, s, d
Ejemplo n.º 4
0
    def loop(niter,
             beta,
             betan,
             phi,
             Acond,
             cs,
             dbarn,
             eplnn,
             rnorm,
             sn,
             Tnorm,
             rnorml,
             xnorm,
             Dnorm,
             gamma,
             pnorm,
             gammal,
             Axnorm,
             relrnorm,
             relArnorml,
             Anorm,
             flag,
             *args):
        #-----------------------------------------------------------------
        ## Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, p = betak P vk,  where  P = C**(-1).
        # .... more description needed.
        #-----------------------------------------------------------------
        xs = args[0 * n_params: 1 * n_params]
        r1s = args[1 * n_params: 2 * n_params]
        r2s = args[2 * n_params: 3 * n_params]
        r3s = args[3 * n_params: 4 * n_params]
        dls = args[4 * n_params: 5 * n_params]
        ds = args[5 * n_params: 6 * n_params]
        betal = beta
        beta = betan
        vs = [r3 / beta for r3 in r3s]
        r3s, upds = compute_Av(*vs)

        r3s = [r3 - shift * v for r3, v in zip(r3s, vs)]
        r3s = [TT.switch(TT.ge(niter, constantX(1.)),
                         r3 - (beta / betal) * r1,
                         r3) for r3, r1 in zip(r3s, r1s)]

        alpha = inner_product(r3s, vs)
        r3s = [r3 - (alpha / beta) * r2 for r3, r2 in zip(r3s, r2s)]
        r1s = [r2 for r2 in r2s]
        r2s = [r3 for r3 in r3s]
        if Ms is not None:
            r3s = [r3 / M for r3, M in zip(r3s, Ms)]
            betan = sqrt_inner_product(r2s, r3s)
        else:
            betan = sqrt_inner_product(r3s)
        pnorml = pnorm
        pnorm = TT.switch(TT.eq(niter, constantX(0.)),
                          TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                          TT.sqrt(TT.sqr(alpha) + TT.sqr(betan) +
                                  TT.sqr(beta)))

        #-----------------------------------------------------------------
        ## Apply previous rotation Qk-1 to get
        #   [dlta_k epln_{k+1}] = [cs  sn][dbar_k    0      ]
        #   [gbar_k  dbar_{k+1} ]   [sn -cs][alpha_k beta_{k+1}].
        #-----------------------------------------------------------------
        dbar = dbarn
        epln = eplnn
        dlta = cs * dbar + sn * alpha
        gbar = sn * dbar - cs * alpha

        eplnn = sn * betan
        dbarn = -cs * betan

        ## Compute the current plane rotation Qk
        gammal2 = gammal
        gammal = gamma
        cs, sn, gamma = symGivens2(gbar, betan)
        tau = cs * phi
        phi = sn * phi
        Axnorm = TT.sqrt(TT.sqr(Axnorm) + TT.sqr(tau))
        # Update d

        dl2s = [dl for dl in dls]
        dls = [d for d in ds]
        ds = [TT.switch(TT.neq(gamma, constantX(0.)),
                        (v - epln * dl2 - dlta * dl) / gamma,
                        v)
              for v, dl2, dl in zip(vs, dl2s, dls)]
        d_norm = TT.switch(TT.neq(gamma, constantX(0.)),
                           sqrt_inner_product(ds),
                           constantX(numpy.inf))

        # Update x except if it will become too big
        xnorml = xnorm
        dl2s = [x for x in xs]
        xs = [x + tau * d for x, d in zip(xs, ds)]

        xnorm = sqrt_inner_product(xs)
        xs = [TT.switch(TT.ge(xnorm, maxxnorm),
                        dl2, x)
              for dl2, x in zip(dl2s, xs)]

        flag = TT.switch(TT.ge(xnorm, maxxnorm),
                         constantX(6.), flag)
        # Estimate various norms
        rnorml = rnorm  # ||r_{k-1}||
        Anorml = Anorm
        Acondl = Acond
        relrnorml = relrnorm
        flag_no_6 = TT.neq(flag, constantX(6.))
        Dnorm = TT.switch(flag_no_6,
                          TT.sqrt(TT.sqr(Dnorm) + TT.sqr(d_norm)),
                          Dnorm)
        xnorm = TT.switch(flag_no_6, sqrt_inner_product(xs), xnorm)
        rnorm = TT.switch(flag_no_6, phi, rnorm)
        relrnorm = TT.switch(flag_no_6,
                             rnorm / (Anorm * xnorm + bnorm),
                             relrnorm)
        Tnorm = TT.switch(flag_no_6,
                          TT.switch(TT.eq(niter, constantX(0.)),
                                    TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                                    TT.sqrt(TT.sqr(Tnorm) +
                                            TT.sqr(beta) +
                                            TT.sqr(alpha) +
                                            TT.sqr(betan))),
                          Tnorm)
        Anorm = TT.maximum(Anorm, pnorm)
        Acond = Anorm * Dnorm
        rootl = TT.sqrt(TT.sqr(gbar) + TT.sqr(dbarn))
        Anorml = rnorml * rootl
        relArnorml = rootl / Anorm

        #---------------------------------------------------------------
        # See if any of the stopping criteria are satisfied.
        # In rare cases, flag is already -1 from above (Abar = const*I).
        #---------------------------------------------------------------
        epsx = Anorm * xnorm * eps
        epsr = Anorm * xnorm * rtol
        #Test for singular Hk (hence singular A)
        # or x is already an LS solution (so again A must be singular).
        t1 = constantX(1) + relrnorm
        t2 = constantX(1) + relArnorml

        flag = TT.switch(
            TT.bitwise_or(TT.eq(flag, constantX(0)),
                          TT.eq(flag, constantX(6))),
            multiple_switch(TT.le(t1, constantX(1)),
                            constantX(3),
                            TT.le(t2, constantX(1)),
                            constantX(4),
                            TT.le(relrnorm, rtol),
                            constantX(1),
                            TT.le(Anorm, constantX(1e-20)),
                            constantX(12),
                            TT.le(relArnorml, rtol),
                            constantX(10),
                            TT.ge(epsx, beta1),
                            constantX(5),
                            TT.ge(xnorm, maxxnorm),
                            constantX(6),
                            TT.ge(niter, TT.cast(maxit,
                                                 theano.config.floatX)),
                            constantX(8),
                            flag),
            flag)

        flag = TT.switch(TT.lt(Axnorm, rtol * Anorm * xnorm),
                         constantX(11.),
                         flag)
        return [niter + constantX(1.),
                beta,
                betan,
                phi,
                Acond,
                cs,
                dbarn,
                eplnn,
                rnorm,
                sn,
                Tnorm,
                rnorml,
                xnorm,
                Dnorm,
                gamma,
                pnorm,
                gammal,
                Axnorm,
                relrnorm,
                relArnorml,
                Anorm,
                flag] + xs + r1s + r2s + r3s + dls + ds, upds, \
                theano.scan_module.scan_utils.until(TT.neq(flag, 0))
Ejemplo n.º 5
0
def minres(
    compute_Av,
    bs,
    rtol=constantX(1e-6),
    maxit=20,
    Ms=None,
    shift=constantX(0.0),
    maxxnorm=constantX(1e15),
    Acondlim=constantX(1e16),
    profile=0,
):
    """
    Attempts to find the minimum-length and minimum-residual-norm
    solution :math:`x` to the system of linear equations :math:`A*x = b`
    or least squares problem :math:`\\min||Ax-b||`. 

    The n-by-n coefficient matrix A must be symmetric (but need not be
    positive definite or invertible). The right-hand-side column vector
    b must have length n.

    .. note:: This code is inspired from
        http://www.stanford.edu/group/SOL/software/minres.html .

    Parameters
    ----------
    compute_Av : callable
        Callable returing the symbolic expression for
        `Av` (the product of matrix A with some vector v).
        `v` should be a list of tensors, where the vector v means
        the vector obtain by concatenating and flattening all tensors in v
    bs : list
        List of Theano expressions. We are looking to compute `A^-1\dot bs`.
    rtol : float, optional
        Specifies the tolerance of the method.  Default is 1e-6.
    maxit : int, positive, optional
        Specifies the maximum number of iterations. Default is 20.
    Ms : list
        List of theano expression of same shape as `bs`. The method uses
        these to precondition with diag(Ms)
    shift : float, optional
        Default is 0.  Effectively solve the system (A - shift I) * x = b.
    maxxnorm : float, positive, optional
        Maximum bound on NORM(x). Default is 1e14.
    Acondlim : float, positive, optional
        Maximum bound on COND(A). Default is 1e15.
    show : bool
        If True, show iterations, otherwise suppress outputs. Default is
        False.

    Returns
    -------
    x : list
        List of Theano tensor representing the solution
    flag : tensor_like
        Theano int scalar - convergence flag

            0. beta1 = 0.  The exact solution is  x = 0.
            1. A solution to (poss. singular) Ax = b found, given rtol.
            2. Pseudoinverse solution for singular LS problem, given rtol.
            3. A solution to (poss. singular) Ax = b found, given eps.
            4. Pseudoinverse solution for singular LS problem, given eps.
            5. x has converged to an eigenvector.
            6. xnorm has exceeded maxxnorm.
            7. Acond has exceeded Acondlim.
            8. The iteration limit was reached.
            9. 10. It is a least squares problem but no converged
               solution yet.
    iter : int
        Iteration number at which x was computed: `0 <= iter <= maxit`.
    relres : float
        Real positive, the relative residual is defined as
        NORM(b-A*x)/(NORM(A) * NORM(x) + NORM(b)),
        computed recurrently here.  If flag is 1 or 3,  relres <= TOL.
    relAres : float
        Real positive, the relative-NORM(Ar) := NORM(Ar) / NORM(A)
        computed recurrently here. If flag is 2 or 4, relAres <= TOL.
    Anorm : float
        Real positive, estimate of matrix 2-norm of A.
    Acond : float
        Real positive, estimate of condition number of A with respect to
        2-norm.
    xnorm : float
        Non-negative positive, recurrently computed NORM(x)
    Axnorm : float
        Non-negative positive, recurrently computed NORM(A * x).

    References
    ----------
    .. [1] Choi, Sou-Cheng. Iterative Methods for Singular Linear 
           Equations and Least-Squares Problems, PhD Dissertation,
           Stanford University, 2006.
    """

    if not isinstance(bs, (tuple, list)):
        bs = [bs]
        return_as_list = False
    else:
        bs = list(bs)
        return_as_list = True

    eps = constantX(1e-23)

    # Initialise
    beta1 = sqrt_inner_product(bs)

    # ------------------------------------------------------------------
    # Set up p and v for the first Lanczos vector v1.
    # p  =  beta1 P' v1,  where  P = C**(-1).
    # v is really P' v1.
    # ------------------------------------------------------------------
    r3s = [b for b in bs]
    r2s = [b for b in bs]
    r1s = [b for b in bs]
    if Ms is not None:
        r3s = [b / m for b, m in zip(bs, Ms)]
        beta1 = sqrt_inner_product(r3s, bs)
    # ------------------------------------------------------------------
    ## Initialize other quantities.
    # Note that Anorm has been initialized by IsOpSym6.
    # ------------------------------------------------------------------
    bnorm = beta1
    n_params = len(bs)

    def loop(
        niter,
        beta,
        betan,
        phi,
        Acond,
        cs,
        dbarn,
        eplnn,
        rnorm,
        sn,
        Tnorm,
        rnorml,
        xnorm,
        Dnorm,
        gamma,
        pnorm,
        gammal,
        Axnorm,
        relrnorm,
        relArnorml,
        Anorm,
        flag,
        *args
    ):
        # -----------------------------------------------------------------
        ## Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, p = betak P vk,  where  P = C**(-1).
        # .... more description needed.
        # -----------------------------------------------------------------
        xs = args[0 * n_params : 1 * n_params]
        r1s = args[1 * n_params : 2 * n_params]
        r2s = args[2 * n_params : 3 * n_params]
        r3s = args[3 * n_params : 4 * n_params]
        dls = args[4 * n_params : 5 * n_params]
        ds = args[5 * n_params : 6 * n_params]
        betal = beta
        beta = betan
        vs = [r3 / beta for r3 in r3s]
        r3s, upds = compute_Av(*vs)

        r3s = [r3 - shift * v for r3, v in zip(r3s, vs)]
        r3s = [TT.switch(TT.ge(niter, constantX(1.0)), r3 - (beta / betal) * r1, r3) for r3, r1 in zip(r3s, r1s)]

        alpha = inner_product(r3s, vs)
        r3s = [r3 - (alpha / beta) * r2 for r3, r2 in zip(r3s, r2s)]
        r1s = [r2 for r2 in r2s]
        r2s = [r3 for r3 in r3s]
        if Ms is not None:
            r3s = [r3 / M for r3, M in zip(r3s, Ms)]
            betan = sqrt_inner_product(r2s, r3s)
        else:
            betan = sqrt_inner_product(r3s)
        pnorml = pnorm
        pnorm = TT.switch(
            TT.eq(niter, constantX(0.0)),
            TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
            TT.sqrt(TT.sqr(alpha) + TT.sqr(betan) + TT.sqr(beta)),
        )

        # -----------------------------------------------------------------
        ## Apply previous rotation Qk-1 to get
        #   [dlta_k epln_{k+1}] = [cs  sn][dbar_k    0      ]
        #   [gbar_k  dbar_{k+1} ]   [sn -cs][alpha_k beta_{k+1}].
        # -----------------------------------------------------------------
        dbar = dbarn
        epln = eplnn
        dlta = cs * dbar + sn * alpha
        gbar = sn * dbar - cs * alpha

        eplnn = sn * betan
        dbarn = -cs * betan

        ## Compute the current plane rotation Qk
        gammal2 = gammal
        gammal = gamma
        cs, sn, gamma = symGivens2(gbar, betan)
        tau = cs * phi
        phi = sn * phi
        Axnorm = TT.sqrt(TT.sqr(Axnorm) + TT.sqr(tau))
        # Update d

        dl2s = [dl for dl in dls]
        dls = [d for d in ds]
        ds = [
            TT.switch(TT.neq(gamma, constantX(0.0)), (v - epln * dl2 - dlta * dl) / gamma, v)
            for v, dl2, dl in zip(vs, dl2s, dls)
        ]
        d_norm = TT.switch(TT.neq(gamma, constantX(0.0)), sqrt_inner_product(ds), constantX(numpy.inf))

        # Update x except if it will become too big
        xnorml = xnorm
        dl2s = [x for x in xs]
        xs = [x + tau * d for x, d in zip(xs, ds)]

        xnorm = sqrt_inner_product(xs)
        xs = [TT.switch(TT.ge(xnorm, maxxnorm), dl2, x) for dl2, x in zip(dl2s, xs)]

        flag = TT.switch(TT.ge(xnorm, maxxnorm), constantX(6.0), flag)
        # Estimate various norms
        rnorml = rnorm  # ||r_{k-1}||
        Anorml = Anorm
        Acondl = Acond
        relrnorml = relrnorm
        flag_no_6 = TT.neq(flag, constantX(6.0))
        Dnorm = TT.switch(flag_no_6, TT.sqrt(TT.sqr(Dnorm) + TT.sqr(d_norm)), Dnorm)
        xnorm = TT.switch(flag_no_6, sqrt_inner_product(xs), xnorm)
        rnorm = TT.switch(flag_no_6, phi, rnorm)
        relrnorm = TT.switch(flag_no_6, rnorm / (Anorm * xnorm + bnorm), relrnorm)
        Tnorm = TT.switch(
            flag_no_6,
            TT.switch(
                TT.eq(niter, constantX(0.0)),
                TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                TT.sqrt(TT.sqr(Tnorm) + TT.sqr(beta) + TT.sqr(alpha) + TT.sqr(betan)),
            ),
            Tnorm,
        )
        Anorm = TT.maximum(Anorm, pnorm)
        Acond = Anorm * Dnorm
        rootl = TT.sqrt(TT.sqr(gbar) + TT.sqr(dbarn))
        Anorml = rnorml * rootl
        relArnorml = rootl / Anorm

        # ---------------------------------------------------------------
        # See if any of the stopping criteria are satisfied.
        # In rare cases, flag is already -1 from above (Abar = const*I).
        # ---------------------------------------------------------------
        epsx = Anorm * xnorm * eps
        epsr = Anorm * xnorm * rtol
        # Test for singular Hk (hence singular A)
        # or x is already an LS solution (so again A must be singular).
        t1 = constantX(1) + relrnorm
        t2 = constantX(1) + relArnorml

        flag = TT.switch(
            TT.bitwise_or(TT.eq(flag, constantX(0)), TT.eq(flag, constantX(6))),
            multiple_switch(
                TT.le(t1, constantX(1)),
                constantX(3),
                TT.le(t2, constantX(1)),
                constantX(4),
                TT.le(relrnorm, rtol),
                constantX(1),
                TT.le(Anorm, constantX(1e-20)),
                constantX(12),
                TT.le(relArnorml, rtol),
                constantX(10),
                TT.ge(epsx, beta1),
                constantX(5),
                TT.ge(xnorm, maxxnorm),
                constantX(6),
                TT.ge(niter, TT.cast(maxit, theano.config.floatX)),
                constantX(8),
                flag,
            ),
            flag,
        )

        flag = TT.switch(TT.lt(Axnorm, rtol * Anorm * xnorm), constantX(11.0), flag)
        return (
            [
                niter + constantX(1.0),
                beta,
                betan,
                phi,
                Acond,
                cs,
                dbarn,
                eplnn,
                rnorm,
                sn,
                Tnorm,
                rnorml,
                xnorm,
                Dnorm,
                gamma,
                pnorm,
                gammal,
                Axnorm,
                relrnorm,
                relArnorml,
                Anorm,
                flag,
            ]
            + xs
            + r1s
            + r2s
            + r3s
            + dls
            + ds,
            upds,
            theano.scan_module.scan_utils.until(TT.neq(flag, 0)),
        )

    states = []
    # 0 niter
    states.append(constantX(0))
    # 1 beta
    states.append(constantX(0))
    # 2 betan
    states.append(beta1)
    # 3 phi
    states.append(beta1)
    # 4 Acond
    states.append(constantX(1))
    # 5 cs
    states.append(constantX(-1))
    # 6 dbarn
    states.append(constantX(0))
    # 7 eplnn
    states.append(constantX(0))
    # 8 rnorm
    states.append(beta1)
    # 9 sn
    states.append(constantX(0))
    # 10 Tnorm
    states.append(constantX(0))
    # 11 rnorml
    states.append(beta1)
    # 12 xnorm
    states.append(constantX(0))
    # 13 Dnorm
    states.append(constantX(0))
    # 14 gamma
    states.append(constantX(0))
    # 15 pnorm
    states.append(constantX(0))
    # 16 gammal
    states.append(constantX(0))
    # 17 Axnorm
    states.append(constantX(0))
    # 18 relrnorm
    states.append(constantX(1))
    # 19 relArnorml
    states.append(constantX(1))
    # 20 Anorm
    states.append(constantX(0))
    # 21 flag
    states.append(constantX(0))
    xs = [TT.zeros_like(b) for b in bs]
    ds = [TT.zeros_like(b) for b in bs]
    dls = [TT.zeros_like(b) for b in bs]

    rvals, loc_updates = scan(
        loop,
        outputs_info=(states + xs + r1s + r2s + r3s + dls + ds),
        n_steps=maxit + numpy.int32(1),
        name="minres",
        profile=profile,
        mode=theano.Mode(linker="cvm"),
    )
    assert isinstance(loc_updates, dict) and "Ordered" in str(type(loc_updates))

    niters = TT.cast(rvals[0][-1], "int32")
    flag = TT.cast(rvals[21][-1], "int32")
    relres = rvals[18][-1]
    relAres = rvals[19][-1]
    Anorm = rvals[20][-1]
    Acond = rvals[4][-1]
    xnorm = rvals[12][-1]
    Axnorm = rvals[17][-1]
    sol = [x[-1] for x in rvals[22 : 22 + n_params]]
    return (sol, flag, niters, relres, relAres, Anorm, Acond, xnorm, Axnorm, loc_updates)
Ejemplo n.º 6
0
    def loop(
        niter,
        beta,
        betan,
        phi,
        Acond,
        cs,
        dbarn,
        eplnn,
        rnorm,
        sn,
        Tnorm,
        rnorml,
        xnorm,
        Dnorm,
        gamma,
        pnorm,
        gammal,
        Axnorm,
        relrnorm,
        relArnorml,
        Anorm,
        flag,
        *args
    ):
        # -----------------------------------------------------------------
        ## Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, p = betak P vk,  where  P = C**(-1).
        # .... more description needed.
        # -----------------------------------------------------------------
        xs = args[0 * n_params : 1 * n_params]
        r1s = args[1 * n_params : 2 * n_params]
        r2s = args[2 * n_params : 3 * n_params]
        r3s = args[3 * n_params : 4 * n_params]
        dls = args[4 * n_params : 5 * n_params]
        ds = args[5 * n_params : 6 * n_params]
        betal = beta
        beta = betan
        vs = [r3 / beta for r3 in r3s]
        r3s, upds = compute_Av(*vs)

        r3s = [r3 - shift * v for r3, v in zip(r3s, vs)]
        r3s = [TT.switch(TT.ge(niter, constantX(1.0)), r3 - (beta / betal) * r1, r3) for r3, r1 in zip(r3s, r1s)]

        alpha = inner_product(r3s, vs)
        r3s = [r3 - (alpha / beta) * r2 for r3, r2 in zip(r3s, r2s)]
        r1s = [r2 for r2 in r2s]
        r2s = [r3 for r3 in r3s]
        if Ms is not None:
            r3s = [r3 / M for r3, M in zip(r3s, Ms)]
            betan = sqrt_inner_product(r2s, r3s)
        else:
            betan = sqrt_inner_product(r3s)
        pnorml = pnorm
        pnorm = TT.switch(
            TT.eq(niter, constantX(0.0)),
            TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
            TT.sqrt(TT.sqr(alpha) + TT.sqr(betan) + TT.sqr(beta)),
        )

        # -----------------------------------------------------------------
        ## Apply previous rotation Qk-1 to get
        #   [dlta_k epln_{k+1}] = [cs  sn][dbar_k    0      ]
        #   [gbar_k  dbar_{k+1} ]   [sn -cs][alpha_k beta_{k+1}].
        # -----------------------------------------------------------------
        dbar = dbarn
        epln = eplnn
        dlta = cs * dbar + sn * alpha
        gbar = sn * dbar - cs * alpha

        eplnn = sn * betan
        dbarn = -cs * betan

        ## Compute the current plane rotation Qk
        gammal2 = gammal
        gammal = gamma
        cs, sn, gamma = symGivens2(gbar, betan)
        tau = cs * phi
        phi = sn * phi
        Axnorm = TT.sqrt(TT.sqr(Axnorm) + TT.sqr(tau))
        # Update d

        dl2s = [dl for dl in dls]
        dls = [d for d in ds]
        ds = [
            TT.switch(TT.neq(gamma, constantX(0.0)), (v - epln * dl2 - dlta * dl) / gamma, v)
            for v, dl2, dl in zip(vs, dl2s, dls)
        ]
        d_norm = TT.switch(TT.neq(gamma, constantX(0.0)), sqrt_inner_product(ds), constantX(numpy.inf))

        # Update x except if it will become too big
        xnorml = xnorm
        dl2s = [x for x in xs]
        xs = [x + tau * d for x, d in zip(xs, ds)]

        xnorm = sqrt_inner_product(xs)
        xs = [TT.switch(TT.ge(xnorm, maxxnorm), dl2, x) for dl2, x in zip(dl2s, xs)]

        flag = TT.switch(TT.ge(xnorm, maxxnorm), constantX(6.0), flag)
        # Estimate various norms
        rnorml = rnorm  # ||r_{k-1}||
        Anorml = Anorm
        Acondl = Acond
        relrnorml = relrnorm
        flag_no_6 = TT.neq(flag, constantX(6.0))
        Dnorm = TT.switch(flag_no_6, TT.sqrt(TT.sqr(Dnorm) + TT.sqr(d_norm)), Dnorm)
        xnorm = TT.switch(flag_no_6, sqrt_inner_product(xs), xnorm)
        rnorm = TT.switch(flag_no_6, phi, rnorm)
        relrnorm = TT.switch(flag_no_6, rnorm / (Anorm * xnorm + bnorm), relrnorm)
        Tnorm = TT.switch(
            flag_no_6,
            TT.switch(
                TT.eq(niter, constantX(0.0)),
                TT.sqrt(TT.sqr(alpha) + TT.sqr(betan)),
                TT.sqrt(TT.sqr(Tnorm) + TT.sqr(beta) + TT.sqr(alpha) + TT.sqr(betan)),
            ),
            Tnorm,
        )
        Anorm = TT.maximum(Anorm, pnorm)
        Acond = Anorm * Dnorm
        rootl = TT.sqrt(TT.sqr(gbar) + TT.sqr(dbarn))
        Anorml = rnorml * rootl
        relArnorml = rootl / Anorm

        # ---------------------------------------------------------------
        # See if any of the stopping criteria are satisfied.
        # In rare cases, flag is already -1 from above (Abar = const*I).
        # ---------------------------------------------------------------
        epsx = Anorm * xnorm * eps
        epsr = Anorm * xnorm * rtol
        # Test for singular Hk (hence singular A)
        # or x is already an LS solution (so again A must be singular).
        t1 = constantX(1) + relrnorm
        t2 = constantX(1) + relArnorml

        flag = TT.switch(
            TT.bitwise_or(TT.eq(flag, constantX(0)), TT.eq(flag, constantX(6))),
            multiple_switch(
                TT.le(t1, constantX(1)),
                constantX(3),
                TT.le(t2, constantX(1)),
                constantX(4),
                TT.le(relrnorm, rtol),
                constantX(1),
                TT.le(Anorm, constantX(1e-20)),
                constantX(12),
                TT.le(relArnorml, rtol),
                constantX(10),
                TT.ge(epsx, beta1),
                constantX(5),
                TT.ge(xnorm, maxxnorm),
                constantX(6),
                TT.ge(niter, TT.cast(maxit, theano.config.floatX)),
                constantX(8),
                flag,
            ),
            flag,
        )

        flag = TT.switch(TT.lt(Axnorm, rtol * Anorm * xnorm), constantX(11.0), flag)
        return (
            [
                niter + constantX(1.0),
                beta,
                betan,
                phi,
                Acond,
                cs,
                dbarn,
                eplnn,
                rnorm,
                sn,
                Tnorm,
                rnorml,
                xnorm,
                Dnorm,
                gamma,
                pnorm,
                gammal,
                Axnorm,
                relrnorm,
                relArnorml,
                Anorm,
                flag,
            ]
            + xs
            + r1s
            + r2s
            + r3s
            + dls
            + ds,
            upds,
            theano.scan_module.scan_utils.until(TT.neq(flag, 0)),
        )
Ejemplo n.º 7
0
def symGivens2(a, b):
    """
    Stable Symmetric Givens rotation plus reflection

    Parameters
    ----------
    a : theano scalar
        first element of a two-vector  [a; b]
    b : theano scalar
        second element of a two-vector [a; b]

    Returns
    -------
    c : WRITEME
        cosine(theta), where theta is the implicit angle of
        rotation (counter-clockwise) in a plane-rotation
    s : WRITEME
        sine(theta)
    d : WRITEME
        two-norm of [a; b]

    Notes
    -----
    * See also:

      - Algorithm 4.9, stable *unsymmetric* Givens rotations in Golub
        and van Loan's book Matrix Computations, 3rd edition.

      - MATLAB's function PLANEROT.

    * This method gives c and s such that

      .. math::

          \\begin{pmatrix} c & s \\\ s & -c \\end{pmatrix}
          \\begin{pmatrix} a \\\ b \\end{pmatrix} =
          \\begin{pmatrix} d \\\ 0 \\end{pmatrix}

      where

      :math:`d = \\left\Vert \\begin{pmatrix} a \\\ b \\end{pmatrix}
      \\right\Vert _{2}`,
      :math:`c = a / \sqrt{a^2 + b^2} = a / d`,
      :math:`s = b / \sqrt{a^2 + b^2} = b / d`.

      The implementation guards against overflow in computing
      :math:`\sqrt{a^2 + b^2}`.

    * Observation: Implementing this function as a single op in C might
      improve speed considerably .
    """
    c_branch1 = T.switch(T.eq(a, constantX(0)),
                          constantX(1),
                          T.sgn(a))
    c_branch21 = (a / b) * T.sgn(b) / \
            T.sqrt(constantX(1) + (a / b) ** 2)
    c_branch22 = T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)

    c_branch2 = T.switch(T.eq(a, constantX(0)),
                          constantX(0),
                          T.switch(T.gt(abs(b), abs(a)),
                                    c_branch21,
                                    c_branch22))
    c = T.switch(T.eq(b, constantX(0)),
                  c_branch1,
                  c_branch2)

    s_branch1 = T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2)
    s_branch2 = (b / a) * T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)
    s = T.switch(T.eq(b, constantX(0)),
                  constantX(0),
                  T.switch(T.eq(a, constantX(0)),
                            T.sgn(b),
                            T.switch(T.gt(abs(b), abs(a)),
                                      s_branch1,
                                      s_branch2)))

    d_branch1 = b / (T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2))
    d_branch2 = a / (T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2))
    d = T.switch(T.eq(b, constantX(0)),
                  abs(a),
                  T.switch(T.eq(a, constantX(0)),
                            abs(b),
                            T.switch(T.gt(abs(b), abs(a)),
                                      d_branch1,
                                      d_branch2)))
    return c, s, d
Ejemplo n.º 8
0
def symGivens2(a, b):
    """
    Stable Symmetric Givens rotation plus reflection

    Parameters
    ----------
    a : theano scalar
        first element of a two-vector  [a; b]
    b : theano scalar
        second element of a two-vector [a; b]

    Returns
    -------
    c : WRITEME
        cosine(theta), where theta is the implicit angle of
        rotation (counter-clockwise) in a plane-rotation
    s : WRITEME
        sine(theta)
    d : WRITEME
        two-norm of [a; b]

    Notes
    -----
    * See also:

      - Algorithm 4.9, stable *unsymmetric* Givens rotations in Golub
        and van Loan's book Matrix Computations, 3rd edition.

      - MATLAB's function PLANEROT.

    * This method gives c and s such that

      .. math::

          \\begin{pmatrix} c & s \\\ s & -c \\end{pmatrix}
          \\begin{pmatrix} a \\\ b \\end{pmatrix} =
          \\begin{pmatrix} d \\\ 0 \\end{pmatrix}

      where

      :math:`d = \\left\Vert \\begin{pmatrix} a \\\ b \\end{pmatrix}
      \\right\Vert _{2}`,
      :math:`c = a / \sqrt{a^2 + b^2} = a / d`,
      :math:`s = b / \sqrt{a^2 + b^2} = b / d`.

      The implementation guards against overflow in computing
      :math:`\sqrt{a^2 + b^2}`.

    * Observation: Implementing this function as a single op in C might
      improve speed considerably .
    """
    c_branch1 = T.switch(T.eq(a, constantX(0)),
                          constantX(1),
                          T.sgn(a))
    c_branch21 = (a / b) * T.sgn(b) / \
            T.sqrt(constantX(1) + (a / b) ** 2)
    c_branch22 = T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)

    c_branch2 = T.switch(T.eq(a, constantX(0)),
                          constantX(0),
                          T.switch(T.gt(abs(b), abs(a)),
                                    c_branch21,
                                    c_branch22))
    c = T.switch(T.eq(b, constantX(0)),
                  c_branch1,
                  c_branch2)

    s_branch1 = T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2)
    s_branch2 = (b / a) * T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2)
    s = T.switch(T.eq(b, constantX(0)),
                  constantX(0),
                  T.switch(T.eq(a, constantX(0)),
                            T.sgn(b),
                            T.switch(T.gt(abs(b), abs(a)),
                                      s_branch1,
                                      s_branch2)))

    d_branch1 = b / (T.sgn(b) / T.sqrt(constantX(1) + (a / b) ** 2))
    d_branch2 = a / (T.sgn(a) / T.sqrt(constantX(1) + (b / a) ** 2))
    d = T.switch(T.eq(b, constantX(0)),
                  abs(a),
                  T.switch(T.eq(a, constantX(0)),
                            abs(b),
                            T.switch(T.gt(abs(b), abs(a)),
                                      d_branch1,
                                      d_branch2)))
    return c, s, d