def make_solver(self, A):
    """Build a STRUMPACK solver for the block operator ``A``.

    Every non-empty block of ``A`` is read in COO form, its row and column
    indices are renumbered into one merged index space, and the result is
    assembled into a single SciPy sparse matrix that is handed to STRUMPACK.

    Parameters
    ----------
    A
        Block operator providing ``RowOffsets()``, ``NumRowBlocks()`` and
        ``NumColBlocks()``.  Individual blocks are fetched through
        ``self.get_block(A, i, j)``; a ``None`` block is skipped.

    Returns
    -------
    The configured ``STRUMPACKSolver``.  The per-block index maps used for
    the renumbering are attached to it as ``_mapper``.

    Notes
    -----
    The column index of a block is looked up in the same per-block map list
    as the row index, and that list has one entry per *row* block, so this
    presumably expects a square block layout (NOTE(review): confirm).
    """
    offset = np.array(A.RowOffsets().ToList(), dtype=int)
    rows = A.NumRowBlocks()
    cols = A.NumColBlocks()

    # Per-block sizes on this process, then the same for every process.
    local_size = np.diff(offset)
    x = allgather_vector(local_size)
    global_size = np.sum(x.reshape(num_proc, -1), 0)
    nicePrint(local_size)

    global_offset = np.hstack(([0], np.cumsum(global_size)))
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(global_offset)

    # new_offset[p, b]: first merged index of block b on process p, with
    # the merged numbering running over x in its gathered order.
    # new_size[p, b]: number of entries of block b on process p.
    new_offset = np.hstack(([0], np.cumsum(x)))[:-1]
    new_size = x.reshape(num_proc, -1)
    new_offset = new_offset.reshape(num_proc, -1)
    print(new_offset)

    def blk_stm_idx_map(i):
        """Return the merged indices of block ``i``, process by process."""
        stm_idx = [new_offset[kk, i] + np.arange(new_size[kk, i], dtype=int)
                   for kk in range(num_proc)]
        return np.hstack(stm_idx)

    # One lookup array per row block (named to avoid shadowing builtin map).
    index_map = [blk_stm_idx_map(i) for i in range(rows)]

    newi = []
    newj = []
    newd = []
    nrows = np.sum(local_size)
    ncols = np.sum(global_size)

    for i in range(rows):
        for j in range(cols):
            m = self.get_block(A, i, j)
            if m is None:
                continue
            # Only the COO triplet is needed; the leading five values
            # returned by GetCooDataArray are not used here.
            (_num_rows, _ilower, _iupper, _jlower, _jupper,
             irn, jcn, data) = m.GetCooDataArray()

            nicePrint(i, j, index_map[i].shape, index_map[i])
            nicePrint(irn)

            # Translate block indices into the merged numbering.
            newi.append(index_map[i][irn])
            newj.append(index_map[j][jcn])
            newd.append(data)

    newi = np.hstack(newi)
    newj = np.hstack(newj)
    newd = np.hstack(newd)

    from scipy.sparse import coo_matrix

    # Rows are shifted by this process's first merged index so that they
    # index into the (nrows, ncols) matrix built below.
    row_shift = new_offset[myid, 0]

    nicePrint(new_offset)
    nicePrint((nrows, ncols),)
    nicePrint('newJ', np.min(newj), np.max(newj))
    nicePrint('newI', np.min(newi) - row_shift,
              np.max(newi) - row_shift)
    mat = coo_matrix((newd, (newi - row_shift, newj)),
                     shape=(nrows, ncols),
                     dtype=newd.dtype).tocsr()

    AA = ToHypreParCSR(mat)

    import mfem.par.strumpack as strmpk
    Arow = strmpk.STRUMPACKRowLocMatrix(AA)

    # Options are passed to STRUMPACK as command-line style arguments.
    args = []
    if self.hss:
        args.extend(["--sp_enable_hss",
                     "--hss_verbose",
                     "--sp_hss_min_sep_size",
                     str(int(self.hss_front_size)),
                     "--hss_rel_tol",
                     str(0.01),
                     "--hss_abs_tol",
                     str(1e-4)])
    print(self.maxiter)
    args.extend(["--sp_maxit", str(int(self.maxiter))])
    args.extend(["--sp_rel_tol", str(self.rctol)])
    args.extend(["--sp_abs_tol", str(self.actol)])
    args.extend(["--sp_gmres_restart", str(int(self.gmres_restart))])

    strumpack = strmpk.STRUMPACKSolver(args, MPI.COMM_WORLD)

    # log_level 0: no statistics; 1: factor statistics only;
    # anything else: factor and solve statistics.
    log_level = self.gui.log_level
    strumpack.SetPrintFactorStatistics(log_level != 0)
    strumpack.SetPrintSolveStatistics(log_level not in (0, 1))

    strumpack.SetKrylovSolver(strmpk.KrylovSolver_DIRECT)
    strumpack.SetReorderingStrategy(strmpk.ReorderingStrategy_METIS)
    strumpack.SetMC64Job(strmpk.MC64Job_NONE)
    # strumpack.SetSymmetricPattern(True)
    strumpack.SetOperator(Arow)
    strumpack.SetFromCommandLine()

    # Keep the index maps on the solver object for later use by callers.
    strumpack._mapper = index_map
    return strumpack
