def test_local_temporaries_to_scalars_multiexec():
    """A temporary accessed in more than one horizontal execution is left alone.

    ``tmp`` is written and read in the first horizontal execution and read
    again in the second. After ``LocalTemporariesToScalars`` it must still be
    declared on the stencil, and no ``ScalarAccess`` may appear in the tree.
    """
    # Factories are built in the same order as before (the write/read
    # statements, then the first hexec, then the read, then the second hexec,
    # then the temporary), so any factory sequence counters behave the same.
    writing_hexec = HorizontalExecutionFactory(
        body=[
            AssignStmtFactory(left__name="tmp"),
            AssignStmtFactory(right__name="tmp"),
        ]
    )
    reading_hexec = HorizontalExecutionFactory(
        body=[AssignStmtFactory(right__name="tmp")]
    )
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            writing_hexec,
            reading_hexec,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    transformed = LocalTemporariesToScalars().visit(testee)

    stencil_decl_names = {decl.name for decl in transformed.declarations}
    assert "tmp" in stencil_decl_names

    scalar_accesses = (
        transformed.iter_tree().if_isinstance(oir.ScalarAccess).to_list()
    )
    assert not scalar_accesses
def _optimize_oir(self, oir):
    """Run the OIR optimization passes in a fixed order.

    Each pass class is instantiated immediately before use. Its ``visit``
    method is applied to the output of the previous pass, and the final
    tree is returned.

    NOTE(review): the parameter name ``oir`` shadows an ``oir`` module if
    one is imported in this file. It is kept for interface compatibility.
    """
    pipeline = (
        GreedyMerging,
        AdjacentLoopMerging,
        LocalTemporariesToScalars,
        WriteBeforeReadTemporariesToScalars,
        OnTheFlyMerging,
        NoFieldAccessPruning,
        IJCacheDetection,
        KCacheDetection,
        PruneKCacheFills,
        PruneKCacheFlushes,
    )
    result = oir
    for pass_cls in pipeline:
        result = pass_cls().visit(result)
    return result
def _optimize_oir(self, oir):
    """Run the OIR optimization passes in a fixed order.

    First, horizontal executions are merged with
    ``optimize_horizontal_executions`` using ``GraphMerging``. Then each
    remaining pass class is instantiated and its ``visit`` applied in
    sequence. The final tree is returned.

    NOTE(review): the parameter name ``oir`` shadows an ``oir`` module if
    one is imported in this file. It is kept for interface compatibility.
    """
    merged = optimize_horizontal_executions(oir, GraphMerging)
    for pass_cls in (
        AdjacentLoopMerging,
        LocalTemporariesToScalars,
        WriteBeforeReadTemporariesToScalars,
        OnTheFlyMerging,
        MaskStmtMerging,
        NoFieldAccessPruning,
        IJCacheDetection,
        KCacheDetection,
        PruneKCacheFills,
        PruneKCacheFlushes,
    ):
        merged = pass_cls().visit(merged)
    return merged
def test_local_temporaries_to_scalars_basic():
    """A temporary used within a single horizontal execution becomes a scalar.

    ``tmp`` is written and then read inside one horizontal execution. After
    ``LocalTemporariesToScalars``, both accesses must be ``ScalarAccess``.
    The stencil must have no declarations left, and the horizontal execution
    must carry exactly one declaration.
    """
    # The statements are built before the temporary, matching the original
    # keyword-argument evaluation order.
    stmts = [
        AssignStmtFactory(left__name="tmp"),
        AssignStmtFactory(right__name="tmp"),
    ]
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions__0__body=stmts,
        declarations=[TemporaryFactory(name="tmp")],
    )

    transformed = LocalTemporariesToScalars().visit(testee)

    section = transformed.vertical_loops[0].sections[0]
    hexec = section.horizontal_executions[0]
    write_stmt = hexec.body[0]
    read_stmt = hexec.body[1]

    assert isinstance(write_stmt.left, oir.ScalarAccess)
    assert isinstance(read_stmt.right, oir.ScalarAccess)
    assert not transformed.declarations
    assert len(hexec.declarations) == 1