コード例 #1
0
def pyamgx_solve(A, b, config = None, x0 = None):
    '''
    Uses the (experimental) pyamgx Python bindings to the Nvidia AMGX library to solve the system Ax=b on the GPU using multigrid.

    A: CSR format sparse matrix
    b: numpy array (or array-like); it is converted to a float64 numpy array before upload.
    config: AMGX config, given either as a dict or as a pyamgx.Config object.
        None selects get_default_pyamgx_config(). See AMGX github for details.
        The config object is destroyed before this function returns, so a
        pyamgx.Config passed in by the caller must not be reused afterwards.
    x0: numpy array (or array-like). Initial guess for the mgrid algorithm;
        converted to float64. None selects a zero vector shaped like b.

    Outputs a numpy array containing the solution to the equation system.

    Raises RuntimeError (chained to the underlying exception) if creating
    any pyamgx object or running the solver fails.
    '''
    pyamgx.initialize()
    try:
        return _pyamgx_solve_initialized(A, b, config, x0)
    except Exception as err:
        # Only ordinary errors are wrapped; KeyboardInterrupt/SystemExit pass
        # through untouched (finalize() below still runs for them).
        raise RuntimeError('pyamgx variable creation or solver error. See stack trace.') from err
    finally:
        # finalize() must run on every exit path - subsequent calls with good
        # inputs will fail otherwise.
        pyamgx.finalize()


def _pyamgx_solve_initialized(A, b, config, x0):
    '''
    Run a single AMGX solve of Ax=b; pyamgx must already be initialized.

    Every AMGX object created here (and the config, even when supplied by the
    caller) is destroyed before returning or raising, in reverse order of
    creation, so a failure part-way through does not leak GPU memory.
    '''
    handles = []  # AMGX objects that need destroy(), in creation order
    try:
        if config is None:
            #default config copied directly from https://github.com/NVIDIA/AMGX/blob/master/core/configs/AMG_CLASSICAL_CG.json
            config = get_default_pyamgx_config()
            config = pyamgx.Config().create_from_dict(config)
        elif isinstance(config, dict):
            config = pyamgx.Config().create_from_dict(config)
        handles.append(config)

        resources = pyamgx.Resources()
        resources.create_simple(config)
        handles.append(resources)

        #Allocate memory for variables on GPU
        A_pyamgx = pyamgx.Matrix()
        A_pyamgx.create(resources, mode='dDDI')
        handles.append(A_pyamgx)
        A_pyamgx.upload_CSR(A)

        if not isinstance(b, np.ndarray):
            b = np.array(b)
        b = b.astype(np.float64)
        b_pyamgx = pyamgx.Vector()
        b_pyamgx.create(resources, mode='dDDI')
        handles.append(b_pyamgx)
        b_pyamgx.upload(b)

        x = pyamgx.Vector().create(resources)
        handles.append(x)
        if x0 is None:
            x0 = np.zeros(b.shape, dtype=b.dtype)
        else:
            # Give the initial guess the same float64 treatment as b.
            x0 = np.asarray(x0, dtype=np.float64)
        x.upload(x0)

        #Solve system
        solver = pyamgx.Solver()
        solver.create(resources, config)
        handles.append(solver)
        solver.setup(A_pyamgx)
        solver.solve(b_pyamgx, x)
        rval = x.download()
        print(solver.get_residual())
        return rval
    finally:
        #Cleanup to prevent GPU memory leak: solver first, then vectors and
        #matrix, then resources, and the config last.
        for handle in reversed(handles):
            handle.destroy()
コード例 #2
0
 def __del__(self):
     """Destroy the AMGX objects held by this instance, then finalize pyamgx.

     The attributes are destroyed in the fixed order ``_solver``, ``_rhs``,
     ``_phi_vec``, ``_matrix``, ``resources``, ``cfg``.
     """
     import pyamgx
     destroy_order = (
         "_solver",
         "_rhs",
         "_phi_vec",
         "_matrix",
         "resources",
         "cfg",
     )
     for attr_name in destroy_order:
         getattr(self, attr_name).destroy()
     pyamgx.finalize()
コード例 #3
0
 def teardown_class(self):
     """Destroy the ``rsrc`` and ``cfg`` objects (in that order), then finalize pyamgx."""
     for attr_name in ("rsrc", "cfg"):
         getattr(self, attr_name).destroy()
     pyamgx.finalize()
コード例 #4
0
ファイル: test_system.py プロジェクト: wd15/pyamgx
 def teardown(self):
     """Finalize pyamgx.

     NOTE(review): presumably invoked by the test framework after each test,
     pairing with an initialize() call made in a setup hook - confirm.
     """
     pyamgx.finalize()
コード例 #5
0
ファイル: test_config.py プロジェクト: wd15/pyamgx
 def teardown(self):
     """Destroy the ``cfg`` object held by this instance, then finalize pyamgx."""
     config = self.cfg
     config.destroy()
     pyamgx.finalize()
コード例 #6
0
# Create the solver from the resources (rsc) and config (cfg) objects.
# NOTE: rsc, cfg, and the A / b / x objects used below are not created in
# this snippet; they must already exist when it runs.
solver = pyamgx.Solver().create(rsc, cfg)

# Build the host-side system: a random 5x5 matrix stored as CSR, a random
# right-hand side, and a zero-filled float64 vector of the same length.
M = sparse.csr_matrix(np.random.rand(5, 5))
rhs = np.random.rand(5)
sol = np.zeros(5, dtype=np.float64)

# Upload the host data into the pyamgx matrix and vectors.
A.upload_CSR(M)
b.upload(rhs)
x.upload(sol)

# Set up the solver for matrix A, then solve with right-hand side b; the
# result is held in x.
solver.setup(A)
solver.solve(b, x)

# Download the solution from x.
# NOTE(review): presumably this fills `sol` in place, since `sol` is what gets
# printed next - confirm against the installed pyamgx version.
x.download(sol)
print("pyamgx solution: ", sol)
# Reference result from splinalg.spsolve for comparison (presumably
# scipy.sparse.linalg, per the label printed here).
print("scipy solution: ", splinalg.spsolve(M, rhs))

# Clean up: destroy the matrix, vectors and solver first, then the resources,
# then the config.
A.destroy()
x.destroy()
b.destroy()
solver.destroy()
rsc.destroy()
cfg.destroy()

# Finalize pyamgx after every object above has been destroyed.
pyamgx.finalize()
コード例 #7
0
 def teardown_class(cls):
     """Finalize pyamgx.

     NOTE(review): presumably invoked once by the test framework after all
     tests in the class have run - confirm.
     """
     pyamgx.finalize()