# Example #2
# 0
def run(order=1, static_cond=False,
        meshfile=def_meshfile, visualization=False,
        use_strumpack=False):
    """Assemble and solve a diffusion problem in parallel, then save it.

    The mesh in ``meshfile`` is refined serially, distributed over
    ``MPI.COMM_WORLD`` and refined twice more in parallel.  A bilinear form
    with a ``DiffusionIntegrator`` and a linear form with a
    ``DomainLFIntegrator`` (both using the constant coefficient 1.0) are
    assembled, with every boundary attribute marked essential and the
    initial grid function set to zero.  Each rank writes its mesh and
    solution to ``mesh.<rank>`` and ``sol.<rank>``.

    Parameters
    ----------
    order
        Polynomial order of the H1 space.  If it is not positive, the
        finite element collection of the mesh nodes is used when the mesh
        has nodes, otherwise order 1 is used.
    static_cond
        Enable static condensation on the bilinear form.
    meshfile
        Path of the mesh file to load.
    visualization
        Accepted for interface compatibility; not used in this function.
    use_strumpack
        Solve with the STRUMPACK solver instead of CG preconditioned with
        BoomerAMG.
    """
    mesh = mfem.Mesh(meshfile, 1, 1)
    dim = mesh.Dimension()

    # Serial refinement count derived from the element count and dimension
    # (the constant 10000 presumably targets the refined element count --
    # NOTE(review): confirm).
    ref_levels = int(np.floor(np.log(10000. / mesh.GetNE()) / np.log(2.) / dim))
    for _ in range(ref_levels):
        mesh.UniformRefinement()
    mesh.ReorientTetMesh()
    pmesh = mfem.ParMesh(MPI.COMM_WORLD, mesh)
    del mesh

    par_ref_levels = 2
    for _ in range(par_ref_levels):
        pmesh.UniformRefinement()

    if order > 0:
        fec = mfem.H1_FECollection(order, dim)
    elif pmesh.GetNodes():
        # The serial mesh was deleted above, so the nodes must be read
        # from the parallel mesh; using `mesh` here raised NameError.
        fec = pmesh.GetNodes().OwnFEC()
        print("Using isoparametric FEs: " + str(fec.Name()))
    else:
        order = 1
        fec = mfem.H1_FECollection(order, dim)

    fespace = mfem.ParFiniteElementSpace(pmesh, fec)
    fe_size = fespace.GlobalTrueVSize()

    if myid == 0:
        print('Number of finite element unknowns: ' + str(fe_size))

    # Mark every boundary attribute as essential and collect its true dofs.
    ess_tdof_list = mfem.intArray()
    if pmesh.bdr_attributes.Size() > 0:
        ess_bdr = mfem.intArray(pmesh.bdr_attributes.Max())
        ess_bdr.Assign(1)
        fespace.GetEssentialTrueDofs(ess_bdr, ess_tdof_list)

    # Right-hand side: linear form with a unit-coefficient domain integrator.
    b = mfem.ParLinearForm(fespace)
    one = mfem.ConstantCoefficient(1.0)
    b.AddDomainIntegrator(mfem.DomainLFIntegrator(one))
    b.Assemble()

    # Solution grid function, initialised to zero.
    x = mfem.ParGridFunction(fespace)
    x.Assign(0.0)

    # Left-hand side: bilinear form with a unit-coefficient diffusion term.
    a = mfem.ParBilinearForm(fespace)
    a.AddDomainIntegrator(mfem.DiffusionIntegrator(one))

    if static_cond:
        a.EnableStaticCondensation()
    a.Assemble()

    A = mfem.HypreParMatrix()
    B = mfem.Vector()
    X = mfem.Vector()
    a.FormLinearSystem(ess_tdof_list, x, b, A, X, B)

    if myid == 0:
        print("Size of linear system: " + str(x.Size()))
        print("Size of linear system: " + str(A.GetGlobalNumRows()))

    if use_strumpack:
        import mfem.par.strumpack as strmpk
        Arow = strmpk.STRUMPACKRowLocMatrix(A)
        args = ["--sp_hss_min_sep_size", "128", "--sp_enable_hss"]
        strumpack = strmpk.STRUMPACKSolver(args, MPI.COMM_WORLD)
        strumpack.SetPrintFactorStatistics(True)
        strumpack.SetPrintSolveStatistics(False)
        strumpack.SetKrylovSolver(strmpk.KrylovSolver_DIRECT)
        strumpack.SetReorderingStrategy(strmpk.ReorderingStrategy_METIS)
        strumpack.SetMC64Job(strmpk.MC64Job_NONE)
        # strumpack.SetSymmetricPattern(True)
        strumpack.SetOperator(Arow)
        strumpack.SetFromCommandLine()
        strumpack.Mult(B, X)
    else:
        amg = mfem.HypreBoomerAMG(A)
        cg = mfem.CGSolver(MPI.COMM_WORLD)
        cg.SetRelTol(1e-12)
        cg.SetMaxIter(200)
        cg.SetPrintLevel(1)
        cg.SetPreconditioner(amg)
        cg.SetOperator(A)
        cg.Mult(B, X)

    a.RecoverFEMSolution(X, b, x)

    # One output file pair per rank, suffixed with the zero-padded rank id.
    smyid = '{:0>6d}'.format(myid)
    mesh_name = "mesh." + smyid
    sol_name = "sol." + smyid

    # The second argument is presumably the output precision
    # (NOTE(review): confirm against the wrapper's signature).
    pmesh.Print(mesh_name, 8)
    x.Save(sol_name, 8)
