def test_container_create_invalid():
    '''Check that Container.create() raises a GenerationError with an
    informative message for each kind of invalid argument.

    '''
    table = SymbolTable()
    table.add(DataSymbol("x", REAL_SINGLE_TYPE))
    kids = [KernelSchedule.create("mod_1", SymbolTable(), [])]
    # Each entry holds the (name, symbol_table, children) arguments to pass
    # and the text expected to appear in the resulting error.
    bad_calls = [
        # name is not a string.
        ((1, table, kids),
         "name argument in create method of Container class "
         "should be a string but found 'int'."),
        # symbol_table is not a SymbolTable.
        (("container", "invalid", kids),
         "symbol_table argument in create method of Container class "
         "should be a SymbolTable but found 'str'."),
        # children is not a list.
        (("mod_name", table, "invalid"),
         "children argument in create method of Container class should "
         "be a list but found 'str'."),
        # children holds something that is not a valid child node.
        (("mod_name", table, ["invalid"]),
         "Item 'str' can't be child 0 of 'Container'. The valid format is:"
         " '[Container | KernelSchedule | InvokeSchedule]*'."),
    ]
    for args, message in bad_calls:
        with pytest.raises(GenerationError) as excinfo:
            _ = Container.create(*args)
        assert message in str(excinfo.value)
def test_scoping_node_symbol_table():
    '''Check the symbol_table property of a ScopingNode, both for a table
    created by the constructor and for one supplied to it.'''
    # ScopingNode is abstract, so a Container stands in for it here.
    default_node = Container("test")
    assert isinstance(default_node.symbol_table, SymbolTable)
    assert default_node.symbol_table is default_node._symbol_table
    # A pre-existing table passed to the constructor is used as-is.
    given_table = SymbolTable()
    supplied_node = Container("test", symbol_table=given_table)
    assert supplied_node.symbol_table is given_table
def test_find_or_create_change_symbol_type():
    ''' Verify that _find_or_create_imported_symbol changes the class of an
    existing symbol when a different symbol_type is requested, while still
    returning the very same object.

    '''
    # pylint: disable=unidiomatic-typecheck
    # Build a minimal PSyIR tree: a Container holding one KernelSchedule.
    outer_table = SymbolTable()
    tmp_symbol = outer_table.new_symbol("tmp")
    routine_symbol = outer_table.new_symbol("my_sub")
    kernel = KernelSchedule.create("mod_1", SymbolTable(), [])
    _ = Container.create("container_name", outer_table, [kernel])
    assignment = Assignment.create(Reference(tmp_symbol),
                                   Literal("1.0", REAL_TYPE))
    kernel.addchild(assignment)

    # With no requested type the generic Symbol comes back unchanged.
    found = _find_or_create_imported_symbol(assignment, "tmp")
    assert found is tmp_symbol
    assert type(found) == Symbol

    # Requesting a DataSymbol specialises that same object.
    found = _find_or_create_imported_symbol(
        assignment, "tmp", symbol_type=DataSymbol, datatype=REAL_TYPE)
    assert found is tmp_symbol
    assert type(found) == DataSymbol
    assert found.datatype == REAL_TYPE

    # The same applies when a RoutineSymbol is requested.
    found_routine = _find_or_create_imported_symbol(
        assignment, "my_sub", symbol_type=RoutineSymbol)
    assert found_routine is routine_symbol
    assert type(found_routine) == RoutineSymbol
