Exemplo n.º 1
0
 def __call__(self, definition_ir) -> Dict[str, Dict[str, str]]:
     """Generate the computation and bindings sources for a stencil.

     ``definition_ir`` is lowered to GTIR (``DefIRToGTIR`` followed by the
     full ``GtirPipeline``), then to OIR. The OIR is optimized by the
     pipeline found under the ``"oir_pipeline"`` backend option, falling
     back to ``DefaultPipeline`` with ``NoFieldAccessPruning`` skipped.
     ``FillFlushToLocalKCaches`` is then applied, the result is lowered to
     CUIR, and the CUIR is transformed by kernel fusion and cache-extent
     analysis before code generation.

     Args:
         definition_ir: The stencil definition IR to compile.

     Returns:
         A nested mapping
         ``{"computation": {"computation.hpp": ...},
         "bindings": {"bindings.cu": ...}}`` holding the generated sources.
     """
     # Frontend lowering: definition IR -> GTIR -> unoptimized OIR.
     gtir_stencil = GtirPipeline(DefIRToGTIR.apply(definition_ir)).full()
     unoptimized_oir = gtir_to_oir.GTIRToOIR().visit(gtir_stencil)

     # A user-supplied OIR pipeline takes precedence over the default one.
     backend_opts = self.backend.builder.options.backend_opts
     fallback_pipeline = DefaultPipeline(skip=[NoFieldAccessPruning])
     selected_pipeline = backend_opts.get("oir_pipeline", fallback_pipeline)

     oir_stencil = selected_pipeline.run(unoptimized_oir)
     oir_stencil = FillFlushToLocalKCaches().visit(oir_stencil)

     # Backend lowering: OIR -> CUIR, followed by CUIR-level transformations.
     cuir_stencil = oir_to_cuir.OIRToCUIR().visit(oir_stencil)
     cuir_stencil = kernel_fusion.FuseKernels().visit(cuir_stencil)
     cuir_stencil = extent_analysis.CacheExtents().visit(cuir_stencil)

     # Code generation; both generators share the same formatting setting.
     should_format = self.backend.builder.options.format_source
     computation_code = cuir_codegen.CUIRCodegen.apply(
         cuir_stencil, format_source=should_format)
     bindings_code = GTCCudaBindingsCodegen.apply(
         cuir_stencil,
         module_name=self.module_name,
         backend=self.backend,
         format_source=should_format,
     )
     return dict(
         computation={"computation.hpp": computation_code},
         bindings={"bindings.cu": bindings_code},
     )
Exemplo n.º 2
0
def test_fill_to_local_k_caches_section_splitting_forward():
    """Filling a k-cache in a forward loop splits the first section.

    The first section reads ``foo`` at k and k + 1 and ends at END - 1; the
    second section starts at END - 1. After the transformation the test
    expects three sections:

    - [START, START + 1)
    - [START + 1, END - 1)
    - [END - 1, END)

    Their bodies are expected to hold 4, 3 and 1 statements respectively
    (original statements plus the required fills).
    """
    # Fixture objects are built in the same order as a single nested
    # expression would build them, so factory sequences stay unchanged.
    first_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(
            end=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions=[
            HorizontalExecutionFactory(
                body=[
                    AssignStmtFactory(left__name="foo",
                                      right__name="foo",
                                      right__offset__k=0),
                    AssignStmtFactory(left__name="foo",
                                      right__name="foo",
                                      right__offset__k=1),
                ],
                declarations=[LocalScalarFactory()],
            )
        ],
    )
    last_section = VerticalLoopSectionFactory(
        interval=IntervalFactory(
            start=AxisBound(level=LevelMarker.END, offset=-1)),
        horizontal_executions__0__body__0=AssignStmtFactory(
            left__name="foo", right__name="foo"),
    )
    filling_cache = KCacheFactory(name="foo", fill=True, flush=False)
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections=[first_section, last_section],
            caches=[filling_cache],
        )
    ])

    transformed = FillFlushToLocalKCaches().visit(testee)
    sections = transformed.vertical_loops[0].sections
    assert len(sections) == 3, "wrong number of vertical sections"

    first, middle, last = (section.interval for section in sections)

    levels_msg = "wrong interval levels in split sections"
    assert (first.start.level == first.end.level == middle.start.level ==
            LevelMarker.START), levels_msg
    assert (middle.end.level == last.start.level == last.end.level ==
            LevelMarker.END), levels_msg

    offsets_msg = "wrong interval offsets in split sections"
    assert first.start.offset == 0, offsets_msg
    assert first.end.offset == middle.start.offset == 1, offsets_msg
    assert middle.end.offset == last.start.offset == -1, offsets_msg
    assert last.end.offset == 0, offsets_msg

    for section, expected_len in zip(sections, (4, 3, 1)):
        body = section.horizontal_executions[0].body
        assert len(body) == expected_len, "wrong number of fill stmts"
