示例#1
0
def test_children_setter():
    ''' Check that the children setter accepts a list (storing it as a
    ChildrenList and connecting the parent references), rejects any other
    type with a TypeError, and disconnects the previous children when it is
    overwritten. '''
    sched = Schedule()

    # A freshly created node already holds a ChildrenList
    assert isinstance(sched.children, ChildrenList)

    # Assigning a plain list converts it into a ChildrenList and makes the
    # node the parent of every new child
    first = Statement()
    second = Statement()
    sched.children = [first, second]
    assert isinstance(sched.children, ChildrenList)
    for child in (first, second):
        assert child.parent is sched

    # Anything that is not a list is rejected
    with pytest.raises(TypeError) as excinfo:
        sched.children = Node()
    expected_msg = ("The 'my_children' parameter of the node.children setter "
                    "must be a list.")
    assert expected_msg in str(excinfo.value)

    # Overwriting the children detaches the ones previously stored
    sched.children = []
    for child in (first, second):
        assert child.parent is None
示例#2
0
def test_node_abs_position_error():
    ''' Check that the abs_position method produces an internal error when
    a node cannot be found amongst the children of its parent (this only
    happens when the parent-child connections are inconsistent). '''

    parent = Schedule()
    node1 = Statement()
    # Set only the child-to-parent link through the private _parent
    # attribute. The parent's children list does not contain node1, so the
    # two-way relationship is deliberately inconsistent.
    node1._parent = parent

    with pytest.raises(InternalError) as err:
        _ = node1.abs_position
    assert "Error in search for Node position in the tree" in str(err.value)
示例#3
0
def test_node_constructor_with_parent():
    ''' Exercise the optional parent argument of the node constructor.

    The argument presets the parent reference and only lets the node be
    added to that same parent. Once that addition has happened, the link
    behaves like any ordinary parent-child connection. '''
    expected_parent = Schedule()
    other_parent = Schedule()

    # Without the argument the node starts detached
    orphan = Statement()
    assert orphan.parent is None
    assert orphan.has_constructor_parent is False

    # With the argument the parent reference is set in advance
    ret = Return(parent=expected_parent)
    assert ret.parent is expected_parent
    assert ret.has_constructor_parent is True

    # Adding it to any other parent is refused
    with pytest.raises(GenerationError) as excinfo:
        other_parent.addchild(ret)
    assert ("'Schedule' cannot be set as parent of 'Return' because its "
            "constructor predefined the parent reference to a different "
            "'Schedule' node." in str(excinfo.value))

    # After being added to the predefined parent, it is a normal node that
    # can be detached and moved somewhere else
    expected_parent.addchild(ret)
    assert ret.parent is expected_parent
    assert ret.has_constructor_parent is False
    moved = ret.detach()
    other_parent.addchild(moved)
    assert ret.parent is other_parent
示例#4
0
def test_parent_references_coherency():
    ''' Check that the parent references are kept up to date by every
    operation that modifies a node's children. '''
    parent = Schedule()

    # Children addition methods
    node1 = Statement()
    parent.addchild(node1)
    assert node1.parent is parent

    node2 = Statement()
    parent.children.append(node2)
    assert node2.parent is parent

    node3 = Statement()
    parent.children.extend([node3])
    assert node3.parent is parent

    node4 = Statement()
    parent.children.insert(0, node4)
    assert node4.parent is parent
    # The children are now [node4, node1, node2, node3]

    # Node deletion
    node = parent.children.pop()
    assert node.parent is None
    assert node is node3

    del parent.children[0]
    assert node4.parent is None

    # The children are now [node1, node2]. Overwriting the list must
    # disconnect both of them.
    parent.children = []
    assert node1.parent is None
    assert node2.parent is None

    # Item assignment both removes the replaced node and adds the new one
    parent.addchild(node1)
    parent.children[0] = node2
    assert node1.parent is None
    assert node2.parent is parent

    # Assigning a new list also deletes and adds nodes
    parent.addchild(node1)
    parent.children = parent.children + [node3]
    assert node1.parent is parent
    assert node2.parent is parent
    assert node3.parent is parent
示例#5
0
def test_accloop():
    ''' Check the name and string representation of the ACCLoopTrans
    transformation, and that its _directive method builds an
    ACCLoopDirective. '''
    acc_trans = ACCLoopTrans()
    assert acc_trans.name == "ACCLoopTrans"
    assert str(acc_trans) == "Adds an 'OpenACC loop' directive to a loop"

    # _directive() is given a list of nodes and should return a new
    # ACCLoopDirective
    directive = acc_trans._directive([Statement()])
    assert isinstance(directive, ACCLoopDirective)