def test_find_or_create_imported_symbol_2():
    ''' Verify that _find_or_create_imported_symbol() only creates a new
    symbol when a wildcard import could be providing it.

    '''
    # The outer (Container) scope holds 'tmp'; the kernel is its child.
    outer_table = SymbolTable()
    outer_table.add(DataSymbol("tmp", REAL_TYPE))
    kernel = KernelSchedule.create("mod_1", SymbolTable(), [])
    container = Container.create("container_name", outer_table, [kernel])
    x_ref = Reference(DataSymbol("x", REAL_TYPE))
    assignment = Assignment.create(x_ref, Literal("1.0", REAL_TYPE))
    kernel.addchild(assignment)

    # There are no wildcard imports yet, so an unknown name is an error.
    with pytest.raises(SymbolError) as err:
        _ = _find_or_create_imported_symbol(assignment, "undefined")
    assert "No Symbol found for name 'undefined'" in str(err.value)

    # 'tmp' is found by searching up into the enclosing Container.
    tmp = _find_or_create_imported_symbol(assignment, "tmp")
    assert tmp.datatype.intrinsic == ScalarType.Intrinsic.REAL

    # Give the kernel a wildcard import from some_mod.
    wildcard_mod = ContainerSymbol("some_mod")
    wildcard_mod.wildcard_import = True
    kernel.symbol_table.add(wildcard_mod)

    # The unknown name might now come from that module, so a generic
    # Symbol with an unresolved interface is created in the kernel's table.
    created = _find_or_create_imported_symbol(assignment, "undefined")
    assert created.name == "undefined"
    assert isinstance(created.interface, UnresolvedInterface)
    # pylint: disable=unidiomatic-typecheck
    assert type(created) == Symbol
    assert "undefined" not in container.symbol_table
    assert kernel.symbol_table.lookup("undefined") is created
def test_scope():
    '''Check that Node.scope yields the nearest enclosing Schedule or
    Container (the node itself included), and that a SymbolError is raised
    for a node with no such ancestor.

    '''
    tmp = DataSymbol("tmp", REAL_TYPE)
    kernel_table = SymbolTable()
    kernel_table.add(tmp)
    lhs = Reference(tmp)
    stmt = Assignment.create(lhs, Literal("0.0", REAL_TYPE))
    kernel = KernelSchedule.create("my_kernel", kernel_table, [stmt])
    module = Container.create("my_container", SymbolTable(), [kernel])
    # Everything inside the kernel (and the kernel itself) is scoped by it.
    for node in (lhs, stmt, kernel):
        assert node.scope is kernel
    assert module.scope is module

    orphan = Literal("x", INTEGER_TYPE)
    with pytest.raises(SymbolError) as excinfo:
        _ = orphan.scope
    assert ("Unable to find the scope of node "
            "'Literal[value:'x', Scalar<INTEGER, UNDEFINED>]' as "
            "none of its ancestors are Container or Schedule nodes."
            in str(excinfo.value))
def test_container_name():
    '''Check that a Container's name can be read and then reassigned.'''
    container = Container("test")
    original_name = container.name
    assert original_name == "test"
    container.name = "new_test"
    assert container.name == "new_test"
def test_replace_with_error2():
    '''Exercise the error paths of Node.replace_with(): a non-Node argument,
    a node without a parent, a replacement that already has a parent, and
    a replacement of a type the parent does not accept.

    '''
    schedule = Schedule()
    first = Statement()
    second = Statement()

    # The replacement must be a Node.
    with pytest.raises(TypeError) as info:
        first.replace_with("hello")
    assert ("The argument node in method replace_with in the Node class "
            "should be a Node but found 'str'." in str(info.value))

    # The node being replaced must have a parent.
    with pytest.raises(GenerationError) as info:
        first.replace_with(second)
    assert ("This node should have a parent if its replace_with method "
            "is called." in str(info.value))

    # The replacement must not already have a parent.
    first.parent = schedule
    second.parent = schedule
    schedule.children = [first, second]
    with pytest.raises(GenerationError) as info:
        first.replace_with(second)
    assert ("The parent of argument node in method replace_with in the "
            "Node class should be None but found 'Schedule'."
            in str(info.value))

    # The replacement must be a valid child of the parent.
    container = Container("hello")
    with pytest.raises(GenerationError) as info:
        first.replace_with(container)
    assert ("Generation Error: Item 'Container' can't be child 0 of "
            "'Schedule'. The valid format is: '[Statement]*'."
            in str(info.value))
