def makeInnerMatrixLoopDivStmt (dataLayout):
    '''Make the C++ statement that finishes entry c of the current row of X
    by dividing by the diagonal entry:

      X_r[c] = (X_r[c] + Y_r[c]) / A_rr

    dataLayout: This describes how the multivectors' data are arranged
      in memory.  It is forwarded to makeRowAndColStrides and
      makeMultiVectorAref, which build the element references.

    Returns the statement as a string, without a trailing semicolon.'''

    # Strides of the output (X) and input (Y) multivectors.
    xRowStride, xColStride = makeRowAndColStrides ('X', dataLayout)
    yRowStride, yColStride = makeRowAndColStrides ('Y', dataLayout)

    # X_r and Y_r already point at row r, so only the column index 'c'
    # is supplied; the row index is left empty.
    xElt = makeMultiVectorAref ('X_r', dataLayout, xRowStride, xColStride, '', 'c')
    yElt = makeMultiVectorAref ('Y_r', dataLayout, yRowStride, yColStride, '', 'c')

    # X_r[c] appears on both sides of the assignment.
    return '%s = (%s + %s) / A_rr' % (xElt, xElt, yElt)


def makeInitLoopBody (dataLayout, unitDiag):
    '''Make the C++ statement that initializes X_r[c] before the loop
    over the entries of sparse row r.

    dataLayout: This describes how the multivectors' data are arranged
      in memory.  It is forwarded to makeRowAndColStrides and
      makeMultiVectorAref, which build the element references.

    unitDiag: True if the sparse matrix has an implicit unit diagonal
      (not stored explicitly), else False.

    Returns the statement as a string, without a trailing semicolon:
    "X_r[c] = Y_r[c]" when unitDiag is true, otherwise
    "X_r[c] = STS::zero ()".'''
    rowStrideX, colStrideX = makeRowAndColStrides ('X', dataLayout)
    rowStrideY, colStrideY = makeRowAndColStrides ('Y', dataLayout)
    X_r_c_expr = makeMultiVectorAref ('X_r', dataLayout, rowStrideX, colStrideX, '', 'c')
    Y_r_c_expr = makeMultiVectorAref ('Y_r', dataLayout, rowStrideY, colStrideY, '', 'c')
    if unitDiag:
        # No division by a diagonal entry follows, so X_r[c] starts as Y_r[c].
        return X_r_c_expr + ' = ' + Y_r_c_expr
    # Stored diagonal: start from zero.  Y_r[c] is added later by the
    # division statement (X_r[c] + Y_r[c]) / A_rr from
    # makeInnerMatrixLoopDivStmt.  Previously this return was unreachable
    # (nested under the unitDiag branch), so this case returned None.
    return X_r_c_expr + ' = STS::zero ()'