# Example #3
# 0
# Excerpt from a larger script: `a`, `fespace`, `one`, `ess_bdr` and
# `use_strumpack` are used below but defined before this point (not shown).
a.Finalize()

# Second bilinear form on the same space, with a mass integrator using the
# coefficient `one`.
m = mfem.ParBilinearForm(fespace)
m.AddDomainIntegrator(mfem.MassIntegrator(one))
m.Assemble()

# Shift the eigenvalue corresponding to eliminated dofs to a large value
# (the diagonal value passed for those dofs is 3.0e-300).
m.EliminateEssentialBCDiag(ess_bdr, 3.0e-300)
m.Finalize()

# Parallel matrices assembled from the two forms.
A = a.ParallelAssemble()
M = m.ParallelAssemble()

if use_strumpack:
    # Imported only on the STRUMPACK path; `strmpk` is reused further down.
    import mfem.par.strumpack as strmpk
    # Wrap A in the matrix type that the STRUMPACK bindings provide.
    Arow = strmpk.STRUMPACKRowLocMatrix(A)

# 8. Define and configure the LOBPCG eigensolver and the BoomerAMG
#    preconditioner for A to be used within the solver. Set the matrices
#    which define the generalized eigenproblem A x = lambda M x.
#    We don't support SuperLU.
#    NOTE(review): only the STRUMPACK configuration is visible in this
#    excerpt; the LOBPCG / BoomerAMG setup is presumably further down.

if use_strumpack:
    # Solver options are passed as command-line style arguments.
    args = ["--sp_hss_min_sep_size", "128", "--sp_enable_hss"]
    strumpack = strmpk.STRUMPACKSolver(args, MPI.COMM_WORLD)
    # Print factorization statistics but not solve statistics.
    strumpack.SetPrintFactorStatistics(True)
    strumpack.SetPrintSolveStatistics(False)
    strumpack.SetKrylovSolver(strmpk.KrylovSolver_DIRECT)
    strumpack.SetReorderingStrategy(strmpk.ReorderingStrategy_METIS)
    strumpack.SetMC64Job(strmpk.MC64Job_NONE)
    # strumpack.SetSymmetricPattern(True)