def test_node_position():
    ''' Check Node.position (index among siblings, counted from 0),
    Node.abs_position (depth-first index from the root, which is itself 0)
    and the private _find_position() helper. START_POSITION is 0.

    '''
    _, invoke_info = parse(
        os.path.join(BASE_PATH, "4.7_multikernel_invokes.f90"),
        api="dynamo0.3")
    psy = PSyFactory("dynamo0.3",
                     distributed_memory=True).create(invoke_info)
    sched = psy.invokes.invoke_list[0].schedule.detach()
    node = sched.children[6]

    # A detached Schedule has no parent and sits at position 0 both ways.
    assert sched.position == 0
    assert sched.abs_position == 0
    assert node.START_POSITION == 0
    assert node.position == 6
    assert node.abs_position == 7

    # Put two extra Container levels above the existing root.
    old_root = node.root
    outer = Container("test1")
    inner = Container("test2")
    inner.addchild(old_root)
    outer.addchild(inner)
    # The sibling index is unchanged; the depth-first index grows by two.
    assert node.position == 6
    assert node.abs_position == 9

    # _find_position reports the distance from the supplied ancestor.
    routine = node.ancestor(Routine)
    _, distance = node._find_position(routine, 0)
    assert distance == 7
    # Without a start position it counts from START_POSITION (0).
    _, default_distance = node._find_position(routine)
    assert default_distance == distance

    # A start position other than 0 is rejected.
    with pytest.raises(InternalError) as excinfo:
        _, _ = node._find_position(node.root.children, -2)
    assert "started from -2 instead of 0" in str(excinfo.value)
def test_containersymbol_str():
    '''Check the string form of a ContainerSymbol before and after it is
    linked to a Container.'''
    csym = ContainerSymbol("my_mod")
    assert str(csym) == "my_mod: <not linked>"
    csym._reference = Container("my_mod")
    assert str(csym) == "my_mod: <linked>"
def test_get_external_symbol(monkeypatch):
    ''' Exercise Symbol.get_external_symbol(): an unsupported local
    interface, a module that cannot be imported, a module that lacks the
    symbol and, finally, successful resolution via resolve_deferred().

    '''
    local_sym = Symbol("a")
    with pytest.raises(NotImplementedError) as err:
        local_sym.get_external_symbol()
    assert ("trying to resolve symbol 'a' properties, the lazy evaluation "
            "of 'Local' interfaces is not supported" in str(err.value))

    # 'b' is declared as imported from the module some_mod.
    mod_sym = ContainerSymbol("some_mod")
    outer_table = SymbolTable()
    outer_table.add(mod_sym)
    imported = Symbol("b", interface=GlobalInterface(mod_sym))
    outer_table.add(imported)
    _ = Container.create("test", outer_table, [KernelSchedule("dummy")])

    # Patch the ContainerSymbol's interface so that import_container always
    # raises, i.e. some_mod can never be found.
    def fake_import(name):
        # pylint: disable=unused-argument
        raise SymbolError("Oh dear")
    monkeypatch.setattr(mod_sym._interface, "import_container", fake_import)
    with pytest.raises(SymbolError) as err:
        imported.get_external_symbol()
    assert ("trying to resolve the properties of symbol 'b' in module "
            "'some_mod': PSyclone SymbolTable error: Oh dear"
            in str(err.value))

    # Link the ContainerSymbol to a real some_mod Container that does not
    # (yet) declare 'b'.
    mod_table = SymbolTable()
    some_mod = Container.create("some_mod", mod_table,
                                [KernelSchedule("dummy2")])
    mod_sym._reference = some_mod
    with pytest.raises(SymbolError) as err:
        imported.get_external_symbol()
    assert ("trying to resolve the properties of symbol 'b'. The interface "
            "points to module 'some_mod' but could not find the definition"
            in str(err.value))

    # Once some_mod declares 'b', the deferred symbol can be resolved.
    mod_table.add(DataSymbol("b", INTEGER_SINGLE_TYPE))
    resolved = imported.resolve_deferred()
    assert isinstance(resolved, DataSymbol)
    assert resolved.datatype == INTEGER_SINGLE_TYPE
def test_container_children_validation():
    '''Test that children added to a Container are validated. A Container
    only accepts Container, KernelSchedule and InvokeSchedule nodes as
    children.

    '''
    container = Container.create("container", SymbolTable(), [])
    # A nested Container is a valid child (it becomes child 0).
    container2 = Container.create("container2", SymbolTable(), [])
    container.addchild(container2)
    # A statement such as a Return is not, so adding it as child 1 fails.
    ret = Return()
    with pytest.raises(GenerationError) as excinfo:
        container.addchild(ret)
    assert ("Item 'Return' can't be child 1 of 'Container'. The valid format"
            " is: '[Container | KernelSchedule | InvokeSchedule]*'."
            in str(excinfo.value))
