def test_distributed_matrix_global_index(na, nev, nblk):
    """Check that local -> global -> local index conversion round-trips.

    For every element stored locally in a ``DistributedMatrix`` (built once
    with a real and once with a complex dtype):

    * the global index returned by ``get_global_index`` must lie in
      ``[0, a.na)`` in both dimensions, and
    * passing that global index to ``get_local_index`` must give back the
      local index we started from.
    """
    import numpy as np
    from pyelpa import ProcessorLayout, DistributedMatrix
    from mpi4py import MPI

    layout = ProcessorLayout(MPI.COMM_WORLD)
    for matrix_dtype in (np.float64, np.complex128):
        dist = DistributedMatrix(layout, na, nev, nblk, dtype=matrix_dtype)
        for i_loc in range(dist.na_rows):
            for j_loc in range(dist.na_cols):
                i_glob, j_glob = dist.get_global_index(i_loc, j_loc)
                i_back, j_back = dist.get_local_index(i_glob, j_glob)
                # Global indices must fall inside the global na x na matrix.
                assert 0 <= i_glob < dist.na
                assert 0 <= j_glob < dist.na
                # The inverse mapping must recover the original local slot.
                assert i_back == i_loc and j_back == j_loc
def test_setting_global_matrix(na, nev, nblk):
    """Check that ``set_data_from_global_matrix`` places every entry correctly.

    A global ``na`` x ``na`` matrix is built from ``get_random_vector`` (a
    helper defined elsewhere in this file) and loaded into a
    ``DistributedMatrix`` for both a real and a complex dtype. Every global
    entry that ``is_local_index`` reports as owned by this process must then
    appear unchanged at the position given by ``get_local_index`` in
    ``a.data``.
    """
    import numpy as np
    from pyelpa import ProcessorLayout, DistributedMatrix
    from mpi4py import MPI

    layout = ProcessorLayout(MPI.COMM_WORLD)
    for matrix_dtype in (np.float64, np.complex128):
        dist = DistributedMatrix(layout, na, nev, nblk, dtype=matrix_dtype)

        # The global matrix is intended to be identical on all cores, so each
        # rank can check its own share against the same reference.
        flat = get_random_vector(na * na)
        global_matrix = flat.reshape(na, na).astype(matrix_dtype)
        dist.set_data_from_global_matrix(global_matrix)

        for i_glob in range(dist.na):
            for j_glob in range(dist.na):
                if dist.is_local_index(i_glob, j_glob):
                    i_loc, j_loc = dist.get_local_index(i_glob, j_glob)
                    assert (dist.data[i_loc, j_loc]
                            == global_matrix[i_glob, j_glob])
def test_distributed_matrix_indexing_loop(na, nev, nblk):
    """Check the local/global index mappings by writing and re-reading tags.

    For both a real and a complex dtype:

    1. Every locally stored element is tagged with a value derived from the
       global index that ``get_global_index`` reports for it.
    2. All global indices are then walked. For each one owned by this
       process (``is_local_index``), the element at the local position given
       by ``get_local_index`` must carry the tag of that same global index.

    The tag is ``global_row * a.na + global_col``, which is unique for every
    position of an ``a.na`` x ``a.na`` matrix. This means a wrong mapping in
    either direction cannot hide behind a coincidentally equal tag.
    """
    import numpy as np
    from pyelpa import ProcessorLayout, DistributedMatrix
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    layout_p = ProcessorLayout(comm)
    for dtype in [np.float64, np.complex128]:
        a = DistributedMatrix(layout_p, na, nev, nblk, dtype=dtype)
        # A stride of a.na (instead of a fixed constant such as 10) keeps the
        # tag injective: with stride 10, (0, 10) and (1, 0) would both map to
        # 10 once a.na > 10. Values stay well below 2**53, so they are exact
        # in float64/complex128.
        stride = a.na

        # Write: tag each local element with its global position.
        for local_row in range(a.na_rows):
            for local_col in range(a.na_cols):
                global_row, global_col = a.get_global_index(
                    local_row, local_col)
                a.data[local_row, local_col] = global_row * stride + global_col

        # Read back through the inverse mapping and compare tags.
        for global_row in range(a.na):
            for global_col in range(a.na):
                if not a.is_local_index(global_row, global_col):
                    continue
                local_row, local_col = a.get_local_index(
                    global_row, global_col)
                assert (a.data[local_row, local_col]
                        == global_row * stride + global_col)