def test_fill_to_local_k_caches_section_splitting_forward():
    """Check how FillFlushToLocalKCaches splits a FORWARD loop with a filling k-cache.

    Input: a FORWARD vertical loop with a ``fill=True, flush=False`` k-cache on
    ``foo`` and two sections:
      * one ending at ``END - 1`` that assigns ``foo`` from ``foo`` at k offsets
        0 and 1 (and declares one local scalar),
      * one starting at ``END - 1`` that assigns ``foo`` from ``foo``.

    After the pass, the loop must have three sections with the interval bounds
    and per-section statement counts asserted below.
    """
    head_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(end=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions=[
            HorizontalExecutionFactory(
                body=[
                    AssignStmtFactory(left__name="foo", right__name="foo", right__offset__k=0),
                    AssignStmtFactory(left__name="foo", right__name="foo", right__offset__k=1),
                ],
                declarations=[LocalScalarFactory()],
            )
        ],
    )
    tail_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(start=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions__0__body__0=AssignStmtFactory(left__name="foo", right__name="foo"),
    )
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections=[head_section, tail_section],
            caches=[KCacheFactory(name="foo", fill=True, flush=False)],
        )
    ])

    sections = FillFlushToLocalKCaches().visit(testee).vertical_loops[0].sections
    assert len(sections) == 3, "wrong number of vertical sections"

    first, middle, last = (section.interval for section in sections)
    assert (
        first.start.level == first.end.level == middle.start.level == LevelMarker.START
        and middle.end.level == last.start.level == last.end.level == LevelMarker.END
    ), "wrong interval levels in split sections"
    assert (
        first.start.offset == 0
        and first.end.offset == middle.start.offset == 1
        and middle.end.offset == last.start.offset == -1
        and last.end.offset == 0
    ), "wrong interval offsets in split sections"

    # Expected statement count in the first horizontal execution of each section.
    for section, expected_count in zip(sections, (4, 3, 1)):
        assert (
            len(section.horizontal_executions[0].body) == expected_count
        ), "wrong number of fill stmts"
def shifted(self, offset: Optional[int]) -> "UnboundedInterval":
    """Return a copy of this interval with both bounds moved by ``offset`` levels.

    Args:
        offset: Number of levels to add to each bound's offset. If ``None``,
            ``UnboundedInterval()`` is returned with no arguments (presumably
            an interval without bounds; its defaults are defined elsewhere).

    Returns:
        A new ``UnboundedInterval``. A bound that is ``None`` on ``self``
        stays ``None``. Every other bound keeps its level marker and has
        ``offset`` added to its offset.
    """
    if offset is None:
        return UnboundedInterval()

    def _shift_bound(bound):
        # A missing bound is carried over as-is; a present one moves within its level.
        if bound is None:
            return None
        return AxisBound(level=bound.level, offset=bound.offset + offset)

    return UnboundedInterval(start=_shift_bound(self.start), end=_shift_bound(self.end))
def test_same_node_read_write_not_overlap():
    """Build and convert an SDFG where two vertical loops write ``field``.

    The first loop assigns ``field`` from ``other`` on the interval
    ``AxisBound.start()`` .. ``AxisBound.from_start(1)``. The second loop
    assigns ``field`` from ``field`` at k offset -1 on
    ``AxisBound.from_start(1)`` .. ``AxisBound.from_start(2)``.
    There is no explicit assertion: the test passes if building the SDFG and
    calling ``convert`` do not raise.
    """
    first_interval = Interval(start=AxisBound.start(), end=AxisBound.from_start(1))
    copy_from_other = AssignStmtFactory(left__name="field", right__name="other")
    first_loop = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        interval=first_interval,
        horizontal_executions__0__body__0=copy_from_other,
    ))

    second_interval = Interval(start=AxisBound.from_start(1), end=AxisBound.from_start(2))
    copy_from_below = AssignStmtFactory(
        left__name="field", right__name="field", right__offset__k=-1)
    second_loop = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        interval=second_interval,
        horizontal_executions__0__body__0=copy_from_below,
    ))

    stencil = StencilFactory(vertical_loops=[first_loop, second_loop])
    sdfg = OirSDFGBuilder().visit(stencil)
    convert(sdfg, stencil.loc)