def test_arraytype_shape_dim_from_parent_scope():
    ''' Verify that an ArrayType dimension may be a reference to a symbol
    declared in an enclosing (parent) scope.

    '''
    module = Container("test_mod")
    extent = module.symbol_table.new_symbol(
        "dim1", symbol_type=DataSymbol, datatype=INTEGER_TYPE)
    kernel = KernelSchedule.create("mod_1", SymbolTable(), [])
    module.addchild(kernel)
    # 'dim1' lives in the Container, not in the kernel's own table.
    array_type = ArrayType(INTEGER_TYPE, [Reference(extent)])
    array_sym = kernel.symbol_table.new_symbol(
        "array1", symbol_type=DataSymbol, datatype=array_type)
    assert isinstance(array_sym, DataSymbol)
def test_dyninvsched_parent():
    ''' Verify that a DynInvokeSchedule's parent is None by default and is
    otherwise the node supplied to the constructor.

    '''
    _, invoke_info = parse(
        os.path.join(BASE_PATH, "1.0.1_single_named_invoke.f90"),
        api=TEST_API)
    kernel_calls = invoke_info.calls[0].kcalls
    orphan = DynInvokeSchedule("my_sched", kernel_calls)
    assert orphan.parent is None
    module = Container("my_mod")
    adopted = DynInvokeSchedule("my_sched", kernel_calls, parent=module)
    assert adopted.parent is module
def test_gosched_parent():
    ''' Verify that the GOInvokeSchedule constructor works both with and
    without an explicit parent node.

    '''
    _, invoke_info = parse(
        os.path.join(BASE_PATH, "single_invoke_two_kernels.f90"),
        api=API)
    kernel_calls = invoke_info.calls[0].kcalls
    orphan = GOInvokeSchedule("my_sched", kernel_calls)
    assert orphan.parent is None
    module = Container("my_mod")
    adopted = GOInvokeSchedule("my_sched", kernel_calls, parent=module)
    assert adopted.parent is module
def test_kernelschedule_constructor():
    ''' Verify the name, is_program, return_type and parent of a newly
    constructed KernelSchedule, with and without an explicit parent.

    '''
    orphan = KernelSchedule("timetable")
    assert orphan.name == "timetable"
    # It is neither a program nor something that returns a value.
    assert not orphan.is_program
    assert orphan.return_type is None
    assert orphan.parent is None
    # The parent can be supplied to the constructor.
    box = Container("BigBox")
    child = KernelSchedule("plan", parent=box)
    assert child.parent is box
def test_file_container_create():
    '''Check that FileContainer.create builds a FileContainer holding a
    module and a program, and that FortranWriter renders it as expected.

    '''
    module_table = SymbolTable()
    module_table.add(DataSymbol("tmp", REAL_SINGLE_TYPE))
    module = Container.create("mod_1", module_table, [])
    program = Routine.create("prog_1", SymbolTable(), [], is_program=True)
    file_container = FileContainer.create("container_name", SymbolTable(),
                                          [module, program])
    assert isinstance(file_container, FileContainer)
    expected = ("module mod_1\n"
                " implicit none\n"
                " real :: tmp\n\n"
                " contains\n\n"
                "end module mod_1\n"
                "program prog_1\n\n\n"
                "end program prog_1\n")
    writer = FortranWriter()
    assert writer.filecontainer_node(file_container) == expected
def test_container_create():
    '''Check that Container.create builds a Container that owns the given
    symbol table and kernels, and that FortranWriter renders it as expected.

    '''
    table = SymbolTable()
    table.add(DataSymbol("tmp", REAL_SINGLE_TYPE))
    kernels = [KernelSchedule.create(name, SymbolTable(), [])
               for name in ("mod_1", "mod_2")]
    container = Container.create("container_name", table, kernels)
    check_links(container, kernels)
    assert container.symbol_table is table
    expected = ("module container_name\n"
                " real :: tmp\n\n"
                " contains\n"
                " subroutine mod_1()\n\n\n"
                " end subroutine mod_1\n"
                " subroutine mod_2()\n\n\n"
                " end subroutine mod_2\n\n"
                "end module container_name\n")
    assert FortranWriter().container_node(container) == expected
