Exemplo n.º 1
0
def test_datasymbol_can_be_printed():
    '''Check that str() works on DataSymbols built in several different
    ways (local, unresolved, array-typed, global and constant), which shows
    that each instance is fully initialised.

    '''
    local_sym = DataSymbol("sname", REAL_SINGLE_TYPE)
    assert "sname: <Scalar<REAL, SINGLE>, Local>" in str(local_sym)

    unres_sym = DataSymbol("s1", INTEGER_SINGLE_TYPE,
                           interface=UnresolvedInterface())
    assert "s1: <Scalar<INTEGER, SINGLE>, Unresolved>" in str(unres_sym)

    # An array whose extents mix an attribute, a literal and a reference
    # to another symbol.
    shape = [ArrayType.Extent.ATTRIBUTE, 2, Reference(unres_sym)]
    array_sym = DataSymbol("s2", ArrayType(REAL_SINGLE_TYPE, shape))
    expected_array = ("s2: <Array<Scalar<REAL, SINGLE>, shape=['ATTRIBUTE', "
                      "Literal[value:'2', Scalar<INTEGER, UNDEFINED>], "
                      "Reference[name:'s1']]>, Local>")
    assert expected_array in str(array_sym)

    # A symbol imported from a container.
    container = ContainerSymbol("my_mod")
    global_sym = DataSymbol("s3", REAL_SINGLE_TYPE,
                            interface=GlobalInterface(container))
    assert ("s3: <Scalar<REAL, SINGLE>, Global(container='my_mod')>"
            in str(global_sym))

    # A symbol with a constant value includes it in its description.
    const_sym = DataSymbol("s3", INTEGER_SINGLE_TYPE, constant_value=12)
    expected_const = ("s3: <Scalar<INTEGER, SINGLE>, Local, "
                      "constant_value=Literal"
                      "[value:'12', Scalar<INTEGER, SINGLE>]>")
    assert expected_const in str(const_sym)

    other_unres = DataSymbol("s4", INTEGER_SINGLE_TYPE,
                             interface=UnresolvedInterface())
    assert "s4: <Scalar<INTEGER, SINGLE>, Unresolved>" in str(other_unres)
Exemplo n.º 2
0
def test_symbol_interface_setter():
    '''Check the Symbol interface setter. After each kind of interface is
    assigned, exactly one of is_local, is_global, is_argument and
    is_unresolved must be true. Assigning something that is not a
    SymbolInterface must raise a TypeError.

    '''
    def assert_only(sym, expected_flag):
        '''Assert that expected_flag is the only truthy interface property
        of sym.'''
        for flag in ("is_local", "is_global", "is_argument",
                     "is_unresolved"):
            if flag == expected_flag:
                assert getattr(sym, flag)
            else:
                assert not getattr(sym, flag)

    symbol = Symbol('sym1')
    # A freshly created Symbol reports itself as local.
    assert_only(symbol, "is_local")

    symbol.interface = GlobalInterface(ContainerSymbol("my_mod"))
    assert_only(symbol, "is_global")

    symbol.interface = ArgumentInterface()
    assert_only(symbol, "is_argument")

    symbol.interface = UnresolvedInterface()
    assert_only(symbol, "is_unresolved")

    with pytest.raises(TypeError) as info:
        symbol.interface = "hello"
    assert ("The interface to a Symbol must be a SymbolInterface but got "
            "'str'" in str(info.value))
Exemplo n.º 3
0
def test_symbol_initialisation():
    '''Check Symbol construction. Valid arguments produce a Symbol with the
    requested name, visibility and interface. Invalid name, visibility or
    interface arguments raise a TypeError. This also exercises the nested
    Visibility class.

    '''
    default_sym = Symbol("sym1")
    assert isinstance(default_sym, Symbol)
    assert default_sym.name == "sym1"
    assert default_sym.visibility == Symbol.DEFAULT_VISIBILITY
    assert isinstance(default_sym.interface, LocalInterface)
    # Symbols are public unless told otherwise.
    assert Symbol.DEFAULT_VISIBILITY == Symbol.Visibility.PUBLIC

    private_sym = Symbol("sym2", Symbol.Visibility.PRIVATE)
    assert private_sym.visibility == Symbol.Visibility.PRIVATE

    unresolved_sym = Symbol("sym3", interface=UnresolvedInterface())
    assert isinstance(unresolved_sym.interface, UnresolvedInterface)

    # Each invalid constructor argument is rejected with a TypeError.
    with pytest.raises(TypeError) as err:
        Symbol(None)
    assert "Symbol 'name' attribute should be of type 'str'" in str(err.value)

    with pytest.raises(TypeError) as err:
        Symbol('sym1', visibility="hello")
    assert ("Symbol 'visibility' attribute should be of type "
            "psyir.symbols.Symbol.Visibility but" in str(err.value))

    with pytest.raises(TypeError) as err:
        Symbol('sym1', interface="hello")
    assert ("The interface to a Symbol must be a SymbolInterface but got "
            "'str'" in str(err.value))
