def test_create_basis_errors(monkeypatch):
    '''Verify that KernelInterface.basis() raises the expected exceptions:

    a) NotImplementedError when the kernel uses an evaluator shape,
       because evaluator shapes are not yet supported; and
    b) InternalError when the kernel reports a quadrature or evaluator
       shape that is not recognised.

    '''
    algorithm_file = os.path.join(BASE_PATH, "6.1_eval_invoke.f90")
    _, invoke_info = parse(algorithm_file, api="dynamo0.3")
    psy_factory = PSyFactory("dynamo0.3", distributed_memory=False)
    first_invoke = psy_factory.create(invoke_info).invokes.invoke_list[0]
    kern = first_invoke.schedule[0].loop_body[0]
    kern_iface = KernelInterface(kern)
    # The first entry in the unique function spaces list ("w1") is one
    # that requires a basis function.
    first_fs = kern.arguments.unique_fss[0]

    # a) The evaluator shape used by this kernel is rejected as
    # unsupported.
    with pytest.raises(NotImplementedError) as excinfo:
        kern_iface.basis(first_fs)
    assert ("Evaluator shapes not implemented in kernel_interface class."
            in str(excinfo.value))

    # b) Overwrite the kernel's shapes with one that basis() does not
    # recognise.
    monkeypatch.setattr(kern, "_eval_shapes", ["invalid_shape"])
    expected_msg = (
        "Unrecognised quadrature or evaluator shape 'invalid_shape'. "
        "Expected one of: ['gh_quadrature_xyoz', 'gh_quadrature_face', "
        "'gh_quadrature_edge', 'gh_evaluator'].")
    with pytest.raises(InternalError) as excinfo:
        kern_iface.basis(first_fs)
    assert expected_msg in str(excinfo.value)
def test_basis_face():
    '''Verify that KernelInterface.basis() declares the expected symbols
    in the symbol table for face quadrature. It must also append the
    basis-function symbol to the _arglist list, with a shape that
    references the ndf, nqp_faces and nfaces symbols.

    '''
    algorithm_file = os.path.join(BASE_PATH, "1.1.6_face_qr.f90")
    _, invoke_info = parse(algorithm_file, api="dynamo0.3")
    psy_factory = PSyFactory("dynamo0.3", distributed_memory=False)
    first_invoke = psy_factory.create(invoke_info).invokes.invoke_list[0]
    kern = first_invoke.schedule[0].loop_body[0]
    # The first entry in the unique function spaces list ("w1") is one
    # that requires a basis function.
    first_fs = kern.arguments.unique_fss[0]
    fs_name = first_fs.orig_name
    kern_iface = KernelInterface(kern)
    kern_iface.basis(first_fs)
    table = kern_iface._symbol_table

    def check_read_arg(sym, expected_type):
        '''Assert that sym is an instance of expected_type, is declared
        through an ArgumentInterface, and has the kernel interface's
        read access.'''
        assert isinstance(sym, expected_type)
        assert isinstance(sym.interface, ArgumentInterface)
        assert sym.interface.access == kern_iface._read_access.access

    # Number of dofs for the function space.
    ndf_sym = table.lookup("ndf_{0}".format(fs_name))
    check_read_arg(ndf_sym, lfric_psyir.NumberOfDofsDataSymbol)

    # Number of faces.
    nfaces_sym = table.lookup("nfaces")
    check_read_arg(nfaces_sym, lfric_psyir.NumberOfFacesDataSymbol)

    # Number of quadrature points per face.
    nqp_sym = table.lookup("nqp_faces")
    check_read_arg(nqp_sym, lfric_psyir.NumberOfQrPointsInFacesDataSymbol)

    # The basis function itself, which must also be the most recently
    # added kernel argument.
    basis_sym = table.lookup("basis_w1_qr_face")
    check_read_arg(basis_sym, lfric_psyir.BasisFunctionQrFaceDataSymbol)
    assert kern_iface._arglist[-1] is basis_sym

    # Shape is (3, ndf, nqp_faces, nfaces): a literal leading extent
    # followed by references to the symbols declared above.
    assert len(basis_sym.shape) == 4
    assert isinstance(basis_sym.shape[0], Literal)
    assert basis_sym.shape[0].value == "3"
    for index, dim_sym in enumerate([ndf_sym, nqp_sym, nfaces_sym], start=1):
        assert isinstance(basis_sym.shape[index], Reference)
        assert basis_sym.shape[index].symbol is dim_sym