def test_two_vertical_loops_no_read():
    """Build and convert an SDFG where two vertical loops only write literals.

    The first loop assigns the FLOAT32 literal 42.0 to ``field`` on an interval
    ending at ``AxisBound.from_start(3)``. The second assigns 43.0 on an
    interval starting at ``AxisBound.from_start(3)``. ``field`` is never read.
    There is no explicit assertion: the test passes if building the SDFG and
    calling ``convert`` do not raise.
    """
    first_write = AssignStmtFactory(
        left__name="field",
        right=Literal(value="42.0", dtype=DataType.FLOAT32),
    )
    first_loop = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        horizontal_executions=[HorizontalExecutionFactory(body__0=first_write)],
        interval__end=AxisBound.from_start(3),
    ))

    second_write = AssignStmtFactory(
        left__name="field",
        right=Literal(value="43.0", dtype=DataType.FLOAT32),
    )
    second_loop = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        horizontal_executions=[HorizontalExecutionFactory(body__0=second_write)],
        interval__start=AxisBound.from_start(3),
    ))

    stencil = StencilFactory(vertical_loops=[first_loop, second_loop])
    sdfg = OirSDFGBuilder().visit(stencil)
    convert(sdfg, stencil.loc)
def test_flush_to_local_k_caches_basic():
    """Check that FillFlushToLocalKCaches replaces a flushing k-cache on ``foo``.

    Input: a FORWARD loop with a ``fill=False, flush=True`` k-cache on ``foo``
    and two sections:
      * one ending at ``START + 1`` that assigns ``foo`` from ``foo``,
      * one starting at ``START + 1`` that assigns ``foo`` from ``foo`` at k - 1.

    Expected result (as asserted below):
      * a single non-filling, non-flushing cache on a new field (not ``foo``),
        which is the first entry of ``transformed.declarations``;
      * in each section, the original statement rewritten to use the cache
        field, followed by one flush statement ``foo = <cache>`` at k offset 0.
    """
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections=[
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(end=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo", right__name="foo"),
                ),
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(start=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo", right__name="foo", right__offset__k=-1),
                ),
            ],
            caches=[KCacheFactory(name="foo", fill=False, flush=True)],
        )
    ])
    transformed = FillFlushToLocalKCaches().visit(testee)
    vertical_loop = transformed.vertical_loops[0]

    assert len(vertical_loop.caches) == 1, "wrong number of caches"
    cache = vertical_loop.caches[0]
    assert not cache.fill, "cache suddenly fills"
    assert not cache.flush, "flushing cache was not removed"
    cache_name = cache.name
    assert cache_name != "foo", "cache name must not be the same as flushing field"
    assert transformed.declarations[0].name == cache_name, "cache field not found in temporaries"
    assert len(vertical_loop.sections) == 2, "number of vertical sections has changed"

    # k offset at which each section's original statement reads ``foo``; the
    # rewritten statement must read the cache field at the same offset.
    expected_read_offsets = (0, -1)
    for section, read_offset in zip(vertical_loop.sections, expected_read_offsets):
        body = section.horizontal_executions[0].body
        assert len(body) == 2, "no or too many flush stmts introduced?"
        cache_access, flush = body

        assert cache_access.left.name == cache_name, "wrong field name in cache access"
        assert cache_access.right.name == cache_name, "wrong field name in cache access"
        assert cache_access.right.offset.k == read_offset, "wrong offset in cache access"

        # The flush copies the cache back into ``foo``: ``foo`` (left-hand side)
        # is the destination and the cache field (right-hand side) is the source.
        assert flush.left.name == "foo", "wrong flush destination"
        assert flush.right.name == cache_name, "wrong flush source"
        # Checked for every section (previously only the first checked the LHS offset).
        assert flush.left.offset.k == flush.right.offset.k == 0, "wrong flush offset"
def full(cls):
    """Construct a ``cls`` spanning ``AxisBound.start()`` to ``AxisBound.end()``.

    NOTE(review): the first parameter is ``cls``, so this is presumably bound
    as a classmethod. No decorator is visible in this snippet; confirm.
    """
    lower = AxisBound.start()
    upper = AxisBound.end()
    return cls(start=lower, end=upper)
def shifted(self, offset: int) -> "Interval":
    """Return a new ``Interval`` with both bounds moved by ``offset`` levels.

    Each bound keeps its level marker; only its offset changes by ``offset``.
    """
    new_start, new_end = (
        AxisBound(level=bound.level, offset=bound.offset + offset)
        for bound in (self.start, self.end)
    )
    return Interval(start=new_start, end=new_end)