def test_find_symbol_table():
    ''' Exercise Symbol.find_symbol_table(): argument type checking, lookup
    through one and two levels of scope, and the not-found cases.

    '''
    a_var = Symbol("a_var")
    with pytest.raises(TypeError) as err:
        a_var.find_symbol_table("3")
    assert ("expected to be passed an instance of psyir.nodes.Node but got "
            "'str'" in str(err.value))

    # A single scope: the kernel's own table.
    kernel = KernelSchedule("dummy")
    kernel_table = kernel.symbol_table
    kernel_table.add(a_var)
    assert a_var.find_symbol_table(kernel) is kernel_table

    # Two scopes: wrap the kernel in a Container whose table holds b_var.
    outer_table = SymbolTable()
    b_var = Symbol("b_var")
    outer_table.add(b_var)
    _ = Container.create("test", outer_table, [kernel])
    assert b_var.find_symbol_table(kernel) is outer_table

    # A symbol that belongs to no table at all.
    missing = Symbol("missing")
    assert missing.find_symbol_table(kernel) is None
    # A node with no associated symbol table.
    orphan = Literal("1", INTEGER_SINGLE_TYPE)
    assert missing.find_symbol_table(orphan) is None
def test_container_can_be_printed():
    '''Check that str() works on a newly created Container, i.e. that the
    constructor leaves it fully initialised.'''
    text = str(Container("box"))
    assert "Container[box]\n" in text
def test_container_node_str():
    '''Check that Container.node_str() contains the coloured class name
    followed by the container name in square brackets.'''
    from psyclone.psyir.nodes.node import colored, SCHEDULE_COLOUR_MAP
    node = Container("bin")
    prefix = colored("Container", SCHEDULE_COLOUR_MAP["Container"])
    expected = prefix + "[bin]"
    assert expected in node.node_str()
def test_container_symbol_table():
    '''Check that Container.symbol_table exposes the Container's internal
    SymbolTable.'''
    node = Container("test")
    table = node._symbol_table
    assert isinstance(table, SymbolTable)
    assert node.symbol_table is table
def test_container_init_parent():
    '''Check that the parent passed to the Container constructor is kept.'''
    node = Container("test", parent="hello")
    assert node.parent == "hello"
def test_container_init():
    '''Check the internal state of a newly constructed Container.'''
    node = Container("test")
    assert node._name == "test"
    assert node._parent is None
    assert isinstance(node._symbol_table, SymbolTable)
