def test_datasymbol_can_be_printed():
    '''Check that str() succeeds, and gives the expected text, for
    DataSymbols built with a variety of datatypes, interfaces and constant
    values (i.e. that each DataSymbol is fully initialised on creation).'''
    local_real = DataSymbol("sname", REAL_SINGLE_TYPE)
    assert "sname: <Scalar<REAL, SINGLE>, Local>" in str(local_real)

    unresolved_int = DataSymbol("s1", INTEGER_SINGLE_TYPE,
                                interface=UnresolvedInterface())
    assert "s1: <Scalar<INTEGER, SINGLE>, Unresolved>" in str(unresolved_int)

    # An array whose shape mixes an attribute extent, a literal extent and
    # an extent given by a reference to another symbol.
    extents = [ArrayType.Extent.ATTRIBUTE, 2, Reference(unresolved_int)]
    array_sym = DataSymbol("s2", ArrayType(REAL_SINGLE_TYPE, extents))
    expected_array_text = ("s2: <Array<Scalar<REAL, SINGLE>, shape=['ATTRIBUTE', "
                           "Literal[value:'2', Scalar<INTEGER, UNDEFINED>], "
                           "Reference[name:'s1']]>, Local>")
    assert expected_array_text in str(array_sym)

    container = ContainerSymbol("my_mod")
    global_real = DataSymbol("s3", REAL_SINGLE_TYPE,
                             interface=GlobalInterface(container))
    assert ("s3: <Scalar<REAL, SINGLE>, Global(container='my_mod')>"
            in str(global_real))

    const_int = DataSymbol("s3", INTEGER_SINGLE_TYPE, constant_value=12)
    expected_const_text = ("s3: <Scalar<INTEGER, SINGLE>, Local, "
                           "constant_value=Literal"
                           "[value:'12', Scalar<INTEGER, SINGLE>]>")
    assert expected_const_text in str(const_int)

    second_unresolved = DataSymbol("s4", INTEGER_SINGLE_TYPE,
                                   interface=UnresolvedInterface())
    assert ("s4: <Scalar<INTEGER, SINGLE>, Unresolved>"
            in str(second_unresolved))
def test_symbol_interface_setter():
    '''Check the Symbol interface setter: assigning each kind of
    SymbolInterface must update the is_local, is_global, is_argument and
    is_unresolved properties accordingly, and assigning anything that is
    not a SymbolInterface must raise a TypeError.

    '''
    flag_names = ("is_local", "is_global", "is_argument", "is_unresolved")

    def current_flags(sym):
        # Truthiness of each property, read in the order of flag_names.
        return tuple(bool(getattr(sym, flag)) for flag in flag_names)

    symbol = Symbol('sym1')
    assert current_flags(symbol) == (True, False, False, False)

    symbol.interface = GlobalInterface(ContainerSymbol("my_mod"))
    assert current_flags(symbol) == (False, True, False, False)

    symbol.interface = ArgumentInterface()
    assert current_flags(symbol) == (False, False, True, False)

    symbol.interface = UnresolvedInterface()
    assert current_flags(symbol) == (False, False, False, True)

    # A non-SymbolInterface value is rejected.
    with pytest.raises(TypeError) as info:
        symbol.interface = "hello"
    assert ("The interface to a Symbol must be a SymbolInterface but got "
            "'str'" in str(info.value))
def test_symbol_initialisation():
    '''Check Symbol construction. Valid arguments must produce a Symbol
    with the expected name, visibility and interface; invalid arguments
    must raise a TypeError. Also checks the Symbol.Visibility enumeration
    used for the default visibility.

    '''
    # Construction with default visibility and interface.
    plain = Symbol("sym1")
    assert isinstance(plain, Symbol)
    assert plain.name == "sym1"
    assert plain.visibility == Symbol.DEFAULT_VISIBILITY
    assert isinstance(plain.interface, LocalInterface)
    # The default visibility is expected to be public.
    assert Symbol.DEFAULT_VISIBILITY == Symbol.Visibility.PUBLIC

    # Explicit visibility passed positionally.
    private = Symbol("sym2", Symbol.Visibility.PRIVATE)
    assert private.visibility == Symbol.Visibility.PRIVATE

    # Explicit interface passed by keyword.
    unresolved = Symbol("sym3", interface=UnresolvedInterface())
    assert isinstance(unresolved.interface, UnresolvedInterface)

    # Invalid arguments.
    with pytest.raises(TypeError) as error:
        Symbol(None)
    assert ("Symbol 'name' attribute should be of type 'str'"
            in str(error.value))

    with pytest.raises(TypeError) as error:
        Symbol('sym1', visibility="hello")
    assert ("Symbol 'visibility' attribute should be of type "
            "psyir.symbols.Symbol.Visibility but" in str(error.value))

    with pytest.raises(TypeError) as error:
        Symbol('sym1', interface="hello")
    assert ("The interface to a Symbol must be a SymbolInterface but got "
            "'str'" in str(error.value))