Exemplo n.º 3
0
 def _optimize_oir(self, oir):
     """Apply a fixed sequence of OIR optimization passes to ``oir``.

     Each pass class is instantiated fresh and its ``visit`` result is fed
     into the next pass, in the order listed below.

     Args:
         oir: The OIR tree to optimize.

     Returns:
         The OIR returned by the last pass.
     """
     pass_sequence = (
         GreedyMerging,
         AdjacentLoopMerging,
         LocalTemporariesToScalars,
         WriteBeforeReadTemporariesToScalars,
         OnTheFlyMerging,
         IJCacheDetection,
         KCacheDetection,
         PruneKCacheFills,
         PruneKCacheFlushes,
         FillFlushToLocalKCaches,
     )
     for pass_cls in pass_sequence:
         oir = pass_cls().visit(oir)
     return oir
Exemplo n.º 4
0
 def _optimize_oir(self, oir):
     """Apply a fixed sequence of OIR optimization passes to ``oir``.

     Horizontal executions are first optimized with ``GraphMerging`` via
     ``optimize_horizontal_executions``. The remaining pass classes are then
     instantiated fresh and applied with ``visit``, one after another, in
     the order listed below.

     Args:
         oir: The OIR tree to optimize.

     Returns:
         The OIR returned by the last pass.
     """
     oir = optimize_horizontal_executions(oir, GraphMerging)
     remaining_passes = (
         AdjacentLoopMerging,
         LocalTemporariesToScalars,
         WriteBeforeReadTemporariesToScalars,
         OnTheFlyMerging,
         MaskStmtMerging,
         IJCacheDetection,
         KCacheDetection,
         PruneKCacheFills,
         PruneKCacheFlushes,
         FillFlushToLocalKCaches,
     )
     for pass_cls in remaining_passes:
         oir = pass_cls().visit(oir)
     return oir
Exemplo n.º 5
0
def test_fill_flush_to_local_k_caches_basic_forward():
    """A filling and flushing k-cache is turned into explicit statements.

    After the transformation the test expects:

    - the cache to carry a new name (declared first in the stencil's
      declarations), with both ``fill`` and ``flush`` cleared;
    - the single vertical section to be kept;
    - its body to read: fill ``cache = foo``, the original statement
      rewritten to ``cache = cache``, then flush ``foo = cache``, all at
      k offset 0.
    """
    single_stmt_body = [
        AssignStmtFactory(
            left__name="foo",
            right__name="foo",
        ),
    ]
    foo_cache = KCacheFactory(name="foo", fill=True, flush=True)
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections__0__horizontal_executions__0__body=single_stmt_body,
            caches=[foo_cache],
        )
    ])

    transformed = FillFlushToLocalKCaches().visit(testee)
    loop = transformed.vertical_loops[0]

    assert len(loop.caches) == 1, "wrong number of caches"
    cache = loop.caches[0]
    assert not cache.fill, "filling cache was not removed"
    assert not cache.flush, "flushing cache was not removed"

    cache_name = cache.name
    assert cache_name != "foo", "cache name must not be the same as filling field"
    assert (transformed.declarations[0].name == cache_name
            ), "cache field not found in temporaries"

    assert len(loop.sections) == 1, "number of vertical sections has changed"

    body = loop.sections[0].horizontal_executions[0].body
    assert len(body) == 3, "no or too many fill/flush stmts introduced?"
    fill_stmt, cache_stmt, flush_stmt = body

    assert fill_stmt.left.name == cache_name, "wrong fill destination"
    assert fill_stmt.right.name == "foo", "wrong fill source"
    assert (fill_stmt.left.offset.k == fill_stmt.right.offset.k == 0
            ), "wrong fill offset"

    assert cache_stmt.left.name == cache_name, "wrong field name in cache access"
    assert cache_stmt.right.name == cache_name, "wrong field name in cache access"
    assert (cache_stmt.left.offset.k == cache_stmt.right.offset.k == 0
            ), "wrong offset in cache access"

    assert flush_stmt.left.name == "foo", "wrong flush destination"
    assert flush_stmt.right.name == cache_name, "wrong flush source"
    assert (flush_stmt.left.offset.k == flush_stmt.right.offset.k == 0
            ), "wrong flush offset"