def create_psyir_tree():
    ''' Create an example PSyIR Tree.

    :returns: an example PSyIR tree.
    :rtype: :py:class:`psyclone.psyir.nodes.Container`

    '''
    # Symbol table, symbols and scalar datatypes. Symbols created without
    # an explicit name get generated ones, so the order of these
    # new_symbol() calls is kept as it was.
    symbol_table = SymbolTable()
    arg1 = symbol_table.new_symbol(symbol_type=DataSymbol,
                                   datatype=REAL_TYPE,
                                   interface=ArgumentInterface(
                                       ArgumentInterface.Access.READWRITE))
    symbol_table.specify_argument_list([arg1])
    tmp_symbol = symbol_table.new_symbol(symbol_type=DataSymbol,
                                         datatype=REAL_DOUBLE_TYPE)
    index_symbol = symbol_table.new_symbol(root_name="i",
                                           symbol_type=DataSymbol,
                                           datatype=INTEGER4_TYPE)
    real_kind = symbol_table.new_symbol(root_name="RKIND",
                                        symbol_type=DataSymbol,
                                        datatype=INTEGER_TYPE,
                                        constant_value=8)
    routine_symbol = RoutineSymbol("my_sub")

    # Array using precision defined by another symbol
    scalar_type = ScalarType(ScalarType.Intrinsic.REAL, real_kind)
    array = symbol_table.new_symbol(root_name="a",
                                    symbol_type=DataSymbol,
                                    datatype=ArrayType(scalar_type, [10]))

    # A node that already has a parent cannot be attached elsewhere, so
    # every occurrence of a leaf needs its own instance. These factories
    # return a fresh node on each call; this replaces the old numbered
    # copies (zero_1, tmp1_6, ...) flagged by the TODO for Issue #1136.
    def zero():
        return Literal("0.0", REAL_TYPE)

    def one():
        return Literal("1.0", REAL4_TYPE)

    def int_one():
        return Literal("1", INTEGER8_TYPE)

    def tmp1():
        return Reference(arg1)

    def tmp2():
        return Reference(tmp_symbol)

    def one_plus_sin_tmp2():
        '''Return a new "one + SIN(tmp2)" BinaryOperation tree.'''
        sin_op = UnaryOperation.create(UnaryOperation.Operator.SIN, tmp2())
        return BinaryOperation.create(BinaryOperation.Operator.ADD,
                                      one(), sin_op)

    # Binary Operation (with a Unary Operation operand)
    binaryoperation = one_plus_sin_tmp2()
    # Nary Operation
    naryoperation = NaryOperation.create(NaryOperation.Operator.MAX,
                                         [tmp1(), tmp2(), one()])
    # Array reference using a range
    lbound = BinaryOperation.create(BinaryOperation.Operator.LBOUND,
                                    Reference(array), int_one())
    ubound = BinaryOperation.create(BinaryOperation.Operator.UBOUND,
                                    Reference(array), int_one())
    my_range = Range.create(lbound, ubound)
    tmparray = ArrayReference.create(array, [my_range])

    # Assignments
    assign1 = Assignment.create(tmp1(), zero())
    assign2 = Assignment.create(tmp2(), zero())
    assign3 = Assignment.create(tmp2(), binaryoperation)
    assign4 = Assignment.create(tmp1(), tmp2())
    assign5 = Assignment.create(tmp1(), naryoperation)
    assign6 = Assignment.create(tmparray, Literal("2.0", scalar_type))

    # Call (its second argument is a separate "one + SIN(tmp2)" tree,
    # because the first one already belongs to assign3)
    call = Call.create(routine_symbol, [tmp1(), one_plus_sin_tmp2()])

    # If statement
    if_condition = BinaryOperation.create(BinaryOperation.Operator.GT,
                                          tmp1(), zero())
    ifblock = IfBlock.create(if_condition, [assign3, assign4])

    # Loop
    loop = Loop.create(index_symbol, Literal("0", INTEGER_SINGLE_TYPE),
                       int_one(), int_one(), [ifblock])

    # KernelSchedule
    kernel_schedule = KernelSchedule.create(
        "work", symbol_table, [assign1, call, assign2, loop, assign5, assign6])

    # Container
    container_symbol_table = SymbolTable()
    container = Container.create("CONTAINER", container_symbol_table,
                                 [kernel_schedule])

    # Import data from another container
    external_container = ContainerSymbol("some_mod")
    container_symbol_table.add(external_container)
    external_var = DataSymbol("some_var", INTEGER_TYPE,
                              interface=GlobalInterface(external_container))
    container_symbol_table.add(external_var)
    routine_symbol.interface = GlobalInterface(external_container)
    container_symbol_table.add(routine_symbol)

    return container
interface=READ_ARG) for symbol in [NQP_XY, NQP_Z, WEIGHTS_XY, WEIGHTS_Z, BASIS_W3, DIFF_BASIS_W3]: SYMBOL_TABLE.add(symbol) SYMBOL_TABLE.specify_argument_list([ NDF_W3, UNDF_W3, NCELL_3D, FIELD2, OPERATOR, NQP_XY, NQP_Z, WEIGHTS_XY, WEIGHTS_Z, BASIS_W3, DIFF_BASIS_W3 ]) # Routine symbol ROUTINE_SYMBOL = RoutineSymbol("my_sub") # Call CALL = Call.create(ROUTINE_SYMBOL, [Reference(FIELD1), Reference(FIELD2), Reference(OPERATOR)]) # KernelSchedule KERNEL_SCHEDULE = KernelSchedule.create("work", SYMBOL_TABLE, [CALL]) # Container CONTAINER_SYMBOL_TABLE = SymbolTable() CONTAINER = Container.create("CONTAINER", CONTAINER_SYMBOL_TABLE, [KERNEL_SCHEDULE]) # Write out the code as Fortran WRITER = FortranWriter() RESULT = WRITER(CONTAINER) print(RESULT)
def test_container_node_str():
    '''Check that Container.node_str() contains the class name coloured with
    Container._colour, followed by the name in square brackets.'''
    node = Container("bin")
    label = colored("Container", Container._colour)
    expected = label + "[bin]"
    assert expected in node.node_str()
