示例#1
0
    def _split_entry_level(
        loop_order: common.LoopOrder,
        section: oir.VerticalLoopSection,
        new_symbol_name: Callable[[str], str],
    ) -> Tuple[oir.VerticalLoopSection, oir.VerticalLoopSection]:
        """Peel the entry level (first iterated level) off a loop section.

        For a forward loop the entry level is ``[start, start + 1)`` of the
        section's interval; for a backward loop it is ``[end - 1, end)``.

        Args:
            loop_order: forward or backward order; any other order fails the
                assertion.
            section: loop section to split.
            new_symbol_name: returns a fresh symbol name for a given name; it is
                called once per declaration found in ``section``.

        Returns:
            ``(entry_section, rest_section)``. In the entry section, every
            ``LocalScalar`` and every ``ScalarAccess`` whose name belongs to a
            declaration of ``section`` is renamed. The rest section reuses the
            original horizontal executions unchanged.
        """
        assert loop_order in (common.LoopOrder.FORWARD,
                              common.LoopOrder.BACKWARD)
        start, end = section.interval.start, section.interval.end
        if loop_order == common.LoopOrder.FORWARD:
            split = common.AxisBound(level=start.level,
                                     offset=start.offset + 1)
            entry_interval = oir.Interval(start=start, end=split)
            rest_interval = oir.Interval(start=split, end=end)
        else:
            split = common.AxisBound(level=end.level, offset=end.offset - 1)
            entry_interval = oir.Interval(start=split, end=end)
            rest_interval = oir.Interval(start=start, end=split)

        # Fresh names for all declarations, so that the entry copy does not
        # declare symbols clashing with the ones kept in the rest section.
        all_decls = list(section.iter_tree().if_isinstance(oir.Decl))
        renamed = {d.name: new_symbol_name(d.name) for d in all_decls}

        class _RenameDeclaredSymbols(NodeTranslator):
            def visit_ScalarAccess(self,
                                   node: oir.ScalarAccess) -> oir.ScalarAccess:
                if node.name in renamed:
                    return oir.ScalarAccess(name=renamed[node.name],
                                            dtype=node.dtype)
                return node

            def visit_LocalScalar(self,
                                  node: oir.LocalScalar) -> oir.LocalScalar:
                return oir.LocalScalar(name=renamed[node.name],
                                       dtype=node.dtype)

        entry_hexecs = _RenameDeclaredSymbols().visit(
            section.horizontal_executions)
        entry_section = oir.VerticalLoopSection(
            interval=entry_interval,
            horizontal_executions=entry_hexecs,
            loc=section.loc,
        )
        rest_section = oir.VerticalLoopSection(
            interval=rest_interval,
            horizontal_executions=section.horizontal_executions,
            loc=section.loc,
        )
        return entry_section, rest_section
示例#2
0
 def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection) -> Any:
     """Visit a loop section and drop it if no horizontal execution remains.

     Args:
         node: loop section to visit.

     Returns:
         ``NOTHING`` when visiting ``node.horizontal_executions`` yields an
         empty result, otherwise a new section with the same interval and
         source location holding the visited horizontal executions.
     """
     horizontal_executions = self.visit(node.horizontal_executions)
     if not horizontal_executions:
         return NOTHING
     return oir.VerticalLoopSection(
         interval=node.interval,
         horizontal_executions=horizontal_executions,
         # Keep the source location instead of silently dropping it.
         loc=node.loc,
     )
示例#3
0
 def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection,
                               **kwargs: Any) -> oir.VerticalLoopSection:
     """Rebuild a loop section with its horizontal executions merged once.

     Args:
         node: loop section to process.
         **kwargs: forwarded to ``self._merge``.

     Returns:
         A new section with the same interval and source location.
     """
     return oir.VerticalLoopSection(
         interval=node.interval,
         horizontal_executions=self._merge(node.horizontal_executions,
                                           **kwargs),
         # Keep the source location instead of silently dropping it.
         loc=node.loc,
     )
