Esempio n. 1
0
def schur(a,output='real',lwork=None,overwrite_a=0):
    """Compute Schur decomposition of matrix a.

    Description:

      Return T, Z such that a = Z * T * (Z**H) where Z is a
      unitary matrix and T is either upper-triangular or quasi-upper
      triangular for output='real'

    Inputs:

      a -- a square matrix.
      output='real' -- 'real' (or 'r') or 'complex' (or 'c'). For
                       'complex' a non-complex input is converted to
                       typecode 'D' or 'F' before the decomposition.
      lwork=None -- work array size handed to gees. If None (or -1)
                    the size is obtained from a workspace query.
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the LAPACK routine.

    Outputs:

      T, Z -- first and third-from-last items returned by gees.

    Raises ValueError for an unknown output value, a non-square input or
    an illegal argument reported by gees (info < 0), and LinAlgError when
    gees reports info > 0.
    """
    if output not in ('real','complex','r','c'):
        raise ValueError("argument must be 'real', or 'complex'")
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
        raise ValueError('expected square matrix')
    typ = a1.typecode()
    if output in ('complex','c') and typ not in ('F','D'):
        # Complex output was requested for non-complex data: promote the
        # array, keeping double precision when the input had it.
        if typ in _double_precision:
            a1 = a1.astype('D')
        else:
            a1 = a1.astype('F')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    gees, = get_lapack_funcs(('gees',),(a1,))
    # NOTE(review): the first gees argument is a callback; a no-op is
    # passed here -- presumably the eigenvalue selection hook, confirm
    # against the gees wrapper.
    if lwork is None or lwork == -1:
        # get optimal work array
        result = gees(lambda x: None,a1,lwork=-1)
        lwork = result[-2][0]
    # Always operate on a1 (the checked / promoted array) and honour an
    # explicitly supplied lwork.
    result = gees(lambda x: None,a1,lwork=lwork,overwrite_a=overwrite_a)
    info = result[-1]
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal gees' % (-info))
    elif info > 0:
        raise LinAlgError("Schur form not found.  Possibly ill-conditioned.")
    return result[0], result[-3]
Esempio n. 2
0
def pinv(a, cond=None):
    """ pinv(a, cond=None) -> a_pinv

    Compute the generalized inverse of a with the least-squares solver.

    The right-hand side handed to lstsq is an identity matrix with
    a.shape[0] rows in the typecode of a; the first item of the lstsq
    result is returned.

    Inputs:

      a -- the matrix to invert.
      cond=None -- forwarded unchanged to lstsq.
    """
    a = asarray_chkfinite(a)
    typecode = a.typecode()
    n_rows = a.shape[0]
    rhs = scipy_base.identity(n_rows,typecode)
    solution = lstsq(a, rhs, cond=cond)
    return solution[0]
Esempio n. 3
0
def cho_solve(clow, b):
    """Solve a previously factored symmetric system of equations.

    Inputs:

      clow -- a tuple (LorU, lower) which is the output to cho_factor.
      b    -- the right-hand side.

    Outputs:

      the solution array returned by LAPACK's ?potrs.

    Raises TypeError when ?potrs reports a non-zero info code; for
    info < 0 the message names the offending argument position (-info).
    """
    c, lower = clow
    c = asarray_chkfinite(c)
    _assert_squareness(c)
    b = asarray_chkfinite(b)
    potrs, = get_lapack_funcs(('potrs',),(c,))
    b, info = potrs(c,b,lower)
    if info < 0:
        # info = -i flags the i-th argument, so report the positive index.
        raise TypeError(
            "Argument %d to lapack's ?potrs() has an illegal value."
            % (-info))
    if info > 0:
        raise TypeError(
            "Unknown error occurred in ?potrs(): error code = %d" % info)
    return b
