コード例 #1
0
def test_stencil_to_computation() -> None:
    """Run ``OirToNpir`` on a one-assignment stencil and check the result.

    The stencil declares a field ``a`` and a scalar ``b`` and contains a
    single assignment with ``a`` on the left and ``b`` on the right.
    """
    field_a = FieldDeclFactory(
        name="a",
        dtype=common.DataType.FLOAT64,
    )
    scalar_b = oir.ScalarDecl(
        name="b",
        dtype=common.DataType.INT32,
    )
    assign_b_to_a = AssignStmtFactory(
        left=FieldAccessFactory(name="a"), right=ScalarAccessFactory(name="b")
    )
    stencil = StencilFactory(
        name="stencil",
        params=[field_a, scalar_b],
        vertical_loops__0__sections__0__horizontal_executions__0__body=[
            assign_b_to_a
        ],
    )
    computation = OirToNpir().visit(stencil)

    # Only the field shows up among the API field declarations ...
    api_field_names = {decl.name for decl in computation.api_field_decls}
    assert api_field_names == {"a"}
    # ... while the argument list contains both the field and the scalar.
    assert set(computation.arguments) == {"a", "b"}
    assert len(computation.vertical_passes) == 1
コード例 #2
0
def sdfg_arrays_to_oir_decls(
        sdfg: dace.SDFG) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Build OIR declarations from the data containers and symbols of ``sdfg``.

    Every entry of ``sdfg.arrays`` is classified as follows:

    * a non-transient ``dace.data.Array`` becomes an ``oir.FieldDecl`` parameter,
    * a transient ``dace.data.Array`` becomes an ``oir.Temporary`` declaration,
    * anything else is asserted to be a ``dace.data.Scalar`` and becomes an
      ``oir.ScalarDecl`` parameter.

    Afterwards every entry of ``sdfg.symbols`` that is not reported by
    ``internal_symbols(sdfg)`` is appended as an ``oir.ScalarDecl`` parameter.

    Returns:
        A ``(params, decls)`` tuple: the parameter declarations followed by
        the temporary declarations.
    """
    api_decls = []
    temporaries = []

    container: dace.data.Data
    for container_name, container in sdfg.arrays.items():
        data_type = common.typestr_to_data_type(
            dace_dtype_to_typestr(container.dtype))

        if not isinstance(container, dace.data.Array):
            assert isinstance(container, dace.data.Scalar)
            api_decls.append(
                oir.ScalarDecl(name=container_name, dtype=data_type))
            continue

        dims = array_dimensions(container)
        # Transient arrays are internal temporaries, all others are API fields.
        if container.transient:
            decl_type, destination = oir.Temporary, temporaries
        else:
            decl_type, destination = oir.FieldDecl, api_decls
        # The trailing shape entries after the first ``sum(dims)`` ones are
        # passed on as data dimensions.
        trailing_shape = container.shape[sum(dims):]
        destination.append(
            decl_type(
                name=container_name,
                dtype=data_type,
                dimensions=dims,
                data_dims=trailing_shape,
            ))

    reserved = internal_symbols(sdfg)
    for symbol_name, symbol_type in sdfg.symbols.items():
        if symbol_name in reserved:
            continue
        numpy_typestr = symbol_type.as_numpy_dtype().str
        api_decls.append(
            oir.ScalarDecl(name=symbol_name,
                           dtype=common.typestr_to_data_type(numpy_typestr)))

    return api_decls, temporaries
コード例 #3
0
def test_stencil_to_computation():
    """Run ``OirToNpir`` on a one-assignment stencil and check its parameter lists."""
    declarations = [
        FieldDeclFactory(name="a", dtype=common.DataType.FLOAT64),
        oir.ScalarDecl(name="b", dtype=common.DataType.INT32),
    ]
    # Only statement of the stencil: field access "a" on the left,
    # scalar access "b" on the right.
    statements = [
        AssignStmtFactory(
            left=FieldAccessFactory(name="a"),
            right=ScalarAccessFactory(name="b"),
        )
    ]
    stencil = StencilFactory(
        name="stencil",
        params=declarations,
        vertical_loops__0__sections__0__horizontal_executions__0__body=statements,
    )
    npir_computation = OirToNpir().visit(stencil)

    assert npir_computation.field_params == ["a"]
    assert npir_computation.params == ["a", "b"]
    assert len(npir_computation.vertical_passes) == 1
コード例 #4
0
ファイル: gtir_to_oir.py プロジェクト: fthaler/gt4py
 def visit_ScalarDecl(self, node: gtir.ScalarDecl) -> oir.ScalarDecl:
     """Return an ``oir.ScalarDecl`` carrying the name, dtype and loc of ``node``."""
     name, dtype, loc = node.name, node.dtype, node.loc
     return oir.ScalarDecl(name=name, dtype=dtype, loc=loc)
コード例 #5
0
 def visit_ScalarDecl(self, node: gtir.ScalarDecl, **kwargs: Any) -> oir.ScalarDecl:
     """Return an ``oir.ScalarDecl`` with the name and dtype of ``node``.

     Extra keyword arguments are accepted but not used.
     """
     scalar_fields = {"name": node.name, "dtype": node.dtype}
     return oir.ScalarDecl(**scalar_fields)
コード例 #6
0
    def visit_VerticalLoopSection(
        self,
        node: oir.VerticalLoopSection,
        *,
        block_extents: Dict[int, Extent],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoopSection:
        """Merge neighbouring horizontal executions of ``node`` where allowed.

        Walking through ``node.horizontal_executions`` in order, each execution
        is fused into the previously kept one unless

        * it reads, with a non-zero value in one of the first two offset
          components, a field written by the previously kept execution, or
        * its extent differs from the extent of the previously kept execution.

        When fusing, local declarations whose names occur in both executions
        are renamed in the second one using ``new_symbol_name``.

        Args:
            node: The section whose horizontal executions are processed.
            block_extents: Extent per horizontal execution, keyed by ``id()``
                of the execution object.
            new_symbol_name: Callable producing a replacement name for a
                clashing local name.
            **kwargs: Forwarded to ``self.visit`` when rewriting a merged body.

        Returns:
            A new section with the same interval and the fused executions.
        """
        merged = [node.horizontal_executions[0]]
        merged_extents = [block_extents[id(merged[-1])]]

        for candidate in node.horizontal_executions[1:]:
            previous = merged[-1]
            previous_extent = merged_extents[-1]

            written_before = AccessCollector.apply(previous).write_fields()
            candidate_reads = AccessCollector.apply(candidate).read_offsets()
            offset_reads = {
                field
                for field, offsets in candidate_reads.items()
                if any(offset[0] != 0 or offset[1] != 0 for offset in offsets)
            }

            conflicts = written_before & offset_reads
            candidate_extent = block_extents[id(candidate)]

            if conflicts or previous_extent != candidate_extent:
                # Not fusable: keep the candidate as its own execution.
                merged.append(candidate)
                merged_extents.append(candidate_extent)
                continue

            # Fusable: find local names declared in both executions ...
            previous_names = {decl.name for decl in previous.declarations}
            candidate_names = {decl.name for decl in candidate.declarations}
            clashing = previous_names & candidate_names
            # ... and rename them in the candidate (old name -> new name).
            renames = {old: new_symbol_name(old) for old in clashing}
            candidate_decls = {
                decl.name: decl
                for decl in candidate.declarations
            }

            renamed_body = self.visit(candidate.body,
                                      scalar_map=renames,
                                      **kwargs)

            kept_decls = [
                decl for decl in candidate.declarations
                if decl.name not in clashing
            ]
            renamed_decls = [
                oir.ScalarDecl(name=new, dtype=candidate_decls[old].dtype)
                for old, new in renames.items()
            ]

            merged[-1] = oir.HorizontalExecution(
                body=previous.body + renamed_body,
                declarations=(previous.declarations + kept_decls +
                              renamed_decls),
            )

        return oir.VerticalLoopSection(interval=node.interval,
                                       horizontal_executions=merged)
コード例 #7
0
    def visit_VerticalLoopSection(
        self,
        node: oir.VerticalLoopSection,
        *,
        block_extents: Dict[int, Extent],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoopSection:
        """Merge neighbouring horizontal executions of ``node`` where allowed.

        The horizontal executions are visited in order. An execution is merged
        into the last kept one unless it reads, with a non-zero value in one of
        the first two offset components, a field contained in ``last_writes``,
        or unless its extent differs from the extent of the last kept
        execution. While merging, intermediate results are held in a plain
        dataclass and only converted back to ``oir.HorizontalExecution`` when
        the result section is built.

        Args:
            node: The section whose horizontal executions are processed.
            block_extents: Extent per horizontal execution, keyed by ``id()``
                of the execution object.
            new_symbol_name: Callable producing a replacement name for a local
                name that is declared in both executions being merged.
            **kwargs: Forwarded to ``self.visit`` when rewriting a merged body.

        Returns:
            A new section with the interval and ``loc`` of ``node`` and the
            merged horizontal executions.
        """
        @dataclass
        class UncheckedHorizontalExecution:
            # local replacement without type checking for type-checked oir node
            # required to reach reasonable run times for large node counts
            body: List[oir.Stmt]
            declarations: List[oir.LocalScalar]
            loc: Optional[SourceLocation]

            # Runs whenever this class body is executed (i.e. on every call of
            # the enclosing method): fails if oir.HorizontalExecution has a
            # different set of fields than the ones handled here.
            assert set(oir.HorizontalExecution.__fields__) == {
                "loc",
                "symtable_",
                "body",
                "declarations",
            }, ("Unexpected field in oir.HorizontalExecution, "
                "probably UncheckedHorizontalExecution needs an update")

            @classmethod
            def from_oir(cls, hexec: oir.HorizontalExecution):
                """Create the unchecked stand-in from body, declarations and loc of ``hexec``."""
                return cls(body=hexec.body,
                           declarations=hexec.declarations,
                           loc=hexec.loc)

            def to_oir(self) -> oir.HorizontalExecution:
                """Convert back into a regular ``oir.HorizontalExecution`` node."""
                return oir.HorizontalExecution(body=self.body,
                                               declarations=self.declarations,
                                               loc=self.loc)

        # Result list, seeded with the first execution; the extent list is
        # kept in step with it (one entry per kept execution).
        horizontal_executions = [
            UncheckedHorizontalExecution.from_oir(
                node.horizontal_executions[0])
        ]
        new_block_extents = [block_extents[id(node.horizontal_executions[0])]]
        # Fields written by the execution currently at the end of the result
        # list; updated incrementally below instead of being recollected from
        # the merged execution in every iteration.
        last_writes = AccessCollector.apply(
            node.horizontal_executions[0]).write_fields()

        for this_hexec in node.horizontal_executions[1:]:
            last_extent = new_block_extents[-1]

            # Fields that this execution reads with a non-zero value in one of
            # the first two offset components (presumably the horizontal
            # offsets -- TODO confirm against the offset representation).
            this_offset_reads = {
                name
                for name, offsets in AccessCollector.apply(
                    this_hexec).read_offsets().items()
                if any(off[0] != 0 or off[1] != 0 for off in offsets)
            }

            reads_with_offset_after_write = last_writes & this_offset_reads
            this_extent = block_extents[id(this_hexec)]

            if reads_with_offset_after_write or last_extent != this_extent:
                # Cannot merge: simply append to list
                horizontal_executions.append(
                    UncheckedHorizontalExecution.from_oir(this_hexec))
                new_block_extents.append(this_extent)
                # A new last execution starts, so its writes replace the set.
                last_writes = AccessCollector.apply(this_hexec).write_fields()
            else:
                # Merge
                # Local names declared in both the last kept execution and
                # this one.
                duplicated_locals = {
                    decl.name
                    for decl in horizontal_executions[-1].declarations
                } & {decl.name
                     for decl in this_hexec.declarations}
                # Map from old to new scalar names applied to the second horizontal execution
                scalar_map = {
                    name: new_symbol_name(name)
                    for name in duplicated_locals
                }
                # Lookup of this execution's declarations by name, used to
                # fetch the dtype of renamed locals.
                locals_symtable = {
                    decl.name: decl
                    for decl in this_hexec.declarations
                }

                # Rewrite the body with ``scalar_map`` passed down to the
                # visitor.
                new_body = self.visit(this_hexec.body,
                                      scalar_map=scalar_map,
                                      **kwargs)

                this_not_duplicated = [
                    decl for decl in this_hexec.declarations
                    if decl.name not in duplicated_locals
                ]
                # NOTE(review): ``declarations`` is annotated as
                # List[oir.LocalScalar] above, but oir.ScalarDecl objects are
                # created here -- confirm this is intended.
                this_mapped = [
                    oir.ScalarDecl(name=scalar_map[name],
                                   dtype=locals_symtable[name].dtype)
                    for name in duplicated_locals
                ]

                # Replace the last kept execution by the merged one; it keeps
                # the loc of the execution that was merged into.
                horizontal_executions[-1] = UncheckedHorizontalExecution(
                    body=horizontal_executions[-1].body + new_body,
                    declarations=(horizontal_executions[-1].declarations +
                                  this_not_duplicated + this_mapped),
                    loc=horizontal_executions[-1].loc,
                )
                # In-place update: add the writes of the appended statements.
                # Assumes write_fields() returned a mutable set owned by this
                # method -- TODO confirm.
                last_writes |= AccessCollector.apply(new_body).write_fields()

        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=[
                hexec.to_oir() for hexec in horizontal_executions
            ],
            loc=node.loc,
        )