示例#6
0
def test_sched_init():
    ''' Check that a Schedule is constructed correctly, both with the
    default arguments and with explicit parent, children and symbol table
    arguments.'''

    # Defaults: no parent, no children and a symbol table of its own
    default_sched = Schedule()
    assert isinstance(default_sched, Schedule)
    assert not default_sched.parent
    assert not default_sched.children
    assert isinstance(default_sched.symbol_table, SymbolTable)

    # Explicitly supplied arguments are stored as given
    table = SymbolTable()
    kids = [Statement(), Statement()]
    custom_sched = Schedule(parent=default_sched, children=kids,
                            symbol_table=table)
    assert isinstance(custom_sched, Schedule)
    assert custom_sched.parent is default_sched
    assert len(custom_sched.children) == 2
    assert custom_sched.symbol_table is table
示例#7
0
def test_detach():
    ''' Check that detach() takes a node out of its parent, returns the
    node itself, and can safely be called on an already detached node. '''

    # Build a schedule containing two statements
    parent = Schedule()
    first = Statement()
    parent.addchild(first)
    second = Statement()
    parent.addchild(second)

    # detach() hands back the very same node
    detached = first.detach()
    assert detached is first

    # The detached node has no parent and only the other one remains
    assert first.parent is None
    assert len(parent.children) == 1
    assert parent.children[0] is second

    # Detaching a node that already has no parent is not an error
    assert first.detach() is first
示例#8
0
def test_pop_all_children():
    ''' Check that pop_all_children() empties the children list of a node,
    disconnects the removed nodes and returns them, in order, in a list. '''

    parent = Schedule()
    first = Statement()
    parent.addchild(first)
    second = Statement()
    parent.addchild(second)

    popped = parent.pop_all_children()

    # The nodes come back in a list and the parent is left childless ...
    assert isinstance(popped, list)
    assert len(parent.children) == 0
    # ... the removed nodes are no longer connected to it ...
    assert first.parent is None
    assert second.parent is None
    # ... and they are returned in their original order
    assert popped[0] is first
    assert popped[1] is second
示例#9
0
def test_accloop():
    ''' Check the name and string representation of the ACCLoopTrans
    transformation, and that its _directive method builds an
    ACCLoopDirective. '''
    from psyclone.transformations import ACCLoopTrans
    from psyclone.psyGen import ACCLoopDirective

    acc_trans = ACCLoopTrans()
    assert acc_trans.name == "ACCLoopTrans"
    assert str(acc_trans) == "Adds an 'OpenACC loop' directive to a loop"

    # Here _directive() is called with a parent node and a list of child
    # nodes
    parent_node = Node()
    child = Statement()
    directive = acc_trans._directive(parent_node, [child])
    assert isinstance(directive, ACCLoopDirective)
示例#10
0
def test_children_setter():
    ''' Check that the children setter accepts a list (storing it as a
    ChildrenList) or None, and raises a TypeError for any other type. '''
    sched = Schedule()

    # A freshly created node already holds a ChildrenList
    assert isinstance(sched.children, ChildrenList)

    # A plain list is converted into a ChildrenList
    sched.children = [Statement(), Statement()]
    assert isinstance(sched.children, ChildrenList)

    # None is stored as it is
    sched.children = None
    assert sched.children is None

    # Anything else is rejected
    with pytest.raises(TypeError) as excinfo:
        sched.children = Node()
    expected_msg = ("The 'my_children' parameter of the node.children setter "
                    "must be a list or None.")
    assert expected_msg in str(excinfo.value)
示例#11
0
def test_profile_trans_invalid_name(value):
    '''Check that ProfileTrans rejects an invalid region name supplied
    through the "region_name" option.'''
    profile_trans = ProfileTrans()

    # A node without a parent would be refused by the transformation, so
    # put the statement inside a schedule first.
    sched = Schedule()
    stmt = Statement(parent=sched)
    sched.addchild(stmt)

    options = {"region_name": value}
    with pytest.raises(TransformationError) as excinfo:
        _ = profile_trans.apply(stmt, options=options)
    assert ("User-supplied region name must be a tuple containing "
            "two non-empty strings." in str(excinfo.value))
示例#12
0
def test_lower_to_language_level(monkeypatch):
    ''' Test that Node has a lower_to_language_level() method that
    recurses to the same method of its children. '''

    # Monkeypatch Statement.lower_to_language_level so that, instead of
    # lowering anything, it just records that it has been called
    def visited(self):
        self._visited_flag = True

    monkeypatch.setattr(Statement, "lower_to_language_level", visited)

    testnode = Schedule()
    node1 = Statement()
    node2 = Statement()
    testnode.children = [node1, node2]

    # Execute method
    testnode.lower_to_language_level()

    # Check the specific children created above. Looping over
    # testnode.children instead would let the test pass vacuously if that
    # list ended up empty. getattr() turns an unvisited node into an
    # assertion failure rather than an AttributeError.
    for child in (node1, node2):
        assert getattr(child, "_visited_flag", False)