Esempio n. 4
0
def norm(x, ord=2):
    """ norm(x, ord=2) -> n

    Norm of a vector or of a matrix.

    Inputs:

      x -- a rank-1 (vector) or rank-2 (matrix) array
      ord -- the order of norm.

    Comments:

      Vectors accept any real ord, including Inf and -Inf:
        ord = Inf    maximum of the magnitudes
        ord = -Inf   minimum of the magnitudes
        otherwise    sum(abs(x)**ord)**(1.0/ord)

      Matrices accept only + or - 1, 2, Inf and 'fro' (or 'f'):
        ord = 2      largest singular value
        ord = -2     smallest singular value
        ord = 1      largest column sum of absolute values
        ord = -1     smallest column sum of absolute values
        ord = Inf    largest row sum of absolute values
        ord = -Inf   smallest row sum of absolute values
        ord = 'fro'  frobenius norm sqrt(sum(diag(X.H * X)))

    Raises ValueError for an unsupported matrix order or when x is
    neither rank-1 nor rank-2.
    """
    x = asarray_chkfinite(x)
    rank = len(x.shape)
    Inf = scipy_base.Inf

    # --- vectors ---
    if rank == 1:
        if ord == Inf:
            return scipy_base.amax(abs(x))
        if ord == -Inf:
            return scipy_base.amin(abs(x))
        powered = abs(x)**ord
        return scipy_base.sum(powered)**(1.0/ord)

    if rank != 2:
        raise ValueError("Improper number of dimensions to norm.")

    # --- matrices ---
    if ord == 2:
        singular_values = decomp.svd(x,compute_uv=0)
        return scipy_base.amax(singular_values)
    if ord == -2:
        singular_values = decomp.svd(x,compute_uv=0)
        return scipy_base.amin(singular_values)
    if ord == 1:
        column_sums = scipy_base.sum(abs(x))
        return scipy_base.amax(column_sums)
    if ord == Inf:
        row_sums = scipy_base.sum(abs(x),axis=1)
        return scipy_base.amax(row_sums)
    if ord == -1:
        column_sums = scipy_base.sum(abs(x))
        return scipy_base.amin(column_sums)
    if ord == -Inf:
        row_sums = scipy_base.sum(abs(x),axis=1)
        return scipy_base.amin(row_sums)
    if ord in ['fro','f']:
        # Sum of squared magnitudes of every element, then the root.
        squared = real((conjugate(x)*x).flat)
        return sqrt(add.reduce(squared))
    raise ValueError("Invalid norm order for matrices.")
Esempio n. 5
0
def lu_solve(a_lu_pivots,b):
    """Solve a previously factored system.

    Inputs:

      a_lu_pivots -- a tuple (lu, pivots) which is the output to lu_factor.
      b           -- the right hand side.

    Outputs:

      the solution array returned by LAPACK's ?getrs.

    Raises TypeError when ?getrs reports a non-zero info code; for
    info < 0 the message names the offending argument position (-info).
    """
    a_lu, pivots = a_lu_pivots
    a_lu = asarray_chkfinite(a_lu)
    pivots = asarray_chkfinite(pivots)
    b = asarray_chkfinite(b)
    _assert_squareness(a_lu)

    # Select the getrs flavour from the factored matrix (the original code
    # referenced an undefined name here).
    getrs, = get_lapack_funcs(('getrs',),(a_lu,))
    b, info = getrs(a_lu,pivots,b)
    if info < 0:
        # info = -i flags the i-th argument, so report the positive index.
        raise TypeError(
            "Argument %d to lapack's ?getrs() has an illegal value."
            % (-info))
    if info > 0:
        raise TypeError(
            "Unknown error occurred in ?getrs(): error code = %d" % info)
    return b
