Example #1
0
def test_fill_to_local_k_caches_section_splitting_forward():
    """Check that a filling k-cache splits the forward loop into three sections.

    The input loop has two sections on ``foo`` (cached with ``fill=True``):
    ``[START, END-1)`` holding two statements that read ``foo`` at ``k`` and
    ``k+1``, and ``[END-1, END)`` holding a single statement.  After
    ``FillFlushToLocalKCaches`` the loop must have the sections
    ``[START, START+1)``, ``[START+1, END-1)`` and ``[END-1, END)``, whose
    bodies hold 4, 3 and 1 statements respectively.
    """
    # Sections are built as locals in the same order as the nested literal
    # would evaluate them, so any factory sequences advance identically.
    first_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(
            end=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions=[
            HorizontalExecutionFactory(
                body=[
                    AssignStmtFactory(left__name="foo",
                                      right__name="foo",
                                      right__offset__k=0),
                    AssignStmtFactory(left__name="foo",
                                      right__name="foo",
                                      right__offset__k=1),
                ],
                declarations=[LocalScalarFactory()],
            )
        ],
    )
    last_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(
            start=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions__0__body__0=AssignStmtFactory(
            left__name="foo", right__name="foo"),
    )
    loop = VerticalLoopFactory(
        loop_order=LoopOrder.FORWARD,
        sections=[first_section, last_section],
        caches=[KCacheFactory(name="foo", fill=True, flush=False)],
    )
    testee = StencilFactory(vertical_loops=[loop])

    transformed = FillFlushToLocalKCaches().visit(testee)
    sections = transformed.vertical_loops[0].sections

    assert len(sections) == 3, "wrong number of vertical sections"
    top, middle, bottom = sections

    assert (top.interval.start.level == top.interval.end.level ==
            middle.interval.start.level == LevelMarker.START
            and middle.interval.end.level == bottom.interval.start.level ==
            bottom.interval.end.level == LevelMarker.END
            ), "wrong interval levels in split sections"
    assert (top.interval.start.offset == 0
            and top.interval.end.offset == middle.interval.start.offset == 1
            and middle.interval.end.offset == bottom.interval.start.offset == -1
            and bottom.interval.end.offset == 0
            ), "wrong interval offsets in split sections"

    # Expected statement count per split section, top to bottom.
    for section, expected_len in zip((top, middle, bottom), (4, 3, 1)):
        assert (len(section.horizontal_executions[0].body) == expected_len
                ), "wrong number of fill stmts"
Example #2
0
 def shifted(self, offset: Optional[int]) -> "Interval":
     """Return a copy of this interval with both bounds moved by ``offset``.

     Each present bound keeps its ``level``; ``offset`` is added to its
     ``offset``. A bound that is ``None`` stays ``None``. When ``offset`` is
     ``None`` the result is ``UnboundedInterval()`` built with no arguments.

     NOTE(review): annotated as ``"Interval"`` but always returns an
     ``UnboundedInterval`` — confirm that is the intended return type.
     """
     if offset is None:
         return UnboundedInterval()

     def _move(bound):
         # Missing bounds are propagated unchanged.
         if bound is None:
             return None
         return AxisBound(level=bound.level, offset=bound.offset + offset)

     # Start is shifted before end, matching the original evaluation order.
     return UnboundedInterval(start=_move(self.start), end=_move(self.end))
Example #3
0
def test_same_node_read_write_not_overlap():
    """Build and convert an SDFG for two loops writing ``field`` on adjacent intervals.

    The first loop writes ``field`` from ``other`` on ``[START, START+1)``.
    The second writes ``field`` from ``field`` at ``k-1`` on
    ``[START+1, START+2)``. The test passes if building the SDFG with
    ``OirSDFGBuilder`` and calling ``convert`` does not raise.
    """
    write_from_other = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        interval=Interval(start=AxisBound.start(),
                          end=AxisBound.from_start(1)),
        horizontal_executions__0__body__0=AssignStmtFactory(
            left__name="field", right__name="other"),
    ))
    write_from_below = VerticalLoopFactory(sections__0=VerticalLoopSectionFactory(
        interval=Interval(start=AxisBound.from_start(1),
                          end=AxisBound.from_start(2)),
        horizontal_executions__0__body__0=AssignStmtFactory(
            left__name="field", right__name="field", right__offset__k=-1),
    ))
    oir = StencilFactory(vertical_loops=[write_from_other, write_from_below])

    convert(OirSDFGBuilder().visit(oir), oir.loc)