def create_psyir_tree():
    ''' Build and return an example PSyIR tree: a Container named
    "CONTAINER" holding a single KernelSchedule "work" that contains
    assignments, a call, an if-block inside a loop and an array range.

    :returns: an example PSyIR tree.
    :rtype: :py:class:`psyclone.psyir.nodes.Container`

    '''
    # Symbols created without an explicit name receive generated names,
    # so the order of the new_symbol() calls below is significant.
    table = SymbolTable()
    arg1 = table.new_symbol(
        symbol_type=DataSymbol, datatype=REAL_TYPE,
        interface=ArgumentInterface(ArgumentInterface.Access.READWRITE))
    table.specify_argument_list([arg1])
    tmp_symbol = table.new_symbol(symbol_type=DataSymbol,
                                  datatype=REAL_DOUBLE_TYPE)
    loop_var = table.new_symbol(root_name="i", symbol_type=DataSymbol,
                                datatype=INTEGER4_TYPE)
    rkind = table.new_symbol(root_name="RKIND", symbol_type=DataSymbol,
                             datatype=INTEGER_TYPE, constant_value=8)
    my_sub = RoutineSymbol("my_sub")
    # A real type whose precision is given by the RKIND symbol.
    rkind_real = ScalarType(ScalarType.Intrinsic.REAL, rkind)
    array = table.new_symbol(root_name="a", symbol_type=DataSymbol,
                             datatype=ArrayType(rkind_real, [10]))

    # Leaf-node factories: each call returns a brand-new node, so the same
    # kind of leaf can appear at several places in the tree.
    def real_zero():
        return Literal("0.0", REAL_TYPE)

    def real_one():
        return Literal("1.0", REAL4_TYPE)

    def int_one():
        return Literal("1", INTEGER8_TYPE)

    def ref_arg():
        return Reference(arg1)

    def ref_tmp():
        return Reference(tmp_symbol)

    # one + SIN(tmp)
    sin_tmp = UnaryOperation.create(UnaryOperation.Operator.SIN, ref_tmp())
    add_op = BinaryOperation.create(BinaryOperation.Operator.ADD,
                                    real_one(), sin_tmp)
    # MAX(arg, tmp, one)
    max_op = NaryOperation.create(NaryOperation.Operator.MAX,
                                  [ref_arg(), ref_tmp(), real_one()])
    # a(LBOUND(a, 1):UBOUND(a, 1))
    bounds = [BinaryOperation.create(operator, Reference(array), int_one())
              for operator in (BinaryOperation.Operator.LBOUND,
                               BinaryOperation.Operator.UBOUND)]
    whole_array = ArrayReference.create(array, [Range.create(*bounds)])

    init_arg = Assignment.create(ref_arg(), real_zero())
    init_tmp = Assignment.create(ref_tmp(), real_zero())
    tmp_from_add = Assignment.create(ref_tmp(), add_op)
    arg_from_tmp = Assignment.create(ref_arg(), ref_tmp())
    arg_from_max = Assignment.create(ref_arg(), max_op)
    fill_array = Assignment.create(whole_array, Literal("2.0", rkind_real))

    # add_op already belongs to tmp_from_add, so the call receives a copy.
    call = Call.create(my_sub, [ref_arg(), add_op.copy()])

    guard = BinaryOperation.create(BinaryOperation.Operator.GT,
                                   ref_arg(), real_zero())
    branch = IfBlock.create(guard, [tmp_from_add, arg_from_tmp])
    loop = Loop.create(loop_var, Literal("0", INTEGER_SINGLE_TYPE),
                       int_one(), int_one(), [branch])
    kernel = KernelSchedule.create(
        "work", table,
        [init_arg, call, init_tmp, loop, arg_from_max, fill_array])

    module_table = SymbolTable()
    module = Container.create("CONTAINER", module_table, [kernel])

    # Symbols imported from the external module "some_mod".
    some_mod = ContainerSymbol("some_mod")
    module_table.add(some_mod)
    module_table.add(DataSymbol("some_var", INTEGER_TYPE,
                                interface=GlobalInterface(some_mod)))
    my_sub.interface = GlobalInterface(some_mod)
    module_table.add(my_sub)
    return module