def test_unresolvedinterface():
    '''Check that an UnresolvedInterface can be constructed and that its
    string representation is "Unresolved".'''
    assert str(UnresolvedInterface()) == "Unresolved"
def test_routinesymbol_init():
    '''Check that a RoutineSymbol can be constructed with default
    arguments, with an explicit visibility and with an explicit
    interface.'''
    candidates = [
        RoutineSymbol('jo'),
        RoutineSymbol('ellie', visibility=Symbol.Visibility.PRIVATE),
        RoutineSymbol('isaac', interface=UnresolvedInterface()),
    ]
    for routine in candidates:
        assert isinstance(routine, RoutineSymbol)
def test_datatypesymbol_copy():
    ''' Check that copying a DataTypeSymbol gives a distinct object that
    keeps the name, a DeferredType datatype, the visibility and the
    interface of the original. '''
    source = DataTypeSymbol("my_type", DeferredType(),
                            visibility=Symbol.Visibility.PRIVATE,
                            interface=UnresolvedInterface())
    duplicate = source.copy()

    assert duplicate is not source
    assert duplicate.name == "my_type"
    assert isinstance(duplicate.datatype, DeferredType)
    assert duplicate.visibility == Symbol.Visibility.PRIVATE
    assert isinstance(duplicate.interface, UnresolvedInterface)
def test_datasymbol_initialisation():
    '''Test that a DataSymbol instance can be created for a range of valid
    arguments: scalar intrinsic types (with and without a constant value),
    a REAL type whose precision is another symbol, arrays with attribute,
    literal and reference extents, an argument interface, private
    visibility and a DataTypeSymbol datatype. Invalid arguments are not
    exercised by this test.'''
    # Scalar intrinsic types with no further arguments. (Each type is
    # checked once; previously REAL_SINGLE_TYPE was asserted twice.)
    for datatype in (REAL_SINGLE_TYPE, REAL_DOUBLE_TYPE, REAL4_TYPE,
                     REAL8_TYPE, INTEGER_SINGLE_TYPE, INTEGER4_TYPE,
                     CHARACTER_TYPE, BOOLEAN_TYPE):
        assert isinstance(DataSymbol('a', datatype), DataSymbol)

    # A REAL type whose precision is given by another DataSymbol.
    kind = DataSymbol('r_def', INTEGER_SINGLE_TYPE)
    real_kind_type = ScalarType(ScalarType.Intrinsic.REAL, kind)
    assert isinstance(DataSymbol('a', real_kind_type), DataSymbol)

    # Scalars with a constant value. Only non-REAL types are used here
    # because real constants are not currently supported.
    # NOTE(review): confirm that restriction still holds.
    assert isinstance(DataSymbol('a', INTEGER_DOUBLE_TYPE, constant_value=0),
                      DataSymbol)
    assert isinstance(DataSymbol('a', CHARACTER_TYPE,
                                 constant_value="hello"),
                      DataSymbol)
    assert isinstance(DataSymbol('a', BOOLEAN_TYPE, constant_value=False),
                      DataSymbol)

    # Arrays with attribute, literal and reference extents. Each shape list
    # is a fresh literal, so every array gets its own Reference node.
    dim = DataSymbol('dim', INTEGER_SINGLE_TYPE,
                     interface=UnresolvedInterface())
    for shape in ([ArrayType.Extent.ATTRIBUTE],
                  [3],
                  [3, ArrayType.Extent.ATTRIBUTE],
                  [Reference(dim)],
                  [3, Reference(dim), ArrayType.Extent.ATTRIBUTE]):
        array_type = ArrayType(REAL_SINGLE_TYPE, shape)
        assert isinstance(DataSymbol('a', array_type), DataSymbol)

    # Non-default interface and visibility.
    assert isinstance(
        DataSymbol('a', REAL_SINGLE_TYPE,
                   interface=ArgumentInterface(
                       ArgumentInterface.Access.READWRITE)),
        DataSymbol)
    assert isinstance(
        DataSymbol('a', REAL_SINGLE_TYPE,
                   visibility=Symbol.Visibility.PRIVATE),
        DataSymbol)

    # A derived type described by a DataTypeSymbol.
    assert isinstance(DataSymbol('field',
                                 DataTypeSymbol("field_type",
                                                DeferredType())),
                      DataSymbol)