Example #4
0
def test_two_vertical_loops_no_read():
    """Build and convert an SDFG for two loops that only write literals to ``field``.

    The first loop assigns ``42.0`` on an interval ending at ``START+3``. The
    second assigns ``43.0`` on an interval starting at ``START+3``. Neither
    loop reads ``field``. The test passes if building the SDFG and calling
    ``convert`` does not raise.
    """
    def constant_loop(value, **interval_bound):
        # One vertical loop with a single section assigning a FLOAT32
        # literal to ``field``; ``interval_bound`` overrides one section bound.
        assignment = AssignStmtFactory(
            left__name="field",
            right=Literal(value=value, dtype=DataType.FLOAT32),
        )
        section = VerticalLoopSectionFactory(
            horizontal_executions=[
                HorizontalExecutionFactory(body__0=assignment)
            ],
            **interval_bound,
        )
        return VerticalLoopFactory(sections__0=section)

    oir_pre = StencilFactory(vertical_loops=[
        constant_loop("42.0", interval__end=AxisBound.from_start(3)),
        constant_loop("43.0", interval__start=AxisBound.from_start(3)),
    ])

    sdfg = OirSDFGBuilder().visit(oir_pre)
    convert(sdfg, oir_pre.loc)
Example #5
0
def test_flush_to_local_k_caches_basic():
    """Check that a flushing k-cache is turned into a local cache plus flush stmts.

    ``foo`` is cached with ``fill=False, flush=True`` in a forward loop with
    two sections. The first reads ``foo`` at ``k``; the second reads ``foo``
    at ``k-1``. After ``FillFlushToLocalKCaches``:

    * the loop has exactly one cache, which neither fills nor flushes, whose
      name differs from ``foo`` and which is the stencil's first declaration;
    * the number of sections is unchanged;
    * every section body is the original statement rewritten to use the
      cache, followed by one flush statement ``foo[k] = cache[k]``.
    """
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections=[
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(
                        end=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo", right__name="foo"),
                ),
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(
                        start=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo",
                        right__name="foo",
                        right__offset__k=-1),
                ),
            ],
            caches=[KCacheFactory(name="foo", fill=False, flush=True)],
        )
    ])
    transformed = FillFlushToLocalKCaches().visit(testee)
    vertical_loop = transformed.vertical_loops[0]

    assert len(vertical_loop.caches) == 1, "wrong number of caches"
    assert not vertical_loop.caches[0].fill, "cache suddenly fills"
    assert not vertical_loop.caches[0].flush, "flushing cache was not removed"

    cache_name = vertical_loop.caches[0].name
    assert cache_name != "foo", "cache name must not be the same as flushing field"
    assert transformed.declarations[
        0].name == cache_name, "cache field not found in temporaries"

    assert len(
        vertical_loop.sections) == 2, "number of vertical sections has changed"

    # k-offset at which each section's rewritten statement reads the cache
    # (mirrors the right__offset__k used when building the input).
    for section, read_offset in zip(vertical_loop.sections, (0, -1)):
        body = section.horizontal_executions[0].body
        assert len(body) == 2, "no or too many flush stmts introduced?"
        compute, flush = body

        assert compute.left.name == cache_name, "wrong field name in cache access"
        assert compute.right.name == cache_name, "wrong field name in cache access"
        assert compute.right.offset.k == read_offset, "wrong offset in cache access"

        # The flush writes the cache back into ``foo``: ``foo`` (left) is the
        # destination and the cache (right) is the source. Both sides are
        # pinned in every section, not only in the first one.
        assert flush.left.name == "foo", "wrong flush destination"
        assert flush.right.name == cache_name, "wrong flush source"
        assert flush.left.offset.k == flush.right.offset.k == 0, "wrong flush offset"
Example #6
0
 def full(cls):
     """Build a ``cls`` bounded by ``AxisBound.start()`` and ``AxisBound.end()``."""
     # NOTE(review): takes ``cls``, so presumably decorated ``@classmethod``
     # in the original source; the decorator is not visible here.
     lower_bound = AxisBound.start()
     upper_bound = AxisBound.end()
     return cls(start=lower_bound, end=upper_bound)
Example #7
0
 def shifted(self, offset: int) -> "Interval":
     """Return a new ``Interval`` with both bounds moved by ``offset``.

     Each bound keeps its ``level``; only its ``offset`` is incremented.
     Both ``self.start`` and ``self.end`` are dereferenced, so neither may be
     ``None`` here.
     """
     moved_start, moved_end = (
         AxisBound(level=bound.level, offset=bound.offset + offset)
         for bound in (self.start, self.end)
     )
     return Interval(start=moved_start, end=moved_end)