Esempio n. 6
0
def det(a, overwrite_a=0):
    """ det(a, overwrite_a=0) -> d

    Return determinant of a square matrix.

    Inputs:

      a -- a square matrix.
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the underlying routine.

    Outputs:

      d -- the determinant computed by the flinalg 'det' routine.

    Raises ValueError if a is not a square rank-2 array or if the
    internal routine reports an illegal argument (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    fdet, = get_flinalg_funcs(('det',),(a1,))
    a_det,info = fdet(a1,overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal det.getrf' % (-info))
    return a_det
Esempio n. 7
0
def cho_factor(a, lower=0, overwrite_a=0):
    """ Compute Cholesky decomposition of matrix and return an object
    to be used for solving a linear system using cho_solve.

    Inputs:

      a -- a square matrix.
      lower=0 -- forwarded to potrf and returned alongside the factor.
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the LAPACK routine.

    Outputs:

      (c, lower) -- the array returned by potrf and the lower flag.

    Raises LinAlgError if potrf reports info > 0 (matrix not positive
    definite) and ValueError for a non-square input or an illegal
    argument reported by potrf (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    potrf, = get_lapack_funcs(('potrf',),(a1,))
    # NOTE(review): clean=0 presumably leaves the unused triangle of the
    # result untouched -- confirm against the potrf wrapper.
    c,info = potrf(a1,lower=lower,overwrite_a=overwrite_a,clean=0)
    if info > 0:
        raise LinAlgError("matrix not positive definite")
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal potrf' % (-info))
    return c, lower
Esempio n. 8
0
def inv(a, overwrite_a=0):
    """ inv(a, overwrite_a=0) -> a_inv

    Return inverse of square matrix a.

    The matrix is LU-factored with getrf and the inverse is then formed
    from the factors with getri.

    Inputs:

      a -- a square matrix.
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the LAPACK routine.

    Raises LinAlgError if getrf or getri reports info > 0 (singular
    matrix) and ValueError for a non-square input or an illegal argument
    reported by either routine (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    #XXX: I found no advantage or disadvantage of using the flinalg
    #     'inv' routine (finv) instead of getrf/getri here.
    getrf,getri = get_lapack_funcs(('getrf','getri'),(a1,))
    #XXX: C ATLAS versions of getrf/i have rowmajor=1, this could be
    #     exploited for further optimization. But it will be probably
    #     a mess. So, a good testing site is required before trying
    #     to do that.
    getrf_is_clapack = getrf.module_name[:7] == 'clapack'
    getri_is_clapack = getri.module_name[:7] == 'clapack'
    if getrf_is_clapack and not getri_is_clapack:
        # ATLAS 3.2.1 has getrf but not getri.
        # Mixed backends: factor the transpose with rowmajor=0 and
        # transpose the result back before handing it to getri.
        lu,piv,info = getrf(transpose(a1),
                            rowmajor=0,overwrite_a=overwrite_a)
        lu = transpose(lu)
    else:
        lu,piv,info = getrf(a1,overwrite_a=overwrite_a)
    if info == 0:
        if getri.module_name[:7] == 'flapack':
            lwork = calc_lwork.getri(getri.prefix,a1.shape[0])
            lwork = lwork[1]
            # XXX: the following line fixes curious SEGFAULT when
            # benchmarking 500x500 matrix inverse. This seems to
            # be a bug in LAPACK ?getri routine because if lwork is
            # minimal (when using lwork[0] instead of lwork[1]) then
            # all tests pass. Further investigation is required if
            # more such SEGFAULTs occur.
            lwork = int(1.01*lwork)
            inv_a,info = getri(lu,piv,
                               lwork=lwork,overwrite_lu=1)
        else: # clapack
            inv_a,info = getri(lu,piv,overwrite_lu=1)
    # info holds the code of the last routine that ran (getrf or getri).
    if info > 0:
        raise LinAlgError("singular matrix")
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal getrf|getri'
            % (-info))
    return inv_a
Esempio n. 9
0
def qr(a,overwrite_a=0,lwork=None):
    """QR decomposition of an M x N matrix a.

    Description:

      Find a unitary matrix, q, and an upper-trapezoidal matrix r
      such that q * r = a

    Inputs:

      a -- the matrix
      overwrite_a=0 -- if non-zero then discard the contents of a,
                     i.e. a is used as a work array if possible.

      lwork=None -- >= shape(a)[1]. If None (or -1) compute optimal
                    work array size.

    Outputs:

      q, r -- matrices such that q * r = a

    Raises ValueError if a is not rank-2 or if geqrf reports an illegal
    argument (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2:
        raise ValueError('expected matrix')
    M,N = a1.shape
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    geqrf, = get_lapack_funcs(('geqrf',),(a1,))
    if lwork is None or lwork == -1:
        # get optimal work array
        qr_fact,tau,work,info = geqrf(a1,lwork=-1,overwrite_a=1)
        lwork = work[0]
    qr_fact,tau,work,info = geqrf(a1,lwork=lwork,overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal geqrf' % (-info))
    gemm, = get_blas_funcs(('gemm',),(qr_fact,))
    t = qr_fact.typecode()
    # r is the upper triangle of the geqrf output; the entries below the
    # diagonal are used further down to rebuild q.
    R = basic.triu(qr_fact)
    Q = scipy_base.identity(M,typecode=t)
    ident = scipy_base.identity(M,typecode=t)
    for i in range(min(M,N)):
        # Reflector vector: zeros above row i, 1 at row i, and column i
        # of the factorisation below it.
        v = scipy_base.zeros((M,),t)
        v[i] = 1
        v[i+1:M] = qr_fact[i+1:M,i]
        # H = ident - tau[i] * v * v^H  (trans_b=2: conjugate transpose
        # of the second operand -- presumably, per the gemm wrapper).
        H = gemm(-tau[i],v,v,1+0j,ident,trans_b=2)
        # Accumulate the product of the reflectors into Q.
        Q = gemm(1,Q,H)
    return Q, R
Esempio n. 10
0
def pinv2(a, cond=None):
    """ pinv2(a, cond=None) -> a_pinv

    Compute the generalized inverse of A from its singular value
    decomposition.

    Inputs:

      a -- the matrix to invert.
      cond=None -- relative cutoff: only singular values greater than
                   cond * (largest singular value) are inverted, the
                   others are treated as zero. If None (or -1) the
                   default is feps*1e3 or eps*1e6, selected through
                   _array_precision from the typecode of u.

    Outputs:

      a_pinv -- conjugate transpose of u * psigma * vh, where psigma
                carries the reciprocals of the retained singular values
                on its diagonal.
    """
    a = asarray_chkfinite(a)
    u, s, vh = decomp.svd(a)
    typecode = u.typecode()
    if cond in [None,-1]:
        default_cond = {0: feps*1e3, 1: eps*1e6}
        cond = default_cond[_array_precision[typecode]]
    n_rows,n_cols = a.shape
    threshold = cond*scipy_base.maximum.reduce(s)
    inv_sigma = zeros((n_rows,n_cols),typecode)
    for k in range(len(s)):
        sv = s[k]
        if sv > threshold:
            inv_sigma[k,k] = 1.0/conjugate(sv)
    #XXX: use lapack/blas routines for dot
    product = dot(dot(u,inv_sigma),vh)
    return transpose(conjugate(product))
Esempio n. 11
0
def svd(a,compute_uv=1,overwrite_a=0):
    """Compute singular value decomposition (SVD) of matrix a.

    Description:

      Singular value decomposition of a matrix a is
        a = u * sigma * v^H,
      where v^H denotes conjugate(transpose(v)), u,v are unitary
      matrices, sigma is zero matrix with a main diagonal containing
      real non-negative singular values of the matrix a.

    Inputs:

      a -- An M x N matrix.
      compute_uv -- If zero, then only the vector of singular values
                    is returned.
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the LAPACK routine.

    Outputs:

      u -- An M x M unitary matrix [compute_uv=1].
      s -- An min(M,N) vector of singular values in descending order,
           sigma = diagsvd(s).
      vh -- An N x N unitary matrix [compute_uv=1], vh = v^H.

    Raises ValueError if a is not rank-2 or gesdd reports an illegal
    argument (info < 0), LinAlgError if gesdd reports info > 0, and
    NotImplementedError when gesdd does not come from flapack.
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2:
        raise ValueError('expected matrix')
    m,n = a1.shape
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    gesdd, = get_lapack_funcs(('gesdd',),(a1,))
    if gesdd.module_name[:7] != 'flapack':
        # Only the Fortran LAPACK wrapper is supported here.
        raise NotImplementedError(
            'calling gesdd from %s' % (gesdd.module_name))
    lwork = calc_lwork.gesdd(gesdd.prefix,m,n,compute_uv)[1]
    u,s,v,info = gesdd(a1,compute_uv = compute_uv, lwork = lwork,
                       overwrite_a = overwrite_a)
    if info > 0:
        raise LinAlgError("SVD did not converge")
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal gesdd' % (-info))
    if compute_uv:
        return u,s,v
    return s
Esempio n. 12
0
def cholesky(a,lower=0,overwrite_a=0):
    """Compute Cholesky decomposition of matrix.

    Description:

      For a hermitian positive-definite matrix a return the
      upper-triangular (or lower-triangular if lower==1) matrix,
      u such that u^H * u = a (or l * l^H = a).

    Inputs:

      a -- a square matrix.
      lower=0 -- forwarded to potrf; selects the triangle (see above).
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the LAPACK routine.

    Raises LinAlgError if potrf reports info > 0 (matrix not positive
    definite) and ValueError for a non-square input or an illegal
    argument reported by potrf (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    potrf, = get_lapack_funcs(('potrf',),(a1,))
    # NOTE(review): clean=1 presumably zeroes the unused triangle of the
    # result -- confirm against the potrf wrapper.
    c,info = potrf(a1,lower=lower,overwrite_a=overwrite_a,clean=1)
    if info > 0:
        raise LinAlgError("matrix not positive definite")
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal potrf' % (-info))
    return c
Esempio n. 13
0
def lu(a,permute_l=0,overwrite_a=0):
    """Return LU decomposition of a matrix.

    Inputs:

      a     -- An M x N matrix.
      permute_l  -- Perform matrix multiplication p * l [disabled].
      overwrite_a=0 -- if non-zero then the contents of a may be
                       discarded by the underlying routine.

    Outputs:

      p,l,u  -- LU decomposition matrices of a [permute_l=0]
      pl,u   -- LU decomposition matrices of a [permute_l=1]

    Definitions:

      a = p * l * u    [permute_l=0]
      a = pl * u       [permute_l=1]

      p   -  An M x M permutation matrix
      l   -  An M x K lower triangular or trapezoidal matrix
             with unit-diagonal
      u   -  An K x N upper triangular or trapezoidal matrix
             K = min(M,N)

    Raises ValueError if a is not rank-2 or if the internal routine
    reports an illegal argument (info < 0).
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2:
        raise ValueError('expected matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    flu, = get_flinalg_funcs(('lu',),(a1,))
    p,l,u,info = flu(a1,permute_l=permute_l,overwrite_a = overwrite_a)
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal lu.getrf' % (-info))
    if permute_l:
        # With permute_l set, the second item already is the product p * l.
        return l,u
    return p,l,u
Esempio n. 14
0
    Solve a system of equations given a previously factored matrix

    Inputs:

      (lu,piv) -- The factored matrix, a (the output of lu_factor)
      b        -- a set of right-hand sides
      trans    -- type of system to solve:
                  0 : a   * x = b   (no transpose)
                  1 : a^T * x = b   (transpose) 
                  2   a^H * x = b   (conjugate transpose)

    Outputs:

       x -- the solution to the system
    """
    b1 = asarray_chkfinite(b)
    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
    if lu.shape[0] != b1.shape[0]:
        raise ValuError, "incompatible dimensions."
    getrs, = get_lapack_funcs(('getrs',),(lu,b1))
    x,info = getrs(lu,piv,b1,trans=trans,overwrite_b=overwrite_b)
    if info==0:
        return x
    raise ValueError,\
          'illegal value in %-th argument of internal gesv|posv'%(-info)

def cho_solve((c, lower), b, overwrite_b=0):
    """ cho_solve((c, lower), b, overwrite_b=0) -> x

    Solve a system of equations given a previously cholesky factored matrix
Esempio n. 15
0
def eig(a,b=None,left=0,right=1,overwrite_a=0,overwrite_b=0):
    """ Solve ordinary and generalized eigenvalue problem
    of a square matrix.

    Inputs:

      a     -- An N x N matrix.
      b     -- An N x N matrix [default is identity(N)].
      left  -- Return left eigenvectors [disabled].
      right -- Return right eigenvectors [enabled].
      overwrite_a=0, overwrite_b=0 -- if non-zero then the contents of
               a (respectively b) may be discarded by the LAPACK routine.

    Outputs:

      w      -- eigenvalues [left==right==0].
      w,vr   -- w and right eigenvectors [left==0,right=1].
      w,vl   -- w and left eigenvectors [left==1,right==0].
      w,vl,vr  -- [left==right==1].

    Definitions:

      a * vr[:,i] = w[i] * b * vr[:,i]

      a^H * vl[:,i] = conjugate(w[i]) * b^H * vl[:,i]

    where a^H denotes transpose(conjugate(a)).

    Raises ValueError if a is not a square rank-2 array or geev reports
    an illegal argument (info < 0), and LinAlgError if geev reports
    info > 0.
    """
    a1 = asarray_chkfinite(a)
    if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
        raise ValueError('expected square matrix')
    # A temporary created from a non-array input may be overwritten freely.
    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
    if b is not None:
        # Generalized problem: handled entirely by _geneig.
        b = asarray_chkfinite(b)
        return _geneig(a1,b,left,right,overwrite_a,overwrite_b)
    geev, = get_lapack_funcs(('geev',),(a1,))
    compute_vl,compute_vr = left,right
    geev_kwargs = {'compute_vl': compute_vl,
                   'compute_vr': compute_vr,
                   'overwrite_a': overwrite_a}
    if geev.module_name[:7] == 'flapack':
        # Only the Fortran wrapper takes an explicit work array size.
        geev_kwargs['lwork'] = calc_lwork.geev(geev.prefix,a1.shape[0],
                                               compute_vl,compute_vr)[1]
    if geev.prefix in 'cz':
        w,vl,vr,info = geev(a1,**geev_kwargs)
    else:
        # The non-complex routines return the eigenvalues as separate
        # real and imaginary parts; combine them into one complex array.
        wr,wi,vl,vr,info = geev(a1,**geev_kwargs)
        w = wr+_I*wi
    if info < 0:
        raise ValueError(
            'illegal value in %d-th argument of internal geev' % (-info))
    if info > 0:
        raise LinAlgError("eig algorithm did not converge")

    only_real = scipy_base.logical_and.reduce(scipy_base.equal(w.imag,0.0))
    if not (geev.prefix in 'cz' or only_real):
        # Non-complex routine with at least one eigenvalue off the real
        # axis: convert the returned eigenvector arrays with
        # _make_complex_eigvecs.
        t = w.typecode()
        if left:
            vl = _make_complex_eigvecs(w, vl, t)
        if right:
            vr = _make_complex_eigvecs(w, vr, t)
    if not (left or right):
        return w
    if left:
        if right:
            return w, vl, vr
        return w, vl
    return w, vr