示例#13
0
def test_omp_do_children_err():
    ''' Check that code generation fails with the expected error when an
    OpenMP parallel do directive has more than one child. '''
    from psyclone.transformations import OMPParallelLoopTrans
    from psyclone.psyGen import OMPParallelDoDirective

    par_loop_trans = OMPParallelLoopTrans()
    psy, invoke = get_invoke("imperfect_nest.f90", api=API, idx=0)
    schedule = invoke.schedule
    outer_body = schedule[0].loop_body

    # Parallelise the third node in the body of the first loop
    par_loop_trans.apply(outer_body[2])
    directive = outer_body[2]
    assert isinstance(directive, OMPParallelDoDirective)

    # Corrupt the tree by giving the directive body a second child
    directive.dir_body.children.append(Statement())
    with pytest.raises(GenerationError) as excinfo:
        _ = psy.gen
    assert ("An OpenMP PARALLEL DO can only be applied to a single loop but "
            "this Node has 2 children:" in str(excinfo.value))
示例#14
0
def test_replace_with():
    '''Check that replace_with() puts the new node at the position of the
    replaced one and updates the parent references of both nodes.'''

    sched = Schedule()
    first = Statement()
    middle = Statement()
    last = Statement()
    sched.children = [first, middle, last]
    replacement = Assignment()

    middle.replace_with(replacement)

    # The replacement takes the middle slot and the old node is detached
    assert sched.children[1] is replacement
    assert replacement.parent is sched
    assert middle.parent is None
示例#15
0
def test_omp_do_update():
    '''Check the OMPDoDirective update function.'''
    psy, invoke = get_invoke("imperfect_nest.f90", api=API, idx=0)
    schedule = invoke.schedule
    par_trans = OMPParallelTrans()
    loop_trans = OMPLoopTrans()
    # Wrap the node in the nested else branch in an OpenMP parallel region,
    # then apply an OpenMP do directive to the loop inside that region
    par_trans.apply(schedule[0].loop_body[1].else_body[0].else_body[0])
    loop_trans.apply(
        schedule[0].loop_body[1].else_body[0].else_body[0].dir_body[0])
    gen_code = str(psy.gen).lower()
    correct = '''      !$omp parallel default(shared), private(ji,jj)
      !$omp do schedule(static)
      do jj = 1, jpj, 1
        do ji = 1, jpi, 1
          zdkt(ji, jj) = (ptb(ji, jj, jk - 1, jn) - ptb(ji, jj, jk, jn)) * \
wmask(ji, jj, jk)
        end do
      end do
      !$omp end do
      !$omp end parallel'''
    assert correct in gen_code
    directive = schedule[0].loop_body[1].else_body[0].else_body[0].dir_body[0]
    assert isinstance(directive, OMPDoDirective)

    # Call update a second time and make sure that this does not
    # trigger the whole update process again, and we get the same ast
    old_ast = directive.ast
    directive.update()
    assert directive.ast is old_ast

    # Remove the existing AST, so we can do more tests:
    directive.ast = None
    # Make the schedule invalid by adding a second child to the body of
    # the OMPDoDirective
    directive.dir_body.children.append(Statement())

    with pytest.raises(GenerationError) as err:
        _ = directive.update()
    assert ("An OpenMP DO can only be applied to a single loop but "
            "this Node has 2 children:" in str(err.value))
示例#16
0
def test_replace_with_error2():
    '''Check that the replace_with method raises the expected exceptions
    if either node is invalid.

    '''
    parent = Schedule()
    node1 = Statement()
    node2 = Statement()

    # The replacement must be a Node
    with pytest.raises(TypeError) as info:
        node1.replace_with("hello")
    assert ("The argument node in method replace_with in the Node class "
            "should be a Node but found 'str'." in str(info.value))

    # The node being replaced must have a parent
    with pytest.raises(GenerationError) as info:
        node1.replace_with(node2)
    assert ("This node should have a parent if its replace_with method "
            "is called." in str(info.value))

    # The replacement must not already have a parent
    node1.parent = parent
    node2.parent = parent
    parent.children = [node1, node2]
    with pytest.raises(GenerationError) as info:
        node1.replace_with(node2)
    assert ("The parent of argument node in method replace_with in the Node "
            "class should be None but found 'Schedule'." in str(info.value))

    # The replacement must be a valid child type for the parent. The message
    # check must sit outside the 'with' block: inside it, it would never run
    # because replace_with() raises first.
    node3 = Container("hello")
    with pytest.raises(GenerationError) as info:
        node1.replace_with(node3)
    assert ("Generation Error: Item 'Container' can't be child 0 of "
            "'Schedule'. The valid format is: '[Statement]*'."
            in str(info.value))