Exemplo n.º 6
0
def test_flush_to_local_k_caches_basic():
    """A flush-only k-cache is turned into explicit flush statements.

    The forward loop has two sections:

    - one ending at START + 1 that assigns ``foo = foo``;
    - one starting at START + 1 that assigns ``foo = foo[k-1]``.

    After the transformation the test expects the cache to carry a new name
    (declared first in the stencil's declarations), with ``fill`` still off
    and ``flush`` cleared. Each section should keep its original statement,
    rewritten to the cache, followed by a flush ``foo[k] = cache[k]``.
    """
    testee = StencilFactory(vertical_loops=[
        VerticalLoopFactory(
            loop_order=LoopOrder.FORWARD,
            sections=[
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(
                        end=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo", right__name="foo"),
                ),
                VerticalLoopSectionFactory(
                    interval=IntervalFactory(
                        start=AxisBound(level=LevelMarker.START, offset=1)),
                    horizontal_executions__0__body__0=AssignStmtFactory(
                        left__name="foo",
                        right__name="foo",
                        right__offset__k=-1),
                ),
            ],
            caches=[KCacheFactory(name="foo", fill=False, flush=True)],
        )
    ])
    transformed = FillFlushToLocalKCaches().visit(testee)
    vertical_loop = transformed.vertical_loops[0]

    assert len(vertical_loop.caches) == 1, "wrong number of caches"
    assert not vertical_loop.caches[0].fill, "cache suddenly fills"
    assert not vertical_loop.caches[0].flush, "flushing cache was not removed"

    cache_name = vertical_loop.caches[0].name
    assert cache_name != "foo", "cache name must not be the same as flushing field"
    assert transformed.declarations[
        0].name == cache_name, "cache field not found in temporaries"

    assert len(
        vertical_loop.sections) == 2, "number of vertical sections has changed"

    # A flush assigns the cache back to the field, so the field "foo" is the
    # destination (left-hand side) and the cache is the source (right-hand
    # side).
    first_body = vertical_loop.sections[0].horizontal_executions[0].body
    assert len(first_body) == 2, "no or too many flush stmts introduced?"
    assert first_body[0].left.name == cache_name, "wrong field name in cache access"
    assert first_body[0].right.name == cache_name, "wrong field name in cache access"
    assert first_body[0].right.offset.k == 0, "wrong offset in cache access"
    assert first_body[1].left.name == "foo", "wrong flush destination"
    assert first_body[1].right.name == cache_name, "wrong flush source"
    assert (first_body[1].left.offset.k == first_body[1].right.offset.k == 0
            ), "wrong flush offset"

    second_body = vertical_loop.sections[1].horizontal_executions[0].body
    assert len(second_body) == 2, "no or too many flush stmts introduced?"
    assert second_body[0].left.name == cache_name, "wrong field name in cache access"
    assert second_body[0].right.name == cache_name, "wrong field name in cache access"
    assert second_body[0].right.offset.k == -1, "wrong offset in cache access"
    assert second_body[1].left.name == "foo", "wrong flush destination"
    assert second_body[1].right.name == cache_name, "wrong flush source"
    # Check both sides of the flush, as in the first section.
    assert (second_body[1].left.offset.k == second_body[1].right.offset.k == 0
            ), "wrong flush offset"