示例#4
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *,
                           ctx: Context) -> oir.VerticalLoop:
        """Lower a gtir vertical loop to an oir loop with a single section.

        Every statement of ``node.body`` becomes its own horizontal execution,
        declaring the local scalars gathered in ``ctx`` while that statement
        was visited. The loop's temporaries are appended to
        ``ctx.temp_fields``.
        """
        hexecs = []
        for gtir_stmt in node.body:
            # Reset before each statement: every horizontal execution declares
            # only the locals produced while lowering its own statement.
            ctx.reset_local_scalars()
            lowered = self.visit(gtir_stmt, ctx=ctx)
            if isinstance(lowered, oir.Stmt):
                lowered = [lowered]
            hexecs.append(
                oir.HorizontalExecution(
                    body=utils.flatten_list(lowered),
                    declarations=ctx.local_scalars,
                ))

        ctx.temp_fields += [
            oir.Temporary(name=tmp.name,
                          dtype=tmp.dtype,
                          dimensions=tmp.dimensions)
            for tmp in node.temporaries
        ]

        section = oir.VerticalLoopSection(
            interval=self.visit(node.interval),
            horizontal_executions=hexecs,
            loc=node.loc,
        )
        return oir.VerticalLoop(loop_order=node.loop_order,
                                sections=[section])
示例#5
0
 def visit_VerticalLoopSection(
     self,
     node: oir.VerticalLoopSection,
     local_tmps: Set[str],
     **kwargs: Any,
 ) -> oir.VerticalLoopSection:
     """Rebuild a loop section from its visited horizontal executions.

     Args:
         node: loop section to visit.
         local_tmps: forwarded unchanged to the visit of the horizontal
             executions.
         **kwargs: forwarded unchanged as well.

     Returns:
         A new section with the same interval and source location.
     """
     return oir.VerticalLoopSection(
         interval=node.interval,
         horizontal_executions=self.visit(node.horizontal_executions,
                                          local_tmps=local_tmps,
                                          **kwargs),
         # Keep the source location instead of silently dropping it.
         loc=node.loc,
     )
    def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection,
                                  **kwargs: Any) -> oir.VerticalLoopSection:
        """Merge horizontal executions until a fixed point is reached.

        ``self._merge`` is applied repeatedly as long as it reduces the number
        of horizontal executions.

        Args:
            node: loop section to process.
            **kwargs: forwarded to every ``self._merge`` call.

        Returns:
            A new section with the same interval and source location, holding
            the result of the last ``_merge`` call.
        """
        hexecs = node.horizontal_executions
        while True:
            # list(): the result is len()-ed and fed back into _merge.
            merged = list(self._merge(hexecs, **kwargs))
            # Fixed point: this pass did not remove any horizontal execution.
            if len(merged) >= len(hexecs):
                break
            hexecs = merged

        # Build (and validate) the section once, not once per iteration.
        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=merged,
            loc=node.loc,
        )
