def test_fusetrans_error_not_same_parent():
    ''' Verify that LoopFuseTrans refuses to fuse two otherwise valid loops
    when they are children of different Schedules. '''
    from psyclone.psyGen import Loop, Schedule, Literal
    from psyclone.transformations import LoopFuseTrans, TransformationError

    def _complete_loop(loop):
        ''' Give `loop` the four children a well-formed Loop carries:
        start, stop and step literals followed by a (empty) body. '''
        for bound in ("1", "10", "1"):  # start, stop, step
            loop.addchild(Literal(bound, parent=loop))
        loop.addchild(Schedule(parent=loop))  # loop body

    first_parent = Schedule()
    second_parent = Schedule()
    first_loop = Loop(variable_name="i", parent=first_parent)
    second_loop = Loop(variable_name="j", parent=second_parent)
    first_parent.addchild(first_loop)
    second_parent.addchild(second_loop)
    _complete_loop(first_loop)
    _complete_loop(second_loop)

    fusion = LoopFuseTrans()
    # The loops differ only in their parent, so that must be the reported
    # reason for rejection.
    with pytest.raises(TransformationError) as excinfo:
        fusion._validate(first_loop, second_loop)
    assert "Error in LoopFuse transformation. Loops do not have the " \
        "same parent" in str(excinfo.value)
def test_fusetrans_error_incomplete():
    ''' Verify that LoopFuseTrans refuses to fuse loops that do not yet
    have all four of their children, and accepts them once they do. '''
    from psyclone.psyGen import Loop, Schedule, Literal, Return
    from psyclone.transformations import LoopFuseTrans, TransformationError

    def _populate(loop):
        ''' Attach start/stop/step literals and a body Schedule containing
        a single Return node to `loop`. '''
        for text in ("start", "stop", "step"):
            loop.addchild(Literal(text, parent=loop))
        loop.addchild(Schedule(parent=loop))
        loop.loop_body.addchild(Return(parent=loop.loop_body))

    parent = Schedule()
    first_loop = Loop(variable_name="i", parent=parent)
    second_loop = Loop(variable_name="j", parent=parent)
    parent.addchild(first_loop)
    parent.addchild(second_loop)

    fusion = LoopFuseTrans()

    # Each entry pairs the loop that is still empty with the error that
    # validation is expected to report while it remains empty.
    expectations = (
        (first_loop, "Error in LoopFuse transformation. The first loop does "
                     "not have 4 children."),
        (second_loop, "Error in LoopFuse transformation. The second loop "
                      "does not have 4 children."),
    )
    for incomplete_loop, message in expectations:
        with pytest.raises(TransformationError) as excinfo:
            fusion._validate(first_loop, second_loop)
        assert message in str(excinfo.value)
        # Complete this loop so the next check targets the following one.
        _populate(incomplete_loop)

    # Both loops now have their four children, so validation should pass.
    fusion._validate(first_loop, second_loop)