Exemplo n.º 4
0
def test_unresolvedinterface():
    '''Check that an UnresolvedInterface can be constructed and that its
    string form is "Unresolved".

    '''
    assert str(UnresolvedInterface()) == "Unresolved"
Exemplo n.º 5
0
def test_routinesymbol_init():
    '''Check that a RoutineSymbol can be created from just a name, with an
    explicit visibility, and with an explicit interface.'''
    candidates = [
        RoutineSymbol('jo'),
        RoutineSymbol('ellie', visibility=Symbol.Visibility.PRIVATE),
        RoutineSymbol('isaac', interface=UnresolvedInterface()),
    ]
    for routine in candidates:
        assert isinstance(routine, RoutineSymbol)
Exemplo n.º 6
0
def test_datatypesymbol_copy():
    ''' Check that copying a DataTypeSymbol gives a distinct object with the
    same name, datatype, visibility and interface. '''
    original = DataTypeSymbol("my_type", DeferredType(),
                              visibility=Symbol.Visibility.PRIVATE,
                              interface=UnresolvedInterface())
    duplicate = original.copy()
    assert duplicate is not original
    assert duplicate.name == "my_type"
    assert isinstance(duplicate.datatype, DeferredType)
    assert duplicate.visibility == Symbol.Visibility.PRIVATE
    assert isinstance(duplicate.interface, UnresolvedInterface)
Exemplo n.º 7
0
def test_datasymbol_initialisation():
    '''Check that DataSymbol can be created from a broad range of valid
    datatypes, constant values, interfaces and visibilities.'''

    def creates_datasymbol(datatype, name='a', **kwargs):
        '''Build a DataSymbol from the given arguments and assert that the
        result really is a DataSymbol.'''
        assert isinstance(DataSymbol(name, datatype, **kwargs), DataSymbol)

    # Real scalars, including one whose kind is given by another symbol.
    creates_datasymbol(REAL_SINGLE_TYPE)
    creates_datasymbol(REAL_DOUBLE_TYPE)
    creates_datasymbol(REAL4_TYPE)
    kind_sym = DataSymbol('r_def', INTEGER_SINGLE_TYPE)
    creates_datasymbol(ScalarType(ScalarType.Intrinsic.REAL, kind_sym))

    # Integer, character and boolean scalars, with and without a constant
    # value (real constants are not currently supported).
    creates_datasymbol(INTEGER_SINGLE_TYPE)
    creates_datasymbol(INTEGER_DOUBLE_TYPE, constant_value=0)
    creates_datasymbol(INTEGER4_TYPE)
    creates_datasymbol(CHARACTER_TYPE)
    creates_datasymbol(CHARACTER_TYPE, constant_value="hello")
    creates_datasymbol(BOOLEAN_TYPE)
    creates_datasymbol(BOOLEAN_TYPE, constant_value=False)

    # Arrays with attribute and literal extents.
    creates_datasymbol(ArrayType(REAL_SINGLE_TYPE,
                                 [ArrayType.Extent.ATTRIBUTE]))
    creates_datasymbol(ArrayType(REAL_SINGLE_TYPE, [3]))
    creates_datasymbol(ArrayType(REAL_SINGLE_TYPE,
                                 [3, ArrayType.Extent.ATTRIBUTE]))
    creates_datasymbol(REAL_SINGLE_TYPE)
    creates_datasymbol(REAL8_TYPE)

    # Arrays whose extents reference another (unresolved) symbol.
    dim = DataSymbol('dim', INTEGER_SINGLE_TYPE,
                     interface=UnresolvedInterface())
    creates_datasymbol(ArrayType(REAL_SINGLE_TYPE, [Reference(dim)]))
    creates_datasymbol(ArrayType(
        REAL_SINGLE_TYPE, [3, Reference(dim), ArrayType.Extent.ATTRIBUTE]))

    # Explicit interface, explicit visibility and a derived datatype.
    creates_datasymbol(REAL_SINGLE_TYPE,
                       interface=ArgumentInterface(
                           ArgumentInterface.Access.READWRITE))
    creates_datasymbol(REAL_SINGLE_TYPE,
                       visibility=Symbol.Visibility.PRIVATE)
    creates_datasymbol(DataTypeSymbol("field_type", DeferredType()),
                       name='field')