def test_datasymbol_scalar_array():
    '''Check the is_scalar and is_array properties of DataSymbol. A scalar
    DataSymbol must report is_scalar but not is_array, and an array
    DataSymbol must report is_array but not is_scalar.

    '''
    scalar_sym = DataSymbol("s1", INTEGER_SINGLE_TYPE,
                            interface=UnresolvedInterface())
    extents = [ArrayType.Extent.ATTRIBUTE, 2, Reference(scalar_sym)]
    array_sym = DataSymbol("s2", ArrayType(REAL_SINGLE_TYPE, extents))

    assert scalar_sym.is_scalar and not scalar_sym.is_array
    assert not array_sym.is_scalar and array_sym.is_array
def test_oclw_kernelschedule():
    '''Check that the kernelschedule_node visitor of the OpenCLWriter class
    produces the expected OpenCL code. The test covers four cases:
    - a generic KernelSchedule raises NotImplementedError;
    - the default writer output;
    - the output with a local work-group size of 4;
    - a VisitorError when the symbol table holds an unresolved data
      symbol.

    '''
    # The kernelschedule OpenCL backend relies on abstract methods that
    # need to be implemented by the APIs. A generic kernelschedule will
    # produce a NotImplementedError.
    oclwriter = OpenCLWriter()
    kschedule = KernelSchedule("kname")
    with pytest.raises(NotImplementedError) as error:
        _ = oclwriter(kschedule)
    assert "Abstract property. Which symbols are data arguments is " \
        "API-specific." in str(error.value)

    # Mock the abstract properties. pytest's monkeypatch does not work
    # with properties, so a sub-class is used instead.
    class MockSymbolTable(SymbolTable):
        ''' Mock needed abstract methods of the Symbol Table '''
        @property
        def iteration_indices(self):
            return self.argument_list[:2]

        @property
        def data_arguments(self):
            return self.argument_list[2:]

    kschedule.symbol_table.__class__ = MockSymbolTable

    # Create a sample symbol table and kernel schedule
    interface = ArgumentInterface(ArgumentInterface.Access.UNKNOWN)
    i = DataSymbol('i', INTEGER_TYPE, interface=interface)
    j = DataSymbol('j', INTEGER_TYPE, interface=interface)
    array_type = ArrayType(REAL_TYPE, [10, 10])
    data1 = DataSymbol('data1', array_type, interface=interface)
    data2 = DataSymbol('data2', array_type, interface=interface)
    arguments = [i, j, data1, data2]
    for symbol in arguments:
        kschedule.symbol_table.add(symbol)
    kschedule.symbol_table.specify_argument_list(arguments)
    kschedule.addchild(Return(parent=kschedule))

    # The kernel text is identical for both writer configurations below;
    # only the optional work-group-size attribute in front of it differs.
    expected_kernel = (
        "__kernel void kname(\n"
        " __global double * restrict data1,\n"
        " __global double * restrict data2\n"
        " ){\n"
        " int data1LEN1 = get_global_size(0);\n"
        " int data1LEN2 = get_global_size(1);\n"
        " int data2LEN1 = get_global_size(0);\n"
        " int data2LEN2 = get_global_size(1);\n"
        " int i = get_global_id(0);\n"
        " int j = get_global_id(1);\n"
        " return;\n"
        "}\n\n")

    result = oclwriter(kschedule)
    assert result == expected_kernel

    # Set a local_size value different to 1 into the KernelSchedule
    oclwriter = OpenCLWriter(kernels_local_size=4)
    result = oclwriter(kschedule)
    assert result == ("__attribute__((reqd_work_group_size(4, 1, 1)))\n"
                      + expected_kernel)

    # Add a symbol with an unresolved interface and check that this raises
    # the expected error. The existing 10x10 REAL array type is reused.
    kschedule.symbol_table.add(
        DataSymbol('broken', array_type, interface=UnresolvedInterface()))
    with pytest.raises(VisitorError) as err:
        _ = oclwriter(kschedule)
    assert ("symbol table contains unresolved data entries (i.e. that have no "
            "defined Interface) which are not used purely to define the "
            "precision of other symbols: 'broken'" in str(err.value))