示例#7
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context,
                           **kwargs: Any) -> oir.VerticalLoop:
        """Lower a gtir vertical loop to an oir loop with a single section.

        ``ctx.horizontal_executions`` is cleared, then ``node.body`` is visited
        with ``ctx``. Presumably the body visits populate that list (TODO
        confirm in the statement visitors). The loop's temporaries are
        registered via ``ctx.add_decl``.

        Args:
            node: gtir vertical loop to lower.
            ctx: lowering context shared across visits.
            **kwargs: forwarded to the visit of ``node.interval``.

        Returns:
            An ``oir.VerticalLoop`` with one section and no caches.
        """
        ctx.horizontal_executions.clear()
        self.visit(node.body, ctx=ctx)

        for temp in node.temporaries:
            ctx.add_decl(oir.Temporary(name=temp.name, dtype=temp.dtype))

        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=[
                oir.VerticalLoopSection(
                    interval=self.visit(node.interval, **kwargs),
                    # Snapshot the list: ctx.horizontal_executions is reused
                    # and cleared again by the next visit_VerticalLoop call.
                    horizontal_executions=list(ctx.horizontal_executions),
                    loc=node.loc,
                )
            ],
            caches=[],
        )
    def visit_VerticalLoopSection(
        self,
        node: oir.VerticalLoopSection,
        *,
        block_extents: Dict[int, Extent],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoopSection:
        """Greedily merge consecutive horizontal executions of a loop section.

        A horizontal execution is appended to the previous (possibly already
        merged) one unless either condition holds:
        - it reads, at a non-zero horizontal offset, a field that the previous
          one writes;
        - the two have different block extents.

        Local scalars declared by both are renamed in the merged-in statements.

        Args:
            node: loop section whose horizontal executions are merged.
            block_extents: extent of each horizontal execution, keyed by
                ``id()`` of the original node.
            new_symbol_name: returns a fresh name for a clashing local scalar.
            **kwargs: forwarded to the visit of each merged-in body.

        Returns:
            A new section with the same interval and source location.
        """
        first_hexec = node.horizontal_executions[0]
        horizontal_executions = [first_hexec]
        new_block_extents = [block_extents[id(first_hexec)]]
        # Fields written by horizontal_executions[-1], updated incrementally
        # instead of re-scanning the whole merged body on every step.
        last_writes = AccessCollector.apply(first_hexec).write_fields()

        for this_hexec in node.horizontal_executions[1:]:
            last_extent = new_block_extents[-1]

            this_offset_reads = {
                name
                for name, offsets in AccessCollector.apply(
                    this_hexec).read_offsets().items()
                if any(off[0] != 0 or off[1] != 0 for off in offsets)
            }

            reads_with_offset_after_write = last_writes & this_offset_reads
            this_extent = block_extents[id(this_hexec)]

            if reads_with_offset_after_write or last_extent != this_extent:
                # Cannot merge: simply append to list
                horizontal_executions.append(this_hexec)
                new_block_extents.append(this_extent)
                last_writes = AccessCollector.apply(this_hexec).write_fields()
            else:
                # Merge into the previous horizontal execution
                target = horizontal_executions[-1]
                target_names = {decl.name for decl in target.declarations}
                # Clashing locals in declaration order. Iterating a set of str
                # would make the fresh names (and the resulting declaration
                # order) depend on the interpreter's hash seed.
                duplicated_locals = {
                    decl.name: decl
                    for decl in this_hexec.declarations
                    if decl.name in target_names
                }
                # Map from old to new scalar names applied to the second horizontal execution
                scalar_map = {
                    name: new_symbol_name(name)
                    for name in duplicated_locals
                }

                new_body = self.visit(this_hexec.body,
                                      scalar_map=scalar_map,
                                      **kwargs)

                this_not_duplicated = [
                    decl for decl in this_hexec.declarations
                    if decl.name not in duplicated_locals
                ]
                this_mapped = [
                    oir.ScalarDecl(name=scalar_map[name], dtype=decl.dtype)
                    for name, decl in duplicated_locals.items()
                ]

                horizontal_executions[-1] = oir.HorizontalExecution(
                    body=target.body + new_body,
                    declarations=(target.declarations + this_not_duplicated +
                                  this_mapped),
                    # Keep the location of the execution merged into.
                    loc=target.loc,
                )
                last_writes |= AccessCollector.apply(new_body).write_fields()

        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=horizontal_executions,
            loc=node.loc,
        )