Exemplo n.º 8
0
def test_datasymbol_scalar_array():
    '''Check that is_scalar and is_array classify DataSymbols correctly. A
    scalar-typed symbol is a scalar and not an array. An array-typed symbol
    is an array and not a scalar.

    '''
    scalar_sym = DataSymbol("s1", INTEGER_SINGLE_TYPE,
                            interface=UnresolvedInterface())
    shape = [ArrayType.Extent.ATTRIBUTE, 2, Reference(scalar_sym)]
    arr_sym = DataSymbol("s2", ArrayType(REAL_SINGLE_TYPE, shape))

    assert scalar_sym.is_scalar
    assert not scalar_sym.is_array

    assert not arr_sym.is_scalar
    assert arr_sym.is_array
Exemplo n.º 9
0
def test_oclw_kernelschedule():
    '''Check that the OpenCLWriter kernelschedule_node visitor generates the
    expected OpenCL code with both the default and an explicit local
    work-group size, and that it rejects unresolved data symbols.

    '''
    # A generic KernelSchedule depends on API-specific abstract properties
    # of its symbol table, so visiting it raises NotImplementedError.
    writer = OpenCLWriter()
    kernel = KernelSchedule("kname")
    with pytest.raises(NotImplementedError) as err:
        _ = writer(kernel)
    assert ("Abstract property. Which symbols are data arguments is "
            "API-specific." in str(err.value))

    # pytest's monkeypatch does not work with properties, so swap the
    # table's class for a subclass that implements them: the first two
    # arguments are the iteration indices and the rest are data arguments.
    class MockSymbolTable(SymbolTable):
        ''' Mock needed abstract methods of the Symbol Table '''
        @property
        def iteration_indices(self):
            return self.argument_list[:2]

        @property
        def data_arguments(self):
            return self.argument_list[2:]

    kernel.symbol_table.__class__ = MockSymbolTable

    # Two index arguments followed by two 10x10 real array arguments.
    arg_interface = ArgumentInterface(ArgumentInterface.Access.UNKNOWN)
    idx_i = DataSymbol('i', INTEGER_TYPE, interface=arg_interface)
    idx_j = DataSymbol('j', INTEGER_TYPE, interface=arg_interface)
    field_type = ArrayType(REAL_TYPE, [10, 10])
    data1 = DataSymbol('data1', field_type, interface=arg_interface)
    data2 = DataSymbol('data2', field_type, interface=arg_interface)
    arguments = [idx_i, idx_j, data1, data2]
    for arg in arguments:
        kernel.symbol_table.add(arg)
    kernel.symbol_table.specify_argument_list(arguments)
    kernel.addchild(Return(parent=kernel))

    # Kernel text shared by both writer configurations below.
    kernel_text = (
        "__kernel void kname(\n"
        "  __global double * restrict data1,\n"
        "  __global double * restrict data2\n"
        "  ){\n"
        "  int data1LEN1 = get_global_size(0);\n"
        "  int data1LEN2 = get_global_size(1);\n"
        "  int data2LEN1 = get_global_size(0);\n"
        "  int data2LEN2 = get_global_size(1);\n"
        "  int i = get_global_id(0);\n"
        "  int j = get_global_id(1);\n"
        "  return;\n"
        "}\n\n")

    generated = writer(kernel)
    assert generated == kernel_text

    # A local size other than 1 prefixes the kernel with a
    # reqd_work_group_size attribute.
    writer = OpenCLWriter(kernels_local_size=4)
    generated = writer(kernel)
    assert generated == ("__attribute__((reqd_work_group_size(4, 1, 1)))\n"
                         + kernel_text)

    # Adding an unresolved data symbol that is not used purely to define
    # the precision of other symbols makes the writer fail.
    kernel.symbol_table.add(
        DataSymbol('broken', ArrayType(REAL_TYPE, [10, 10]),
                   interface=UnresolvedInterface()))
    with pytest.raises(VisitorError) as err:
        _ = writer(kernel)
    assert ("symbol table contains unresolved data entries (i.e. that have no "
            "defined Interface) which are not used purely to define the "
            "precision of other symbols: 'broken'" in str(err.value))