def test_scoping_node_copy_hierarchy():
    ''' Test that the ScopingNode copy() method creates a new symbol table
    with copied symbols and updates the children references.

    This test has 2 ScopingNodes, and the copy will only be applied to the
    inner one. This means that the References to the symbols on the outer
    scope should not be duplicated. Also it contains argument symbols and
    a reference inside another reference to make sure all these are copied
    appropriately.

    After copying, every symbol of the copy is renamed with a "_new"
    suffix and the outer-scope symbol 'b' is renamed with a "_global"
    suffix. The copy is then added back to the Container and the generated
    Fortran is checked: the outer rename must appear in both routines and
    the "_new" renames only in the copy.

    '''
    # Outer scope: a Container declaring the integer array 'b'.
    parent_node = Container("module")
    symbol_b = parent_node.symbol_table.new_symbol("b",
                                                   symbol_type=DataSymbol,
                                                   datatype=ArrayType(
                                                       INTEGER_TYPE, [5]))
    # Inner scope: a Routine with the argument 'a' and the local 'i'.
    schedule = Routine("routine")
    parent_node.addchild(schedule)
    symbol_a = schedule.symbol_table.new_symbol(
        "a", symbol_type=DataSymbol, datatype=INTEGER_TYPE,
        interface=ArgumentInterface(ArgumentInterface.Access.READWRITE))
    schedule.symbol_table.specify_argument_list([symbol_a])
    symbol_i = schedule.symbol_table.new_symbol("i",
                                                symbol_type=DataSymbol,
                                                datatype=INTEGER_TYPE)
    # a = b(i): 'b' is in the outer scope, 'a' and 'i' in the inner one
    # (the Reference to 'i' is nested inside the ArrayReference to 'b').
    schedule.addchild(
        Assignment.create(
            Reference(symbol_a),
            ArrayReference.create(symbol_b, [Reference(symbol_i)])))

    new_schedule = schedule.copy()

    # Check that the symbol_table has been deep copied
    assert new_schedule.symbol_table is not schedule.symbol_table
    assert new_schedule.symbol_table.lookup("i") is not \
        schedule.symbol_table.lookup("i")
    assert new_schedule.symbol_table.lookup("a") is not \
        schedule.symbol_table.lookup("a")

    # Check that the References in the copy point to the copied 'a' (lhs)
    # and 'i' (index inside the rhs ArrayReference), not to the originals.
    assert new_schedule[0].lhs.symbol not in schedule.symbol_table.symbols
    assert new_schedule[0].lhs.symbol in new_schedule.symbol_table.symbols
    assert new_schedule[0].rhs.children[0].symbol not in \
        schedule.symbol_table.symbols
    assert new_schedule[0].rhs.children[0].symbol in \
        new_schedule.symbol_table.symbols

    # Add the "_new" suffix to all symbols in the copied schedule.
    # NOTE(review): this renames while iterating over .symbols; it presumably
    # returns a snapshot list, so this is safe - confirm.
    for symbol in new_schedule.symbol_table.symbols:
        new_schedule.symbol_table.rename_symbol(symbol, symbol.name + "_new")

    # An update to a symbol in the outer scope must affect both copies of the
    # inner schedule (the copy was shallow with respect to the outer scope).
    parent_node.symbol_table.rename_symbol(symbol_b,
                                           symbol_b.name + "_global")

    # Insert the copied schedule into the original container as well.
    parent_node.addchild(new_schedule)

    # Check that the expected code is generated
    # TODO #1200: the new 'routine' RoutineSymbol also needs to change.
    # (Until then, both subroutines are emitted with the name 'routine'.)
    expected = '''\
module module
  implicit none
  integer, dimension(5) :: b_global

  contains
  subroutine routine(a)
    integer, intent(inout) :: a
    integer :: i

    a = b_global(i)

  end subroutine routine
  subroutine routine(a_new)
    integer, intent(inout) :: a_new
    integer :: i_new

    a_new = b_global(i_new)

  end subroutine routine

end module module
'''
    writer = FortranWriter()
    output = writer(parent_node)
    assert expected == output