示例#9
0
    def visit_VerticalLoopSection(
        self,
        node: oir.VerticalLoopSection,
        *,
        block_extents: Dict[int, Extent],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoopSection:
        """Greedily merge consecutive horizontal executions of a loop section.

        A horizontal execution is appended to the previous (possibly already
        merged) one unless either condition holds:
        - it reads, at a non-zero horizontal offset, a field that the previous
          one writes;
        - the two have different block extents.

        Local scalars declared by both are renamed in the merged-in statements.

        Args:
            node: loop section whose horizontal executions are merged.
            block_extents: extent of each horizontal execution, keyed by
                ``id()`` of the original ``oir.HorizontalExecution`` node.
            new_symbol_name: returns a fresh name for a clashing local scalar.
            **kwargs: forwarded to the visit of each merged-in body.

        Returns:
            A new section with the same interval and source location.
        """

        @dataclass
        class UncheckedHorizontalExecution:
            # local replacement without type checking for type-checked oir node
            # required to reach reasonable run times for large node counts
            body: List[oir.Stmt]
            declarations: List[oir.LocalScalar]
            loc: Optional[SourceLocation]

            assert set(oir.HorizontalExecution.__fields__) == {
                "loc",
                "symtable_",
                "body",
                "declarations",
            }, ("Unexpected field in oir.HorizontalExecution, "
                "probably UncheckedHorizontalExecution needs an update")

            @classmethod
            def from_oir(cls, hexec: oir.HorizontalExecution):
                # Copy the lists: merging extends them in place and must not
                # mutate the original oir node.
                return cls(body=list(hexec.body),
                           declarations=list(hexec.declarations),
                           loc=hexec.loc)

            def to_oir(self) -> oir.HorizontalExecution:
                return oir.HorizontalExecution(body=self.body,
                                               declarations=self.declarations,
                                               loc=self.loc)

        first_hexec = node.horizontal_executions[0]
        horizontal_executions = [
            UncheckedHorizontalExecution.from_oir(first_hexec)
        ]
        new_block_extents = [block_extents[id(first_hexec)]]
        # Fields written by horizontal_executions[-1], updated incrementally.
        last_writes = AccessCollector.apply(first_hexec).write_fields()

        for this_hexec in node.horizontal_executions[1:]:
            last_extent = new_block_extents[-1]

            this_offset_reads = {
                name
                for name, offsets in AccessCollector.apply(
                    this_hexec).read_offsets().items()
                if any(off[0] != 0 or off[1] != 0 for off in offsets)
            }

            reads_with_offset_after_write = last_writes & this_offset_reads
            this_extent = block_extents[id(this_hexec)]

            if reads_with_offset_after_write or last_extent != this_extent:
                # Cannot merge: simply append to list
                horizontal_executions.append(
                    UncheckedHorizontalExecution.from_oir(this_hexec))
                new_block_extents.append(this_extent)
                last_writes = AccessCollector.apply(this_hexec).write_fields()
            else:
                # Merge into the previous horizontal execution
                target = horizontal_executions[-1]
                target_names = {decl.name for decl in target.declarations}
                # Clashing locals in declaration order. Iterating a set of str
                # would make the fresh names (and the resulting declaration
                # order) depend on the interpreter's hash seed.
                duplicated_locals = {
                    decl.name: decl
                    for decl in this_hexec.declarations
                    if decl.name in target_names
                }
                # Map from old to new scalar names applied to the second horizontal execution
                scalar_map = {
                    name: new_symbol_name(name)
                    for name in duplicated_locals
                }

                new_body = self.visit(this_hexec.body,
                                      scalar_map=scalar_map,
                                      **kwargs)

                # Extend in place rather than concatenating: a long chain of
                # merges into one target would otherwise copy its lists on
                # every merge.
                target.body.extend(new_body)
                target.declarations.extend(
                    decl for decl in this_hexec.declarations
                    if decl.name not in duplicated_locals)
                # NOTE(review): declarations are annotated as oir.LocalScalar
                # above; confirm oir.ScalarDecl is the intended node type.
                target.declarations.extend(
                    oir.ScalarDecl(name=scalar_map[name], dtype=decl.dtype)
                    for name, decl in duplicated_locals.items())
                last_writes |= AccessCollector.apply(new_body).write_fields()

        return oir.VerticalLoopSection(
            interval=node.interval,
            horizontal_executions=[
                hexec.to_oir() for hexec in horizontal_executions
            ],
            loc=node.loc,
        )