def makeFunctionBody (defDict):
    '''Make the body of the implementation of the sparse triangular solve routine.

    The input dictionary defDict must have at least the following fields:

    upLo: 'lower' for lower triangular solve, or 'upper' for upper
      triangular solve.

    dataLayout: This describes how the multivectors' data are arranged
      in memory.  Currently we only accept 'column major' or 'row
      major'.

    unitDiag: True if the routine is for a sparse matrix with unit
      diagonal (which is not stored explicitly in the sparse matrix),
      else False for a sparse matrix with explicitly stored diagonal
      entries.

    conjugateMatrixElements: Whether to use the conjugate of each
      matrix element.

    Returns the generated C++ function body as a string (the result of
    Template.substitute on the skeleton below).'''

    # All four keys are read with [], so a missing key raises KeyError
    # (for a plain dict).
    upLo = defDict['upLo']
    dataLayout = defDict['dataLayout']
    unitDiag = defDict['unitDiag']
    conjugateMatrixElements = defDict['conjugateMatrixElements']            

    # Skeleton of the generated C++ routine; the ${...} placeholders are
    # filled in by the t.substitute (...) call at the end of this function.
    # NOTE(review): in this copy the template literal below is never closed
    # before the Python statements that follow it.  The paired assignments
    # under "if unitDiag:" have no "else:" between them, so the second
    # assignment always wins.  The final substitute (...) call is also left
    # unclosed.  Lines appear to be missing; restore them from version control.
    t = Template ('''{
  typedef Teuchos::ScalarTraits<RangeScalar> STS;

  for (${row_loop_expr}) {
    // Following line: Unit diag only
    const RangeScalar* const Y_r = &${Y_r_expr};
    RangeScalar* const X_r = &${X_r_expr};
    for (Ordinal c = 0; c < numColsX; ++c) {
    for (Ordinal k = ptr[r]; k < ptr[r+1]; ++k) {

    rowLoopExpr = makeCsrTriSolveRowLoopBounds (upLo, 'startRow', 'endRowPlusOne')
    rowStrideX, colStrideX = makeRowAndColStrides ('X', dataLayout)
    rowStrideY, colStrideY = makeRowAndColStrides ('Y', dataLayout)
    X_r_expr = makeMultiVectorAref ("X", dataLayout, rowStrideX, colStrideX, 'r', '')
    Y_r_expr = makeMultiVectorAref ("Y", dataLayout, rowStrideY, colStrideY, 'r', '')
    initLoopBody = makeInitLoopBody (dataLayout, unitDiag)
    if unitDiag:
        setDiagMatrixElt = ''
        setDiagMatrixElt = 'MatrixScalar A_rr = STS::zero ();'
    innerMatrixLoopBody = makeInnerMatrixLoopBody (dataLayout, unitDiag,

    if unitDiag:
        nonUnitDiagDivLoop = ''
        nonUnitDiagDivLoop = Template('''for (Ordinal c = 0; c < numColsX; ++c) {
      // This assumes the following:
      // 1. operator+(RangeScalar, DomainScalar) exists,
      // 2. it returns a result of a type T1 such that 
      //    operator/(T1, MatrixScalar) exists, and
      // 3. that in turn returns a result of a type T2 such that 
      //    operator=(RangeScalar, T2) exists.
    }''').substitute (inner_loop_div_stmt=makeInnerMatrixLoopDivStmt(dataLayout))

    return t.substitute (row_loop_expr=rowLoopExpr,
def makeInnerMatrixLoopBody (dataLayout, unitDiag, conjugateMatrixElements):
    '''Make the inner loop body in CSR sparse triangular solve.

    dataLayout: This describes how the multivectors' data are arranged
      in memory.  Currently we only accept 'column major' or 'row

    unitDiag: True if the routine is for a sparse matrix with unit
      diagonal (which is not stored explicitly in the sparse matrix),
      else False for a sparse matrix with explicitly stored diagonal

    conjugateMatrixElements: Whether to use the conjugate of each
      matrix element.'''

    if not unitDiag:
        t = Template ('''const MatrixScalar A_rj = ${matElt};
      const Ordinal j = ind[k];
      if (j == r) {
        // We merge repeated diagonal elements additively.
        A_rr += A_rj;
      else {
        const DomainScalar* const Y_j = &${Y_j_expr};
        for (Ordinal c = 0; c < numColsX; ++c) {
          // This assumes the following:
          // 1. operator*(MatrixScalar, DomainScalar) exists,
          // 2. it returns a result of a type T1 such that 
          //    operator*(RangeScalar, T1) exists,
          // 3. that in turn returns a result of a type T2 such that 
          //    operator-(RangeScalar, T2) exists, and 
          // 4. that in turn returns a result of a type T3 such that
          //    operator=(RangeScalar, T3) exists.
          // For example, this relies on the usual C++ type promotion rules 
          // if MatrixScalar = float, DomainScalar = float, and RangeScalar 
          // = double (the typical iterative refinement case).
        t = Template ('''const MatrixScalar A_rj = val[k];
      const Ordinal j = ind[k];
      const DomainScalar* const Y_j = &${Y_j_expr};
      for (Ordinal c = 0; c < numColsX; ++c) {
        // This assumes the following:
        // 1. operator*(MatrixScalar, DomainScalar) exists,
        // 2. it returns a result of a type T1 such that 
        //    operator*(RangeScalar, T1) exists,
        // 3. that in turn returns a result of a type T2 such that 
        //    operator-(RangeScalar, T2) exists, and 
        // 4. that in turn returns a result of a type T3 such that
        //    operator=(RangeScalar, T3) exists.
        // For example, this relies on the usual C++ type promotion rules 
        // if MatrixScalar = float, DomainScalar = float, and RangeScalar 
        // = double (the typical iterative refinement case).

    if conjugateMatrixElements:
        matElt = 'STS::conjugate (val[k])'
        matElt = 'val[k]'

    rowStrideY, colStrideY = makeRowAndColStrides ('Y', dataLayout)
    Y_j_expr = makeMultiVectorAref ('Y', dataLayout, rowStrideY, colStrideY, 'j', '')
    innerLoopMultStmt = makeInnerMatrixLoopMultStmt(dataLayout)
    innerLoopDivStmt = makeInnerMatrixLoopDivStmt(dataLayout)
    return t.substitute (matElt=matElt,
