Esempio n. 1
0
def test_create_basis_errors(monkeypatch):
    '''Check that the appropriate exceptions are raised by
    KernelInterface.basis() when a) an evaluator shape is provided, as
    evaluators are not yet supported (NotImplementedError), and b) an
    unrecognised quadrature or evaluator shape is found (InternalError).

    '''
    _, invoke_info = parse(os.path.join(BASE_PATH, "6.1_eval_invoke.f90"),
                           api="dynamo0.3")
    psy = PSyFactory("dynamo0.3", distributed_memory=False).create(invoke_info)
    schedule = psy.invokes.invoke_list[0].schedule
    kernel = schedule[0].loop_body[0]
    kernel_interface = KernelInterface(kernel)

    # "w1" requires a basis function and is the first entry in the
    # unique function spaces list
    w1_fs = kernel.arguments.unique_fss[0]
    # Evaluator shapes are not yet supported.
    with pytest.raises(NotImplementedError) as info:
        kernel_interface.basis(w1_fs)
    assert ("Evaluator shapes not implemented in kernel_interface class."
            in str(info.value))
    # Force an unsupported shape
    monkeypatch.setattr(kernel, "_eval_shapes", ["invalid_shape"])
    with pytest.raises(InternalError) as info:
        kernel_interface.basis(w1_fs)
    # This check must sit outside the 'with' block: once basis() raises,
    # no later statement inside the pytest.raises context is executed.
    assert (
        "Unrecognised quadrature or evaluator shape 'invalid_shape'. "
        "Expected one of: ['gh_quadrature_xyoz', 'gh_quadrature_face', "
        "'gh_quadrature_edge', 'gh_evaluator']." in str(info.value))
Esempio n. 2
0
def test_basis_face():
    '''Check that KernelInterface.basis() declares the expected symbols for
    face quadrature. The ndf, nfaces, nqp_faces and face-basis symbols must
    be read-only kernel arguments of the right symbol classes. The basis
    symbol must be appended to the _arglist list and have shape
    (3, ndf, nqp_faces, nfaces).

    '''
    qr_file = os.path.join(BASE_PATH, "1.1.6_face_qr.f90")
    _, invoke_info = parse(qr_file, api="dynamo0.3")
    psy_factory = PSyFactory("dynamo0.3", distributed_memory=False)
    psy = psy_factory.create(invoke_info)
    first_loop = psy.invokes.invoke_list[0].schedule[0]
    kern = first_loop.loop_body[0]

    # "w1" needs a basis function and is the first unique function space.
    w1_space = kern.arguments.unique_fss[0]
    space_name = w1_space.orig_name

    iface = KernelInterface(kern)
    iface.basis(w1_space)
    table = iface._symbol_table

    def _check_read_only_arg(sym, expected_type):
        # Each declared symbol must have the given class and be a kernel
        # argument with the interface's read access.
        assert isinstance(sym, expected_type)
        assert isinstance(sym.interface, ArgumentInterface)
        assert sym.interface.access == iface._read_access.access

    # Number of dofs for the function space.
    ndf = table.lookup("ndf_{0}".format(space_name))
    _check_read_only_arg(ndf, lfric_psyir.NumberOfDofsDataSymbol)
    # Number of faces.
    nfaces = table.lookup("nfaces")
    _check_read_only_arg(nfaces, lfric_psyir.NumberOfFacesDataSymbol)
    # Number of quadrature points on the faces.
    nqp = table.lookup("nqp_faces")
    _check_read_only_arg(nqp, lfric_psyir.NumberOfQrPointsInFacesDataSymbol)
    # The basis function array, which must also be the last argument.
    basis = table.lookup("basis_w1_qr_face")
    _check_read_only_arg(basis, lfric_psyir.BasisFunctionQrFaceDataSymbol)
    assert iface._arglist[-1] is basis

    # Shape is (3, ndf, nqp_faces, nfaces): a literal first extent, then
    # references to the symbols checked above.
    dims = basis.shape
    assert len(dims) == 4
    first_dim = dims[0]
    assert isinstance(first_dim, Literal)
    assert first_dim.value == "3"
    for dim, expected_sym in zip(dims[1:], (ndf, nqp, nfaces)):
        assert isinstance(dim, Reference)
        assert dim.symbol is expected_sym